In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from typing import List, Dict, Any, Optional
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, roc_auc_score
from tqdm import tqdm

In [36]:
class model(nn.Module):

    def __init__(self, L: int = 30):
        super().__init__()

        self.embeds = nn.Embedding(num_embeddings= 20, embedding_dim= 128)
        self.dense1 = nn.Linear(in_features = 128, out_features = 64, bias = False )
        self.dropout = nn.Dropout(p = 0.1)
        self.flat = nn.Flatten()
        self.dense2 = nn.Linear(in_features = 64, out_features = 16, bias = False)
        self.batch = nn.BatchNorm1d(num_features= 16)
        self.dense3 = nn.Linear(in_features = 16, out_features = 4, bias = False)

    def forward(self, x):
        x = self.embeds(x)
        x = F.relu(self.dense1( x ) )
        x = F.relu(self.dense2(self.dropout(x) ) )
        print(x.size() )
        #x = self.flat(x)
        x = F.softmax(self.dense3(x ), dim =1 )

        return x


def train(model: nn.Module, train_loader: DataLoader, num_epochs: int, device: str = "cuda") -> List[float]:
    model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(params=model.parameters(), lr= 0.001)
    train_loss = []
    
    for epoch in tqdm(range(num_epochs)):
        model.train()
        curr_loss =  0.0

        for inputs, labels in train_loader:
            #with torch.autocast(device_type = device, dtype=torch.bfloat16):
            inputs = inputs.float().to(device)
            labels = labels.float().to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels.view(outputs.size() ))
            #print(loss.dtype, outputs.dtype )
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            curr_loss +=loss.item()
                
        train_loss.append( curr_loss/(len(train_loader)) )

    return train_loss  

def evaluate(model: nn.Module, eval_loader: DataLoader, device: str = "cuda") -> List[float]:
    model.to(device)
    eval_loss = []
    
    model.eval()
    curr_loss =  0.0
    for inputs, labels in eval_loader:
        #with torch.autocast(device_type = device, dtype=torch.bfloat16):
        inputs = inputs.float().to(device)
        labels = labels.float().to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels.view(outputs.size() ))
        #print(loss.dtype, outputs.dtype )
        
        curr_loss +=loss.item()
            
    eval_loss.append( curr_loss/(len(eval_loader)) )

    return eval_loss  

In [37]:
mdl = model(L = 30)
sum([p.numel() for p in mdl.parameters()])

11872

In [38]:
ins = torch.randint(0,20,(12,30))
mdl(ins).shape

torch.Size([12, 30, 16])


torch.Size([12, 30, 4])

In [32]:
ins[0]

tensor([16,  5,  9, 10,  2, 11, 12,  1, 10, 17, 11,  6,  4, 19, 13, 13,  3, 10,
        13, 16,  4, 12, 12,  0, 19,  6,  0, 11, 16, 10])

In [9]:
class dataset_spec(Dataset):
    def __init__(self, inputs,labels):
        self.inputs = inputs
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.inputs[idx], self.labels[idx]

def enc_list_bl_max_len(aa_seqs, blosum, max_seq_len):
    # encode sequences:
    sequences=[]
    for seq in aa_seqs:
        e_seq=np.zeros((len(seq),len(blosum["A"])))
        count=0
        for aa in seq:
            if aa in blosum:
                e_seq[count]=blosum[aa]
                count+=1
            else:
                print(aa)
                sys.stderr.write("Unknown amino acid in peptides: "+ aa +", encoding aborted!\n")
                sys.exit(2)
                
        sequences.append(e_seq)

    # pad sequences:
    #max_seq_len = max([len(x) for x in aa_seqs])
    n_seqs = len(aa_seqs)
    n_features = sequences[0].shape[1]

    enc_aa_seq = np.zeros((n_seqs, max_seq_len, n_features))
    for i in range(0,n_seqs):
        enc_aa_seq[i, :sequences[i].shape[0], :n_features] = sequences[i]

    return enc_aa_seq


