In [1]:
import numpy as np
import pandas as pd
import csv
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import time
from sklearn.metrics import roc_auc_score

In [2]:
# --- Run configuration -------------------------------------------------------
# NOTE(review): CUDA is presumably the GPU device index; it is not used in the
# cells visible here — confirm against the training cells.
CUDA = 0
# Seed handed to setup_seed() for the random / numpy / torch generators.
SEED = 600

# --- ESM embedding source ----------------------------------------------------
# ESM_FILE is the sub-directory (under the ipd_imgt data directory) holding one
# saved embedding file per MHC; ESM_LAYER selects the entry read from each
# file's "representations" mapping; ESM_DIM is the per-position embedding width
# used for the padded MHC arrays and the MHC convolution kernels.
ESM_FILE = "esm2_650m_out"
ESM_DIM = 1280
ESM_LAYER = 33

# Alternative setting (larger ESM-2 output); switch all three values together.
# ESM_FILE = "esm2_3b_out"
# ESM_DIM = 2560
# ESM_LAYER = 36

# --- Optimisation ------------------------------------------------------------
# NOTE(review): presumably learning rate and number of epochs; not used in the
# visible cells.
LR = 0.001
EPOCH = 50

# --- Projection / feature sizes ----------------------------------------------
# NOTE(review): not referenced in the visible cells; presumably sizes of the
# projection heads and shared feature space — confirm where they are consumed.
PROJ_PMHC_DIM_MI = 50
PROJ_TCR_DIM_MI = 70
FEAT_DIM = 70

# NOTE(review): presumably a softmax temperature for a contrastive loss — confirm.
TEMPERATURE = 0.1

# NOTE(review): presumably the number of negatives paired with each positive.
K_NEG = 5

# Batch size is a multiple of (1 + K_NEG): 50 groups of (1 + K_NEG) rows.
BATCH_SIZE = (1+K_NEG)*50

In [3]:
def setup_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Args:
        seed: Integer seed applied to every generator.

    Note:
        cuDNN autotuning is switched on and cuDNN deterministic mode is
        switched off here, so GPU runs can still differ from run to run even
        with a fixed seed.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Favour speed over bit-exact GPU reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = True
    cudnn.deterministic = False

In [4]:
# Seed every random generator once, before any data is loaded or any model is built.
setup_seed(SEED)

In [5]:
# Build mhc_dic: one entry per file found in the ESM output directory, keyed by
# the part of the file name before its first "." and holding the representation
# stored under layer ESM_LAYER.
# NOTE(review): torch.load unpickles its input — only point this at trusted files.
esm_dir = "/project/DPDS/Wang_lab/shared/pMTnet_v2/data/ipd_imgt/" + ESM_FILE
directory = os.fsencode(esm_dir)
mhc_dic = {}
for file in os.listdir(directory):
    filename = os.fsdecode(file)
    allele_key = filename.partition(".")[0]
    esm_file = torch.load(os.path.join(esm_dir, filename))
    mhc_dic[allele_key] = esm_file["representations"][ESM_LAYER]

In [6]:
# Load the amino-acid factor table: the header row is skipped, the first column
# of every remaining row is the lookup key and the other columns become a float
# vector.
aa_dict_dir='/work/DPDS/s213303/pmtnetv2/test_data/pmtnetv1/pMTnet-master/library/Atchley_factors.csv'
aa_dict_atchley = {}
with open(aa_dict_dir,'r') as aa:
    aa_reader = csv.reader(aa)
    next(aa_reader, None)  # discard the header line
    for rows in aa_reader:
        aa_dict_atchley[rows[0]] = np.asarray(rows[1:], dtype='float')

In [7]:
def peptideMap(dataset, aa_dict, column, padding):
    """Encode every sequence in ``dataset[column]`` into one float32 array.

    Args:
        dataset: Table-like object supporting ``len()`` and ``dataset[column]``.
        aa_dict: Mapping from residue letter to its 5-value factor vector.
        column: Name of the column holding the sequences.
        padding: Fixed number of positions each sequence is padded/truncated to.

    Returns:
        Array of shape ``(len(dataset), 1, padding, 5)`` and dtype float32,
        where entry ``[i, 0]`` is ``aamapping`` of the i-th sequence.
    """
    n_rows = len(dataset)
    encoded = np.zeros((n_rows, 1, padding, 5), dtype=np.float32)
    for row_idx, sequence in enumerate(dataset[column]):
        encoded[row_idx, 0, :, :] = aamapping(sequence, aa_dict, padding)
    return encoded

In [8]:
def aamapping(peptideSeq, aa_dict, padding):
    """Encode one sequence as a ``(padding, 5)`` array of per-residue factors.

    Sequences longer than ``padding`` are truncated; shorter ones are padded
    with all-zero rows. Residues that are not keys of ``aa_dict`` are encoded
    as five zeros.

    Args:
        peptideSeq: Sequence string, one character per residue.
        aa_dict: Mapping from residue character to a length-5 vector.
        padding: Number of rows in the returned array.

    Returns:
        ``numpy.ndarray`` of shape ``(padding, 5)``. An empty sequence yields
        an all-zero float32 array.
    """
    if len(peptideSeq) > padding:
        peptideSeq = peptideSeq[0:padding]

    unknown_residue = np.zeros(5, dtype='float32')
    peptideArray = []
    for aa_single in peptideSeq:
        try:
            peptideArray.append(aa_dict[aa_single])
        except KeyError:
            # Only a missing residue is tolerated; any other error should surface.
            peptideArray.append(unknown_residue)

    zero_tail = np.zeros((padding - len(peptideSeq), 5), dtype='float32')
    if not peptideArray:
        # np.asarray([]) is 1-D and cannot be concatenated with the 2-D tail,
        # so an empty sequence is returned as pure padding.
        return zero_tail
    return np.concatenate((np.asarray(peptideArray), zero_tail), axis=0)

In [9]:
def mhcMap(dataset, allele, mhc_dic):
    """Build the padded embedding array for the MHC column ``allele``.

    Each distinct value is encoded once with ``esmmapping`` and reused for
    later rows holding the same value.

    Args:
        dataset: Table-like object supporting ``len()`` and ``dataset[allele]``.
        allele: Name of the column whose values are looked up in ``mhc_dic``.
        mhc_dic: Mapping from MHC key to its embedding tensor.

    Returns:
        float32 array of shape ``(len(dataset), 1, 380, ESM_DIM)``.
    """
    mhc_array = np.zeros((len(dataset), 1, 380, ESM_DIM), dtype=np.float32)
    mhc_seen = dict()
    for pos, mhc in enumerate(dataset[allele]):
        # Explicit cache lookup instead of a bare ``except`` so that unrelated
        # errors (and KeyboardInterrupt) are no longer silently swallowed.
        if mhc in mhc_seen:
            mhc_array[pos, 0] = mhc_seen[mhc]
            continue
        if len(mhc)>380:
            print('Length: '+str(len(mhc))+'is over bound!')
            mhc=mhc[0:380]
        mhc_array[pos, 0] = esmmapping(mhc,mhc_dic)
        # Cache under the (possibly truncated) key, as the lookup above expects.
        mhc_seen[mhc] = mhc_array[pos, 0]
    return mhc_array

In [10]:
def esmmapping(mhc,mhc_dic):
    """Return the embedding of ``mhc`` zero-padded to 380 rows.

    Args:
        mhc: Key into ``mhc_dic``.
        mhc_dic: Mapping from key to a tensor exposing ``.numpy()``.
            # assumes each tensor is (length, ESM_DIM) with length <= 380 — TODO confirm

    Returns:
        Array with 380 rows: the embedding followed by float32 zero rows of
        width ``ESM_DIM``.
    """
    embedding = mhc_dic[mhc].numpy()
    pad_rows = np.zeros((380 - embedding.shape[0], ESM_DIM), dtype='float32')
    return np.concatenate((embedding, pad_rows), axis=0)

