In [1]:
from tdc.multi_pred import DTI
from Bio import ExPASy
from Bio import SwissProt
from Bio.PDB import PDBList
from tqdm import tqdm
import pandas as pd
import gc

In [2]:
# import os
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"

In [2]:
%env PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:32

env: PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:32


In [7]:
df = pd.read_csv("davis/raw/data.csv")
lenghts = [len(seq) for seq in df.iloc[:, 2]]
from statistics import mean, median

mean_len, median_len = mean(lenghts), median(lenghts)
mean_len, median_len

(788.947963800905, 707.0)

In [2]:
def extract_esm_feature(dataset):
    from transformers import AutoTokenizer, EsmModel
    import torch

    device = "cuda"

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D" )
    model = EsmModel.from_pretrained("facebook/esm2_t12_35M_UR50D")
    model.to(device)

    train_path = f"{dataset}/raw/data_train.csv"
    train_save = f"{dataset}/raw/train/"
    test_path = f"{dataset}/raw/data_test.csv"
    test_save = f"{dataset}/raw/test/"

    import os
    os.makedirs(train_save, exist_ok=True)
    os.makedirs(test_save, exist_ok=True)

    train_df = pd.read_csv(train_path)
    test_df = pd.read_csv(test_path)

    batch_size = 1

    for i in tqdm(range(0, train_df.shape[0], batch_size)):
        if batch_size == 1:
            seqs = train_df.iloc[i, 1]
        else:
            seqs = train_df.iloc[i:i+batch_size, 1].tolist()
        inputs = tokenizer(
            seqs,
            add_special_tokens = True,
            max_length = 1200, 
            padding = 'max_length',
            truncation = True,
            return_tensors = 'pt'
        )
        inputs.to(device)
        outputs = model(**inputs)
        last_hidden_states = outputs.last_hidden_state
        embedding = last_hidden_states.cpu().detach().numpy()
        torch.save(torch.Tensor(embedding), train_save+f"{i//batch_size}.pt")

    for i in tqdm(range(0, test_df.shape[0], batch_size)):
        if batch_size == 1:
            seqs = train_df.iloc[i, 1]
        else:
            seqs = train_df.iloc[i:i+batch_size, 1].tolist()
        inputs = tokenizer(
            seqs,
            add_special_tokens = True,
            max_length = 1200, 
            padding = 'max_length',
            truncation = True,
            return_tensors = 'pt'
        )
        inputs.to(device)
        outputs = model(**inputs)
        last_hidden_states = outputs.last_hidden_state
        embedding = last_hidden_states.cpu().detach().numpy()
        torch.save(torch.Tensor(embedding), test_save+f"{i//batch_size}.pt")

In [3]:
extract_esm_feature("davis")

Some weights of the model checkpoint at facebook/esm2_t12_35M_UR50D were not used when initializing EsmModel: ['lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias']
- This IS expected if you are initializing EsmModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.weight', 'esm.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [2]:
def uniprot_to_pdb(uniprot_id):
    # Retrieve UniProt entry
    handle = ExPASy.get_sprot_raw(uniprot_id)
    record = SwissProt.read(handle)

    print(record.entry_name)

    # Extract PDB IDs
    pdb_ids = [x[0] for x in record.cross_references if x[0].startswith('PDB:')]

    return pdb_ids

In [2]:
data_davis = DTI(name = 'DAVIS')
data_davis.convert_to_log("binding")
df_davis = data_davis.get_data("df")

Found local copy...
Loading...
Done!
To log space...


In [3]:
df_davis.head(10)

