# 1. Original model prediction

In [None]:
import os
import torch
import numpy as np
import csv
from predictor import utils
from pandas import read_csv, DataFrame
from preprocessing.prep import get_protseqs_ntseqs,determine_tcr_seq_nt,determine_tcr_seq_vj
from LM.bert_mdl import retrieve_model, compute_embs

# Use the default CUDA device when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

def Original_model_prediction(dataPath='../data/pre_TCRconv_small.csv', mode= 'prediction',\
               epitope_labels= './data/unique_epitopes_test.npy', \
               chains='B', h_cdr31='CDR3B', h_long1='LongB', model_file='./model/statedict_vdjdb-b-small.pt',  \
               embedfile1= 'embeddings/bert_vdjdb-b-example_0.bin',\
               predfile='outputs/preds-b-example.csv', batch_size=256):
    """Score every row of a TCR table with a saved model and write the scores to CSV.

    The dataset is read as strings, walked row by row and processed in batches.
    For each batch the embeddings are fetched, the model is applied, the output
    is passed through ``utils.toprob`` and the rows are appended to ``predfile``.
    Rows with an empty sequence keep NaN in every label column.

    Args:
        dataPath: CSV file holding the TCRs to score.
        mode: Stored under ``p['mode']``; this function does not read it itself.
        epitope_labels: ``.npy`` file whose entries become the label columns.
        chains: Chain letters. One ``'TCR'+letter`` header column is written per
            character; more than one character enables the two-chain path.
        h_cdr31: Name of the chain-1 CDR3 column in the dataset.
        h_long1: Name of the chain-1 sequence column read for the hard-coded
            ``input_type`` (``'tcr+cdr3'``).
        model_file: Path handed to ``utils.load_model``.
        embedfile1: Path handed to ``utils.get_embedding_dict`` for chain 1.
        predfile: Output CSV. It is truncated first, then appended to per batch.
        batch_size: NOTE(review): currently ignored -- ``p['batch_size']`` is
            hard-coded to 512 below. Confirm which value is intended.

    Returns:
        None. The result is the file written at ``predfile``.
    """

    # Parameter dictionary passed to the ``utils`` helpers. Most entries are
    # fixed here; only the function arguments above vary between calls.
    p = {      
        'dataset': dataPath,
        'chains': chains,
        'epitope_labels': epitope_labels,
        'embedtype': 'cdr3+context',
        'h_cdr31': h_cdr31,
        'h_long1': h_long1,
        'embedfile1': embedfile1,
        # Second-chain slots are filled with the string 'None' (not the None object).
        'h_cdr32': 'None',
        'h_long2': 'None',    
        'embedfile2': 'None',
        'delimiter': ',',
        
        'mode': mode, 
        'folds': 'None',
        'fold_num': 0,  
        'h_epitope': 'Epitope',

        # NOTE(review): the ``batch_size`` argument is not used here.
        'batch_size': 512,
        'betas': [0.9, 0.999],
        'use_pos_weight': True,
        'dropouts': [0.1, 0.1],
        'iters_adam': 2500,
        'lr_conv': 0.0002,
        'lr_linear': 0.01,
        'T_anneal': 3000,
        'iters_swa': 500,
        'anneal_strategy': 'cos',
        'lr_swa': 0.0001,
        'T_anneal_swa': 300,
        
        'resultfile': './outputs/results.tsv',
        'print_every': 100,
        'lossfile': './outputs/loss_train.tsv',
        'lossfile_test': './outputs/loss_test.tsv',
        # 'model_folder': model_folder,
        'params_to_print': ['fold_num'],

        'binary': False,
        'binary_label': None,
        
        'append_oh': False,
        'kernel_sizes': [5,9,15,21,3],
        'pool': 'max',
        'h_v1': 'none',
        'h_j1': 'none',
        'h_v2': 'none',
        'h_j2': 'none',
        'h_nt1': 'none',
        'h_nt2': 'none',
        'guess_allele01': True,
        'model_file': model_file,  
        # use_LM False: embeddings are looked up in the embedding file rather
        # than computed with the language model.
        'use_LM': False,
        'num_features': 1024,
        # With this value the loop below takes the final ``else`` branch and
        # reads the ``h_long*`` columns.
        'input_type': 'tcr+cdr3',
        'predfile': predfile,  
        'additional_columns': [],
        'decimals': 4,
        
        }

    # Derived entries.
    p['two_chains']=len(p['chains'])>1
    p['table']=(p['params_to_print'],[p[h] for h in p['params_to_print']])
    p['save_intermediate'] = p['lossfile'].lower()!='none' or p['lossfile_test'].lower()!='none'
    
    # Everything is read as str, and empty cells stay '' instead of becoming NaN,
    # which is what the empty-sequence check in the loop relies on.
    data = read_csv(p['dataset'],delimiter=p['delimiter'],dtype=str,keep_default_na=False)
    epis_u = np.load(p['epitope_labels'])
    n_labels=len(epis_u)
    n_chains=len(p['chains'])
    
    # Start the output file with a header row; batches are appended further down.
    resfile = p['predfile']
    with open(resfile,'w') as f:
        f.write(p['delimiter'].join(['TCR'+c for c in p['chains']]
                +p['additional_columns']+list(epis_u))+'\n')
    
    # Load TCRconv model
    model = utils.load_model(p['model_file'],p,p['num_features'],n_labels,device)
    # embeddings / embedding-models
    if p['use_LM']:
        LM=retrieve_model().to(device)
    
    # Get requested sequences for genes
    # (only populated for the 'cdr3+nt' and 'cdr3+vj' input types)
    geneseqs={}
    for ic,chain in enumerate(p['chains']):
        c=str(ic+1)
        if p['input_type']=='cdr3+nt':
            geneseqs['protV'+c],geneseqs['protJ'+c],geneseqs['ntV'+c],geneseqs['ntJ'+c] = \
                    get_protseqs_ntseqs(chain=chain)
        elif p['input_type']=='cdr3+vj':
            geneseqs['protV'+c],geneseqs['protJ'+c],_,_ = get_protseqs_ntseqs(chain=chain)
    
    # Per-batch state:
    #   I      - one bool per row of the batch, True when all sequences are non-empty
    #   ts_all - the sequences of the batch, one list per chain
    #   icount - number of usable (True) rows collected so far
    #   i0     - dataset index of the first row of the current batch
    I=[]
    ts_all = [[] for c in p['chains']] # separate list for each chain
    icount,i0 = 0,0
    imax=len(data)-1
    
    for i in range(len(data)):
    
        # Build the per-chain sequence list for row i according to input_type.
        if p['input_type']=='cdr3+nt':
            tcr12 = []
            for ic in range(n_chains):
                c=str(ic+1)
                t,_,_ = determine_tcr_seq_nt(data[p['h_nt'+c]][i],data[p['h_cdr3'+c]][i],geneseqs['protV'+c],
                        geneseqs['protJ'+c],geneseqs['ntV'+c],geneseqs['ntJ'+c],guess01=p['guess_allele01'])
                tcr12.append(t)
        elif p['input_type']=='cdr3+vj':
            tcr12 = []
            for ic in range(n_chains):
                c=str(ic+1)
                t = determine_tcr_seq_vj(data[p['h_cdr3'+c]][i],data[p['h_v'+c]][i],data[p['h_j'+c]][i],
                    geneseqs['protV'+c],geneseqs['protJ'+c],guess01=p['guess_allele01'])
                tcr12.append(t)
        elif p['input_type']=='cdr3':
            tcr12 = [data[p['h_cdr3'+str(ic+1)]][i] for ic in range(n_chains)]
        else: # tcr+cdr3 / tcr
            tcr12 = [data[p['h_long'+str(ic+1)]][i] for ic in range(n_chains)]
    
    
        # Check if (either) sequence is empty
        # Both branches record the sequences so the output keeps one row per
        # input row; only the validity flag and the usable-row counter differ.
        if np.any([t=='' for t in tcr12]):
            I.append(False)
            for ic in range(n_chains):
                ts_all[ic].append(tcr12[ic])
        else:
            I.append(True)
            for ic in range(n_chains):
                ts_all[ic].append(tcr12[ic])
            icount+=1
    
    
        # Flush when enough usable rows were collected or the data is exhausted.
        if icount==p['batch_size'] or i==imax:
            ts_all =[np.array(t) for t in ts_all]
            if icount>0:
                I= np.array(I,dtype=bool)
                if p['use_LM']: # If LM is used, compute embeddings
                    cdr3s = data[p['h_cdr31']][i0:i+1][I].values
                    print('computing embeddings: {:d}-{:d}/{:d}'.format(i0,i,imax))
                    if p['embedtype']=='cdr3+context':
                        embeddings1 = compute_embs(LM, ts_all[0][I], cdr3s)
                        embeddings1 = utils.stack_embeddings(embeddings1,device,p['append_oh'],cdr3s)
                    else:
                        embeddings1 = compute_embs(LM, ts_all[0][I], None)
                        embeddings1 = utils.stack_embeddings(embeddings1,device,p['append_oh'])
    
                    if p['two_chains']:
                        cdr3s = data[p['h_cdr32']][i0:i+1][I].values
                        if p['embedtype']=='cdr3+context':
                            embeddings2 = compute_embs(LM,ts_all[1][I], cdr3s)
                            embeddings2 = utils.stack_embeddings(embeddings2,device,p['append_oh'],cdr3s)
                        else:
                            embeddings2 = compute_embs(LM, ts_all[1][I], None)
                            embeddings2 = utils.stack_embeddings(embeddings2,device,p['append_oh'])
    
                else: # 1-2 embedding dictionaries are used
    
                    #print(ts_all[0][I])
                    # NOTE(review): utils.get_embedding_dict is called again for
                    # every batch; if it re-reads the file each time, loading it
                    # once before the loop would avoid repeated I/O -- confirm.
                    cdr3s = data[p['h_cdr31']][i0:i+1][I].values
                    cdr3max = utils.maxlen(cdr3s)
                    embeddings1 = utils.get_embeddings(utils.get_embedding_dict(p['embedfile1']),
                                cdr3s, ts_all[0][I], cdr3max, device, p['append_oh'])
                    if p['two_chains']:
                        cdr3s = data[p['h_cdr32']][i0:i+1][I].values
                        cdr3max = utils.maxlen(cdr3s)
                        embeddings2 = utils.get_embeddings(utils.get_embedding_dict(p['embedfile2']),
                                cdr3s,ts_all[1][I], cdr3max, device, p['append_oh'])
    
                # Predictions
                if p['two_chains']:
                    output = model(embeddings1,embeddings2).detach().cpu().numpy()
                else:
                    output = model(embeddings1).detach().cpu().numpy()
                output = utils.toprob(output)
    
                # NaN everywhere, then fill in only the rows that were scored.
                pred_ar = np.ones((len(I),n_labels),dtype=float)*np.nan
                pred_ar[I,:]= output
    
            else: # No proper sequences were found, add fillers
                pred_ar = np.ones((len(I),n_labels),dtype=float)*np.nan
    
            # append results to result file
            # The ``i`` inside the first comprehension is local to it; the
            # ``i0:i+1`` slice in the second one uses the row index of the loop.
            df = DataFrame(np.concatenate([np.expand_dims(ts_all[i],1) for i in range(len(ts_all))] \
                +[np.expand_dims(data[col].values[i0:i+1],1) for col in p['additional_columns']] \
                +[np.round(pred_ar,p['decimals'])],axis=1))
            df.to_csv(resfile,sep=p['delimiter'],mode='a',header=False,index=False)
    
            # Reset the batch state; the next batch starts at the following row.
            I = []
            ts_all = [[] for c in p['chains']]
            icount = 0
            i0 = i+1    

