## Available Checkpoints

In [20]:
# Checkpoint identifiers that can be passed to create_embedding / setup_model.

# ESM-2 checkpoints, listed from the smallest (8M) to the largest (3B)
ESMs = [
    "facebook/esm2_t6_8M_UR50D",
    "facebook/esm2_t12_35M_UR50D",
    "facebook/esm2_t30_150M_UR50D",
    "facebook/esm2_t33_650M_UR50D",
    "facebook/esm2_t36_3B_UR50D",
]

# Ankh checkpoints (base and large)
Ankhs = [
    "ElnaggarLab/ankh-base",
    "ElnaggarLab/ankh-large",
]

# ProtT5 checkpoint
ProtT5 = ["Rostlab/prot_t5_xl_uniref50"]

## Imports and env. variables

In [1]:
#import dependencies
import os.path
# Placeholder: replace the string below with the local path of the git repository
# before running. The data paths used by the later cells are relative ("./..."),
# so they are resolved against the working directory set here.
os.chdir("set path to git repo here")

import torch

import numpy as np
import pandas as pd
import time

import transformers, datasets

from transformers import T5EncoderModel, T5Tokenizer
from transformers import EsmModel, AutoTokenizer

# restrict transformers log output to errors
transformers.logging.set_verbosity_error()

from tqdm import tqdm
# NOTE(review): time, datasets, random and itertools are not referenced by any
# cell visible in this notebook section - confirm before removing them.
import random
import itertools

# Methods

In [3]:
# preprocess (model_type = "esm", "ankh", "pt")

def seq_preprocess(df, model_type = "esm"):
    """Bring the "sequence" column into the input format of a model family.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with a "sequence" column holding one string per row.
    model_type : str
        "esm" or "ankh": sequences are used as they are.
        "pt": residues are separated by single spaces ("MKV" -> "M K V").

    Returns
    -------
    pandas.DataFrame or None
        For "esm" / "ankh" the very same object that was passed in.
        For "pt" a new DataFrame with the rewritten "sequence" column; the
        frame passed in is left untouched.
        None when ``model_type`` is not one of the three known values.
    """
    if model_type in ("esm", "ankh"):
        # nothing to reformat for these two families
        return df

    if model_type == "pt":
        # assign() builds a new frame, so the caller's sequences are not
        # overwritten; Series.map also copes with an empty frame, which the
        # row-wise DataFrame.apply(axis=1) did not.
        return df.assign(sequence = df["sequence"].map(" ".join))

    # unknown model family
    return None
    

In [4]:
def setup_model(checkpoint):
    """Load tokenizer and encoder for ``checkpoint`` and move the encoder to the GPU.

    The model family is derived from the checkpoint name:

    * contains "esm"  -> "esm":  AutoTokenizer + EsmModel, float16
    * contains "ankh" -> "ankh": AutoTokenizer + T5EncoderModel, default precision
    * anything else   -> "pt":   T5Tokenizer + T5EncoderModel, float16

    Returns
    -------
    tuple
        ``(model, tokenizer, mod_type)`` with ``mod_type`` one of
        "esm", "ankh" or "pt".
    """
    # (mod_type, tokenizer class, model class, load in half precision?)
    if "esm" in checkpoint:
        family = ("esm", AutoTokenizer, EsmModel, True)
    elif "ankh" in checkpoint:
        # half precision does not work with ankh
        family = ("ankh", AutoTokenizer, T5EncoderModel, False)
    else:
        family = ("pt", T5Tokenizer, T5EncoderModel, True)

    mod_type, tokenizer_cls, model_cls, use_half = family

    tokenizer = tokenizer_cls.from_pretrained(checkpoint)

    if use_half:
        model = model_cls.from_pretrained(checkpoint, torch_dtype=torch.float16)
        model = model.to("cuda").half()
    else:
        model = model_cls.from_pretrained(checkpoint).to("cuda")

    return model, tokenizer, mod_type
 
    

In [5]:
#embedding types:
# per Protein  "per_prot"
# per Residue  "per_res"


