In [1]:
import torch
from torch import nn
import numpy as np
from transformers import AutoTokenizer, AutoModel
from torch.nn import functional as F
from torch.utils.data import Dataset,DataLoader
import pandas as pd
import numpy as np
from Bio import SeqIO
import random
from scipy.spatial.transform import Rotation as R
from sklearn.model_selection import train_test_split,KFold
import torch
from torch import nn
import torch
import numpy as np
import random
import os

#### Protein

In [None]:
# Checkpoint identifiers handed to AutoTokenizer/AutoModel.from_pretrained by the
# feature-extractor classes defined in this cell.
# NOTE(review): esm_model1 is not referenced by any cell visible in this section.
esm_model1 = 'DeepChem/esm2_t6_8M_UR50D'
# Checkpoint used by both extraction cells further down; the recorded per-residue
# output of this notebook shows 640 feature columns for it.
esm_model2 = 'DeepChem/esm2_t30_150M_UR50D'
class Pro_Feature(nn.Module):
    """Sequence-level protein embedding: mean of the encoder's last hidden states.

    Args:
        pro_model: Name or path passed to ``AutoTokenizer.from_pretrained`` and
            ``AutoModel.from_pretrained``.
        max_length: Maximum number of tokens kept per sequence (longer inputs
            are truncated by the tokenizer).
        finetune: ``0`` (default) freezes every parameter of the pretrained
            encoder; any other value leaves them trainable.
    """

    def __init__(self,pro_model,max_length,finetune=0):
        super(Pro_Feature, self).__init__()
        self.max_length = max_length
        self.pro_token = AutoTokenizer.from_pretrained(pro_model)
        self.pro_model  = AutoModel.from_pretrained(pro_model)

        if finetune == 0:
            # Frozen backbone: used purely as a feature extractor.
            for param in self.pro_model.parameters():
                param.requires_grad = False

    def forward(self,pro_seq,device):
        """Embed one sequence or a batch of sequences.

        Args:
            pro_seq: A single amino-acid string, or an iterable of strings.
            device: Device the token tensors are moved to before the encoder
                is called (the encoder is expected to live on the same device).

        Returns:
            Tensor of shape ``(batch, hidden)``; ``batch`` is 1 for a single
            string, which matches the previous single-sequence behaviour.
        """
        # Always tokenize as a batch so a lone string and a list share one code
        # path (the earlier unsqueeze(0) only worked for a lone string).
        sequences = [pro_seq] if isinstance(pro_seq, str) else list(pro_seq)
        protein = self.pro_token(sequences,
                                 truncation=True,
                                 padding=True,
                                 max_length=self.max_length,
                                 add_special_tokens=False,
                                 return_tensors='pt')
        input_ids = protein['input_ids'].to(device)
        attention_mask = protein['attention_mask'].to(device)
        temp_output = self.pro_model(input_ids=input_ids,attention_mask=attention_mask)

        hidden = temp_output.last_hidden_state
        # Mask-weighted mean: padded positions must not contribute to the
        # average. For a single unpadded sequence this equals a plain mean.
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        token_count = mask.sum(dim=1).clamp(min=1.0)  # avoid division by zero
        pro_feat = (hidden * mask).sum(dim=1) / token_count
        return pro_feat
    
# all sequence
class ProSeq_Feature(nn.Module):
    """Per-residue protein embedding: the encoder's full last hidden state.

    Args:
        pro_model: Name or path passed to ``AutoTokenizer.from_pretrained`` and
            ``AutoModel.from_pretrained``.
        max_length: Maximum number of tokens kept per sequence (longer inputs
            are truncated by the tokenizer).
        finetune: ``0`` (default) freezes every parameter of the pretrained
            encoder; any other value leaves them trainable.
    """

    def __init__(self,pro_model,max_length,finetune=0):
        super(ProSeq_Feature, self).__init__()
        self.max_length = max_length
        self.pro_token = AutoTokenizer.from_pretrained(pro_model)
        self.pro_model  = AutoModel.from_pretrained(pro_model)

        if finetune == 0:
            # Frozen backbone: used purely as a feature extractor.
            for param in self.pro_model.parameters():
                param.requires_grad = False

    def forward(self,pro_seq,device):
        """Return token-level features for one sequence or a batch.

        Args:
            pro_seq: A single amino-acid string, or an iterable of strings.
            device: Device the token tensors are moved to before the encoder
                is called (the encoder is expected to live on the same device).

        Returns:
            Tensor of shape ``(batch, tokens, hidden)``; ``batch`` is 1 for a
            single string. In a batch, positions that are padding are zeroed.
        """
        # Always tokenize as a batch so a lone string and a list share one code
        # path (the earlier unsqueeze(0) only worked for a lone string).
        sequences = [pro_seq] if isinstance(pro_seq, str) else list(pro_seq)
        protein = self.pro_token(sequences,
                                 truncation=True,
                                 padding=True,
                                 max_length=self.max_length,
                                 add_special_tokens=False,
                                 return_tensors='pt')
        input_ids = protein['input_ids'].to(device)
        attention_mask = protein['attention_mask'].to(device)
        temp_output = self.pro_model(input_ids=input_ids,attention_mask=attention_mask)

        hidden = temp_output.last_hidden_state
        # Zero the padded positions so batched output is well defined; with a
        # single sequence the mask is all ones and the values are untouched.
        pro_feat = hidden * attention_mask.unsqueeze(-1).to(hidden.dtype)
        return pro_feat