In [None]:
import pandas as pd
def fix(data_ori,data_new):
    """Write a copy of ``data_ori`` extended with peptide descriptor columns.

    The input columns ``CDR3B``, ``Epitope``, ``MHC`` and ``Affinity`` are
    renamed to ``CDR3.beta``, ``antigen_epitope``, ``mhc.a`` and ``label``.
    The descriptors returned by ``peptides.Peptide(seq).descriptors()`` are
    appended for every CDR3 (prefix ``cdr3_``) and every epitope (prefix
    ``epitope_``).

    Args:
        data_ori: Path of the CSV to read.
        data_new: Path of the CSV to write. The DataFrame index is written as
            the first column, exactly as before.
    """
    # The notebook used ``peptides`` without ever importing it, so this
    # function raised NameError. Import it here to keep the cell self-contained.
    import peptides

    def _descriptor_table(sequences, prefix):
        # One row of descriptors per sequence, with prefixed column names.
        table = pd.DataFrame([peptides.Peptide(s).descriptors() for s in sequences])
        table.columns = prefix + table.columns
        return table

    data = pd.read_csv(data_ori)
    #'CDR3.beta', 'antigen_epitope','mhc.a','label','negative.source','license'
    data = data.rename(columns={'CDR3B': 'CDR3.beta', 'Epitope': 'antigen_epitope',
                                'MHC': 'mhc.a', 'Affinity': 'label'})
    df = data[['CDR3.beta', 'antigen_epitope', 'mhc.a', 'label']]

    df_epi = _descriptor_table(df.antigen_epitope, 'epitope_')
    df_cdrb = _descriptor_table(df['CDR3.beta'], 'cdr3_')

    # Column-wise concat; all three frames share the default index of the
    # freshly read CSV, so rows line up one to one.
    df = pd.concat([df, df_cdrb, df_epi], axis=1)
    df.to_csv(data_new)
        
# Build the prediction input table from the raw test set.
preData_ori = "../data/test_CDR3B_others.csv"
# Plain literal: the original f-string had no placeholders.
preData = "../data/pre_TCRconv_small.csv"
fix(preData_ori, preData)

In [None]:
import numpy as np
from preprocessing import prep 
# Read column 0 of the prepared CSV (header row skipped), derive the label set
# with prep.get_labels and store it where the prediction cell looks for it.
filename = '../data/pre_TCRconv_small.csv'
epis = np.loadtxt(
    filename,
    dtype='str',
    delimiter=',',
    comments=None,
    skiprows=1,
    usecols=(0),
    unpack=True,
)
epis_u, labels = prep.get_labels(epis)
os.makedirs('./data/', exist_ok=True)
np.save('./data/unique_epitopes_test.npy', epis_u)

In [None]:
# Score the prepared test set with the original model.
# The function opens ``predfile`` for writing straight away, so the target
# directory has to exist before the call.
os.makedirs('../result_path/Original_model_prediction', exist_ok=True)
Original_model_prediction(
    dataPath='../data/pre_TCRconv_small.csv',
    mode='prediction',
    epitope_labels='./data/unique_epitopes_test.npy',
    chains='B',
    h_cdr31='CDR3B',
    h_long1='LongB',
    model_file='../Original_model/TCRconv_small.pt',
    embedfile1='embeddings/bert_vdjdb-b-small.bin',
    # The original literal opened with " and closed with ', a SyntaxError.
    predfile='../result_path/Original_model_prediction/test.csv',
    batch_size=256,
)

# 2. Model retraining: fullA

In [None]:

import os
import sys  # was missing: sys.path.append below raised NameError
from os import path, mkdir

# Work from inside the tcrconv checkout; the package imports and the relative
# paths in the following cells are resolved from there.
os.chdir('./tcrconv/')

# Put LM/ on the module search path so `from bert_mdl import ...` can resolve
# (presumably bert_mdl lives in LM/ -- it is also imported as LM.bert_mdl).
sys.path.append('./LM/')
import pandas as pd
from preprocessing import prep
import os
import torch
import numpy as np
import csv
from predictor import utils
from pandas import read_csv, DataFrame
from preprocessing.prep import get_protseqs_ntseqs,determine_tcr_seq_nt,determine_tcr_seq_vj
from LM.bert_mdl import retrieve_model, compute_embs
from argparse import ArgumentParser,ArgumentDefaultsHelpFormatter,ArgumentTypeError
from bert_mdl import retrieve_model, extract_and_save_embeddings
import torch
import os

# Prefer the second GPU as before, but fall back to cuda:0 on single-GPU hosts
# (where 'cuda:1' is an invalid device ordinal) and to the CPU without CUDA.
if not torch.cuda.is_available():
    device = torch.device('cpu')
elif torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
else:
    device = torch.device('cuda:0')