In [11]:
def preprocess(filedir, mhc_dic, a_allele, b_allele):
    """Read a tab-separated dataset and drop rows that cannot be encoded.

    Filtering steps, each reported on stdout:
      1. rows containing any missing value are removed;
      2. rows whose ``peptide`` is longer than 30 characters are removed;
      3. rows whose ``a_allele`` or ``b_allele`` value is not a key of
         ``mhc_dic`` are removed.

    Args:
        filedir: Path of the TSV file (first line is the header).
        mhc_dic: Mapping whose keys are the MHC names that have an embedding.
        a_allele: Name of the column holding the alpha-chain MHC name.
        b_allele: Name of the column holding the beta-chain MHC name.

    Returns:
        The filtered DataFrame with a fresh 0..n-1 index, or the integer ``0``
        when ``filedir`` does not exist.
    """
    print('Processing: '+filedir)
    # Step 1: bail out early on a missing file.
    if not os.path.exists(filedir):
        print('Invalid file path: ' + filedir)
        return 0

    raw = pd.read_csv(filedir, header=0, sep="\t")
    print("Number of rows in raw dataset: " + str(raw.shape[0]))
    dataset = raw.dropna()
    print("Number of rows in this dataset after dropping NA: " + str(dataset.shape[0]))

    # Step 2: remove over-long antigen peptides.
    rows_before = dataset.shape[0]
    peptide_len = dataset.peptide.str.len()
    dataset_antigen_dropped = dataset[peptide_len > 30]
    dataset = dataset[peptide_len <= 30]
    n_antigen_dropped = rows_before - dataset.shape[0]
    if n_antigen_dropped > 0:
        print(str(n_antigen_dropped)+' antigens longer than ' + str(30) + 'aa are dropped:')
        print(dataset_antigen_dropped)

    # Step 3: keep only rows where both MHC chains have an embedding.
    rows_before = dataset.shape[0]
    mhc_dic_keys = set(mhc_dic.keys())
    alpha_known = dataset[a_allele].isin(mhc_dic_keys)
    beta_known = dataset[b_allele].isin(mhc_dic_keys)
    missing_alpha = dataset[~alpha_known]
    missing_beta = dataset[~beta_known]
    dataset = dataset[alpha_known & beta_known]
    n_mhc_dropped = rows_before - dataset.shape[0]
    if n_mhc_dropped > 0:
        print(str(n_mhc_dropped)+' MHCs without ESM embedding are dropped:')
        print(pd.unique(missing_alpha[a_allele]))
        print(pd.unique(missing_beta[b_allele]))

    dataset = dataset.reset_index(drop=True)
    print("Number of rows in processed dataset: " + str(dataset.shape[0]))
    return dataset

