This notebook was adapted from [https://github.com/RSchmirler/data-repo_plm-finetune-eval/tree/main](https://github.com/RSchmirler/data-repo_plm-finetune-eval/tree/main).

The goal is to extract embeddings from pretrained protein language models.

## Setup

In [2]:
#import dependencies
import itertools
import os
import random
import time

import torch

import numpy as np
import pandas as pd

import transformers, datasets

from transformers import EsmModel, AutoTokenizer
from transformers import T5EncoderModel, T5Tokenizer

transformers.logging.set_verbosity_error()

from tqdm import tqdm

In [None]:
# Checkpoint names of the ESM-2 models to embed with; setup_model prefixes
# each name with the "facebook/" hub namespace.
ESMs = [
    "esm2_t6_8M_UR50D",
    "esm2_t12_35M_UR50D",
    "esm2_t30_150M_UR50D",
    "esm2_t33_650M_UR50D",
    "esm2_t36_3B_UR50D",
]

# Checkpoint names of the ProtT5 models; setup_model prefixes each name with
# the "Rostlab/" hub namespace.
ProtT5 = [
    "prot_t5_xl_uniref50",
]

# Environment to run this notebook


These are the versions of the core packages we use to run this notebook:

In [3]:
# Print one "<Package> version:  <version>" line per core dependency.
print(
    *(
        f"{label} version:  {module.__version__}"
        for label, module in (
            ("Torch", torch),
            ("Numpy", np),
            ("Pandas", pd),
            ("Transformers", transformers),
        )
    ),
    sep="\n",
)

Torch version:  2.5.1+cu124
Numpy version:  1.26.4
Pandas version:  2.2.3
Transformers version:  4.46.2


# Methods

In [4]:
def setup_model(checkpoint):
    """Load a pretrained protein language model and its tokenizer onto the GPU.

    Checkpoints whose name contains "esm" are loaded from the ``facebook/``
    hub namespace with ``EsmModel`` / ``AutoTokenizer``; every other
    checkpoint is loaded from ``Rostlab/`` with ``T5EncoderModel`` /
    ``T5Tokenizer``.

    Args:
        checkpoint: Checkpoint name without the hub namespace prefix.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is loaded in float16 and
        moved to the "cuda" device.
    """
    # Pick the hub repository and the loader classes for this model family.
    if "esm" in checkpoint:
        repo_id = f'facebook/{checkpoint}'
        tokenizer_cls, model_cls = AutoTokenizer, EsmModel
    else:
        repo_id = f'Rostlab/{checkpoint}'
        tokenizer_cls, model_cls = T5Tokenizer, T5EncoderModel

    tokenizer = tokenizer_cls.from_pretrained(repo_id)

    # Load half-precision weights, move them to the GPU and keep them in fp16.
    model = model_cls.from_pretrained(repo_id, torch_dtype=torch.float16)
    model = model.to("cuda").half()

    return model, tokenizer

In [5]:
def create_embedding(checkpoint, df, max_len=1024, seq_colname="Sequence", id_colname="Entry"):
    """Compute per-residue embeddings for every row of ``df`` and save them to disk.

    One ``.npy`` file is written per row to
    ``../data/embeddings/{id}_{checkpoint}.npy``. Each file holds the model's
    last hidden state for that sequence with the special-token positions
    removed.

    ``df`` is not modified. ProtT5-specific preprocessing is applied to a
    local copy of the sequences, so the same frame can be passed to several
    checkpoints (or the same one repeatedly) in any order.

    Args:
        checkpoint: Checkpoint name as accepted by ``setup_model``.
        df: DataFrame holding the sequences and their identifiers. Any index
            is accepted; rows are processed in positional order.
        max_len: Passed to the tokenizer as ``max_length`` with truncation
            enabled, so longer inputs are truncated.
        seq_colname: Name of the column containing the sequences.
        id_colname: Name of the column containing the identifiers used in the
            output file names.

    Returns:
        None. Results are written to disk only.
    """
    model, tokenizer = setup_model(checkpoint)

    # Work on a local Series so the caller's DataFrame is never overwritten.
    sequences = df[seq_colname]
    if "prot_t5" in checkpoint:
        # ProtT5 preprocessing as shown on the model card: map the rare
        # amino acids U, Z, O and B to X, then separate residues by spaces.
        sequences = sequences.str.replace(r"[UZOB]", "X", regex=True).map(" ".join)

    # np.save does not create missing parent directories.
    os.makedirs("../data/embeddings", exist_ok=True)

    # Iterate over values rather than labels so a filtered or re-indexed
    # frame (non-default index) works as well.
    for seq_id, seq in tqdm(zip(df[id_colname], sequences), total=len(df)):
        inputs = tokenizer(seq, return_tensors="pt", max_length=max_len, truncation=True, padding=False).to("cuda")

        with torch.no_grad():
            hidden = model(**inputs).last_hidden_state

        # Select the single batch entry explicitly; unlike np.squeeze this
        # cannot collapse the token axis when only one token is present.
        out = hidden[0].cpu().numpy()

        if 'esm' in checkpoint:
            # remove first and last special token
            out = out[1:-1, :]
        else:
            # drop the last position (end-of-sequence token appended by the
            # T5 tokenizer)
            out = out[:-1, :]

        np.save(f"../data/embeddings/{seq_id}_{checkpoint}.npy", out)
    

## Create Embedding

In [6]:
# Read the tab-separated UniProt export and keep only the two columns the
# embedding step uses.
df = pd.read_csv("../data/uniprot_all_human_proteins.txt.gz", sep='\t')[['Entry', 'Sequence']]
df

Unnamed: 0,Entry,Sequence
0,A0A087X1C5,MGLEALVPLAMIVAIFLLLVDLMHRHQRWAARYPPGPLPLPGLGNL...
1,A0A0B4J2F0,MFRRLTFAQLLFATVLGIAGGVYIFQPVFEQYAKDQKELKEKMQLV...
2,A0A0B4J2F2,MVIMSEFSADPAGQGQGQQKPLRVGFYDIERTLGKGNFAVVKLARH...
3,A0A0C5B5G6,MRWQEMGYIFYPRKLR
4,A0A0K2S4Q6,MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...
...,...,...
20429,Q9UI54,MESPKCLYSRITVNTAFGTKFSHISFIILFKVFLFPRITISKKTKL...
20430,Q9UI72,MGMALELYWLCGFRSYWPLGTNAENEGNRKENRRQMQSRNERGCNV...
20431,Q9Y3F1,MSLLWTPQILTISFVSYILSLFPSPFPSCYTSCWFETSITTEKELN...
20432,Q9Y6C7,MAHHSLNTFYIWHNNVLHTHLVFFLPHLLNQPFSRGSFLIWLLLCW...


In [7]:
# Embed all sequences in df with each ESM-2 checkpoint in turn, printing the
# checkpoint name before its progress bar.
for esm in ESMs:
    print(esm)
    create_embedding(checkpoint=esm, df=df)

In [18]:
# Embed all sequences in df with the (single) ProtT5 checkpoint.
create_embedding(checkpoint=ProtT5[0], df=df)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20434/20434 [08:49<00:00, 38.58it/s]