def fix(data_ori,data_new):
    """Write the positive alpha-chain rows of ``data_ori`` in training layout.

    Columns ``TRBV``/``TRBJ``/``TRAV``/``TRAJ``/``Affinity`` are renamed to
    ``VB``/``JB``/``VA``/``JA``/``label``. Only rows with ``label == 1`` are
    kept, a constant ``Subject`` column (``'PMID:0'``) is added, and the
    columns ``Epitope, Subject, CDR3A, VA, JA, LongA`` are written without
    the index.

    Args:
        data_ori: Path of the CSV to read.
        data_new: Path of the CSV to write.
    """
    data = pd.read_csv(data_ori)
    data = data.rename(columns={'TRBV': 'VB', 'TRBJ': 'JB', 'TRAV': 'VA',
                                'TRAJ': 'JA', 'Affinity': 'label'})
    # Take an explicit copy of the filtered rows: assigning a new column to
    # the bare ``.loc`` result triggers pandas' SettingWithCopyWarning.
    positives = data.loc[data.label == 1].copy()
    positives['Subject'] = 'PMID:0'
    df = positives[['Epitope', 'Subject', 'CDR3A', 'VA', 'JA', 'LongA']]
    df.to_csv(data_new, index=False)
        
def extract_epitopes(test_file,epitope_file):
    """Derive the label set from column 0 of ``test_file`` and save it as ``.npy``.

    The header row is skipped, the first comma-separated column is read as
    strings, and the first value returned by ``prep.get_labels`` is written
    to ``epitope_file`` with ``np.save``.
    """
    first_column = np.loadtxt(
        test_file,
        dtype='str',
        delimiter=',',
        comments=None,
        skiprows=1,
        usecols=(0),
        unpack=True,
    )
    label_set, _ = prep.get_labels(first_column)
    np.save(epitope_file, label_set)


def make_embeddings(testfile, embed_model, h_cdr3='CDR3A', h_long='LongA',
                    delimiter=',', seqs_per_file=50000):
    """Compute and save embeddings for one chain of the sequences in ``testfile``.

    The model returned by ``retrieve_model()`` is moved to the module-level
    ``device`` and handed to ``extract_and_save_embeddings``.

    Args:
        testfile: Filename of the dataset to embed.
        embed_model: Name passed as ``emb_name`` for the saved embeddings.
        h_cdr3: Column name of the CDR3 in the dataset file.
        h_long: Column name of the long TCR sequence in the dataset file.
        delimiter: Column delimiter of the dataset file.
        seqs_per_file: Maximum number of sequences in one embedding file; with
            more sequences the embeddings are split into several files.

    The keyword defaults reproduce the values that used to be hard-coded, so
    existing two-argument calls behave exactly as before.
    """
    print("start embedding")
    model = retrieve_model().to(device)
    # extract and save the embeddings
    extract_and_save_embeddings(
        model,
        data_f_n=testfile,
        sequence_col=h_long,
        cdr3_col=h_cdr3,
        seqs_per_file=seqs_per_file,
        emb_name=embed_model,
        separator=delimiter,
    )
    print("finished embedding")

def train_and_save(modelName='vdjdb-b-example', dataPath="./data/vdjdb-b-example.csv", \
               epitope_labels= './data/unique_epitopes_vdjdb-b-example.npy', mode= 'train', \
               chains='A', h_cdr32='CDR3A', h_long2='LongA', \
               embedfile2= 'embeddings/bert_vdjdb-b-example_0.bin',\
               model_folder='models_retrained', save_model_path='models_retrained/model_example.pt'):
    """Train a single-chain model and save its ``state_dict``.

    Builds the parameter dictionary expected by the ``utils`` helpers, gets
    the data loaders, constructs the model, trains it with
    ``utils.iterate_model_batches_swa`` and saves the weights with
    ``torch.save``.

    Args:
        modelName: Stored as ``p['name']``.
        dataPath: CSV file with the training data.
        epitope_labels: ``.npy`` file with the epitope labels.
        mode: ``'train'`` (train loader only) or ``'cv'`` (train and test
            loaders); matched case-insensitively.
        chains: Chain letters; a single letter selects the one-chain path.
        h_cdr32: CDR3 column name. Despite the ``2`` suffix it is stored in the
            chain-1 slot (``p['h_cdr31']``).
        h_long2: Long-sequence column name, stored in ``p['h_long1']``.
        embedfile2: Embedding file, stored in ``p['embedfile1']``.
        model_folder: Stored as ``p['model_folder']``; not used for saving here.
        save_model_path: File the trained ``state_dict`` is written to.

    Raises:
        ValueError: If ``mode`` is neither ``'train'`` nor ``'cv'``.
    """
    p = {
        'dataset': dataPath,
        'chains': chains,
        'epitope_labels': epitope_labels,
        'embedtype': 'cdr3+context',
        # The single trained chain always occupies the chain-1 slots.
        'h_cdr31': h_cdr32,
        'h_long1': h_long2,
        'embedfile1': embedfile2,
        'h_cdr32': 'None',
        'h_long2': 'None',
        'embedfile2': 'None',
        'delimiter': ',',

        'mode': mode,
        'name': modelName,
        'folds': 'None',
        'fold_num': 0,
        'h_epitope': 'Epitope',

        'batch_size': 512,
        'betas': [0.9, 0.999],
        'use_pos_weight': True,
        'dropouts': [0.1, 0.1],
        'iters_adam': 2500,
        'lr_conv': 0.0002,
        'lr_linear': 0.01,
        'T_anneal': 3000,
        'iters_swa': 500,
        'anneal_strategy': 'cos',
        'lr_swa': 0.0001,
        'T_anneal_swa': 300,

        'resultfile': None,
        'print_every': 100,
        'lossfile': 'None',
        'lossfile_test': 'None',
        'model_folder': model_folder,
        'params_to_print': ['name', 'fold_num'],

        'binary': False,
        'binary_label': None,

        'append_oh': False,
        'kernel_sizes': [5,9,15,21,3],
        'pool': 'max',
    }

    # Derived entries.
    p['two_chains'] = len(p['chains']) > 1
    p['table'] = (p['params_to_print'], [p[h] for h in p['params_to_print']])
    p['save_intermediate'] = p['lossfile'].lower() != 'none' or p['lossfile_test'].lower() != 'none'

    # Resolve which loaders are needed. Previously an unknown mode left
    # ``usetypes`` unbound and failed later with an unrelated UnboundLocalError.
    usetypes_by_mode = {'cv': ['train', 'test'], 'train': ['train']}
    mode = p['mode'].lower()
    if mode not in usetypes_by_mode:
        raise ValueError(f"Unsupported mode {p['mode']!r}; expected 'train' or 'cv'")
    usetypes = usetypes_by_mode[mode]

    # Create the output directory up front so a missing folder is not
    # discovered only after training has finished.
    save_dir = os.path.dirname(save_model_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    print('loading data...')
    loader, n_categories, n_feat = utils.get_dataloaders(p,device=device,usetypes=usetypes)
    print('creating model...')
    model = utils.construct_model(p,n_feat,n_categories['train'],device)
    print('training model...')
    model = utils.iterate_model_batches_swa(model,loader,p,device)
    print('saving model...')
    torch.save(model.state_dict(), save_model_path)


In [None]:
# --- prepare the training table and its label file ---
trainfile_path = "../data/train_CDR3B_others.csv"
# Plain assignment: the original used `==`, which raised NameError.
trainfile_out = "../data/TCRconv-A/"
os.makedirs(trainfile_out, exist_ok=True)

trainfile_new = trainfile_out + "train_CDR3B_others.csv"
fix(trainfile_path, trainfile_new)
epitope_file = trainfile_out + 'train_unique_epitopes.npy'
extract_epitopes(trainfile_new, epitope_file)

# --- embeddings ---
# The original line ended in a dangling `+` (SyntaxError); the bare prefix
# matches the `embed_model+'_0.bin'` file name used below.
embed_model = './embeddings/train/TCRconv_fullA/bert'
os.makedirs('./embeddings/train/TCRconv_fullA/', exist_ok=True)
make_embeddings(trainfile_new, embed_model)

# --- training ---
save_dir = "../result_path/Retraining_model_prediction"
os.makedirs(save_dir, exist_ok=True)
save_model_path = "../Retraining_model/Retraining_model.pt"
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(save_model_path), exist_ok=True)
# The training function defined in this notebook is `train_and_save`;
# `Model_retraining` was never defined.
train_and_save(
    dataPath=trainfile_new,
    epitope_labels=epitope_file,
    mode='train',
    chains='A',
    h_cdr32='CDR3A',
    h_long2='LongA',
    embedfile2=embed_model + '_0.bin',
    save_model_path=save_model_path,
)