def create_embedding(checkpoint, df, emb_type = "per_prot"):
    """Embed every sequence of ``df`` with the model behind ``checkpoint``.

    The model is loaded through ``setup_model`` on every call and released
    again (including ``torch.cuda.empty_cache()``) before the function
    returns or raises. ``df`` itself is never modified; all intermediate
    columns live on an internal copy.

    Parameters
    ----------
    checkpoint : str
        Checkpoint identifier handed to ``setup_model``.
    df : pandas.DataFrame
        Needs a "sequence" column (one string per row) and a "label" column.
        For ``emb_type="per_res"`` each label must be a sequence with one
        entry per residue; an optional "mask" column of per-residue strings
        is carried over as well. Any index is accepted.
    emb_type : str
        "per_prot": one output row per sequence - the mean of the last hidden
        state over all tokens (special tokens are part of the mean).
        "per_res": one output row per residue, special tokens removed.

    Returns
    -------
    pandas.DataFrame or None
        Numbered embedding columns followed by

        * "sequence", "label" for "per_prot";
        * "seq_idx_pos" ("<index label in df>_<1-based position>"),
          "residue", "label" and, if present in ``df``, "mask" for "per_res".

        None (after printing a hint) when ``emb_type`` is not recognised.

    Raises
    ------
    ValueError
        If ``df`` has no rows.
    """
    # Validate the cheap arguments first so that no model is loaded onto the
    # GPU for a call that cannot succeed.
    if emb_type not in ("per_prot", "per_res"):
        print("input valid embedding type")
        return None

    if len(df) == 0:
        raise ValueError("df contains no sequences to embed")

    model, tokenizer, mod_type = setup_model(checkpoint)

    try:
        # Work on a private copy with a 0..n-1 index: the caller's frame stays
        # untouched and the column assignments below line up row by row even
        # if df was filtered or shuffled. The original labels are kept for
        # the "seq_idx_pos" column.
        source_index = df.index
        work = df.copy()
        work.index = pd.RangeIndex(len(work))
        work = seq_preprocess(work, model_type = mod_type)

        emb = []

        for seq in tqdm(work["sequence"].tolist()):
            hidden = _last_hidden_state(model, tokenizer, seq)

            if emb_type == "per_prot":
                # mean across the seq len dimension -> array of shape (1, dim)
                emb.append(torch.mean(hidden, dim = 1).numpy())
            else:
                # drop only the batch dimension -> (tokens, dim)
                per_token = hidden.numpy()[0]

                #remove special tokens
                if mod_type == "esm":
                    # remove first and last special token
                    per_token = per_token[1:-1, :]
                elif mod_type in ("pt", "ankh"):
                    # remove last special token
                    per_token = per_token[:-1, :]

                emb.append(per_token)

        #create embedding df
        df_emb = pd.DataFrame(np.concatenate(emb))

        if mod_type == "pt":
            # undo the space separation that seq_preprocess added for the tokenizer
            work["sequence"] = work["sequence"].replace(' ', '', regex=True)

        if emb_type == "per_prot":
            df_emb["sequence"] = work["sequence"]
            df_emb["label"] = work["label"]
        else:
            #add corresponding "sequence index"_"position" (position is 1-based)
            lengths = work["sequence"].str.len()
            df_emb["seq_idx_pos"] = [str(idx) + "_" + str(pos)
                                     for idx, n_res in zip(source_index, lengths)
                                     for pos in range(1, n_res + 1)]

            #add corresponding residue
            df_emb["residue"] = list(work["sequence"].str.cat())

            #add corresponding label
            df_emb["label"] = list(np.concatenate(work["label"].values))

            #add corresponding mask
            if "mask" in work.columns:
                df_emb["mask"] = list(work["mask"].str.cat())

        return df_emb

    finally:
        # clean up gpu, also when embedding failed half way through
        del model
        del tokenizer
        torch.cuda.empty_cache()


def _last_hidden_state(model, tokenizer, sequence):
    """Run one sequence through the model and return its last hidden state as a CPU tensor."""
    inputs = tokenizer(sequence, return_tensors="pt", max_length = 10000, truncation=True, padding=False).to("cuda")

    with torch.no_grad():
        return model(**inputs).last_hidden_state.cpu()

# Example per protein


## Input

In [6]:
# pickled GB1 test split, relative to the working directory; read with pd.read_pickle in the next cell
path = "./training data/GB1/test.pkl"

In [7]:
# Keep the first two columns of the pickled table and name them
# "sequence" and "label", then preview the first five rows.
df = pd.read_pickle(path).iloc[:, :2]
df.columns = ["sequence", "label"]
df.head()

Unnamed: 0,sequence,label
0,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGAAAEWTYD...,1.61
1,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGAACEWTYD...,3.74
2,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGAAEEWTYD...,0.0
3,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGAAFEWTYD...,1.08
4,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGAAFEWTYD...,2.09


In [8]:
# map the uncommon amino acid codes O, B, U, Z and J to X
df["sequence"] = df["sequence"].str.replace("O|B|U|Z|J", "X", regex=True)

## Create Embedding

In [9]:
# one embedding row per protein, computed with the 150M ESM-2 checkpoint
emb = create_embedding(checkpoint=ESMs[2], df=df, emb_type="per_prot")

