## Generating Contact Maps from the Categorical Jacobian of the Protein Language Model ESM-2

1. Set up the ESM-2 protein language model.
2. Process the input data (FASTA files).
3. Calculate the categorical Jacobian for each protein sequence:
    a. Tokenise the amino acid sequence, with padding (`x`).
    b. Get the length of the sequence (`ln`).
    c. For every position in the sequence (`ln` positions), build a batch of tokenised sequences in which the amino acid at that position is replaced by each possible token. Feed the batch into the model and extract the output logits.
        > Note: the logits score how likely each amino acid is at each position. Each Jacobian entry is the difference between the logits for a mutated sequence and the logits for the original sequence.
4. Condense the resulting L×A×L×A Jacobian (L = sequence length, A = number of tokens) to an L×L matrix to produce the contact map prediction for the protein.
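
Step 3 can be illustrated without loading ESM-2. The sketch below is a minimal illustration that assumes a toy linear model in place of the real network; `toy_model`, `categorical_jacobian`, and the random `weights` are hypothetical names introduced here and are not the helpers defined in `utils.ipynb`. It mutates each position to every token in turn and stores the change in logits relative to the original sequence.

```python
import numpy as np

def toy_model(tokens, n_tokens, weights):
    """Hypothetical stand-in for ESM-2: maps L token ids to (L, A) logits."""
    one_hot = np.eye(n_tokens)[tokens]                  # (L, A)
    # Every output position mixes information from every input position.
    return np.einsum("ia,iajb->jb", one_hot, weights)   # (L, A)

def categorical_jacobian(tokens, n_tokens, model_fn):
    """Return an (L, A, L, A) array: the change in the logit for token b at
    position j when position i is set to token a."""
    length = len(tokens)
    base = model_fn(tokens)                             # logits of the original sequence
    jac = np.zeros((length, n_tokens, length, n_tokens))
    for i in range(length):
        for a in range(n_tokens):
            mutated = tokens.copy()
            mutated[i] = a                              # substitute token a at position i
            jac[i, a] = model_fn(mutated) - base        # difference in logits
    return jac

rng = np.random.default_rng(0)
L, A = 6, 20                                            # toy sequence length and alphabet size
weights = rng.normal(size=(L, A, L, A))
tokens = rng.integers(0, A, size=L)
jac = categorical_jacobian(tokens, A, lambda t: toy_model(t, A, weights))
print(jac.shape)
```

With a real language model, the A mutations at one position would normally be stacked into a single batch rather than evaluated one at a time, which is what step 3c describes.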

In [8]:
#%pip install matplotlib
%run utils.ipynb

In [9]:
#1. Set up ESM-2 Protein Language Model

In [10]:
# Source: https://github.com/facebookresearch/esm/tree/main?tab=readme-ov-file#esmfold
model, alphabet = torch.hub.load("facebookresearch/esm:main", "esm2_t33_650M_UR50D")

Using cache found in C:\Users\aguba/.cache\torch\hub\facebookresearch_esm_main


In [4]:
batch_converter = alphabet.get_batch_converter()
model.eval()  # disable dropout for deterministic results

In [5]:
#2. Process input data (FASTA files)

In [15]:
folder_path = "uniprot_sequences"
all_data,lengths = process_uniprot_folder(folder_path)
data = [(protein[0],protein[1]) for protein in all_data] # get id and sequence only
print(data)

[('P01116', 'MTEYKLVVVGAGGVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVVIDGETCLLDILDTAGQEEYSAMRDQYMRTGEGFLCVFAINNTKSFEDIHHYREQIKRVKDSEDVPMVLVGNKCDLPSRTVDTKQAQDLARSYGIPFIETSAKTRQRVEDAFYTLVREIRQYRLKKISKEEKTPGCVKIKKCIIM'), ('P31749', 'MSDVAIVKEGWLHKRGEYIKTWRPRYFLLKNDGTFIGYKERPQDVDQREAPLNNFSVAQCQLMKTERPRPNTFIIRCLQWTTVIERTFHVETPEEREEWTTAIQTVADGLKKQEEEEMDFRSGSPSDNSGAEEMEVSLAKPKHRVTMNEFEYLKLLGKGTFGKVILVKEKATGRYYAMKILKKEVIVAKDEVAHTLTENRVLQNSRHPFLTALKYSFQTHDRLCFVMEYANGGELFFHLSRERVFSEDRARFYGAEIVSALDYLHSEKNVVYRDLKLENLMLDKDGHIKITDFGLCKEGIKDGATMKTFCGTPEYLAPEVLEDNDYGRAVDWWGLGVVMYEMMCGRLPFYNQDHEKLFELILMEEIRFPRTLGPEAKSLLSGLLKKDPKQRLGGGSEDAKEIMQHRFFAGIVWQHVYEKKLSPPFKPQVTSETDTRYFDEEFTAQMITITPPDQDDSMECVDSERRPHFPQFSYSASGTA')]
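
`process_uniprot_folder` is defined in `utils.ipynb` and is not shown here. As a rough sketch of what such a helper might do, assuming UniProt-style FASTA headers of the form `>sp|ACCESSION|NAME ...` and files ending in `.fasta`, the hypothetical `read_fasta_folder` below returns `(id, sequence)` pairs together with the sequence lengths. The demo accession `P00000` and its sequence are made up for illustration.

```python
from pathlib import Path
import tempfile

def read_fasta_folder(folder_path):
    """Hypothetical helper: read every *.fasta file in a folder and return
    ([(accession, sequence), ...], [sequence lengths])."""
    records = []
    for path in sorted(Path(folder_path).glob("*.fasta")):
        header, chunks = None, []
        for line in path.read_text().splitlines():
            line = line.strip()
            if line.startswith(">"):
                if header is not None:                  # close the previous record
                    records.append((header, "".join(chunks)))
                parts = line[1:].split("|")
                # UniProt headers look like sp|ACCESSION|NAME; otherwise use the first word.
                header = parts[1] if len(parts) >= 3 else (line[1:].split() or [""])[0]
                chunks = []
            elif line:
                chunks.append(line)                     # sequence may span several lines
        if header is not None:
            records.append((header, "".join(chunks)))
    lengths = [len(seq) for _, seq in records]
    return records, lengths

# Demo with a made-up record written to a temporary folder.
with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "demo.fasta").write_text(">sp|P00000|DEMO_HUMAN Demo protein\nMTEYK\nLVVVG\n")
    demo_records, demo_lengths = read_fasta_folder(tmp)
print(demo_records, demo_lengths)
```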


In [17]:
cjs = {}
#3. Calculate the categorical jacobian for each protein sequence
def get_cj_by_index(i):
    protein = data[i]
    # tokenise; the batch converter returns (labels, strings, tokens), so [-1] is the token tensor
    x = batch_converter([protein])[-1]
    ln = lengths[i]
    # categorical Jacobian: ∂out/∂in
    cj = get_categorical_jacobian(x, ln, model)
    # add protein with its corresponding categorical jacobian to dict
    cjs[protein[0]] = cj
get_cj_by_index(1) # P31749 (480 residues)

  0%|                                                                                          | 0/480 [51:32<?, ?it/s]


KeyboardInterrupt: 

In [None]:
cjs = {}
for i, protein in enumerate(data): # loop over every protein in data
    # tokenise; [-1] selects the token tensor
    x = batch_converter([protein])[-1]
    ln = lengths[i]
    # categorical Jacobian: ∂out/∂in
    cj = get_categorical_jacobian(x, ln, model)
    # add protein with its corresponding categorical jacobian to dict
    cjs[protein[0]] = cj

In [None]:
#4. from the categorical jacobian, produce contact maps for each resulting protein
import matplotlib.pyplot as plt

In [None]:
#Source: https://github.com/zzhangzzhang/pLMs-interpretability/blob/main/jac/01_jac_calculate_visualise.ipynb
# cjs is keyed by the IDs in data (UniProt accessions); P31749 is at index 1
plt.figure(figsize=(5,5))
plt.imshow(get_contacts(cjs["P31749"]))
plt.show()
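
`get_contacts` comes from `utils.ipynb` and is not shown in this notebook. One plausible way to condense an L×A×L×A Jacobian to an L×L map, sketched here as an assumption rather than as the actual implementation, is to centre the tensor, take the Frobenius norm over the two amino-acid axes, apply an average product correction (APC), and symmetrise. The function name `contacts_from_jacobian` is hypothetical, and the input is random data.

```python
import numpy as np

def contacts_from_jacobian(jac):
    """Reduce an (L, A, L, A) Jacobian to a symmetric (L, L) contact score map."""
    j = np.array(jac, dtype=float)                      # work on a copy
    # Centre along each axis so constant offsets do not inflate the norm.
    for axis in range(4):
        j -= j.mean(axis=axis, keepdims=True)
    # Frobenius norm over the two amino-acid axes: one score per position pair.
    scores = np.sqrt(np.square(j).sum(axis=(1, 3)))     # (L, L)
    np.fill_diagonal(scores, 0.0)
    # Average product correction: subtract the expected background score.
    row = scores.sum(axis=0, keepdims=True)             # (1, L)
    col = scores.sum(axis=1, keepdims=True)             # (L, 1)
    scores = scores - (col * row) / scores.sum()
    np.fill_diagonal(scores, 0.0)
    return (scores + scores.T) / 2                      # symmetrise

rng = np.random.default_rng(0)
demo_jac = rng.normal(size=(8, 20, 8, 20))              # random stand-in for a real Jacobian
contact_map = contacts_from_jacobian(demo_jac)
print(contact_map.shape)
```

The resulting array can be passed to `plt.imshow` in the same way as the output of `get_contacts`.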

In [None]:
# Options for generating reference contact maps: mdtraj, biopython, or a custom function that mirrors how the model produces its contacts (check the model's code if it is open source)
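
For the custom-function option, a minimal sketch is shown below. It assumes one coordinate per residue (for example the Cα atom) has already been extracted from a structure file, and it uses an 8 Å cutoff, which is a common convention but an assumption here. `contact_map_from_coords` and the four demo coordinates are hypothetical.

```python
import numpy as np

def contact_map_from_coords(coords, threshold=8.0, min_separation=1):
    """Boolean (L, L) contact map from an (L, 3) array of residue coordinates.
    Pairs closer than `threshold` (in Å) count as contacts; pairs fewer than
    `min_separation` positions apart in the sequence are ignored."""
    coords = np.asarray(coords, dtype=float)
    diff = coords[:, None, :] - coords[None, :, :]      # (L, L, 3) pairwise offsets
    dist = np.linalg.norm(diff, axis=-1)                # (L, L) pairwise distances
    contacts = dist < threshold
    idx = np.arange(len(coords))
    too_close = np.abs(idx[:, None] - idx[None, :]) < min_separation
    contacts[too_close] = False                         # drops the diagonal by default
    return contacts

# Four made-up residues on a line: the first three are near each other, the last is far away.
demo_coords = np.array([[0.0, 0.0, 0.0], [3.8, 0.0, 0.0], [7.6, 0.0, 0.0], [20.0, 0.0, 0.0]])
demo_contacts = contact_map_from_coords(demo_coords)
print(demo_contacts.astype(int))
```

A map built this way from an experimental structure gives a ground truth against which the Jacobian-derived map can be compared.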