# 3. Retraining_model_prediction

In [None]:
import sys  # was missing: the append below raised NameError

# Put LM/ on the module search path for the `from bert_mdl import ...` line below.
sys.path.append('./LM/')
import pandas as pd
from preprocessing import prep
import os
import torch
import numpy as np
import csv
from predictor import utils
from pandas import read_csv, DataFrame
from preprocessing.prep import get_protseqs_ntseqs,determine_tcr_seq_nt,determine_tcr_seq_vj
from LM.bert_mdl import retrieve_model, compute_embs
from argparse import ArgumentParser,ArgumentDefaultsHelpFormatter,ArgumentTypeError
from bert_mdl import retrieve_model, extract_and_save_embeddings
import torch
import os

# Prefer the second GPU as before, but fall back to cuda:0 on single-GPU hosts
# (where 'cuda:1' is an invalid device ordinal) and to the CPU without CUDA.
if not torch.cuda.is_available():
    device = torch.device('cpu')
elif torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
else:
    device = torch.device('cuda:0')

def fix(data_ori,data_new):
    """Rewrite ``data_ori`` in the paired-chain layout and save it to ``data_new``.

    The existing ``VB``/``JB`` columns are dropped, ``TRBV``/``TRBJ``/``TRAV``/
    ``TRAJ``/``Affinity`` are renamed to ``VB``/``JB``/``VA``/``JA``/``label``,
    a constant ``Subject`` column (``'PMID:0'``) is added, and the selected
    columns are written without the index.
    """
    column_renames = {'TRBV': 'VB', 'TRBJ': 'JB', 'TRAV': 'VA', 'TRAJ': 'JA', 'Affinity': 'label'}
    output_columns = ['Epitope', 'Subject', 'CDR3B', 'VB', 'JB',
                      'CDR3A', 'VA', 'JA', 'LongB', 'LongA']
    table = (
        pd.read_csv(data_ori)
        .drop(columns=['VB', 'JB'])
        .rename(columns=column_renames)
        .assign(Subject='PMID:0')
    )
    table[output_columns].to_csv(data_new, index=False)
        
def extract_epitopes(test_file,epitope_file):
    """Derive the label set from column 0 of ``test_file`` and save it as ``.npy``.

    The header row is skipped, the first comma-separated column is read as
    strings, and the first value returned by ``prep.get_labels`` is written
    to ``epitope_file`` with ``np.save``.
    """
    first_column = np.loadtxt(
        test_file,
        dtype='str',
        delimiter=',',
        comments=None,
        skiprows=1,
        usecols=(0),
        unpack=True,
    )
    label_set, _ = prep.get_labels(first_column)
    np.save(epitope_file, label_set)


def make_embeddings(testfile, embed_model, h_cdr3='CDR3B', h_long='LongB',
                    delimiter=',', seqs_per_file=50000):
    """Compute and save embeddings for one chain of the sequences in ``testfile``.

    The model returned by ``retrieve_model()`` is moved to the module-level
    ``device`` and handed to ``extract_and_save_embeddings``.

    Args:
        testfile: Filename of the dataset to embed.
        embed_model: Name passed as ``emb_name`` for the saved embeddings.
        h_cdr3: Column name of the CDR3 in the dataset file.
        h_long: Column name of the long TCR sequence in the dataset file.
        delimiter: Column delimiter of the dataset file.
        seqs_per_file: Maximum number of sequences in one embedding file; with
            more sequences the embeddings are split into several files.

    The keyword defaults reproduce the values that used to be hard-coded, so
    existing two-argument calls behave exactly as before.
    """
    print("start embedding")
    model = retrieve_model().to(device)
    # extract and save the embeddings
    extract_and_save_embeddings(
        model,
        data_f_n=testfile,
        sequence_col=h_long,
        cdr3_col=h_cdr3,
        seqs_per_file=seqs_per_file,
        emb_name=embed_model,
        separator=delimiter,
    )
    print("finished embedding")

def make_embeddings_a(testfile, embed_model, h_cdr3='CDR3A', h_long='LongA',
                      delimiter=',', seqs_per_file=50000):
    """Compute and save embeddings using the alpha-chain columns by default.

    The model returned by ``retrieve_model()`` is moved to the module-level
    ``device`` and handed to ``extract_and_save_embeddings``.

    Args:
        testfile: Filename of the dataset to embed.
        embed_model: Name passed as ``emb_name`` for the saved embeddings.
        h_cdr3: Column name of the CDR3 in the dataset file.
        h_long: Column name of the long TCR sequence in the dataset file.
        delimiter: Column delimiter of the dataset file.
        seqs_per_file: Maximum number of sequences in one embedding file; with
            more sequences the embeddings are split into several files.

    The keyword defaults reproduce the values that used to be hard-coded, so
    existing two-argument calls behave exactly as before.
    """
    print("start embedding")
    model = retrieve_model().to(device)
    # extract and save the embeddings
    extract_and_save_embeddings(
        model,
        data_f_n=testfile,
        sequence_col=h_long,
        cdr3_col=h_cdr3,
        seqs_per_file=seqs_per_file,
        emb_name=embed_model,
        separator=delimiter,
    )
    print("finished embedding")
    
