In [28]:
from transformers import AutoTokenizer, EsmForProteinFolding
import torch
import numpy as np
from scipy.special import softmax
from Bio import AlignIO
import os

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
model.to(device)
model

Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


EsmForProteinFolding(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 2560, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 2560, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-35): 36 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=2560, out_features=2560, bias=True)
              (key): Linear(in_features=2560, out_features=2560, bias=True)
              (value): Linear(in_features=2560, out_features=2560, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=2560, out_features=2560, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((2560,), eps=1

In [30]:
# grab a sequence from sample cadherin MSA
cadherin_msa_path = os.path.join('..', 'data', 'cadherin', 'PF00028.alignment.seed')
alignment = AlignIO.read(cadherin_msa_path, 'stockholm')
seqs = []
for record in alignment:
    seqs.append(str(record.seq).replace('-', ''))

seq = seqs[0]
seq

'QTVRIKENVPVGTKTIGYKAYDPETGSSSGIRYKKSSDPEGWVDVDKNSGVITILKRLDREARSGVYNISIIASDKDGRTCNGVLGIVLE'

In [31]:
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
inputs = tokenizer([seq], return_tensors="pt", add_special_tokens=False)  # A tiny random peptide
inputs.to(device)

{'input_ids': tensor([[ 5, 16, 19,  1,  9, 11,  6,  2, 19, 14, 19,  7, 16, 11, 16,  9,  7, 18,
         11,  0, 18,  3, 14,  6, 16,  7, 15, 15, 15,  7,  9,  1, 18, 11, 11, 15,
         15,  3, 14,  6,  7, 17, 19,  3, 19,  3, 11,  2, 15,  7, 19,  9, 16,  9,
         10, 11,  1, 10,  3,  1,  6,  0,  1, 15,  7, 19, 18,  2,  9, 15,  9,  9,
          0, 15,  3, 11,  3,  7,  1, 16,  4,  2,  7, 19, 10,  7,  9, 19, 10,  6]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='cuda:0')}

In [32]:
output = model(**inputs)

In [33]:
# adapted from ColabFold notebook: https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/ESMFold.ipynb#scrollTo=CcyNpAvhTX6q
def parse_output(output):
  pae = (output["aligned_confidence_probs"][0].detach().cpu() * np.arange(64)).mean(-1) * 31
  plddt = output["plddt"][0,:,1].detach().cpu()

  bins = np.append(0,np.linspace(2.3125,21.6875,63))
  sm_contacts = softmax(output["distogram_logits"].detach().cpu(),-1)[0]

  # NOTE: pretty sure this just looks for all distances <8 angstroms?  
  sm_contacts = sm_contacts[...,bins<8].sum(-1)
  xyz = output["positions"][-1,0,:,1].detach().cpu()
  mask = (output["atom37_atom_exists"][0,:,1] == 1).detach().cpu()
  o = {"pae":pae[mask,:][:,mask],
       "plddt":plddt[mask],
       "sm_contacts":sm_contacts[mask,:][:,mask],
       "xyz":xyz[mask]}
  
  return o

In [34]:
parsed_output = parse_output(output)
parsed_output

  pae = (output["aligned_confidence_probs"][0].detach().cpu() * np.arange(64)).mean(-1) * 31


{'pae': tensor([[3.6453e-03, 1.1457e+00, 2.3981e+00,  ..., 3.0285e+00, 3.3765e+00,
          5.9505e+00],
         [6.6039e-01, 1.1372e-04, 6.5739e-01,  ..., 1.5370e+00, 2.2050e+00,
          2.4775e+00],
         [1.2210e+00, 5.7133e-01, 3.8820e-05,  ..., 9.4825e-01, 1.1879e+00,
          1.5378e+00],
         ...,
         [1.8017e+00, 1.2327e+00, 9.1236e-01,  ..., 8.8048e-06, 5.2342e-01,
          8.9983e-01],
         [2.3614e+00, 1.6779e+00, 1.4129e+00,  ..., 5.2969e-01, 1.3646e-05,
          6.1498e-01],
         [6.5776e+00, 4.7676e+00, 3.7720e+00,  ..., 1.7073e+00, 7.3644e-01,
          3.0903e-04]], dtype=torch.float64),
 'plddt': tensor([0.7583, 0.8268, 0.8352, 0.8453, 0.8415, 0.8417, 0.8347, 0.8235, 0.8495,
         0.8637, 0.8720, 0.8816, 0.8722, 0.8666, 0.7781, 0.7389, 0.7154, 0.7719,
         0.8248, 0.8525, 0.8579, 0.8505, 0.8376, 0.8366, 0.8268, 0.8305, 0.8501,
         0.8377, 0.8228, 0.8411, 0.8860, 0.9037, 0.9015, 0.8958, 0.8795, 0.8751,
         0.8485, 0.8323, 0.78

In [35]:
# pae, plddt == measures of confidence in positions of atoms in structure
parsed_output.keys()

dict_keys(['pae', 'plddt', 'sm_contacts', 'xyz'])

In [36]:
# shape LxL with probabilities of contacts between residue i,j
parsed_output['sm_contacts'].shape

(90, 90)

In [37]:
parsed_output['sm_contacts']

array([[1.0000000e+00, 9.9994224e-01, 8.8756698e-01, ..., 5.9885546e-03,
        5.2144169e-04, 9.7415084e-04],
       [9.9994224e-01, 1.0000000e+00, 9.9997872e-01, ..., 9.8105997e-01,
        5.4218003e-04, 7.3163002e-04],
       [8.8756698e-01, 9.9997872e-01, 1.0000000e+00, ..., 9.9179685e-01,
        1.9427864e-03, 1.4877091e-03],
       ...,
       [5.9885546e-03, 9.8105997e-01, 9.9179685e-01, ..., 1.0000000e+00,
        9.9999678e-01, 9.9157101e-01],
       [5.2144169e-04, 5.4218003e-04, 1.9427864e-03, ..., 9.9999678e-01,
        1.0000000e+00, 9.9983257e-01],
       [9.7415084e-04, 7.3163002e-04, 1.4877091e-03, ..., 9.9157101e-01,
        9.9983257e-01, 1.0000000e+00]], dtype=float32)

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, EsmForProteinFolding
from scipy.special import softmax
import numpy as np

from Bio import AlignIO
import os
import argparse
from tqdm import tqdm


# Custom pytorch dataset for holding raw sequence data
class SequenceDataset(Dataset):
    def __init__(self, msa_file):
        super().__init__()
        self.msa_file = msa_file
        # self.tokenizer = tokenizer
        self.seqs = self._get_sequences()
    
    def __len__(self):
        return len(self.seqs)
    
    def __getitem__(self, idx):
        item = self.seqs[idx]
        return item

    def _get_sequences(self):
        alignment = AlignIO.read(self.msa_file, 'stockholm')
        seqs = []
        for record in alignment:
            seqs.append(str(record.seq).replace('-', ''))
        
        return seqs

# adapted from ColabFold notebook: https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/ESMFold.ipynb#scrollTo=CcyNpAvhTX6q
def parse_output(output, batch):
    # pae = (output["aligned_confidence_probs"][0].detach().cpu() * np.arange(64)).mean(-1) * 31
    # plddt = output["plddt"][0,:,1].detach().cpu()

    # bins = np.append(0,np.linspace(2.3125,21.6875,63))
    # sm_contacts = softmax(output["distogram_logits"].detach().cpu(),-1)[0]

    # # NOTE: pretty sure this just looks for all distances <8 angstroms?  
    # sm_contacts = sm_contacts[...,bins<8].sum(-1)
    # xyz = output["positions"][-1,0,:,1].detach().cpu()
    # mask = (output["atom37_atom_exists"][0,:,1] == 1).detach().cpu()
    # o = {"pae":pae[mask,:][:,mask],
    #     "plddt":plddt[mask],
    #     "sm_contacts":sm_contacts[mask,:][:,mask],
    #     "xyz":xyz[mask]}

    bins = np.append(0,np.linspace(2.3125,21.6875,63))
    sm_contacts = softmax(output["distogram_logits"].detach().cpu(),-1)
    sm_contacts = sm_contacts[...,bins<8].sum(-1)
    mask = (output["atom37_atom_exists"][:,:,1] == 1).detach().cpu()

    contacts = []

    # over each protein sequence in batch
    for batch_idx in range(sm_contacts.shape[0]):
        contacts_matrix = sm_contacts[batch_idx][mask[batch_idx],:][:,mask[batch_idx]]
        contacts.append(contacts_matrix)

        # assert that if protein sequence is size L, then contacts_matrix is LxL
        assert len(batch[batch_idx]) == contacts_matrix.shape[0]

    return contacts

# Compute contact matrix for each sequence in MSA using PL model 
def plm_inference_per_sequence(model, 
                               tokenizer, 
                               msa_file, 
                               batch_size,
                               device):

    # setup dataset and dataloader
    dataset = SequenceDataset(msa_file)
    eval_dataloader = DataLoader(dataset, batch_size = batch_size)
    
    contact_matrices = []
    with torch.no_grad():
        for step, batch in enumerate(tqdm(eval_dataloader, f"PLM Inference on {os.path.basename(msa_file)}")):
            inputs = tokenizer(batch, 
                               return_tensors = 'pt', 
                               padding = True,
                            #    truncation=True,
                               add_special_tokens=False)

            inputs.to(device)

            out = model(**inputs)
            contacts = parse_output(out, batch)
            contact_matrices.extend(contacts)
    
    return contact_matrices

parser = argparse.ArgumentParser()
parser.add_argument(
    '--msa_file',
    type=str,
    required=True,
    help="Path of MSA (Stockholm format) file for inference",
)

parser.add_argument(
    '--plm_model',
    type = str,
    required=False,
    default = 'facebook/esmfold_v1',
    help = 'Huggingface model string for Protein Langauge Model'
)

parser.add_argument(
    '--batch_size',
    type = int,
    required=False,
    default=16,
    help = 'Batch size for PLM inference'
)

args_str = '--msa_file /nfshomes/vla/cmsc702-protein-lm/data/cadherin/PF00028.alignment.seed --batch_size 16'.split()
args = parser.parse_args(args_str)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
args

Namespace(msa_file='/nfshomes/vla/cmsc702-protein-lm/data/cadherin/PF00028.alignment.seed', plm_model='facebook/esmfold_v1', batch_size=16)

In [2]:
# Init model
model = EsmForProteinFolding.from_pretrained(args.plm_model)
model.to(device)

# Init tokenizer
tokenizer =  AutoTokenizer.from_pretrained(args.plm_model)


Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
contact_matrices = plm_inference_per_sequence(model, 
                            tokenizer,
                            args.msa_file,
                            args.batch_size,
                            device)


PLM Inference on PF00028.alignment.seed: 100%|██████████| 4/4 [00:41<00:00, 10.37s/it]


In [12]:
output_dir = '/nfshomes/vla/cmsc702-protein-lm/results/cadherin/esmfold'
job_name = 'cadherin'

In [13]:
for i,matrix in enumerate(contact_matrices):
    np.save(os.path.join(output_dir, f'{i}_{job_name}.npy'), matrix)

In [None]:
# bins = np.append(0,np.linspace(2.3125,21.6875,63))
# sm_contacts = softmax(out["distogram_logits"].detach().cpu(),-1)
# sm_contacts.shape

# # NOTE: pretty sure this just looks for all distances <8 angstroms?  
# sm_contacts = sm_contacts[...,bins<8].sum(-1)
# sm_contacts.shape
# mask = (out["atom37_atom_exists"][:,:,1] == 1).detach().cpu()
# mask.shape
# sm_contacts[0].shape
# mask[0].shape

# sm_contacts[0][mask[0],:][:,mask[0]]

(1, 90, 90, 64)