In [1]:
#Use huggingface transformers library for ESM model
import os
os.chdir('/home/azamh/bioinf')
import sys
from transformers import EsmTokenizer, EsmModel, EsmForSequenceClassification
from transformers import TrainingArguments, Trainer
import evaluate
import numpy as np
import torch
from utils.parse_data import *
from utils.rep3d import *
from utils.visuallize import *
from script.models import *
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import nn
torch.backends.cudnn.benchmark = True
from sklearn.metrics import accuracy_score, matthews_corrcoef, classification_report, f1_score


%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Resolve the compute device once with the project's get_device() helper,
# bind it to `device` for the training cells, and echo it to the cell output.
print(device := get_device())

NVIDIA GeForce RTX 2080 Ti
cuda


In [3]:
# Report which GPU is in use. Guarded because torch.cuda.get_device_name raises
# when no CUDA device is present, which would break Run All on CPU-only hosts.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
else:
    print('CUDA not available; running on CPU')

NVIDIA GeForce RTX 2080 Ti


In [4]:
#Read dataframe 
# One row per protein with its sequence, localisation label and paths to the
# AF2 structure / voxel files (first spreadsheet column is used as the index).
deeploc_af2_df = pd.read_excel('data/esm_pred/deeploc_af2_df.xlsx', header = 0, index_col = 0)

# The dataset/dataloader cells index these columns directly; fail here with a
# clear message instead of a KeyError deep inside the dataset setup.
_missing_columns = {'Protein', 'Voxel Path', 'Location Label'} - set(deeploc_af2_df.columns)
assert not _missing_columns, f'deeploc_af2_df is missing columns: {sorted(_missing_columns)}'

deeploc_af2_df

Unnamed: 0,Protein,Sequence,Location,Extra Location,Split,Location Label,PDB Path,Voxel Path,ESM Pred
10700,Q9DA32,MPRTRNIGALCTLPEDTTHSGRPRRGVQRSYISRMAEPAPANMNDP...,Nucleus,M,train,0,data/deeploc_af2/AF-Q9DA32-F1-model_v4.pdb,voxels/deeploc/Q9DA32.pt,0
10701,O42927,MNPTSFIYDKPPPPPIINKPFEQTNSSASLTQKNSSSETENVGRHG...,Nucleus,U,test,0,data/deeploc_af2/AF-O42927-F1-model_v4.pdb,voxels/deeploc/O42927.pt,0
10704,Q8TAS1,MAGSGCAWGAEPPRFLEAFGRLWQVQSRLGSGSSASVYRVRCCGNP...,Nucleus,U,train,0,data/deeploc_af2/AF-Q8TAS1-F1-model_v4.pdb,voxels/deeploc/Q8TAS1.pt,0
10706,Q96WW3,MAKSARSKSIRRNKKVLRENVFQPVIDERTKRLSAHLRDQVNDLTK...,Nucleus,U,train,0,data/deeploc_af2/AF-Q96WW3-F1-model_v4.pdb,voxels/deeploc/Q96WW3.pt,0
10708,Q8VYI0,MAATTGLETLVDQIISVITNDGRNIVGVLKGFDQATNIILDESHER...,Nucleus,U,train,0,data/deeploc_af2/AF-Q8VYI0-F1-model_v4.pdb,voxels/deeploc/Q8VYI0.pt,0
...,...,...,...,...,...,...,...,...,...
11076,Q9S850,MPGIRGPSEYSQEPPRHPSLKVNAKEPFNAEPPRSALVSSYVTPVD...,Peroxisome,U,test,9,data/deeploc_af2/AF-Q9S850-F1-model_v4.pdb,voxels/deeploc/Q9S850.pt,1
11075,P11930,MSSSSSWRRAATVMLAAGWTHSSPAGFRLLLLQRAQNQRFLPGAHV...,Peroxisome,U,train,9,data/deeploc_af2/AF-P11930-F1-model_v4.pdb,voxels/deeploc/P11930.pt,3
11074,Q9LRS0,MEITNVTEYDAIAKAKLPKMVYDYYASGAEDQWTLQENRNAFARIL...,Peroxisome,U,train,9,data/deeploc_af2/AF-Q9LRS0-F1-model_v4.pdb,voxels/deeploc/Q9LRS0.pt,9
11073,A2AKK5,MMIKLIATPSNALVDEPVSIRATGLPPSQIVTIKATVKDENDNVFQ...,Peroxisome,U,train,9,data/deeploc_af2/AF-A2AKK5-F1-model_v4.pdb,voxels/deeploc/A2AKK5.pt,1