def predict_and_save(dataPath='./data/vdjdb-b-example.csv', mode= 'prediction',\
               epitope_labels= './data/unique_epitopes_vdjdb-b-example.npy', \
               chains='AB', h_cdr31='CDR3B', h_long1='LongB',  h_cdr32='CDR3A', h_long2='LongA',  h_v1='VB', h_j1='JB',h_v2='VA', h_j2='JA', model_file='./model/statedict_vdjdb-b-small.pt',  \
               embedfile1= 'embeddings/bert_vdjdb-b-example_0.bin',  embedfile2= 'embeddings/bert_vdjdb-b-example_0.bin',\
               predfile='outputs/preds-b-example.csv', result_path='outputs/b-example_probability.csv', batch_size=256,testfile=None):
    """Score a TCR table with a saved model, then summarise the top prediction per row.

    Phase 1 walks the dataset in batches and appends one row of per-epitope
    scores per input row to ``predfile`` (NaN for rows with an empty sequence).
    Phase 2 re-reads ``predfile``, picks the highest-scoring epitope per row,
    compares it with the epitope listed in ``dataPath`` and writes the summary
    to ``result_path + 'probability.csv'``.

    Args:
        dataPath: CSV file holding the TCRs to score.
        mode: Stored under ``p['mode']``; not read by this function itself.
        epitope_labels: ``.npy`` file whose entries become the score columns.
        chains: Chain letters; one ``'TCR'+letter`` header column is written
            per character and more than one character enables two-chain mode.
        h_cdr31, h_long1: CDR3 / long-sequence column names for chain 1.
        h_cdr32, h_long2: CDR3 / long-sequence column names for chain 2.
        h_v1, h_j1, h_v2, h_j2: Stored in ``p``; only read for
            ``input_type == 'cdr3+vj'``, which is not the value set here.
        model_file: Path handed to ``utils.load_model``.
        embedfile1, embedfile2: Paths handed to ``utils.get_embedding_dict``.
        predfile: Raw score CSV (truncated, then appended to per batch).
        result_path: Prefix; the summary goes to ``result_path+'probability.csv'``.
        batch_size: NOTE(review): currently ignored -- ``p['batch_size']`` is
            hard-coded to 512. Confirm which value is intended.
        testfile: Optional CSV with ``LongA`` and ``Affinity`` columns used to
            fill ``y_true``. When ``None`` the column is simply not added
            (previously ``pd.read_csv(None)`` raised).

    Returns:
        None.
    """

    p = {
        'dataset': dataPath,
        'chains': chains,
        'epitope_labels': epitope_labels,
        'embedtype': 'cdr3+context',
        'h_cdr31': h_cdr31,
        'h_long1': h_long1,
        'embedfile1': embedfile1,
        'h_cdr32':h_cdr32,
        'h_long2':h_long2,
        'embedfile2': embedfile2,
        'delimiter': ',',

        'mode': mode,
        # 'name': modelName,
        'folds': 'None',
        'fold_num': 0,
        'h_epitope': 'Epitope',

        # NOTE(review): the ``batch_size`` argument is not used here.
        'batch_size': 512,
        'betas': [0.9, 0.999],
        'use_pos_weight': True,
        'dropouts': [0.1, 0.1],
        'iters_adam': 2500,
        'lr_conv': 0.0002,
        'lr_linear': 0.01,
        'T_anneal': 3000,
        'iters_swa': 500,
        'anneal_strategy': 'cos',
        'lr_swa': 0.0001,
        'T_anneal_swa': 300,

        'resultfile': './outputs/results.tsv',
        'print_every': 100,
        'lossfile': './outputs/loss_train.tsv',
        'lossfile_test': './outputs/loss_test.tsv',
        # 'model_folder': model_folder,
        'params_to_print': ['fold_num'],

        'binary': False,
        'binary_label': None,

        'append_oh': False,
        'kernel_sizes': [5,9,15,21,3],
        'pool': 'max',
        'h_v1': h_v1,
        'h_j1': h_j1,
        'h_v2': h_v2,
        'h_j2': h_j2,
        'h_nt1': 'none',
        'h_nt2': 'none',
        'guess_allele01': True,
        'model_file': model_file,
        'use_LM': False,
        'num_features': 1024,
        'input_type': 'tcr+cdr3',
        'predfile': predfile,
        'additional_columns': [],
        'decimals': 4,

        }

    # Derived entries.
    p['two_chains']=len(p['chains'])>1
    p['table']=(p['params_to_print'],[p[h] for h in p['params_to_print']])
    p['save_intermediate'] = p['lossfile'].lower()!='none' or p['lossfile_test'].lower()!='none'

    # Read everything as str and keep empty cells as '' so the empty-sequence
    # check below works.
    data = read_csv(p['dataset'],delimiter=p['delimiter'],dtype=str,keep_default_na=False)
    epis_u = np.load(p['epitope_labels'])
    n_labels=len(epis_u)
    n_chains=len(p['chains'])

    # Header row. NOTE(review): the names follow the letters in ``chains``
    # while the values follow h_long1/h_long2, so with the defaults
    # (chains='AB', h_long1='LongB') the 'TCRA' column holds LongB and 'TCRB'
    # holds LongA. The post-processing below relies on exactly that.
    resfile = p['predfile']
    with open(resfile,'w') as f:
        f.write(p['delimiter'].join(['TCR'+c for c in p['chains']]
                +p['additional_columns']+list(epis_u))+'\n')

    # Load TCRconv model
    model = utils.load_model(p['model_file'],p,p['num_features'],n_labels,device)
    # embeddings / embedding-models
    if p['use_LM']:
        LM=retrieve_model().to(device)

    # Get requested sequences for genes (only for 'cdr3+nt' / 'cdr3+vj')
    geneseqs={}
    for ic,chain in enumerate(p['chains']):
        c=str(ic+1)
        if p['input_type']=='cdr3+nt':
            geneseqs['protV'+c],geneseqs['protJ'+c],geneseqs['ntV'+c],geneseqs['ntJ'+c] = \
                    get_protseqs_ntseqs(chain=chain)
        elif p['input_type']=='cdr3+vj':
            geneseqs['protV'+c],geneseqs['protJ'+c],_,_ = get_protseqs_ntseqs(chain=chain)

    # Per-batch state: validity flags, per-chain sequences, number of usable
    # rows and the dataset index where the current batch starts.
    I=[]
    ts_all = [[] for c in p['chains']] # separate list for each chain
    icount,i0 = 0,0
    imax=len(data)-1

    for i in range(len(data)):

        # Build the per-chain sequence list for row i according to input_type.
        if p['input_type']=='cdr3+nt':
            tcr12 = []
            for ic in range(n_chains):
                c=str(ic+1)
                t,_,_ = determine_tcr_seq_nt(data[p['h_nt'+c]][i],data[p['h_cdr3'+c]][i],geneseqs['protV'+c],
                        geneseqs['protJ'+c],geneseqs['ntV'+c],geneseqs['ntJ'+c],guess01=p['guess_allele01'])
                tcr12.append(t)
        elif p['input_type']=='cdr3+vj':
            tcr12 = []
            for ic in range(n_chains):
                c=str(ic+1)
                t = determine_tcr_seq_vj(data[p['h_cdr3'+c]][i],data[p['h_v'+c]][i],data[p['h_j'+c]][i],
                    geneseqs['protV'+c],geneseqs['protJ'+c],guess01=p['guess_allele01'])
                tcr12.append(t)
        elif p['input_type']=='cdr3':
            tcr12 = [data[p['h_cdr3'+str(ic+1)]][i] for ic in range(n_chains)]
        else: # tcr+cdr3 / tcr
            tcr12 = [data[p['h_long'+str(ic+1)]][i] for ic in range(n_chains)]

        # A row is usable only when none of its sequences is empty. The
        # sequences are recorded either way so the output keeps one row per
        # input row.
        row_ok = not np.any([t=='' for t in tcr12])
        I.append(row_ok)
        for ic in range(n_chains):
            ts_all[ic].append(tcr12[ic])
        if row_ok:
            icount+=1

        # Flush when enough usable rows were collected or the data is exhausted.
        if icount==p['batch_size'] or i==imax:
            ts_all =[np.array(t) for t in ts_all]
            if icount>0:
                I= np.array(I,dtype=bool)
                if p['use_LM']: # If LM is used, compute embeddings
                    cdr3s = data[p['h_cdr31']][i0:i+1][I].values
                    print('computing embeddings: {:d}-{:d}/{:d}'.format(i0,i,imax))
                    if p['embedtype']=='cdr3+context':
                        embeddings1 = compute_embs(LM, ts_all[0][I], cdr3s)
                        embeddings1 = utils.stack_embeddings(embeddings1,device,p['append_oh'],cdr3s)
                    else:
                        embeddings1 = compute_embs(LM, ts_all[0][I], None)
                        embeddings1 = utils.stack_embeddings(embeddings1,device,p['append_oh'])

                    if p['two_chains']:
                        cdr3s = data[p['h_cdr32']][i0:i+1][I].values
                        if p['embedtype']=='cdr3+context':
                            embeddings2 = compute_embs(LM,ts_all[1][I], cdr3s)
                            embeddings2 = utils.stack_embeddings(embeddings2,device,p['append_oh'],cdr3s)
                        else:
                            embeddings2 = compute_embs(LM, ts_all[1][I], None)
                            embeddings2 = utils.stack_embeddings(embeddings2,device,p['append_oh'])

                else: # 1-2 embedding dictionaries are used
                    cdr3s = data[p['h_cdr31']][i0:i+1][I].values
                    cdr3max = utils.maxlen(cdr3s)
                    embeddings1 = utils.get_embeddings(utils.get_embedding_dict(p['embedfile1']),
                                cdr3s, ts_all[0][I], cdr3max, device, p['append_oh'])
                    if p['two_chains']:
                        cdr3s = data[p['h_cdr32']][i0:i+1][I].values
                        cdr3max = utils.maxlen(cdr3s)
                        embeddings2 = utils.get_embeddings(utils.get_embedding_dict(p['embedfile2']),
                                cdr3s,ts_all[1][I], cdr3max, device, p['append_oh'])

                # Predictions
                if p['two_chains']:
                    output = model(embeddings1,embeddings2).detach().cpu().numpy()
                else:
                    output = model(embeddings1).detach().cpu().numpy()
                output = utils.toprob(output)

                # NaN everywhere, then fill in only the rows that were scored.
                pred_ar = np.ones((len(I),n_labels),dtype=float)*np.nan
                pred_ar[I,:]= output

            else: # No proper sequences were found, add fillers
                pred_ar = np.ones((len(I),n_labels),dtype=float)*np.nan

            # append results to result file
            df = DataFrame(np.concatenate([np.expand_dims(chain_seqs,1) for chain_seqs in ts_all] \
                +[np.expand_dims(data[col].values[i0:i+1],1) for col in p['additional_columns']] \
                +[np.round(pred_ar,p['decimals'])],axis=1))
            df.to_csv(resfile,sep=p['delimiter'],mode='a',header=False,index=False)

            # Reset the batch state; the next batch starts at the following row.
            I = []
            ts_all = [[] for c in p['chains']]
            icount = 0
            i0 = i+1

    # ---- process result file ----
    df_result=pd.read_csv(resfile)
    # Independent copy: the columns added below must not be written into a
    # slice of df_result.
    probability = df_result[['TCRB']].copy()

    df_ori=pd.read_csv(dataPath)
    print('df_ori',df_ori)
    # Attach the listed epitope to every result row whose sequence matches.
    # For duplicated sequences the last row of df_ori wins, as before.
    for seq, epitope in zip(df_ori['LongA'], df_ori['Epitope']):
        probability.loc[probability.TCRB==seq,'Epitope']=epitope
    print(probability)

    # All columns after the first; non-numeric ones (the remaining sequence
    # column) are coerced to NaN.
    get_label=df_result[df_result.columns[1:]]
    get_label = get_label.apply(pd.to_numeric, errors='coerce')

    probability['Epitope_pred']=get_label.idxmax(axis=1).tolist()
    probability['y_prob']=get_label.max(axis=1).tolist()
    probability['y_pred']=1

    # Zero out rows whose top-scoring epitope differs from the listed one.
    # The original used chained indexing (probability['y_prob'][k]=0), which
    # pandas warns about and which is a silent no-op under copy-on-write.
    mismatch = probability['Epitope_pred']!=probability['Epitope']
    probability.loc[mismatch,'y_prob']=0
    probability.loc[mismatch,'y_pred']=0

    # Ground-truth labels are optional.
    if testfile is not None:
        df_tmp=pd.read_csv(testfile)
        for seq, affinity in zip(df_tmp['LongA'], df_tmp['Affinity']):
            probability.loc[probability.TCRB==seq,'y_true']=affinity

    print(probability)
    probability.to_csv(result_path+'probability.csv',index=False)
    print('done saving!')