In [12]:
# model 1
class pMHC(nn.Module):
    """Convolutional encoder combining a peptide and two MHC-chain inputs.

    Each of the three inputs is fed to three convolution towers that differ
    only in kernel height; every tower ends in a 3-value vector. The nine
    vectors are concatenated and passed through two linear layers.

    # assumes x_p is (batch, 1, 30, 5) and x_a / x_b are (batch, 1, 380, ESM_DIM),
    # inferred from the kernel widths and pooling sizes below — TODO confirm
    """

    @staticmethod
    def _conv_tower(kernel_height, input_width, seq_len):
        """Build one tower: 3 x (Conv2d, BatchNorm2d, ReLU), max-pool, 2-layer MLP.

        The first convolution spans the full input width; the pooling window
        covers every position left after the three unpadded convolutions.
        """
        n_filters = 200
        return nn.Sequential(
            nn.Conv2d(1, n_filters, (kernel_height, input_width)),
            nn.BatchNorm2d(n_filters),
            nn.ReLU(),
            nn.Conv2d(n_filters, n_filters, (kernel_height, 1)),
            nn.BatchNorm2d(n_filters),
            nn.ReLU(),
            nn.Conv2d(n_filters, n_filters, (kernel_height, 1)),
            nn.BatchNorm2d(n_filters),
            nn.ReLU(),
            nn.MaxPool2d((seq_len - 3 * kernel_height + 3, 1)),
            nn.Flatten(),
            nn.Linear(n_filters, n_filters // 2),
            nn.ReLU(),
            nn.Linear(n_filters // 2, 3),
        )

    def __init__(self):
        super().__init__()
        # Peptide towers: kernel heights 2 / 4 / 6 over 30 positions of width 5.
        self.layerP1 = self._conv_tower(2, 5, 30)
        self.layerP2 = self._conv_tower(4, 5, 30)
        self.layerP3 = self._conv_tower(6, 5, 30)
        # First MHC input: kernel heights 10 / 20 / 30 over 380 positions.
        self.layerA1 = self._conv_tower(10, ESM_DIM, 380)
        self.layerA2 = self._conv_tower(20, ESM_DIM, 380)
        self.layerA3 = self._conv_tower(30, ESM_DIM, 380)
        # Second MHC input: same tower shapes, separate weights.
        self.layerB1 = self._conv_tower(10, ESM_DIM, 380)
        self.layerB2 = self._conv_tower(20, ESM_DIM, 380)
        self.layerB3 = self._conv_tower(30, ESM_DIM, 380)
        # Head over the 9 concatenated 3-value tower outputs.
        self.fc1 = nn.Linear(3 * 9, 30)
        self.fc2 = nn.Linear(30, 3)

    def forward(self, x_p, x_a, x_b):
        """Run all towers and the head.

        Returns:
            Tuple ``(encoded_act, out)`` where ``encoded_act`` is the ReLU of
            ``fc1`` (30 values per sample) and ``out`` is ``fc2(encoded_act)``
            (3 values per sample).
        """
        peptide_towers = (self.layerP1, self.layerP2, self.layerP3)
        alpha_towers = (self.layerA1, self.layerA2, self.layerA3)
        beta_towers = (self.layerB1, self.layerB2, self.layerB3)

        tower_outputs = [tower(x_p) for tower in peptide_towers]
        tower_outputs += [tower(x_a) for tower in alpha_towers]
        tower_outputs += [tower(x_b) for tower in beta_towers]

        encoded_act = F.relu(self.fc1(torch.cat(tower_outputs, dim=1)))
        return encoded_act, self.fc2(encoded_act)

In [13]:
# 2. V gene alpha model
class vGdVAEa(nn.Module):
    """Variational autoencoder with a transformer front-end and 10 conv branches.

    The input first passes through a transformer encoder, then through ten
    convolution branches (kernel heights 10, 20, ..., 100) whose 3-value
    outputs are concatenated and mapped to a 5-dimensional mean / log-variance.
    A decoder maps the sampled latent back to the input layout.

    # assumes x is (batch, 1, 100, 5), inferred from the kernel width and the
    # pooling sizes below — TODO confirm
    """

    @staticmethod
    def _conv_branch(kernel_height):
        """Build one branch: Conv2d over the full width, global max-pool, 2-layer MLP."""
        n_filters, seq_len = 180, 100
        return nn.Sequential(
            nn.Conv2d(1, n_filters, (kernel_height, 5)),
            nn.ReLU(),
            nn.MaxPool2d((seq_len - kernel_height + 1, 1)),
            nn.Flatten(),
            nn.Linear(n_filters, n_filters // 2),
            nn.ReLU(),
            nn.Linear(n_filters // 2, 3),
        )

    def __init__(self):
        super().__init__()
        # layer1 .. layer10 use kernel heights 10 .. 100 in steps of 10.
        for branch_no in range(1, 11):
            setattr(self, f"layer{branch_no}", self._conv_branch(10 * branch_no))
        encoder_layer = nn.TransformerEncoderLayer(d_model=5, nhead=5)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        self.linear1 = nn.Linear(in_features=30, out_features=15)
        self.linear2 = nn.Linear(in_features=15, out_features=5)   # mean head
        self.linear3 = nn.Linear(in_features=15, out_features=5)   # log-variance head
        self.decoder = nn.Sequential(
            nn.Linear(in_features=5, out_features=900),
            nn.ReLU(),
            nn.Linear(in_features=900, out_features=1800),
            nn.Unflatten(1, (10, 180, 1)),
            # Kernel height 81 reduces 180 rows to 100; width padding 4 widens 1 column to 5.
            nn.Conv2d(in_channels=10, out_channels=1, kernel_size=(81, 5), padding=(0, 4)),
        )

    def forward(self, x):
        """Encode, sample and decode.

        Returns:
            Tuple ``(encoded, decoded, mu, logvar)`` where ``encoded`` is the
            sample ``mu + eps * exp(0.5 * logvar)``.
        """
        # Positions whose features sum to exactly zero are masked in attention.
        padding_mask = x.squeeze(dim=1).sum(dim=2) == 0
        # Swap the batch and position axes around the sequence-first encoder.
        tokens = x.permute(2, 1, 0, 3).squeeze(dim=1)
        tokens = self.transformer_encoder(src=tokens, src_key_padding_mask=padding_mask)
        x = tokens.unsqueeze(dim=1).permute(2, 1, 0, 3)

        branch_outputs = [getattr(self, f"layer{branch_no}")(x) for branch_no in range(1, 11)]
        h1 = F.relu(self.linear1(torch.cat(branch_outputs, 1)))
        mu, logvar = self.linear2(h1), self.linear3(h1)

        # Reparameterised sample from N(mu, std^2).
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        encoded = mu + eps * std
        decoded = self.decoder(encoded)
        return encoded, decoded, mu, logvar

In [14]:
# 3 v gene beta model
class vGdVAEb(nn.Module):
    """Variational autoencoder with a transformer front-end and 10 conv branches.

    The input first passes through a transformer encoder, then through ten
    convolution branches (kernel heights 10, 20, ..., 100) whose 3-value
    outputs are concatenated and mapped to a 5-dimensional mean / log-variance.
    A decoder maps the sampled latent back to the input layout.

    # assumes x is (batch, 1, 100, 5), inferred from the kernel width and the
    # pooling sizes below — TODO confirm
    """

    @staticmethod
    def _conv_branch(kernel_height):
        """Build one branch: Conv2d over the full width, global max-pool, 2-layer MLP."""
        n_filters, seq_len = 180, 100
        return nn.Sequential(
            nn.Conv2d(1, n_filters, (kernel_height, 5)),
            nn.ReLU(),
            nn.MaxPool2d((seq_len - kernel_height + 1, 1)),
            nn.Flatten(),
            nn.Linear(n_filters, n_filters // 2),
            nn.ReLU(),
            nn.Linear(n_filters // 2, 3),
        )

    def __init__(self):
        super().__init__()
        # layer1 .. layer10 use kernel heights 10 .. 100 in steps of 10.
        for branch_no in range(1, 11):
            setattr(self, f"layer{branch_no}", self._conv_branch(10 * branch_no))
        encoder_layer = nn.TransformerEncoderLayer(d_model=5, nhead=5)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        self.linear1 = nn.Linear(in_features=30, out_features=15)
        self.linear2 = nn.Linear(in_features=15, out_features=5)   # mean head
        self.linear3 = nn.Linear(in_features=15, out_features=5)   # log-variance head
        self.decoder = nn.Sequential(
            nn.Linear(in_features=5, out_features=900),
            nn.ReLU(),
            nn.Linear(in_features=900, out_features=1800),
            nn.Unflatten(1, (10, 180, 1)),
            # Kernel height 81 reduces 180 rows to 100; width padding 4 widens 1 column to 5.
            nn.Conv2d(in_channels=10, out_channels=1, kernel_size=(81, 5), padding=(0, 4)),
        )

    def forward(self, x):
        """Encode, sample and decode.

        Returns:
            Tuple ``(encoded, decoded, mu, logvar)`` where ``encoded`` is the
            sample ``mu + eps * exp(0.5 * logvar)``.
        """
        # Positions whose features sum to exactly zero are masked in attention.
        padding_mask = x.squeeze(dim=1).sum(dim=2) == 0
        # Swap the batch and position axes around the sequence-first encoder.
        tokens = x.permute(2, 1, 0, 3).squeeze(dim=1)
        tokens = self.transformer_encoder(src=tokens, src_key_padding_mask=padding_mask)
        x = tokens.unsqueeze(dim=1).permute(2, 1, 0, 3)

        branch_outputs = [getattr(self, f"layer{branch_no}")(x) for branch_no in range(1, 11)]
        h1 = F.relu(self.linear1(torch.cat(branch_outputs, 1)))
        mu, logvar = self.linear2(h1), self.linear3(h1)

        # Reparameterised sample from N(mu, std^2).
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        encoded = mu + eps * std
        decoded = self.decoder(encoded)
        return encoded, decoded, mu, logvar

In [15]:
# 4 CDR3 alpha model
class cdr3VAEa(nn.Module):
    """Variational autoencoder with a transformer front-end and 25 conv branches.

    The input first passes through a transformer encoder, then through 25
    convolution branches (kernel heights 1 .. 25) whose 3-value outputs are
    concatenated and mapped directly to a 30-dimensional mean / log-variance.
    A decoder maps the sampled latent back to the input layout.

    # assumes x is (batch, 1, 25, 5), inferred from the kernel width and the
    # pooling sizes below — TODO confirm
    """

    @staticmethod
    def _conv_branch(kernel_height):
        """Build one branch: Conv2d over the full width, global max-pool, 2-layer MLP."""
        n_filters, seq_len = 150, 25
        return nn.Sequential(
            nn.Conv2d(1, n_filters, (kernel_height, 5)),
            nn.ReLU(),
            nn.MaxPool2d((seq_len - kernel_height + 1, 1)),
            nn.Flatten(),
            nn.Linear(n_filters, n_filters // 2),
            nn.ReLU(),
            nn.Linear(n_filters // 2, 3),
        )

    def __init__(self):
        super().__init__()
        # layer1 .. layer25: branch k uses kernel height k.
        for branch_no in range(1, 26):
            setattr(self, f"layer{branch_no}", self._conv_branch(branch_no))
        encoder_layer = nn.TransformerEncoderLayer(d_model=5, nhead=5)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        self.linear1 = nn.Linear(in_features=75, out_features=30)   # mean head
        self.linear2 = nn.Linear(in_features=75, out_features=30)   # log-variance head
        self.decoder = nn.Sequential(
            nn.Linear(in_features=30, out_features=1875),
            nn.ReLU(),
            nn.Linear(in_features=1875, out_features=3750),
            nn.Unflatten(1, (25, 150, 1)),
            # Kernel height 126 reduces 150 rows to 25; width padding 4 widens 1 column to 5.
            nn.Conv2d(in_channels=25, out_channels=1, kernel_size=(126, 5), padding=(0, 4)),
        )

    def forward(self, x):
        """Encode, sample and decode.

        Returns:
            Tuple ``(encoded, decoded, mu, logvar)`` where ``encoded`` is the
            sample ``mu + eps * exp(0.5 * logvar)``.
        """
        # Positions whose features sum to exactly zero are masked in attention.
        padding_mask = x.squeeze(dim=1).sum(dim=2) == 0
        # Swap the batch and position axes around the sequence-first encoder.
        tokens = x.permute(2, 1, 0, 3).squeeze(dim=1)
        tokens = self.transformer_encoder(src=tokens, src_key_padding_mask=padding_mask)
        x = tokens.unsqueeze(dim=1).permute(2, 1, 0, 3)

        branch_outputs = [getattr(self, f"layer{branch_no}")(x) for branch_no in range(1, 26)]
        features = torch.cat(branch_outputs, 1)
        mu = self.linear1(features)
        logvar = self.linear2(features)

        # Reparameterised sample from N(mu, std^2).
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        encoded = mu + eps * std
        decoded = self.decoder(encoded)
        return encoded, decoded, mu, logvar

In [16]:
# 5 cdr3 beta
class cdr3VAEb(nn.Module):
    """Variational autoencoder for CDR3 beta sequence maps.

    Encoder: a 6-layer transformer over the sequence positions, followed by 25
    parallel convolutional branches (kernel heights 1..25, each max-pooled over
    all remaining positions and projected to 3 features). The 75 concatenated
    features feed two linear heads (``mu`` and ``logvar``) of size 30.

    Decoder: two linear layers, an unflatten to (25, 150, 1) and a final
    convolution that yields a (1, 25, 5) map per sample.

    The branches are deliberately registered as attributes ``layer1`` ...
    ``layer25`` (not an ``nn.ModuleList``) and created in that order so that
    ``state_dict`` keys and seeded initialisation stay identical to the
    hand-unrolled definition the saved checkpoints were produced with.
    """

    # Number of positions in the padded input map; also the number of branches.
    SEQ_LEN = 25
    # Features per position in the input map and transformer model width.
    N_FEATURES = 5
    # Convolution channels per branch.
    N_CHANNELS = 150
    # Output features per branch.
    BRANCH_DIM = 3
    # Size of the latent code.
    LATENT_DIM = 30

    def __init__(self):
        super(cdr3VAEb, self).__init__()
        seq_len = self.SEQ_LEN
        channels = self.N_CHANNELS
        half_channels = int(channels / 2)

        # Branch k looks at windows of k consecutive positions; pooling over the
        # seq_len - k + 1 window offsets leaves one value per channel.
        for k in range(1, seq_len + 1):
            setattr(self, "layer%d" % k, nn.Sequential(
                nn.Conv2d(1, channels, (k, self.N_FEATURES)),
                nn.ReLU(),
                nn.MaxPool2d((seq_len - k + 1, 1)),
                nn.Flatten(),
                nn.Linear(channels, half_channels),
                nn.ReLU(),
                nn.Linear(half_channels, self.BRANCH_DIM)
            ))

        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=self.N_FEATURES, nhead=5), num_layers=6)
        self.linear1 = nn.Linear(in_features=self.BRANCH_DIM*seq_len, out_features=self.LATENT_DIM)
        self.linear2 = nn.Linear(in_features=self.BRANCH_DIM*seq_len, out_features=self.LATENT_DIM)
        self.decoder = nn.Sequential(
            nn.Linear(in_features=self.LATENT_DIM, out_features=int(channels*seq_len/2)),
            nn.ReLU(),
            nn.Linear(in_features=int(channels*seq_len/2), out_features=channels*seq_len),
            nn.Unflatten(1, (seq_len, channels, 1)),
            # Kernel height channels-seq_len+1 maps the 150 rows back to 25;
            # width padding 4 with kernel width 5 expands 1 column to 5.
            nn.Conv2d(in_channels=seq_len, out_channels=1,
                      kernel_size=(channels-seq_len+1, self.N_FEATURES), padding=(0, 4))
        )

    def forward(self, x):
        """Encode ``x`` and reconstruct it from a sampled latent code.

        Args:
            x: 4-D float tensor; the layers above imply the layout
               (batch, 1, 25, 5).

        Returns:
            (encoded, decoded, mu, logvar): sampled latent (batch, 30), the
            reconstruction (batch, 1, 25, 5) and the two heads the latent was
            sampled from. Sampling happens on every call, including in eval
            mode.
        """
        # Positions whose features sum to exactly zero are treated as padding.
        padding_mask = torch.sum(x.squeeze(dim=1), dim=2) == 0

        # (batch, 1, seq, feat) -> (seq, batch, feat) for the transformer and back.
        x = x.permute(2, 1, 0, 3)
        x = torch.squeeze(x, dim=1)
        x = self.transformer_encoder(src=x, src_key_padding_mask=padding_mask)
        x = torch.unsqueeze(x, dim=1)
        x = x.permute(2, 1, 0, 3)

        # Evaluate layer1 ... layer25 in order and concatenate once; both heads
        # read the same feature tensor.
        features = torch.cat(
            [getattr(self, "layer%d" % k)(x) for k in range(1, self.SEQ_LEN + 1)], 1)
        mu = self.linear1(features)
        logvar = self.linear2(features)

        # Reparameterisation trick.
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        encoded = mu + eps*std
        decoded = self.decoder(encoded)
        return encoded, decoded, mu, logvar

In [17]:
def set_model(model_device):
    """Build the contrastive model and restore the five pretrained encoders.

    Every checkpoint is read first (mapped onto ``model_device``); then a fresh
    ``pMHCTCR`` is created, followed by one encoder per checkpoint whose weights
    are taken from the checkpoint's ``'net'`` entry.

    Returns:
        Tuple ``(CLmodel, pMHC, vGdVAEa, vGdVAEb, cdr3VAEa, cdr3VAEb)`` of
        modules, all on ``model_device``.
    """
    checkpoint_files = (
        "/work/DPDS/s213303/pmtnetv2/script/pytorch_model/pMHC_Copy42_Seed3000Channel200Batch200LR0.001Epoch20.pth",
        "/work/DPDS/s213303/pmtnetv2/script/pytorch_model/vgene_dvae_Copy26_5neuron_Seed100Channel180Batch100LR0.0001Epoch390.pth",
        "/work/DPDS/s213303/pmtnetv2/script/pytorch_model/vgene_dvae_Copy29_5neuron_Seed100Channel180Batch100LR0.0001Epoch365.pth",
        "/work/DPDS/s213303/pmtnetv2/script/pytorch_model/cdr3_a_vae_Copy128_Seed100Bottle30Batch200Lr0.0005N_Trans6Channel150Embed3Epoch225.pth",
        "/work/DPDS/s213303/pmtnetv2/script/pytorch_model/cdr3_b_vae_Copy132_Seed100Bottle30Batch200Lr0.0005N_Trans6Channel150Embed3Epoch170.pth",
    )
    # Read all checkpoints up front so a missing file fails before any model is built.
    checkpoints = [torch.load(path, map_location=model_device) for path in checkpoint_files]

    # The contrastive head is the only module that starts from random weights.
    CLmodel_new = pMHCTCR(temperature=TEMPERATURE).to(model_device)

    # Encoder classes in the same order as checkpoint_files.
    encoder_classes = (pMHC, vGdVAEa, vGdVAEb, cdr3VAEa, cdr3VAEb)
    restored = []
    for encoder_cls, checkpoint in zip(encoder_classes, checkpoints):
        encoder = encoder_cls().to(model_device)
        encoder.load_state_dict(checkpoint['net'])
        restored.append(encoder)

    return (CLmodel_new, *restored)

In [18]:
def pmhcEncoder(model, source_dataset, model_device):
    """Embed every row of ``source_dataset`` with the pMHC model.

    The peptide column is mapped with ``peptideMap`` (padding 30) and the
    ``mhca`` / ``mhcb`` columns with ``mhcMap``; the three tensors are moved to
    ``model_device`` and passed to ``model``. Only the first element of the
    model's two-element output is returned.
    """
    peptide_map = peptideMap(source_dataset, aa_dict_atchley, "peptide", 30)
    x_p = torch.Tensor(peptide_map).to(model_device)
    mhc_inputs = [
        torch.Tensor(mhcMap(source_dataset, chain_column, mhc_dic)).to(model_device)
        for chain_column in ("mhca", "mhcb")
    ]
    encoded, _ = model(x_p, *mhc_inputs)
    return encoded

In [19]:
def tcrEncoder(model, source_dataset, column, padding, model_device):
    """Embed one TCR sequence column with a VAE-style model.

    ``column`` is mapped with ``peptideMap`` to the given ``padding`` length,
    moved to ``model_device`` and passed to ``model``, which must return four
    values; the first one (the latent code) is returned. Rows of the latent
    that are NaN in every component are overwritten with zeros in place.
    """
    seq_map = peptideMap(source_dataset, aa_dict_atchley, column, padding)
    encoded, _recon, _mu, _logvar = model(torch.Tensor(seq_map).to(model_device))
    # Zero out rows that came back entirely NaN.
    all_nan_rows = torch.isnan(encoded).all(dim=1)
    encoded[all_nan_rows] = 0
    return encoded

In [20]:
# input dataset, like vcdr3pmhc
# output Zpmhc * Ztcr
class pMHCTCR(nn.Module):
    """Contrastive head scoring row-aligned (pMHC, TCR) embedding pairs.

    Both inputs are projected into a shared ``FEAT_DIM``-dimensional space by a
    two-layer MLP, L2-normalised, and compared row by row. The result is the
    cosine similarity of row ``i`` of ``pmhc`` with row ``i`` of ``tcr``,
    divided by ``temperature``.

    Args:
        temperature: divisor applied to the cosine similarities.
    """
    def __init__(self, temperature):
        super(pMHCTCR, self).__init__()
        self.temperature = temperature
        # Proj for pMHC
        self.Proj1 = nn.Sequential(
            nn.Linear(30, PROJ_PMHC_DIM_MI),
            nn.ReLU(),
            nn.Linear(PROJ_PMHC_DIM_MI, FEAT_DIM)
        )
        # Proj for TCR dim_in is 5*2+30*2
        self.Proj2 = nn.Sequential(
            nn.Linear(70, PROJ_TCR_DIM_MI),
            nn.ReLU(),
            nn.Linear(PROJ_TCR_DIM_MI, FEAT_DIM)
        )

    def forward(self, pmhc, tcr):
        """Return one logit per aligned pair.

        Args:
            pmhc: 2-D tensor with 30 features per row.
            tcr: 2-D tensor with 70 features per row.

        Returns:
            1-D tensor of length ``min(len(pmhc), len(tcr))``.
        """
        Zpmhc = F.normalize(self.Proj1(pmhc))
        Ztcr = F.normalize(self.Proj2(tcr))
        # Only matching rows are scored, so take the row-wise dot product
        # directly instead of building the full N x M similarity matrix and
        # reading its diagonal. Slicing to the shorter input keeps the length
        # that torch.diagonal(torch.mm(...)) produced for unequal row counts.
        n_pairs = min(Zpmhc.shape[0], Ztcr.shape[0])
        similarity = torch.sum(Zpmhc[:n_pairs] * Ztcr[:n_pairs], dim=1)
        logits = torch.div(similarity, self.temperature)
        return logits

In [21]:
# it needs to be sure that there is at least 1 positive pair in the input data
def LossFunction2(logits, label, pmhc_int, model_device):
    """Grouped contrastive loss over (pMHC, TCR) pair logits.

    For every pMHC group ``g`` the loss term is
    ``log(sum_{i in g} exp(logits[i])) - mean_{i in g, label[i]==1} logits[i]``;
    the terms of all groups are summed.

    Args:
        logits: 1-D float tensor, one score per pair, on ``model_device``.
        label: 1-D tensor of 0/1 flags (1 = positive pair), on ``model_device``.
        pmhc_int: 1-D integer tensor assigning each pair to a group id in
            ``0 .. pmhc_int.max()``; may live on any device.
        model_device: device on which the helper matrices are built.

    Returns:
        Scalar tensor with the summed loss.

    Note:
        A group without any positive pair contributes only its log-sum-exp
        term, and a group id in range that has no members yields ``log(0)``;
        callers are expected to pass contiguous ids with at least one positive
        per group.
    """
    # Build everything on the device that was passed in rather than relying on
    # a notebook-level `device` global.
    pmhc_int = pmhc_int.to(model_device)
    n_pairs = len(pmhc_int)
    n_groups = int(pmhc_int.max()) + 1

    # 1.1 membership matrix (groups x pairs): 1 where the pair belongs to the group
    label_pmhc_sum = torch.zeros(n_groups, n_pairs, device=model_device)
    label_pmhc_sum[pmhc_int, torch.arange(n_pairs, device=model_device)] = 1

    # 1.2 log of the summed exp(logit) over all pairs of each group
    log_sum_exp_prob = torch.log(torch.matmul(label_pmhc_sum, torch.exp(logits)))

    # 2.1 keep positives only and L1-normalise each row, turning the matmul
    # below into a per-group mean (label broadcasts across the group rows)
    label_pmhc_mean = F.normalize(label_pmhc_sum * label, p=1, dim=1)

    # 2.2 mean logit of the positive pairs of each group
    mean_pos_prob = torch.matmul(label_pmhc_mean, logits)

    loss = torch.sum(log_sum_exp_prob - mean_pos_prob)

    return loss

In [22]:
def train(dataset, model, pmhcmodel, vamodel, vbmodel, cdr3amodel, cdr3bmodel, loss_fn, optimizer, model_device):
    """Run one optimisation pass over ``dataset`` in chunks of ``BATCH_SIZE`` rows.

    All six modules are switched to train mode. Per batch, pMHC embeddings are
    computed once per distinct (peptide, mhca, mhcb, randn) row and repeated
    ``1 + K_NEG`` times; this assumes every distinct pMHC occupies exactly
    ``1 + K_NEG`` consecutive rows of the batch (TODO confirm against the data
    preparation). The four TCR embeddings are concatenated feature-wise, scored
    by ``model`` and passed to ``loss_fn`` together with the ``label`` column
    and the category codes of ``randn``.

    Returns:
        Sum of the batch losses divided by ``num_rows / (1 + K_NEG)``.
    """
    for net in (model, pmhcmodel, vamodel, vbmodel, cdr3amodel, cdr3bmodel):
        net.train()

    normalize = nn.functional.normalize
    # (encoder, column, padded length); the order fixes the feature layout of
    # the concatenated TCR embedding and the order of the random draws.
    tcr_branches = (
        (vamodel, "vaseq", 100),
        (vbmodel, "vbseq", 100),
        (cdr3amodel, "cdr3a", 25),
        (cdr3bmodel, "cdr3b", 25),
    )

    num_dataset = dataset.shape[0]
    train_loss = 0.0

    for start in range(0, num_dataset, BATCH_SIZE):
        batch_data = dataset[start:(start+BATCH_SIZE)].reset_index(drop=True)

        # One encoder pass per distinct pMHC, then expand back to one row per pair.
        pmhc_batch_uniq = batch_data[["peptide","mhca","mhcb","randn"]].drop_duplicates()
        pmhc_embedding = normalize(pmhcEncoder(pmhcmodel, pmhc_batch_uniq, model_device))
        pmhc_embedding = pmhc_embedding.repeat_interleave(repeats=(1+K_NEG), dim=0)

        tcr_parts = [
            normalize(tcrEncoder(encoder, batch_data, column, padding, model_device))
            for encoder, column, padding in tcr_branches
        ]
        tcr_embedding = torch.cat(tcr_parts, dim=1)

        optimizer.zero_grad()

        cos = model(pmhc_embedding, tcr_embedding)

        pair_labels = torch.from_numpy(batch_data['label'].values).to(model_device)
        group_codes = torch.from_numpy(pd.Categorical(batch_data['randn']).codes.astype("int64"))
        loss = loss_fn(cos, pair_labels, group_codes, model_device)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    return train_loss/(num_dataset/(1+K_NEG))

    #return train_loss/((loop+1)*PMHC_PER_BATCH)
    #return train_loss

In [23]:
# val with AUROC
# def val(dataset, model, pmhcmodel, vamodel, vbmodel, cdr3amodel, cdr3bmodel, model_device):
#     model.eval()
#     pmhcmodel.eval()
#     vamodel.eval()
#     vbmodel.eval()
#     cdr3amodel.eval()
#     cdr3bmodel.eval()
#     results = []
#     labels = []
#     num_dataset = dataset.shape[0]
    
#     with torch.no_grad():
#         for loop, batch in enumerate(range(0, num_dataset, BATCH_SIZE)):
#             batch_data = dataset[batch:(batch+BATCH_SIZE)].reset_index(drop=True)
        
#             #pmhc_data = batch_data[['mhca','mhcb','peptide']]
#             pmhc_embedding = nn.functional.normalize(pmhcEncoder(pmhcmodel, batch_data, model_device))

#             #va_data = batch_data["vaseq"].to_frame(name="vaseq")
#             va_embedding = nn.functional.normalize(tcrEncoder(vamodel, batch_data, "vaseq", 100, model_device))

#             #vb_data = batch_data["vbseq"].to_frame(name="vbseq")
#             vb_embedding = nn.functional.normalize(tcrEncoder(vbmodel, batch_data, "vbseq", 100, model_device))

#             #cdr3a_data = batch_data["cdr3a"].to_frame(name="cdr3a")
#             cdr3a_embedding = nn.functional.normalize(tcrEncoder(cdr3amodel, batch_data, "cdr3a", 25, model_device))

#             #cdr3b_data = batch_data["cdr3b"].to_frame(name="cdr3b")
#             cdr3b_embedding = nn.functional.normalize(tcrEncoder(cdr3bmodel, batch_data, "cdr3b", 25, model_device))

#             tcr_embedding = torch.cat((va_embedding,vb_embedding,cdr3a_embedding,cdr3b_embedding),dim=1)
            
#             cos = model(pmhc_embedding, tcr_embedding)
            
#             results.append(cos.detach().cpu().numpy())
#             labels.append(batch_data['label'].values)
    
#     results_array = np.concatenate(results, axis=0)
#     labels_array = np.concatenate(labels, axis=0)
#     auroc = roc_auc_score(labels_array, results_array)
#     return auroc

In [24]:
# val with loss
def val(dataset, model, pmhcmodel, vamodel, vbmodel, cdr3amodel, cdr3bmodel, loss_fn, model_device):
    """Evaluate the loss on ``dataset`` without updating any weights.

    All six modules are switched to eval mode and the pass runs under
    ``torch.no_grad()``. Unlike ``train``, the pMHC encoder is run on every row
    of the batch, and groups are taken from the ``valrandn`` column.

    Returns:
        Sum of the batch losses divided by ``num_rows / (1 + K_NEG)``.
    """
    for net in (model, pmhcmodel, vamodel, vbmodel, cdr3amodel, cdr3bmodel):
        net.eval()

    normalize = nn.functional.normalize
    # (encoder, column, padded length) in the order the features are concatenated.
    tcr_branches = (
        (vamodel, "vaseq", 100),
        (vbmodel, "vbseq", 100),
        (cdr3amodel, "cdr3a", 25),
        (cdr3bmodel, "cdr3b", 25),
    )

    num_dataset = dataset.shape[0]
    val_loss = 0.0

    with torch.no_grad():
        for start in range(0, num_dataset, BATCH_SIZE):
            batch_data = dataset[start:(start+BATCH_SIZE)].reset_index(drop=True)

            pmhc_embedding = normalize(pmhcEncoder(pmhcmodel, batch_data, model_device))

            tcr_parts = [
                normalize(tcrEncoder(encoder, batch_data, column, padding, model_device))
                for encoder, column, padding in tcr_branches
            ]
            tcr_embedding = torch.cat(tcr_parts, dim=1)

            cos = model(pmhc_embedding, tcr_embedding)

            pair_labels = torch.from_numpy(batch_data['label'].values).to(model_device)
            group_codes = torch.from_numpy(pd.Categorical(batch_data['valrandn']).codes.astype("int64"))
            val_loss += loss_fn(cos, pair_labels, group_codes, model_device).item()

    return val_loss/(num_dataset/(1+K_NEG))
    #return val_loss/((loop+1)*PMHC_PER_BATCH)
    #return val_loss

In [25]:
def get_script_num(script_path=None):
    """Return the tag embedded in a script name of the form ``<prefix>-<tag>.<ext>``.

    The tag is the text between the first ``-`` and the next ``-`` or ``.`` of
    the file's base name, e.g. ``"train-42.py"`` gives ``"42"``.

    Args:
        script_path: path (or bare file name) to parse. Defaults to the running
            script's ``__file__``, which keeps the original zero-argument call
            working when this code is executed as a script.

    Returns:
        The tag as a string.

    Raises:
        NameError: ``script_path`` was omitted and ``__file__`` is not defined,
            which is typically the case inside a notebook kernel.
        IndexError: the base name contains no ``-``.
    """
    if script_path is None:
        try:
            script_path = __file__
        except NameError:
            raise NameError(
                "__file__ is not defined here (e.g. inside a notebook); "
                "pass script_path explicitly to get_script_num()"
            )
    script_name = os.path.basename(script_path)
    return script_name.split("-")[1].split(".")[0]

In [26]:
def compare_models(model_1, model_2):
    """Print whether two modules hold identical ``state_dict`` tensors.

    Entries are compared pairwise in ``state_dict`` order. Entries that exist
    in only one of the two models count as differences, so models with a
    different number of entries never report a match.

    Args:
        model_1, model_2: objects exposing ``state_dict()``.

    Raises:
        Exception: two entries at the same position differ in value and also
            carry different names, i.e. the two state dicts are misaligned.
    """
    state_1 = model_1.state_dict()
    state_2 = model_2.state_dict()

    # zip() stops at the shorter state dict; count the unmatched tail up front
    # so it cannot be silently ignored.
    models_differ = abs(len(state_1) - len(state_2))

    for (name_1, tensor_1), (name_2, tensor_2) in zip(state_1.items(), state_2.items()):
        if torch.equal(tensor_1, tensor_2):
            continue
        models_differ += 1
        if name_1 != name_2:
            raise Exception(
                "state_dict entries are misaligned: %r vs %r" % (name_1, name_2)
            )

    if models_differ == 0:
        print('Models match perfectly! :)')
    else:
        print("Models doesn't match")

In [27]:
# Load the full training table.
# Per the output of this cell, rows whose MHC has no ESM embedding in mhc_dic
# are dropped. Known alleles without an embedding:
#   - A*08:01 couldn't be found
#   - A*02:15 (listed as A*02:15N)
#   - H-2-IAg7 couldn't be found (both the alpha and the beta chain)
training_dataset = preprocess("/project/DPDS/Wang_lab/shared/pMTnet_v2/data/vcdr3pmhc/pos1negk50/pairing_training_Dec19_all.txt", mhc_dic, "mhca", "mhcb")

Processing: /project/DPDS/Wang_lab/shared/pMTnet_v2/data/vcdr3pmhc/pos1negk50/pairing_training_Dec19_all.txt
Number of rows in raw dataset: 33149400
Number of rows in this dataset after dropping NA: 33149400
27000 MHCs without ESM embedding are dropped:
['A*08:01' 'H-2-IAg7_alpha' 'A*02:15']
['H-2-IAg7_beta']
Number of rows in processed dataset: 33122400


In [28]:
# Preview the first six rows; in the output below they share one `randn`
# value and exactly one of them has label 1.
training_dataset[:6]

Unnamed: 0,vb,va,peptide,mhca,mhcb,pMHC_SPECIES,TCR_SPECIES,class,randn,cdr3a,cdr3b,label,vaseq,vbseq
0,XXX,XXX,KLGGALQAK,A*03:01,human_microglobulin,human,human,7,93,CAPSGGSYIPTF,CASRFEGSTGELFF,0,XXX,XXX
1,XXX,XXX,KLGGALQAK,A*03:01,human_microglobulin,human,human,7,93,CAYRRWGAQKLVFF,CASKTGTLRTGPYEQYFF,1,XXX,XXX
2,XXX,XXX,KLGGALQAK,A*03:01,human_microglobulin,human,human,7,93,CTGNTPLVF,CAWGLGTGAQPQHF,0,XXX,XXX
3,XXX,XXX,KLGGALQAK,A*03:01,human_microglobulin,human,human,7,93,CALSEGGNQGGKLIF,CASSSDQKAGTFYEQFF,0,XXX,XXX
4,XXX,XXX,KLGGALQAK,A*03:01,human_microglobulin,human,human,7,93,CAVGHNNARLMF,CASSIQTGSLGGYTF,0,XXX,XXX
5,XXX,XXX,KLGGALQAK,A*03:01,human_microglobulin,human,human,7,93,CAVNKRDSSYKLIF,CASSLKASGNTGELFF,0,XXX,XXX


In [29]:
# Smaller training file; in the cells shown here it is only used to size one
# epoch (see PAIR_PER_EPOCH).
training_1batch = preprocess("/project/DPDS/Wang_lab/shared/pMTnet_v2/data/vcdr3pmhc/pos1negk50/pairing_training_Dec19_100_5.txt", mhc_dic, "mhca", "mhcb")

Processing: /project/DPDS/Wang_lab/shared/pMTnet_v2/data/vcdr3pmhc/pos1negk50/pairing_training_Dec19_100_5.txt
Number of rows in raw dataset: 662988
Number of rows in this dataset after dropping NA: 662988
540 MHCs without ESM embedding are dropped:
['A*08:01' 'H-2-IAg7_alpha' 'A*02:15']
['H-2-IAg7_beta']
Number of rows in processed dataset: 662448


In [30]:
# Number of rows of training_dataset consumed per epoch: the training loop
# slices training_dataset into consecutive windows of this length.
PAIR_PER_EPOCH = len(training_1batch)

In [31]:
# Load the validation table with the same preprocessing as the training data.
validation_dataset=preprocess("/project/DPDS/Wang_lab/shared/pMTnet_v2/data/vcdr3pmhc/pos1negk50/pairing_validation_Dec19_100_5.txt", mhc_dic, "mhca", "mhcb")

Processing: /project/DPDS/Wang_lab/shared/pMTnet_v2/data/vcdr3pmhc/pos1negk50/pairing_validation_Dec19_100_5.txt
Number of rows in raw dataset: 5268
Number of rows in this dataset after dropping NA: 5268
Number of rows in processed dataset: 5268


In [32]:
# Preview the first rows; validation groups are identified by `valrandn`
# (the training table uses `randn` instead).
validation_dataset[0:12]

Unnamed: 0,vb,va,data_file,peptide,mhca,mhcb,pMHC_SPECIES,TCR_SPECIES,valrandn,cdr3a,cdr3b,label,vaseq,vbseq
0,TRBV11-3*01,TRAV12-2*01,Yu,KAFSPEVIPMF,B*57:01,human_microglobulin,human,human,72,CAVKLRPGNTPLVF,CASSPTRVSSYNEQFF,0,QKEVEQNSGPLSVPEGAIASLNCTYSDRGSQSFFWYRQYSGKSPEL...,EAGVVQSPRYKIIEKKQPVAFWCNPISGHNTLYWYLQNLGQGPELL...
1,TRBV11-3*01,TRAV19*01,Yu,KAFSPEVIPMF,B*57:01,human_microglobulin,human,human,72,CALSEDILTGGGNKLTF,CASSPTGPRNYGYTF,0,AQKVTQAQTEISVVEKEDVTLDCVYETRDTTYYLFWYKQPPSGELV...,EAGVVQSPRYKIIEKKQPVAFWCNPISGHNTLYWYLQNLGQGPELL...
2,TRBV13*01,TRAV38-2/DV8*01,Yu,KAFSPEVIPMF,B*57:01,human_microglobulin,human,human,72,CVLCIYGNKLVF,CASSSGLAGTKTQYF,0,AQTVTQSQPEMSVQEAETVTLSCTYDTSESDYYLFWYKQPPSRQMI...,AAGVIQSPRHLIKEKRETATLKCYPIPRHDTVYWYQQGPGQDPQFL...
3,TRBV19*01,TRAV5*01,Yu,KAFSPEVIPMF,B*57:01,human_microglobulin,human,human,72,CAESGGYQKVTF,CATTGSYGYTF,1,GEDVEQSLFLSVREGDSSVINCTYTDSSSTYLYWYKQEPGAGLQLL...,DGGITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLI...
4,TRBV10-3*01,TRAV21*01,Yu,KAFSPEVIPMF,B*57:01,human_microglobulin,human,human,72,CAVRPDSGNTGKLIF,CAIRQAGTEAFF,0,KQEVTQIPAALSVPEGENLVLNCSFTDSAIYNLQWFRQDPGKGLTS...,DAGITQSPRHKVTETGTPVTLRCHQTENHRYMYWYRQDPGHGLRLI...
5,TRBV11-3*01,TRAV39*01,Yu,KAFSPEVIPMF,B*57:01,human_microglobulin,human,human,72,CAVATNAGNMLTF,CASSRGFGREHTEAFF,0,ELKVEQNPLFLSMQEGKNYTIYCNYSTTSDRLYWYRQDPGKSLESL...,EAGVVQSPRYKIIEKKQPVAFWCNPISGHNTLYWYLQNLGQGPELL...
6,TRBV14*01,TRAV13-1*01,Schinkelshoek,ERNAGSGIIISDT,DQA1*01:02,DQB1*06:02,human,human,211,CAASARGGADGLTF,CASSLPGTSTNEKLFF,0,GENVEQHPSTLSVQEGDSAVIKCTYSDSASNYFPWYKQELGKGPQL...,EAGVTQFPSHSVIEKGQTVTLRCDPISGHDNLYWYRRVMGKEIKFL...
7,TRBV19*01,TRAV29/DV5*01,Schinkelshoek,ERNAGSGIIISDT,DQA1*01:02,DQB1*06:02,human,human,211,CAASSDTGANSKLTF,CASSMGQANTEAFF,0,DQQVKQNSPSLSVQEGRISILNCDYTNSMFDYFLWYKKYPAEGPTF...,DGGITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLI...
8,TRBV11-3*01,TRAV12-2*01,Schinkelshoek,ERNAGSGIIISDT,DQA1*01:02,DQB1*06:02,human,human,211,CAVKIGGFQKLVF,CASSLVDRGEQFF,0,QKEVEQNSGPLSVPEGAIASLNCTYSDRGSQSFFWYRQYSGKSPEL...,EAGVVQSPRYKIIEKKQPVAFWCNPISGHNTLYWYLQNLGQGPELL...
9,TRBV4-3*01,TRAV17*01,Schinkelshoek,ERNAGSGIIISDT,DQA1*01:02,DQB1*06:02,human,human,211,CATASYNTDKLIF,CASSRGTAATNEKLF,1,SQQGEEDPQALSIQEGENATMNCSYKTSINNLQWYRQNSGRGLVHL...,ETGVTQTPRHLVMGMTNKKSLKCEQHLGHNAMYWYKQSAKKPLELM...


In [35]:
# Split the validation table by species and by whether the second MHC chain is
# the species' microglobulin ("c1") or anything else ("c2").
validation_hc1 = validation_dataset.loc[
    validation_dataset['pMHC_SPECIES'].eq("human") & validation_dataset['mhcb'].eq("human_microglobulin")
]
validation_hc2 = validation_dataset.loc[
    validation_dataset['pMHC_SPECIES'].eq("human") & validation_dataset['mhcb'].ne("human_microglobulin")
]
validation_mc1 = validation_dataset.loc[
    validation_dataset['pMHC_SPECIES'].eq("mouse") & validation_dataset['mhcb'].eq("mouse_microglobulin")
]
validation_mc2 = validation_dataset.loc[
    validation_dataset['pMHC_SPECIES'].eq("mouse") & validation_dataset['mhcb'].ne("mouse_microglobulin")
]
print("human c1 validation pairs: " + str(validation_hc1.shape[0]))
print("human c2 validation pairs: " + str(validation_hc2.shape[0]))
print("mouse c1 validation pairs: " + str(validation_mc1.shape[0]))
print("mouse c2 validation pairs: " + str(validation_mc2.shape[0]))

human c1 validation pairs: 2544
human c2 validation pairs: 2304
mouse c1 validation pairs: 270
mouse c2 validation pairs: 150


In [36]:
# Use the GPU selected by CUDA when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:"+str(CUDA))
else:
    device = torch.device("cpu")
print("used device is:" + str(device))

used device is:cuda:0


In [37]:
# Build the contrastive head and restore the five pretrained encoders on `device`.
CLmodel, pMHCmodel, vGdVAEamodel, vGdVAEbmodel, cdr3VAEamodel, cdr3VAEbmodel = set_model(device)
# Additional independent model sets (currently disabled):
#CLmodel2, pMHCmodel2, vGdVAEamodel2, vGdVAEbmodel2, cdr3VAEamodel2, cdr3VAEbmodel2 = set_model(device)
#CLmodel3, pMHCmodel3, vGdVAEamodel3, vGdVAEbmodel3, cdr3VAEamodel3, cdr3VAEbmodel3 = set_model(device)

In [38]:
# Only CLmodel's parameters are handed to the optimizer, so optimizer.step()
# never updates the pretrained encoders.
AdamOptimizer = torch.optim.Adam(CLmodel.parameters(), lr=LR)

In [39]:
# Baseline validation loss before any training step, with wall-clock timing.
# The value is not exactly reproducible between runs: the VAE encoders draw
# their latent noise with torch.randn_like in forward(), also in eval mode.
time_tmp1 = time.time()
val_metric = val(validation_dataset, CLmodel, pMHCmodel, vGdVAEamodel, vGdVAEbmodel, cdr3VAEamodel, cdr3VAEbmodel, LossFunction2, device)
print("Before training, metric of validation dataset is: " + str(val_metric) + ", time elapsed: " + str(time.time()-time_tmp1))

Before training, metric of validation dataset is: 1.8588364933509218, time elapsed: 28.795634269714355


In [40]:
# Second run of the same baseline evaluation; the slightly different value in
# the output shows the run-to-run spread caused by the encoders' random
# latent sampling.
time_tmp1 = time.time()
val_metric = val(validation_dataset, CLmodel, pMHCmodel, vGdVAEamodel, vGdVAEbmodel, cdr3VAEamodel, cdr3VAEbmodel, LossFunction2, device)
print("Before training, metric of validation dataset is: " + str(val_metric) + ", time elapsed: " + str(time.time()-time_tmp1))

Before training, metric of validation dataset is: 1.8525143582076853, time elapsed: 24.48460602760315


In [41]:
# Baseline loss per validation subset, evaluated in the order
# human c1, human c2, mouse c1, mouse c2.
val_hc1_metric, val_hc2_metric, val_mc1_metric, val_mc2_metric = [
    val(subset, CLmodel, pMHCmodel, vGdVAEamodel, vGdVAEbmodel, cdr3VAEamodel, cdr3VAEbmodel, LossFunction2, device)
    for subset in (validation_hc1, validation_hc2, validation_mc1, validation_mc2)
]
print("Before training, metric of validation dataset is: h c1:{:.5f}, h c2:{:.5f}, m c1:{:.5f}, m c2:{:.5f}".format(
    val_hc1_metric, val_hc2_metric, val_mc1_metric, val_mc2_metric))

Before training, metric of validation dataset is: h c1:1.90790, h c2:1.78276, m c1:2.02987, m c2:1.73645


In [42]:
# Repeat of the per-subset baseline evaluation (same order: human c1,
# human c2, mouse c1, mouse c2) to show the run-to-run spread.
val_hc1_metric, val_hc2_metric, val_mc1_metric, val_mc2_metric = [
    val(subset, CLmodel, pMHCmodel, vGdVAEamodel, vGdVAEbmodel, cdr3VAEamodel, cdr3VAEbmodel, LossFunction2, device)
    for subset in (validation_hc1, validation_hc2, validation_mc1, validation_mc2)
]
print("Before training, metric of validation dataset is: h c1:{:.5f}, h c2:{:.5f}, m c1:{:.5f}, m c2:{:.5f}".format(
    val_hc1_metric, val_hc2_metric, val_mc1_metric, val_mc2_metric))

Before training, metric of validation dataset is: h c1:1.90407, h c2:1.78555, m c1:2.00835, m c2:1.77115


In [None]:
# Main training loop: each epoch trains on the next PAIR_PER_EPOCH rows of
# training_dataset, then reports the loss on the full validation set and on
# its four subsets. Ctrl-C / kernel interrupt stops training cleanly.
try:
    for epoch in range(EPOCH):
        epoch_start_time = time.time()

        epoch_rows = training_dataset[PAIR_PER_EPOCH*epoch:PAIR_PER_EPOCH*(epoch+1)]
        train_loss = train(epoch_rows, CLmodel, pMHCmodel, vGdVAEamodel, vGdVAEbmodel, cdr3VAEamodel, cdr3VAEbmodel, LossFunction2, AdamOptimizer, device)

        # Evaluation order: full set, human c1, human c2, mouse c1, mouse c2.
        val_metric, val_hc1_metric, val_hc2_metric, val_mc1_metric, val_mc2_metric = [
            val(subset, CLmodel, pMHCmodel, vGdVAEamodel, vGdVAEbmodel, cdr3VAEamodel, cdr3VAEbmodel, LossFunction2, device)
            for subset in (validation_dataset, validation_hc1, validation_hc2, validation_mc1, validation_mc2)
        ]

        print("Epoch:{}, Train Loss:{:.5f}, Validation Loss:{:.5f}, human c1:{:.5f}, human c2:{:.5f}, mouse c1:{:.5f}, mouse c2:{:.5f}, Time:{:.5f}".format(
            epoch+1, train_loss, val_metric, val_hc1_metric, val_hc2_metric, val_mc1_metric, val_mc2_metric, time.time()-epoch_start_time))

except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early')