Unnamed: 0,Drug_ID,Drug,Target_ID,Target,Y
0,11314340,Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12,AAK1,MKKFFDSRREQGGSGLGSGSSGGGGSTSGLGSGYIGRVFGIGRQQV...,7.365523
1,11314340,Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12,ABL1p,PFWKILNPLLERGTYYYFMGQQPGKVLGDQRRPSLPALHFIKGAGK...,4.999996
2,11314340,Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12,ABL2,MVLGTVLLPPNSYGRDQDTSLCCLCTEASESALPDLTDHFASCVED...,4.999996
3,11314340,Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12,ACVR1,MVDGVMILPVLIMIALPSPSMEDEKPKVNPKLYMCVCEGLSCGNED...,4.999996
4,11314340,Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12,ACVR1B,MAESAGASSFFPLVVLLLAGSGGSGPRGVQALLCACTSCLQANYTC...,4.999996
5,11314340,Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12,ACVR2A,MGAAAKLAFAVFLISCSSGAILGRSETQECLFFNANWEKDRTNQTG...,4.999996
6,11314340,Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12,ACVR2B,MTAPWVALALLWGSLCAGSGRGEAETRECIYYNANWELERTNQSGL...,4.999996
7,11314340,Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12,ACVRL1,MTLGSPRKGLLMLLMALVTQGDPVKPSRGPLVTCTCESPHCKGPTC...,4.999996
8,11314340,Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12,ADCK3,MAAILGDTIMVAKGLVKLTQAAVETHLQHLGIGGELIMAARALQST...,4.999996
9,11314340,Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12,ADCK4,MWLKVGGLLRGTGGQLGQTVGWPCGALGPGPHRWGPCGGSWAQKFY...,4.999996


In [6]:
data_kiba = DTI("kiba")
df_kiba = data_kiba.get_data("df")
df_kiba.head(10)

Downloading...
100%|██████████| 96.6M/96.6M [00:22<00:00, 4.31MiB/s]
Loading...
Done!


Unnamed: 0,Drug_ID,Drug,Target_ID,Target,Y
0,CHEMBL1087421,COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2,O00141,MTVKTEAAKGTLTYSRMRGMVAILIAFMKQRRMGLNDFIQKIANNS...,11.1
1,CHEMBL1087421,COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2,O14920,MSWSPSLTTQTCGAWEMKERLGTGGFGNVIRWHNQETGEQIAIKQC...,11.1
2,CHEMBL1087421,COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2,O15111,MERPPGLRPGAGGPWEMRERLGTGGFGNVCLYQHRELDLKIAIKSC...,11.1
3,CHEMBL1087421,COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2,P00533,MRPSGTAGAALLALLAALCPASRALEEKKVCQGTSNKLTQLGTFED...,11.1
4,CHEMBL1087421,COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2,P04626,MELAALCRWGLLLALLPPGAASTQVCTGTDMKLRLPASPETHLDML...,11.1
5,CHEMBL1087421,COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2,P06239,MGCGCSSHPEDDWMENIDVCENCHYPIVPLDGKGTLLIRNGSEVRD...,11.1
6,CHEMBL1087421,COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2,P07333,MGPGVLLLLLVATAWHGQGIPVIEPSVPELVVKPGATVTLRCVGNG...,11.1
7,CHEMBL1087421,COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2,P15056,MAALSGGGGGGAEPGQALFNGDMEPEAGAGAGAAASSAADPAIPEE...,11.1
8,CHEMBL1087421,COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2,P24941,MENFQKVEKIGEGTYGVVYKARNKLTGEVVALKKIRLDTETEGVPS...,11.1
9,CHEMBL1087421,COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2,P28482,MAAAAAAGAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNV...,10.1


In [20]:
ids = uniprot_to_pdb("P28482")
len(ids)

MK01_HUMAN


0

In [7]:
data_bdb = DTI("BindingDB_Kd")
df_bdb = data_bdb.get_data("df")
df_bdb.head(10)

Downloading...
100%|██████████| 54.4M/54.4M [00:15<00:00, 3.56MiB/s]
Loading...
Done!