# The following call was indented into the function body above, where it ran
# after every prediction and failed because ``sys`` was never imported. It
# precedes the same import list as the other cells, so it is restored as
# module-level setup.
import sys
sys.path.append('./LM/')
import pandas as pd
from preprocessing import prep
import os
import torch
import numpy as np
import csv
from predictor import utils
from pandas import read_csv, DataFrame
from preprocessing.prep import get_protseqs_ntseqs,determine_tcr_seq_nt,determine_tcr_seq_vj
from LM.bert_mdl import retrieve_model, compute_embs
from argparse import ArgumentParser,ArgumentDefaultsHelpFormatter,ArgumentTypeError
from bert_mdl import retrieve_model, extract_and_save_embeddings
import torch
import os

# Prefer the second GPU as before, but fall back to cuda:0 on single-GPU hosts
# (where 'cuda:1' is an invalid device ordinal) and to the CPU without CUDA.
if not torch.cuda.is_available():
    device = torch.device('cpu')
elif torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
else:
    device = torch.device('cuda:0')

def fix(data_ori,data_new):
    """Rewrite ``data_ori`` in the paired-chain layout and save it to ``data_new``.

    ``TRBV``/``TRBJ``/``TRAV``/``TRAJ``/``Affinity`` are renamed to
    ``VB``/``JB``/``VA``/``JA``/``label``, a constant ``Subject`` column
    (``'PMID:0'``) is added, and the selected columns are written without
    the index.
    """
    column_renames = {'TRBV': 'VB', 'TRBJ': 'JB', 'TRAV': 'VA', 'TRAJ': 'JA', 'Affinity': 'label'}
    output_columns = ['Epitope', 'Subject', 'CDR3B', 'VB', 'JB',
                      'CDR3A', 'VA', 'JA', 'LongB', 'LongA']
    table = (
        pd.read_csv(data_ori)
        .rename(columns=column_renames)
        .assign(Subject='PMID:0')
    )
    table[output_columns].to_csv(data_new, index=False)
        
def extract_epitopes(test_file,epitope_file):
    """Derive the label set from column 0 of ``test_file`` and save it as ``.npy``.

    The header row is skipped, the first comma-separated column is read as
    strings, and the first value returned by ``prep.get_labels`` is written
    to ``epitope_file`` with ``np.save``.
    """
    first_column = np.loadtxt(
        test_file,
        dtype='str',
        delimiter=',',
        comments=None,
        skiprows=1,
        usecols=(0),
        unpack=True,
    )
    label_set, _ = prep.get_labels(first_column)
    np.save(epitope_file, label_set)


def make_embeddings(testfile, embed_model, h_cdr3='CDR3B', h_long='LongB',
                    delimiter=',', seqs_per_file=50000):
    """Compute and save embeddings for one chain of the sequences in ``testfile``.

    The model returned by ``retrieve_model()`` is moved to the module-level
    ``device`` and handed to ``extract_and_save_embeddings``.

    Args:
        testfile: Filename of the dataset to embed.
        embed_model: Name passed as ``emb_name`` for the saved embeddings.
        h_cdr3: Column name of the CDR3 in the dataset file.
        h_long: Column name of the long TCR sequence in the dataset file.
        delimiter: Column delimiter of the dataset file.
        seqs_per_file: Maximum number of sequences in one embedding file; with
            more sequences the embeddings are split into several files.

    The keyword defaults reproduce the values that used to be hard-coded, so
    existing two-argument calls behave exactly as before.
    """
    print("start embedding")
    model = retrieve_model().to(device)
    # extract and save the embeddings
    extract_and_save_embeddings(
        model,
        data_f_n=testfile,
        sequence_col=h_long,
        cdr3_col=h_cdr3,
        seqs_per_file=seqs_per_file,
        emb_name=embed_model,
        separator=delimiter,
    )
    print("finished embedding")

def make_embeddings_a(testfile, embed_model, h_cdr3='CDR3A', h_long='LongA',
                      delimiter=',', seqs_per_file=50000):
    """Compute and save embeddings using the alpha-chain columns by default.

    The model returned by ``retrieve_model()`` is moved to the module-level
    ``device`` and handed to ``extract_and_save_embeddings``.

    Args:
        testfile: Filename of the dataset to embed.
        embed_model: Name passed as ``emb_name`` for the saved embeddings.
        h_cdr3: Column name of the CDR3 in the dataset file.
        h_long: Column name of the long TCR sequence in the dataset file.
        delimiter: Column delimiter of the dataset file.
        seqs_per_file: Maximum number of sequences in one embedding file; with
            more sequences the embeddings are split into several files.

    The keyword defaults reproduce the values that used to be hard-coded, so
    existing two-argument calls behave exactly as before.
    """
    print("start embedding")
    model = retrieve_model().to(device)
    # extract and save the embeddings
    extract_and_save_embeddings(
        model,
        data_f_n=testfile,
        sequence_col=h_long,
        cdr3_col=h_cdr3,
        seqs_per_file=seqs_per_file,
        emb_name=embed_model,
        separator=delimiter,
    )
    print("finished embedding")
    
def _lookup_by_sequence(sequences, table, seq_col, value_col):
    """Map each sequence to ``table[value_col]`` of the row with that ``table[seq_col]``.

    When several rows of ``table`` share a sequence the last one wins.
    Sequences that do not occur in ``table`` and missing (NaN) sequences
    map to NaN.
    """
    # Drop rows without a sequence so that a NaN key can never match a NaN query.
    known = table.dropna(subset=[seq_col])
    mapping = dict(zip(known[seq_col], known[value_col]))
    return sequences.map(mapping)