In [None]:
# Table of protein targets; the extraction cells below read the 'gene' and 'seq'
# columns of every row.
df_pro = pd.read_csv('reproductive_targets.csv')

In [None]:
# Mean-pooled ESM embedding per gene, written to one CSV (rows = genes).
# Bug fix: `torch.cuda.is_available` must be *called*; the bare function object
# is always truthy, so 'cuda:0' was selected even on CPU-only machines.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = Pro_Feature(esm_model2,1200).to(device)
model.eval()  # inference only: keep dropout disabled
all_feat = []
genes = []
with torch.no_grad():  # no autograd bookkeeping needed for feature extraction
    for i,row in df_pro.iterrows():
        gene = row['gene']
        seq = row['seq']
        print(gene,len(seq))

        feat = model(seq,device)
        # Move each result to host memory right away so accelerator memory
        # does not grow with the number of proteins.
        all_feat.append(feat.cpu())
        genes.append(gene)
all_feat = torch.cat(all_feat).numpy()
df_feat = pd.DataFrame(all_feat,index=genes,columns = [str(i) for i in range(all_feat.shape[1])])
df_feat.to_csv('esm_150m.csv')

In [None]:

# Per-residue ESM embeddings, one tensor file per gene under proallfeat/.
# Bug fix: `torch.cuda.is_available` must be *called*; the bare function object
# is always truthy, so 'cuda:0' was selected even on CPU-only machines.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = ProSeq_Feature(esm_model2,1200).to(device)
model.eval()  # inference only: keep dropout disabled

# torch.save does not create missing directories, so make sure the target exists.
os.makedirs('proallfeat',exist_ok=True)

with torch.no_grad():  # no autograd bookkeeping needed for feature extraction
    for i,row in df_pro.iterrows():
        gene = row['gene']
        seq = row['seq']
        print(gene,len(seq))
        # Drop the batch axis: each file holds a (tokens, hidden) tensor.
        feat = model(seq,device).squeeze(0)
        print(gene,': ',feat.shape)
        torch.save(feat,f'proallfeat/{gene}.pt')

        # Release the reference before the next (possibly long) sequence.
        del feat
del model

Some weights of EsmModel were not initialized from the model checkpoint at DeepChem/esm2_t30_150M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ACHE 614
ACHE :  torch.Size([614, 640])
ACVR1 509
ACVR1 :  torch.Size([509, 640])
ADRA1A 466
ADRA1A :  torch.Size([466, 640])
ADRA1B 520
ADRA1B :  torch.Size([520, 640])
ADRA1D 572
ADRA1D :  torch.Size([572, 640])
ADRA2A 465
ADRA2A :  torch.Size([465, 640])
ADRA2B 450
ADRA2B :  torch.Size([450, 640])
ADRA2C 462
ADRA2C :  torch.Size([462, 640])
ADRB1 477
ADRB1 :  torch.Size([477, 640])
ADRB2 413
ADRB2 :  torch.Size([413, 640])
ADRB3 408
ADRB3 :  torch.Size([408, 640])
AGTR1 359
AGTR1 :  torch.Size([359, 640])
AGTR2 363
AGTR2 :  torch.Size([363, 640])
AKR1C3 323
AKR1C3 :  torch.Size([323, 640])
AKT1 480
AKT1 :  torch.Size([480, 640])
ALDH1A1 501
ALDH1A1 :  torch.Size([501, 640])
ALOX5 674
ALOX5 :  torch.Size([674, 640])
APAF1 1000
APAF1 :  torch.Size([1000, 640])
APOBEC3G 384
APOBEC3G :  torch.Size([384, 640])
AR 920
AR :  torch.Size([920, 640])
AURKA 403
AURKA :  torch.Size([403, 640])
AVPR1A 418
AVPR1A :  torch.Size([418, 640])
BMPR1A 532
BMPR1A :  torch.Size([532, 640])
BRAF 766
BRAF 