Unnamed: 0,Drug_ID,Drug,Target_ID,Target,Y
0,444607.0,Cc1ccc(CNS(=O)(=O)c2ccc(S(N)(=O)=O)s2)cc1,P00918,MSHHWGYGKHNGPEHWHKDFPIAKGERQSPVDIDTHTAKYDPSLKP...,0.46
1,4316.0,COc1ccc(CNS(=O)(=O)c2ccc(S(N)(=O)=O)s2)cc1,P00918,MSHHWGYGKHNGPEHWHKDFPIAKGERQSPVDIDTHTAKYDPSLKP...,0.49
2,4293.0,NS(=O)(=O)c1ccc(S(=O)(=O)NCc2cccs2)s1,P00918,MSHHWGYGKHNGPEHWHKDFPIAKGERQSPVDIDTHTAKYDPSLKP...,0.83
3,1611.0,NS(=O)(=O)c1cc2c(s1)S(=O)(=O)N(Cc1cccs1)CC2O,P00918,MSHHWGYGKHNGPEHWHKDFPIAKGERQSPVDIDTHTAKYDPSLKP...,0.2
4,1612.0,COc1ccc(N2CC(O)c3cc(S(N)(=O)=O)sc3S2(=O)=O)cc1,P00918,MSHHWGYGKHNGPEHWHKDFPIAKGERQSPVDIDTHTAKYDPSLKP...,0.16
5,4369102.0,CCN[C@H]1CN(CCOC)S(=O)(=O)c2sc(S(N)(=O)=O)cc21,P00918,MSHHWGYGKHNGPEHWHKDFPIAKGERQSPVDIDTHTAKYDPSLKP...,0.32
6,68844.0,CCN[C@H]1CN(CCCOC)S(=O)(=O)c2sc(S(N)(=O)=O)cc21,P00918,MSHHWGYGKHNGPEHWHKDFPIAKGERQSPVDIDTHTAKYDPSLKP...,0.13
7,1604.0,COc1cccc(N2CCc3cc(S(N)(=O)=O)sc3S2(=O)=O)c1,P00918,MSHHWGYGKHNGPEHWHKDFPIAKGERQSPVDIDTHTAKYDPSLKP...,0.1
8,3013848.0,CN[C@H]1CN(c2cccc(OC)c2)S(=O)(=O)c2sc(S(N)(=O)...,P00918,MSHHWGYGKHNGPEHWHKDFPIAKGERQSPVDIDTHTAKYDPSLKP...,0.1
9,4369101.0,CN[C@@H]1CN(c2cccc(OC)c2)S(=O)(=O)c2sc(S(N)(=O...,P00918,MSHHWGYGKHNGPEHWHKDFPIAKGERQSPVDIDTHTAKYDPSLKP...,1.7


In [4]:
from transformers import AutoTokenizer, EsmModel
import torch

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D" )
model = EsmModel.from_pretrained("facebook/esm2_t30_150M_UR50D")
# model.to("cuda")

Some weights of the model checkpoint at facebook/esm2_t30_150M_UR50D were not used when initializing EsmModel: ['lm_head.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing EsmModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t30_150M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference

In [4]:
df_davis.iloc[0, 3]

'MKKFFDSRREQGGSGLGSGSSGGGGSTSGLGSGYIGRVFGIGRQQVTVDEVLAEGGFAIVFLVRTSNGMKCALKRMFVNNEHDLQVCKREIQIMRDLSGHKNIVGYIDSSINNVSSGDVWEVLILMDFCRGGQVVNLMNQRLQTGFTENEVLQIFCDTCEAVARLHQCKTPIIHRDLKVENILLHDRGHYVLCDFGSATNKFQNPQTEGVNAVEDEIKKYTTLSYRAPEMVNLYSGKIITTKADIWALGCLLYKLCYFTLPFGESQVAICDGNFTIPDNSRYSQDMHCLIRYMLEPDPDKRPDIYQVSYFSFKLLKKECPIPNVQNSPIPAKLPEPVKASEAAAKKTQPKARLTDPIPTTETSIAPRQRPKAGQTQPNPGILPIQPALTPRKRATVQPPPQAAGSSNQPGLLASVPQPKPQAPPSQPLPQTQAKQPQAPPTPQQTPSTQAQGLPAQAQATPQHQQQLFLKQQQQQQQPPPAQQQPAGTFYQQQQAQTQQFQAVHPATQKPAIAQFPVVSQGGSQQQLMQNFYQQQQQQQQQQQQQQLATALHQQQLMTQQAALQQKPTMAAGQQPQPQPAAAPQPAPAQEPAIQAPVRQQPKVQTTPPPAVQGQKVGSLTPPSSPKTQRAGHRRILSDVTHSAVFGVPASKSTQLLQAAAAEASLNKSKSATTTPSGSPRTSQQNVYNPSEGSTWNPFDDDNFSKLTAEELLNKDFAKLGEGKHPEKLGGSAESLIPGFQSTQGDAFATTSFSAGTAEKRKGGQTVDSGLPLLSVSDPFIPLQVPDAPEKLIEGLKSPDTSLLLPDLLPMTDPFGSTSDAVIEKADVAVESLIPGLEPPVPQRLPSQTESVTSNRTDSLTGEDSLLDCSLLSNPTTDLLEEFAPTAISAPVHKAAEDSNLISGFDVPEGSDKVAEDEFDPIPVLITKNPQGGHSRNSSGSSESSLPNLARSLLLVDQLIDL'

In [None]:
inputs = tokenizer(df_davis.iloc[0, 3], return_tensors="pt", add_special_tokens=False)  # A tiny random peptide
outputs = model(**inputs)
outputs

In [None]:
folded_positions = outputs.positions

In [4]:
model.device

device(type='cpu')

In [5]:
model.to("cuda")

EsmModel(
  (embeddings): EsmEmbeddings(
    (word_embeddings): Embedding(33, 640, padding_idx=1)
    (dropout): Dropout(p=0.0, inplace=False)
    (position_embeddings): Embedding(1026, 640, padding_idx=1)
  )
  (encoder): EsmEncoder(
    (layer): ModuleList(
      (0): EsmLayer(
        (attention): EsmAttention(
          (self): EsmSelfAttention(
            (query): Linear(in_features=640, out_features=640, bias=True)
            (key): Linear(in_features=640, out_features=640, bias=True)
            (value): Linear(in_features=640, out_features=640, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (rotary_embeddings): RotaryEmbedding()
          )
          (output): EsmSelfOutput(
            (dense): Linear(in_features=640, out_features=640, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (LayerNorm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
        )
        (intermediate): EsmIntermediate(
          

In [17]:
inputs = tokenizer(df_davis.iloc[0, 3], return_tensors="pt")
outputs = model(**inputs)

last_hidden_states = outputs.last_hidden_state

In [18]:
last_hidden_states

tensor([[[ 0.0614,  0.3664,  0.2256,  ...,  0.8194, -0.2384, -0.1046],
         [-0.0200,  0.5158, -0.0029,  ...,  0.6510, -0.3253,  0.1299],
         [-0.3415, -0.1170,  0.1768,  ...,  0.3345, -0.1586,  0.2183],
         ...,
         [ 0.0359,  0.0042, -0.2250,  ...,  0.2451,  0.0614, -0.2504],
         [ 0.1495, -0.0454, -0.1660,  ...,  0.2780, -0.6120, -0.1613],
         [ 0.0572,  0.0423, -0.0627,  ...,  0.5494, -0.7131, -0.2728]]],
       grad_fn=<NativeLayerNormBackward0>)

In [20]:
len(df_davis.iloc[0, 3])

961

In [19]:
last_hidden_states.shape

torch.Size([1, 963, 320])

In [5]:
test_list = df_davis.iloc[130:134, 3].tolist()

In [6]:
inputs2 = tokenizer(
    test_list,
    add_special_tokens = True,
    max_length = 1800, 
    padding = 'max_length',
    truncation = True,
    return_tensors = 'pt'
)
# inputs2.to("cuda")
# print(inputs2.device)

In [7]:
outputs2 = model(**inputs2)

last_hidden_states2 = outputs2.last_hidden_state
print(last_hidden_states2.shape)
print(len(test_list))

torch.Size([4, 1800, 640])
4


In [10]:
last_hidden_states2.shape

torch.Size([1, 1200, 640])