def Retraining_model_prediction(dataPath='./data/vdjdb-b-example.csv', mode='prediction',
               epitope_labels='./data/unique_epitopes_vdjdb-b-example.npy',
               chains='AB', h_cdr31='CDR3B', h_long1='LongB', h_cdr32='CDR3A', h_long2='LongA',
               h_v1='VB', h_j1='JB', h_v2='VA', h_j2='JA',
               model_file='./model/statedict_vdjdb-b-small.pt',
               embedfile1='embeddings/bert_vdjdb-b-example_0.bin',
               embedfile2='embeddings/bert_vdjdb-b-example_0.bin',
               predfile='outputs/preds-b-example.csv',
               result_path='outputs/b-example_probability.csv', batch_size=256, testfile=None):
    """Predict epitope probabilities for every TCR in a dataset and summarise them.

    The function works in two stages:

    1. Rows of ``dataPath`` are processed in batches. For each batch the
       precomputed embeddings are fetched, the model loaded from
       ``model_file`` is applied, its output is passed through
       ``utils.toprob`` and one row per TCR is appended to ``predfile``
       (TCR column(s) followed by one column per epitope label). Rows with
       an empty sequence get NaN in every epitope column.
    2. ``predfile`` is read back and reduced to one row per TCR with the
       columns ``TCR<chain>``, ``Epitope`` (looked up in ``dataPath``),
       ``Epitope_pred`` (label with the highest score), ``y_prob``,
       ``y_pred`` and, when ``testfile`` is given, ``y_true``. ``y_prob``
       and ``y_pred`` are set to 0 wherever ``Epitope_pred`` differs from
       ``Epitope``. The table is written to
       ``result_path + 'probability.csv'``.

    Args:
        dataPath: CSV file with the TCRs to predict. Must contain the
            columns named by ``h_cdr3*`` / ``h_long*`` and an ``Epitope``
            column.
        mode: Stored in the parameter dict under ``'mode'``; not read
            directly in this function.
        epitope_labels: ``.npy`` file with the epitope labels; they become
            the prediction column headers and fix the number of outputs.
        chains: One character per chain (e.g. ``'A'``, ``'B'``, ``'AB'``).
            More than one character selects the two-input model call.
        h_cdr31, h_long1: CDR3 / long-sequence column names of chain 1.
        h_cdr32, h_long2: CDR3 / long-sequence column names of chain 2
            (only read when two chains are used).
        h_v1, h_j1, h_v2, h_j2: V/J gene column names. Stored in the
            parameter dict; only read for ``input_type == 'cdr3+vj'``,
            which is not the input type fixed below.
        model_file: Path handed to ``utils.load_model``.
        embedfile1, embedfile2: Embedding files handed to
            ``utils.get_embedding_dict`` for chain 1 / chain 2.
        predfile: Path of the per-epitope prediction CSV (overwritten).
        result_path: Prefix of the summary file. ``'probability.csv'`` is
            appended directly, without a path separator.
        batch_size: Number of non-empty TCRs passed to the model at once.
        testfile: CSV with the long-sequence column and an ``Affinity``
            column that supplies ``y_true``. If ``None`` the ``y_true``
            column is omitted.
    """
    # The first notebook cell only imports read_csv/DataFrame from pandas, so
    # bring the ``pd`` alias used by the post-processing into scope here.
    import pandas as pd

    # Parameter dict in the layout expected by ``utils``. The optimiser
    # settings are not read in this function; they are kept because the whole
    # dict is passed to ``utils.load_model``.
    p = {
        'dataset': dataPath,
        'chains': chains,
        'epitope_labels': epitope_labels,
        'embedtype': 'cdr3+context',
        'h_cdr31': h_cdr31,
        'h_long1': h_long1,
        'embedfile1': embedfile1,
        'h_cdr32':h_cdr32,
        'h_long2':h_long2,
        'embedfile2': embedfile2,
        'delimiter': ',',

        'mode': mode,
        # 'name': modelName,
        'folds': 'None',
        'fold_num': 0,
        'h_epitope': 'Epitope',

        # Honour the caller's batch size (this used to be hard-coded to 512,
        # silently ignoring the ``batch_size`` argument).
        'batch_size': batch_size,
        'betas': [0.9, 0.999],
        'use_pos_weight': True,
        'dropouts': [0.1, 0.1],
        'iters_adam': 2500,
        'lr_conv': 0.0002,
        'lr_linear': 0.01,
        'T_anneal': 3000,
        'iters_swa': 500,
        'anneal_strategy': 'cos',
        'lr_swa': 0.0001,
        'T_anneal_swa': 300,

        'resultfile': './outputs/results.tsv',
        'print_every': 100,
        'lossfile': './outputs/loss_train.tsv',
        'lossfile_test': './outputs/loss_test.tsv',
        # 'model_folder': model_folder,
        'params_to_print': ['fold_num'],

        'binary': False,
        'binary_label': None,

        'append_oh': False,
        'kernel_sizes': [5,9,15,21,3],
        'pool': 'max',
        'h_v1': h_v1,
        'h_j1': h_j1,
        'h_v2': h_v2,
        'h_j2': h_j2,
        'h_nt1': 'none',
        'h_nt2': 'none',
        'guess_allele01': True,
        'model_file': model_file,
        'use_LM': False,
        'num_features': 1024,
        'input_type': 'tcr+cdr3',
        'predfile': predfile,
        'additional_columns': [],
        'decimals': 4,

        }

    p['two_chains']=len(p['chains'])>1
    p['table']=(p['params_to_print'],[p[h] for h in p['params_to_print']])
    p['save_intermediate'] = p['lossfile'].lower()!='none' or p['lossfile_test'].lower()!='none'

    # Everything is read as text and empty cells stay '' so that missing
    # sequences can be detected with a plain string comparison below.
    data = read_csv(p['dataset'],delimiter=p['delimiter'],dtype=str,keep_default_na=False)
    epis_u = np.load(p['epitope_labels'])
    n_labels=len(epis_u)
    n_chains=len(p['chains'])

    # Start the prediction file with its header; batches are appended later.
    resfile = p['predfile']
    with open(resfile,'w') as f:
        f.write(p['delimiter'].join(['TCR'+c for c in p['chains']]
                +p['additional_columns']+list(epis_u))+'\n')

    # Load TCRconv model
    model = utils.load_model(p['model_file'],p,p['num_features'],n_labels,device)
    # embeddings / embedding-models
    if p['use_LM']:
        LM=retrieve_model().to(device)

    # Get requested sequences for genes
    geneseqs={}
    for ic,chain in enumerate(p['chains']):
        c=str(ic+1)
        if p['input_type']=='cdr3+nt':
            geneseqs['protV'+c],geneseqs['protJ'+c],geneseqs['ntV'+c],geneseqs['ntJ'+c] = \
                    get_protseqs_ntseqs(chain=chain)
        elif p['input_type']=='cdr3+vj':
            geneseqs['protV'+c],geneseqs['protJ'+c],_,_ = get_protseqs_ntseqs(chain=chain)

    I=[]                                # per-row flag of the current batch: all chains non-empty
    ts_all = [[] for c in p['chains']]  # separate list for each chain
    icount,i0 = 0,0                     # valid rows in the batch / first data row of the batch
    imax=len(data)-1

    for i in range(len(data)):

        # Determine the sequence of every chain for this row.
        if p['input_type']=='cdr3+nt':
            tcr12 = []
            for ic in range(n_chains):
                c=str(ic+1)
                t,_,_ = determine_tcr_seq_nt(data[p['h_nt'+c]][i],data[p['h_cdr3'+c]][i],geneseqs['protV'+c],
                        geneseqs['protJ'+c],geneseqs['ntV'+c],geneseqs['ntJ'+c],guess01=p['guess_allele01'])
                tcr12.append(t)
        elif p['input_type']=='cdr3+vj':
            tcr12 = []
            for ic in range(n_chains):
                c=str(ic+1)
                t = determine_tcr_seq_vj(data[p['h_cdr3'+c]][i],data[p['h_v'+c]][i],data[p['h_j'+c]][i],
                    geneseqs['protV'+c],geneseqs['protJ'+c],guess01=p['guess_allele01'])
                tcr12.append(t)
        elif p['input_type']=='cdr3':
            tcr12 = [data[p['h_cdr3'+str(ic+1)]][i] for ic in range(n_chains)]
        else: # tcr+cdr3 / tcr
            tcr12 = [data[p['h_long'+str(ic+1)]][i] for ic in range(n_chains)]

        # A row is only predicted when none of its sequences is empty; the
        # sequences are recorded either way so every input row gets an output row.
        is_valid = not any(t=='' for t in tcr12)
        I.append(is_valid)
        for ic in range(n_chains):
            ts_all[ic].append(tcr12[ic])
        if is_valid:
            icount+=1

        # Flush when the batch is full or the last row has been reached.
        if icount==p['batch_size'] or i==imax:
            ts_all =[np.array(t) for t in ts_all]
            # Rows without a usable sequence keep NaN in every epitope column.
            pred_ar = np.ones((len(I),n_labels),dtype=float)*np.nan
            if icount>0:
                I= np.array(I,dtype=bool)
                if p['use_LM']: # If LM is used, compute embeddings
                    cdr3s = data[p['h_cdr31']][i0:i+1][I].values
                    print('computing embeddings: {:d}-{:d}/{:d}'.format(i0,i,imax))
                    if p['embedtype']=='cdr3+context':
                        embeddings1 = compute_embs(LM, ts_all[0][I], cdr3s)
                        embeddings1 = utils.stack_embeddings(embeddings1,device,p['append_oh'],cdr3s)
                    else:
                        embeddings1 = compute_embs(LM, ts_all[0][I], None)
                        embeddings1 = utils.stack_embeddings(embeddings1,device,p['append_oh'])

                    if p['two_chains']:
                        cdr3s = data[p['h_cdr32']][i0:i+1][I].values
                        if p['embedtype']=='cdr3+context':
                            embeddings2 = compute_embs(LM,ts_all[1][I], cdr3s)
                            embeddings2 = utils.stack_embeddings(embeddings2,device,p['append_oh'],cdr3s)
                        else:
                            embeddings2 = compute_embs(LM, ts_all[1][I], None)
                            embeddings2 = utils.stack_embeddings(embeddings2,device,p['append_oh'])

                else: # 1-2 embedding dictionaries are used
                    cdr3s = data[p['h_cdr31']][i0:i+1][I].values
                    cdr3max = utils.maxlen(cdr3s)
                    embeddings1 = utils.get_embeddings(utils.get_embedding_dict(p['embedfile1']),
                                cdr3s, ts_all[0][I], cdr3max, device, p['append_oh'])
                    if p['two_chains']:
                        cdr3s = data[p['h_cdr32']][i0:i+1][I].values
                        cdr3max = utils.maxlen(cdr3s)
                        embeddings2 = utils.get_embeddings(utils.get_embedding_dict(p['embedfile2']),
                                cdr3s,ts_all[1][I], cdr3max, device, p['append_oh'])

                # Predictions
                if p['two_chains']:
                    output = model(embeddings1,embeddings2).detach().cpu().numpy()
                else:
                    output = model(embeddings1).detach().cpu().numpy()
                output = utils.toprob(output)

                pred_ar[I,:]= output

            # append results to result file
            df = DataFrame(np.concatenate([np.expand_dims(ts_all[ic],1) for ic in range(len(ts_all))] \
                +[np.expand_dims(data[col].values[i0:i+1],1) for col in p['additional_columns']] \
                +[np.round(pred_ar,p['decimals'])],axis=1))
            df.to_csv(resfile,sep=p['delimiter'],mode='a',header=False,index=False)

            # Reset the batch state.
            I = []
            ts_all = [[] for c in p['chains']]
            icount = 0
            i0 = i+1

    #process result file
    df_result=pd.read_csv(resfile)

    # Join key: the sequence column of the last chain and the dataset column it
    # was filled from. With the default arguments this is 'TCRB' / 'LongA',
    # i.e. the pair that used to be hard-coded; for any other ``chains`` value
    # the hard-coded 'TCRB' column does not exist in the prediction file.
    key_col = 'TCR'+p['chains'][-1]
    long_col = p['h_long'+str(n_chains)]

    # Work on an independent copy: assigning into a slice of ``df_result`` is
    # unreliable and has no effect under pandas copy-on-write.
    probability = df_result[[key_col]].copy()

    df_ori=pd.read_csv(dataPath)
    print('df_ori',df_ori)
    probability['Epitope'] = _lookup_by_sequence(probability[key_col], df_ori, long_col, 'Epitope')
    print(probability)

    # Epitope score columns follow the TCR and additional columns.
    first_label_col = n_chains+len(p['additional_columns'])
    scores = df_result.iloc[:, first_label_col:].apply(pd.to_numeric, errors='coerce')

    # Rows of empty sequences hold only NaN; idxmax is restricted to rows that
    # have at least one score because all-NaN input is deprecated there.
    has_pred = scores.notna().any(axis=1)
    best_label = pd.Series(np.nan, index=scores.index, dtype=object)
    if has_pred.any():
        best_label.loc[has_pred] = scores.loc[has_pred].idxmax(axis=1)
    probability['Epitope_pred'] = best_label

    probability['y_prob'] = scores.max(axis=1)
    probability['y_pred'] = 1
    # A prediction only counts when the top label equals the recorded epitope;
    # a missing value on either side compares as different.
    wrong = probability['Epitope_pred'] != probability['Epitope']
    probability.loc[wrong, 'y_prob'] = 0
    probability.loc[wrong, 'y_pred'] = 0

    if testfile is not None:
        df_tmp=pd.read_csv(testfile)
        probability['y_true'] = _lookup_by_sequence(probability[key_col], df_tmp, long_col, 'Affinity')

    print(probability)
    # NOTE: plain string concatenation - ``result_path`` acts as a filename
    # prefix, not as a directory.
    probability.to_csv(result_path+'probability.csv',index=False)
    print('done saving!')
    