100%|██████████| 5743/5743 [02:29<00:00, 38.52it/s]


In [10]:
# Display the per-protein result: the numbered embedding columns come first,
# the "sequence" and "label" columns are appended at the end.
emb

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,632,633,634,635,636,637,638,639,sequence,label
0,-0.114624,0.003830,-0.048645,-0.120544,-0.027832,-0.149902,-0.030777,0.015610,-0.241089,0.039703,...,-0.622070,-0.065857,-0.164185,0.000514,-0.027084,-0.014130,-0.083069,-0.090881,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGAAAEWTYD...,1.61
1,-0.122803,0.008003,-0.049133,-0.126465,-0.027176,-0.150391,-0.032745,0.015305,-0.244873,0.041656,...,-0.617188,-0.068420,-0.165649,-0.001229,-0.026840,-0.014725,-0.079163,-0.098999,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGAACEWTYD...,3.74
2,-0.115051,0.012665,-0.044800,-0.126099,-0.027664,-0.143433,-0.036102,0.017960,-0.242554,0.047607,...,-0.605957,-0.078247,-0.156128,-0.004128,-0.025986,-0.009087,-0.084351,-0.091980,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGAAEEWTYD...,0.00
3,-0.120300,0.004383,-0.058014,-0.123596,-0.029541,-0.151978,-0.030014,0.017731,-0.238770,0.038879,...,-0.622559,-0.065735,-0.169678,-0.003948,-0.021332,-0.013313,-0.081177,-0.089417,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGAAFEWTYD...,1.08
4,-0.119995,0.007362,-0.055481,-0.123108,-0.030670,-0.148926,-0.029053,0.022125,-0.238647,0.043671,...,-0.617676,-0.070923,-0.162476,-0.005623,-0.019211,-0.012268,-0.082214,-0.092407,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGAAFEWTYD...,2.09
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5738,-0.122559,0.006401,-0.055389,-0.134033,-0.022736,-0.154419,-0.034119,0.013084,-0.260742,0.041412,...,-0.608398,-0.086609,-0.151733,-0.010902,-0.021255,-0.020660,-0.079285,-0.092285,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGYYSEWTYD...,0.37
5739,-0.125000,0.004883,-0.051453,-0.134155,-0.029785,-0.154907,-0.033875,0.016190,-0.260742,0.041595,...,-0.605957,-0.086609,-0.150269,-0.011421,-0.021606,-0.017960,-0.080200,-0.097473,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGYYVEWTYD...,1.04
5740,-0.126953,-0.001006,-0.055328,-0.137329,-0.022827,-0.154297,-0.033600,0.015175,-0.261963,0.037231,...,-0.602539,-0.086487,-0.154053,-0.010155,-0.026733,-0.019516,-0.080322,-0.099121,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGYYVEWTYD...,0.00
5741,-0.121948,0.002306,-0.055573,-0.139404,-0.019562,-0.153931,-0.030869,0.011703,-0.264160,0.041412,...,-0.606934,-0.082031,-0.146606,-0.013870,-0.020340,-0.021683,-0.079712,-0.094727,MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGYYYEWTYD...,0.03


# Example per residue


## Input

In [11]:
# pickled SecStr test split, relative to the working directory; read with pd.read_pickle in the next cell
path = "./training data/SecStr/test.pkl"

In [12]:
# Keep the first three columns of the pickled table and name them
# "sequence", "label" and "mask", then preview the first five rows.
df = pd.read_pickle(path).iloc[:, 0:3]
df.columns = ["sequence", "label", "mask"]
df.head()

Unnamed: 0,sequence,label,mask
0,MTPAVTTYKLVINGKTLKGETTTKAVDAETAEKAFKQYANDNGVDG...,"[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, ...",0000111111111111111111111111111111111111111111...
1,MNDQEKIDKFTHSYINDDFGLTIDQLVPKVKGYGRFNVWLGGNESK...,"[0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 1, 0, ...",0111111111111111111111111111111111111111111111...
2,GPGFMRDSGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCET...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0000000000000000000000111111111111111111111111...
3,SNALSRNEVLLNGDINFKEVRCVGDNGEVYGIISSKEALKIAQNLG...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 2, 0, ...",0000000111111111111111111111111111111111111111...
4,MAKGKSEVVEQNHTLILGWSDKLGSLLNQLAIANESLGGGTIAVMA...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, ...",0000000000111111111111111111111111111111111111...


In [13]:
# check that every label list has one entry per residue of its sequence
print("Label length correspond to seq:", all(df["sequence"].str.len().eq(df["label"].str.len())))

Label length correspond to seq: True