blosum50_20aa = {
        'A': np.array((5,-2,-1,-2,-1,-1,-1,0,-2,-1,-2,-1,-1,-3,-1,1,0,-3,-2,0)),
        'R': np.array((-2,7,-1,-2,-4,1,0,-3,0,-4,-3,3,-2,-3,-3,-1,-1,-3,-1,-3)),
        'N': np.array((-1,-1,7,2,-2,0,0,0,1,-3,-4,0,-2,-4,-2,1,0,-4,-2,-3)),
        'D': np.array((-2,-2,2,8,-4,0,2,-1,-1,-4,-4,-1,-4,-5,-1,0,-1,-5,-3,-4)),
        'C': np.array((-1,-4,-2,-4,13,-3,-3,-3,-3,-2,-2,-3,-2,-2,-4,-1,-1,-5,-3,-1)),
        'Q': np.array((-1,1,0,0,-3,7,2,-2,1,-3,-2,2,0,-4,-1,0,-1,-1,-1,-3)),
        'E': np.array((-1,0,0,2,-3,2,6,-3,0,-4,-3,1,-2,-3,-1,-1,-1,-3,-2,-3)),
        'G': np.array((0,-3,0,-1,-3,-2,-3,8,-2,-4,-4,-2,-3,-4,-2,0,-2,-3,-3,-4)),
        'H': np.array((-2,0,1,-1,-3,1,0,-2,10,-4,-3,0,-1,-1,-2,-1,-2,-3,2,-4)),
        'I': np.array((-1,-4,-3,-4,-2,-3,-4,-4,-4,5,2,-3,2,0,-3,-3,-1,-3,-1,4)),
        'L': np.array((-2,-3,-4,-4,-2,-2,-3,-4,-3,2,5,-3,3,1,-4,-3,-1,-2,-1,1)),
        'K': np.array((-1,3,0,-1,-3,2,1,-2,0,-3,-3,6,-2,-4,-1,0,-1,-3,-2,-3)),
        'M': np.array((-1,-2,-2,-4,-2,0,-2,-3,-1,2,3,-2,7,0,-3,-2,-1,-1,0,1)),
        'F': np.array((-3,-3,-4,-5,-2,-4,-3,-4,-1,0,1,-4,0,8,-4,-3,-2,1,4,-1)),
        'P': np.array((-1,-3,-2,-1,-4,-1,-1,-2,-2,-3,-4,-1,-3,-4,10,-1,-1,-4,-3,-3)),
        'S': np.array((1,-1,1,0,-1,0,-1,0,-1,-3,-3,0,-2,-3,-1,5,2,-4,-2,-2)),
        'T': np.array((0,-1,0,-1,-1,-1,-1,-2,-2,-1,-1,-1,-1,-2,-1,2,5,-3,-2,0)),
        'W': np.array((-3,-3,-4,-5,-5,-1,-3,-3,-3,-3,-2,-3,-1,1,-4,-4,-3,15,2,-3)),
        'Y': np.array((-2,-1,-2,-3,-3,-1,-2,-3,2,-1,-1,-2,0,4,-3,-2,-2,2,8,-1)),
        'V': np.array((0,-3,-3,-4,-1,-3,-3,-4,-4,4,1,-3,1,-1,-3,-2,0,-3,-1,5))
    }

In [None]:
P_samples = int(500)
L =30; N_samples = int(P_samples)
EPOCHS, batch_size = 25, 128
repeats = int(25)
device = "cuda"

pm,beta = 0.2,1.00

peptide1, peptide2, peptide3 = 'AMFWSVPTV','GLCTLVAML','VTEHDTLLY'
check1, check2, check3 = 1500,2100,630

df1_1 = pd.read_csv(f'./Binders_{peptide1}.csv').drop_duplicates()
df2 = pd.read_csv(f'./Generated_binders_{peptide1}_BERT_pmask{pm}_beta{beta:.2f}_finetuned_wcheckpoint{check1}.csv').drop_duplicates()
df2 = df2[df2['CDR3b'].str.len()>7]
df2 = df2[~df2['CDR3b'].isin(df1_1['CDR3b'])].dropna()

data_pos_1 = pd.concat((df1_1.sample(50),df2.sample(P_samples - 50 ) )).drop_duplicates()
data_pos_1['labels'] = 0
df1_1['labels'] = 0

data_pos_2 = pd.read_csv('./Binders_GLCTLVAML.csv').drop_duplicates()
data_pos_2['labels'] = 1

df1_3 = pd.read_csv(f'./Binders_{peptide3}.csv').drop_duplicates()
df2 = pd.read_csv(f'./Generated_binders_{peptide3}_BERT_pmask{pm}_beta{beta:.2f}_finetuned_wcheckpoint{check3}.csv').drop_duplicates()
df2 = df2[df2['CDR3b'].str.len()>7]
df2 = df2[~df2['CDR3b'].isin(df1_3['CDR3b'])].dropna()

data_pos_3 = pd.concat((df1_3.sample(170),df2.sample(P_samples - 170) )).drop_duplicates()
data_pos_3['labels'] = 2
df1_3['labels'] = 2

T = int(25)

results = []
data_neg = pd.read_csv('./Background_notaligned.csv')
data_neg['labels'] = 3

print('Training... and testing over %d repeats with entry-state: ' %(repeats))

df_tchar = pd.read_csv(f'./tchard_{peptide1}.csv').drop_duplicates();
ext_eval_1 = df_tchar[~df_tchar['CDR3b'].isin(data_pos_1['CDR3b'])]
ext_eval_1['labels']=0