In [None]:
# Predict the test set with the retrained alpha-chain model.
model_name='TCRconv-fullA'

testfile_path ="../data/test_CDR3B_others.csv"
file_out="../data/TCRconv-A/"
os.makedirs(file_out, exist_ok=True)

trainfile_new=file_out+"train_CDR3B_others.csv"
testfile_new=file_out+"test_CDR3B_others.csv"
# fix() is defined elsewhere in this notebook; it is given the raw test file
# and the path of the copy used for embedding and prediction below.
fix(testfile_path,testfile_new)
trainfile_out="../data/TCRconv-A/"
epitope_file=trainfile_out+'train_unique_epitopes.npy'

embed_model='./embeddings/test/TCRconv_fullA/bert'
os.makedirs('./embeddings/test/TCRconv_fullA/bert', exist_ok=True)
# NOTE(review): make_embeddings embeds the CDR3B/LongB columns, while the
# prediction below reads CDR3A/LongA; make_embeddings_a exists for the alpha
# chain - confirm which one is intended for this model.
make_embeddings(testfile_new, embed_model)
save_model_path= "../Retraining_model/Retraining_model.pt"
result_path="../result_path/Retraining_model_prediction"
os.makedirs(result_path,exist_ok=True)
# NOTE(review): Retraining_model_prediction writes result_path+'probability.csv',
# i.e. '../result_path/Retraining_model_predictionprobability.csv' - next to,
# not inside, the directory created above, and the same file the validation
# cell writes. Confirm this is intended.
initial_result="../result_path/Retraining_model_prediction/test.csv"
Retraining_model_prediction(dataPath=testfile_new, mode= 'prediction',\
           epitope_labels= epitope_file, \
           chains='A', h_cdr31='CDR3A', h_long1='LongA',  h_v1='VA', h_j1='JA',  model_file=save_model_path, \
           embedfile1= embed_model+'_0.bin',\
           predfile=initial_result, result_path= result_path, batch_size=256, testfile=testfile_path
           )


In [None]:
# Predict the validation set with the retrained alpha-chain model.
model_name='TCRconv-fullA'
testfile_path ="../data/Validation_CDR3B_others.csv"
file_out="../data/TCRconv-A/"
os.makedirs(file_out, exist_ok=True)

trainfile_new=file_out+"train_CDR3B_others.csv"
testfile_new=file_out+"Validation_CDR3B_others.csv"
# fix() is defined elsewhere in this notebook; it is given the raw validation
# file and the path of the copy used for embedding and prediction below.
fix(testfile_path,testfile_new)
trainfile_out="../data/TCRconv-A/"
epitope_file=trainfile_out+'train_unique_epitopes.npy'

embed_model='./embeddings/Validation/TCRconv_fullA/bert'
os.makedirs('./embeddings/Validation/TCRconv_fullA/bert', exist_ok=True)
# NOTE(review): make_embeddings embeds the CDR3B/LongB columns, while the
# prediction below reads CDR3A/LongA; make_embeddings_a exists for the alpha
# chain - confirm which one is intended for this model.
make_embeddings(testfile_new, embed_model)
save_model_path= "../Retraining_model/Retraining_model.pt"
result_path="../result_path/Retraining_model_prediction"
os.makedirs(result_path,exist_ok=True)
# NOTE(review): Retraining_model_prediction writes result_path+'probability.csv',
# i.e. '../result_path/Retraining_model_predictionprobability.csv' - next to,
# not inside, the directory created above, and the same file the test-set
# cell writes. Confirm this is intended.
initial_result="../result_path/Retraining_model_prediction/Validation.csv"
Retraining_model_prediction(dataPath=testfile_new, mode= 'prediction',\
           epitope_labels= epitope_file, \
           chains='A', h_cdr31='CDR3A', h_long1='LongA',  h_v1='VA', h_j1='JA',  model_file=save_model_path, \
           embedfile1= embed_model+'_0.bin',\
           predfile=initial_result, result_path= result_path, batch_size=256, testfile=testfile_path
           )