In [14]:
# check that every mask string has one character per residue (only needed if you have a mask)
print("Mask length correspond to seq:", all(df["sequence"].str.len().eq(df["mask"].str.len())))

Mask length correspond to seq: True


In [15]:
# map the uncommon amino acid codes O, B, U, Z and J to X
df["sequence"] = df["sequence"].str.replace("O|B|U|Z|J", "X", regex=True)

## Create Embedding

In [16]:
# One embedding row per residue. After the embedding columns the result holds
# "seq_idx_pos" (index in the input df and residue position), "residue",
# "label" and "mask".
emb = create_embedding(checkpoint=ESMs[2], df=df, emb_type="per_res")

100%|██████████| 364/364 [00:09<00:00, 37.40it/s]


In [17]:
# Display the per-residue result: numbered embedding columns followed by
# "seq_idx_pos", "residue", "label" and "mask".
emb

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,634,635,636,637,638,639,seq_idx_pos,residue,label,mask
0,-0.052795,0.017120,0.078491,-0.093079,0.200195,0.410400,0.108154,-0.546875,-0.291016,0.350342,...,0.372803,0.375000,-0.090637,-0.082581,-0.432861,-0.090027,0_1,M,0,0
1,-0.407227,0.119629,-0.003481,0.324463,0.098633,-0.297607,0.079224,-0.544434,0.007919,0.194580,...,-0.208374,0.558105,0.013992,0.183594,-0.359619,-0.291260,0_2,T,0,0
2,-0.407959,-0.047546,0.151367,0.133423,-0.210938,-0.009377,0.187378,-0.518555,-0.221802,0.303223,...,-0.186890,0.105469,-0.190918,-0.147949,-0.155762,0.043518,0_3,P,0,0
3,-0.069153,-0.224365,0.114380,0.312988,-0.119995,0.210083,0.095093,-0.447998,-0.253662,0.113159,...,-0.341553,0.432861,-0.073242,-0.188354,0.008858,-0.022873,0_4,A,0,0
4,-0.311523,-0.020187,0.256592,0.226074,-0.264648,0.322021,0.077698,-0.239868,0.287109,0.155396,...,0.076660,0.267090,0.139893,-0.139893,-0.059296,-0.036041,0_5,V,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
84536,0.080933,0.051422,-0.087708,0.128662,-0.157471,0.511719,-0.035950,0.074341,-0.122498,0.193970,...,0.110046,-0.320312,0.024063,-0.219604,-0.624512,0.034454,363_451,W,0,1
84537,-0.053589,0.122986,-0.095825,-0.165405,-0.008362,0.051880,-0.133545,0.189941,0.030411,-0.070557,...,-0.018097,-0.001393,0.051666,-0.202148,-0.500977,-0.025070,363_452,X,0,0
84538,0.096436,0.238403,0.124268,-0.165283,0.310791,-0.026337,0.037628,0.294678,-0.050385,0.167969,...,-0.037415,-0.191040,-0.223755,-0.001476,-0.496582,0.124512,363_453,X,0,0
84539,-0.101440,0.251953,0.022568,-0.163574,0.200562,0.130005,0.056854,0.158569,-0.015137,0.238525,...,-0.016693,-0.071411,-0.226074,-0.026001,-0.583008,0.154419,363_454,X,0,0


# GB1 - ESM2 8M example

1. Creates embeddings for training, validation and test set of the GB1 task, using the ESM2 8M model
2. Embedding dataframes are saved in pickle format to the `GB1_embedded` folder, one file per split named `<split>_ESM2_8M.pkl`

In [21]:
# Embed each GB1 split per protein with the 8M ESM-2 checkpoint and pickle the result.
raw_dir = "./notebooks/embedding/example_data/GB1_raw/"
embedded_dir = "./notebooks/embedding/example_data/GB1_embedded/"

for data in ("test", "valid", "train"):
    path = raw_dir + data + ".pkl"

    # first two columns hold the sequence and its label
    df = pd.read_pickle(path).iloc[:, :2]
    df.columns = ["sequence", "label"]

    # map the uncommon amino acid codes O, B, U, Z and J to X
    df["sequence"] = df["sequence"].str.replace("O|B|U|Z|J", "X", regex=True)

    emb = create_embedding(checkpoint=ESMs[0], df=df, emb_type="per_prot")
    emb.to_pickle(embedded_dir + data + "_ESM2_8M.pkl")

100%|██████████| 5743/5743 [00:32<00:00, 179.40it/s]
100%|██████████| 299/299 [00:01<00:00, 177.21it/s]
100%|██████████| 2691/2691 [00:15<00:00, 175.56it/s]