df_tchar = pd.read_csv(f'./tchard_{peptide2}.csv').drop_duplicates();
ext_eval_2 = df_tchar[~df_tchar['CDR3b'].isin(data_pos_2['CDR3b'])]
ext_eval_2['labels']=1

df_tchar = pd.read_csv(f'./tchard_{peptide3}.csv').drop_duplicates();
ext_eval_3 = df_tchar[~df_tchar['CDR3b'].isin(data_pos_3['CDR3b'])]
ext_eval_3['labels']=2

In [None]:
Test_insample = []; Test_ext = []

for K in range(repeats):  
    if (K == 0 or K == repeats-1):
        print('\t \t repeat n: %d' %K)
    
    data_pos_in_1 = data_pos_1.sample(P_samples)
    data_pos_in_2 = data_pos_2.sample(P_samples)
    data_pos_in_3 = data_pos_3.sample(P_samples)
    data_neg_in = data_neg.sample(N_samples)
    
    
    cdr_in = pd.concat((data_pos_in_1,data_pos_in_2,data_pos_in_3,data_neg_in ))
    cdr_in.reset_index(inplace=True,drop=True)
    y_train = cdr_in['labels']; y_train = F.one_hot(torch.tensor(y_train), num_classes = 4) 
    cdr_in = enc_list_bl_max_len(cdr_in['CDR3b'], blosum50_20aa, L)

    train_dataset = dataset_spec( torch.tensor(cdr_in.transpose(0,2,1),dtype = torch.float32 ), y_train )
    train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True )

    mdl = model(L = 30)
    train_tmp = train(mdl, train_loader, num_epochs = EPOCHS)
    with torch.no_grad():
        mdl.eval()

        pos_test_1 = df1_1[~df1_1['CDR3b'].isin(data_pos_in_1['CDR3b'])].dropna().sample(T)
        pos_test_2 = data_pos_2[~data_pos_2['CDR3b'].isin(data_pos_in_2['CDR3b'])].dropna().sample(T)
        pos_test_3 = df1_3[~df1_3['CDR3b'].isin(data_pos_in_3['CDR3b'])].dropna().sample(T)
        neg_test = data_neg[~data_neg.isin(data_neg_in)].dropna()
        cdr_test = pd.concat((pos_test_1,pos_test_2,pos_test_3,neg_test.sample(T) ))
        cdr_test.reset_index(inplace=True,drop=True)
        
        y_test = cdr_test['labels'];
        cdr_test = enc_list_bl_max_len(cdr_test['CDR3b'], blosum50_20aa, L)
        output = mdl(torch.tensor(cdr_test.transpose(0,2,1), dtype= torch.float32).to(device) )
        auc = roc_auc_score(y_score= output.detach().cpu().numpy(), y_true = y_test, multi_class= 'ovo')  
        acc = accuracy_score(y_pred= np.argmax(output.detach().cpu().numpy(),axis = 1), y_true = y_test)        
        Test_insample.append([auc, acc])
        
        neg_test = data_neg[~data_neg.isin(data_neg_in)].dropna().sample(2*T)
        cdr_test_ext = pd.concat((ext_eval_1.sample(2*T),ext_eval_2.sample(2*T),ext_eval_3.sample(2*T),neg_test ))
        cdr_test_ext.reset_index(inplace=True,drop=True)
        y_test_ext = cdr_test_ext['labels']; #y_test_ext = F.one_hot(torch.tensor(y_test_ext.astype(np.int64)), num_classes = 4)
        cdr_test_ext = enc_list_bl_max_len(cdr_test_ext['CDR3b'], blosum50_20aa, L)
        
        output = mdl(torch.tensor(cdr_test_ext.transpose(0,2,1), dtype= torch.float32).to(device) )
        auc = roc_auc_score(y_score= output.detach().cpu().numpy(), y_true = y_test_ext, multi_class= 'ovo')  
        acc = accuracy_score(y_pred= np.argmax(output.detach().cpu().numpy(),axis = 1), y_true = y_test_ext) 
        print(acc)
        Test_ext.append([auc, acc])

    del mdl




In [None]:
accuracy_score(y_pred= np.argmax(output.detach().cpu().numpy(),axis = 1), y_true = y_test_ext) 

In [None]:
Tback_n=np.mean(np.array(Test_insample),axis=0); ETback_n= np.std(np.array(Test_insample),axis=0)
Text_n=np.mean(np.array(Test_ext),axis=0); EText_n= np.std(np.array(Test_ext),axis=0)

    
results.append(np.concatenate(([P_samples],[N_samples],Tback_n,ETback_n,Text_n,EText_n )) )

columns_name = ['Psamples','Nsamples','AUC-back','ACC-back','std-back','stdACC-back','AUC-ext','ACC-ext','std-ext','stdACC-ext']

df = pd.DataFrame(results,columns=columns_name)

In [None]:
df