In [5]:
#Create fusion model
# Seed the global RNGs so weight initialisation is reproducible. Because a
# shuffling DataLoader without an explicit generator draws from torch's global
# RNG, this also fixes the training-batch order when cells run top to bottom.
# cudnn.benchmark (enabled in the import cell) can still cause small GPU-side
# run-to-run differences.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# 5 input channels and 10 output classes (Location Label takes values 0-9).
fusion_model = nn.DataParallel(FusionModel(in_channels = 5, num_classes = 10))
if torch.cuda.is_available():
    # DataParallel requires the wrapped module's parameters on device_ids[0]
    # (the default CUDA device) before the first forward pass.
    fusion_model.cuda()
    print('N Gpus:', torch.cuda.device_count())

N Gpus: 4


In [6]:
#Create combined dataloader
pin_memory = False
num_workers = 0
batch_size = 32

train_deeploc_af2_df, test_deeploc_af2_df = split_deeploc(deeploc_af2_df)

#Create dataset
# Each protein's precomputed ESM encoding is expected at esm_encoding/<Protein>.pt;
# voxel tensors come from the dataframe's 'Voxel Path' column and labels from
# 'Location Label'.
train_encoding_paths = [f'esm_encoding/{protein}.pt' for protein in train_deeploc_af2_df['Protein']]
test_encoding_paths = [f'esm_encoding/{protein}.pt' for protein in test_deeploc_af2_df['Protein']]
fusion_train_set = FusionDataset(list(train_deeploc_af2_df['Voxel Path']), train_encoding_paths, list(train_deeploc_af2_df['Location Label']))
fusion_test_set = FusionDataset(list(test_deeploc_af2_df['Voxel Path']), test_encoding_paths, list(test_deeploc_af2_df['Location Label']))

#Create loaders
# Only the training loader is shuffled. The evaluation loader iterates in a
# fixed order so evaluation is deterministic across runs and (assuming
# FusionDataset indexes its input lists positionally) batches follow the row
# order of test_deeploc_af2_df.
fusion_train_loader = torch.utils.data.DataLoader(fusion_train_set, batch_size=batch_size,
                                              shuffle=True, num_workers = num_workers, pin_memory=pin_memory)
fusion_test_loader = torch.utils.data.DataLoader(fusion_test_set, batch_size=batch_size,
                                              shuffle=False, num_workers = num_workers, pin_memory=pin_memory)

In [7]:
#Train fusion model
# Cross-entropy over class scores with integer 'Location Label' targets.
# Presumably FusionModel returns raw logits (CrossEntropyLoss applies
# log-softmax itself) — confirm against script.models.
loss_fn = nn.CrossEntropyLoss()

# Adam optimizer over all fusion-model parameters.
learning_rate = 1e-3
optimizer = torch.optim.Adam(fusion_model.parameters(), lr = learning_rate)

In [8]:
# Checkpoint location handed to train_fusion_model and number of training epochs.
model_save_path = 'models/fusion'
epochs = 100

# NOTE(review): the test loader is passed as the evaluation loader during
# training. If train_fusion_model selects or saves checkpoints based on it, the
# final test metrics will be optimistically biased; consider a validation split
# carved from the training data. Confirm against train_fusion_model.
fusion_model = train_fusion_model(fusion_model,
                                  epochs,
                                  model_save_path,
                                  fusion_train_loader,
                                  fusion_test_loader,
                                  optimizer,
                                  loss_fn,
                                  device)

Epoch: 0
	 Batch 0 Average loss: 2.2962541580200195
	 Batch 1 Average loss: 2.2936692237854004
	 Batch 2 Average loss: 2.174450159072876
	 Batch 3 Average loss: 1.4377466440200806
	 Batch 4 Average loss: 1.0855435132980347
	 Batch 5 Average loss: 0.7153589129447937
	 Batch 6 Average loss: 0.42661601305007935
	 Batch 7 Average loss: 0.41614678502082825
	 Batch 8 Average loss: 0.35630613565444946
	 Batch 9 Average loss: 0.3947293758392334
	 Batch 10 Average loss: 0.70977783203125


KeyboardInterrupt: 