In [1]:
import os 
import pandas as pd 
import numpy as np
from tqdm import tqdm 
import scipy.stats as stats
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer 

from process_data_new import process_data

## Data loading 

In [2]:
# Held-out evaluation split, produced by the project's `process_data` helper.
# (An earlier version read ../data/emerson/emerson_processed/whole_seqs_nn_test.tsv
# directly with pandas.)
test_data = process_data(train=False)

Reading the .tsv file...
Creating the allele codes...
Loading the sequences...
Masking the dataset...


100%|██████████| 26427243/26427243 [04:34<00:00, 96415.74it/s] 


In [3]:
# process_data returns four parallel collections; presumably V genes, CDR3s, J genes and
# full target sequences, in that order — TODO confirm in process_data_new.
V_data, CDR3_data, J_data, tgt_data =test_data

In [4]:
# Spot-check one CDR3 entry (space-separated amino-acid tokens, as the output shows).
CDR3_data[3]

'A S S P D R D S P L H F'

## Evaluation

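The evaluation follows the procedure used for the TCRpeg model:

- Keep the CDR3s observed more than twice in the test set.
- Compute their empirical frequencies ($p_{data}$).
- Compute the model's probability of each CDR3, averaged over all its occurrences and renormalised to sum to 1.
- Measure agreement between the two with the Pearson correlation coefficient.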

https://github.com/jiangdada1221/TCRpeg/blob/main/tcrpeg/evaluate.py

In [5]:
# Group every test record by its CDR3, keyed with a 'C ' prefix:
#   counts[key] = [number of occurrences,
#                  [full target sequence of each occurrence],
#                  [(len(v) + 1, len(cdr)) of each occurrence]]
counts = {}
for v, cdr, seq in tqdm(zip(V_data, CDR3_data, tgt_data), total=len(CDR3_data), leave=True, position=0):
    key = 'C ' + cdr
    span = (len(v) + 1, len(cdr))
    entry = counts.get(key)
    if entry is None:
        counts[key] = [1, [seq], [span]]
    else:
        entry[0] += 1
        entry[1].append(seq)
        entry[2].append(span)

100%|██████████| 26427243/26427243 [02:55<00:00, 150181.37it/s]


In [6]:
# Unpack the per-CDR3 groups into parallel lists (same order as `counts`).
c_data_ = [entry[0] for entry in counts.values()]
seqs_ = [entry[1] for entry in counts.values()]
lengths_ = [entry[2] for entry in counts.values()]
cdr3_seqs_ = list(counts.keys())

# Keep only CDR3s observed more than twice. No justification for this threshold was
# found; it is kept because the GRU paper's evaluation does the same.
keep = [idx for idx, n_seen in enumerate(c_data_) if n_seen > 2]
c_data = [c_data_[idx] for idx in keep]
cdr3_seqs = [cdr3_seqs_[idx] for idx in keep]
seqs = [seqs_[idx] for idx in keep]
lengths = [lengths_[idx] for idx in keep]

# Empirical probability of each retained CDR3: counts normalised to sum to 1.
p_data = np.array(c_data)
sum_p = np.sum(p_data)
p_data = p_data / sum_p

In [7]:
# Normalised empirical probabilities of the retained CDR3s (same order as cdr3_seqs).
p_data

array([4.35998147e-06, 3.96361952e-07, 2.47726220e-06, ...,
       2.97271464e-07, 2.97271464e-07, 2.97271464e-07])

In [None]:
# Run on the GPU when one is available; the checkpoint itself is materialised on the CPU first.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): the checkpoint stores a whole pickled model object, so torch.load unpickles
# arbitrary code; only load trusted files. Newer torch releases default to weights_only=True,
# which may require passing weights_only=False here.
checkpoint = torch.load(
    "../code/model_parallel_current/030823_customprotBERT_parallel_checkpoint.pth",
    map_location="cpu",
)
model = checkpoint["model"]
model.to(device)
model.eval()

## Unbatched version (one sequence at a time)

In [None]:
def sampling(model, src, length, cdr):
    """Return the model's probability of one target CDR3 (unbatched).

    The model is called as ``model(src, length)`` and its logits are softmaxed over
    the last axis. The result is the product, over every CDR3 position ``pos``, of
    the probability assigned to the target token ``cdr[:, pos]`` at that position.

    Each lookup is converted with ``int``/``float``, so a batch of one is assumed.
    Logit position ``pos`` is assumed to correspond to CDR3 token ``pos`` — TODO
    confirm against the model. Returns 1 when ``cdr`` has no positions.
    """
    token_probs = model(src,
                        length).softmax(-1)
    per_position = [
        float(token_probs[:, pos, int(cdr[:, pos])])
        for pos in range(cdr.shape[-1])
    ]
    joint = 1
    for p in per_position:
        joint *= p
    return joint

In [None]:
tokenizer = AutoTokenizer.from_pretrained('Rostlab/prot_bert_bfd', add_special_tokens=False)
torch.set_default_dtype(torch.float64)
# One model probability per retained CDR3, aligned with `cdr3_seqs` and therefore with `p_data`.
record = np.zeros(len(cdr3_seqs))

with torch.no_grad():
    # Iterate over the CDR3 keys. `seqs` holds *lists* of sequences, and a list cannot be
    # used as a key into `counts` (it is unhashable).
    for i, s in tqdm(enumerate(cdr3_seqs), total=len(cdr3_seqs), leave=True, position=0):
        _, full_seqs, seq_lengths = counts[s]
        # The target CDR3 tokens are the same for every occurrence of `s`, so tokenize once.
        # `s[1:]` drops the leading 'C' prepended when building `counts`.
        cdr = tokenizer(s[1:], return_tensors='pt')['input_ids'][:, 1:-1]  # remove [CLS] and [SEP]
        s_prob = []
        for full_seq, seq_len in zip(full_seqs, seq_lengths):
            # Model inputs must live on the same device as the model.
            src = tokenizer(full_seq, return_tensors='pt')['input_ids'][:, :-1].to(device)  # remove [SEP]
            length = torch.tensor(seq_len).unsqueeze(0).to(device)
            s_prob.append(sampling(model, src, length, cdr))
        # Average the model probability over all occurrences of this CDR3.
        record[i] = sum(s_prob) / len(s_prob)
        #print(record[i])

## Batched version (DataLoader)

In [None]:
import sys 

# Make ../code/model_parallel_current importable so CDR3Dataset can be loaded from it.
sys.path.insert(0, os.path.join(os.getcwd(), os.pardir, "code", "model_parallel_current"))

from CDR3Dataset import CDR3Dataset

In [None]:
# ProtBERT-BFD tokenizer, re-created here (presumably so this section runs without the unbatched one).
tokenizer = AutoTokenizer.from_pretrained('Rostlab/prot_bert_bfd', add_special_tokens=False)

In [None]:
# Flatten the per-CDR3 groups into one record per occurrence. Element k of
# cdr3_test / tgt_test / lengths_test / v_test / j_test describes the same record.
# s[1:] drops the 'C' prepended when building `counts`.
cdr3_test = [s[1:] for s, c in tqdm(zip(cdr3_seqs, c_data), total=len(cdr3_seqs), leave=True, position=0) for _ in range(c)]

lengths_test = []
tgt_test = []
for s, l in tqdm(zip(seqs, lengths), total=len(seqs), leave=True, position=0):
    tgt_test += s
    lengths_test += l

# Later cells map seq_probs[k] back to cdr3_test[k]; make sure the flattening stayed aligned.
assert len(cdr3_test) == len(tgt_test) == len(lengths_test), "flattened test lists are misaligned"

v_test = []
j_test = []
for seq, lens in tqdm(zip(tgt_test, lengths_test), total=len(tgt_test), leave=True, position=0):
    # lens == (len(v) + 1, len(cdr)) as recorded in `counts`; presumably these slices are the
    # V prefix and J suffix of the full sequence — TODO confirm the separator layout.
    v_test.append(seq[:lens[0]-1])
    j_test.append(seq[lens[0]+lens[1]+1:])


In [None]:

# Spot-check one flattened CDR3 target.
cdr3_test[11]

In [None]:
# Full target sequence of the same record (index 11).
tgt_test[11]

In [None]:
# evaluate=True presumably makes items carry what the batched loop reads
# ('seq', 'length', 'CDR3_label', 'pad_mask') — TODO confirm in CDR3Dataset.
dataset = CDR3Dataset(V = v_test, CDR3 = cdr3_test, J = j_test, tgt = tgt_test, tokenizer=tokenizer, evaluate=True)

In [None]:
# Inspect one dataset item.
dataset[0]

In [None]:
# shuffle=False keeps batches in cdr3_test order; the aggregation below indexes seq_probs by that order.
test_loader = DataLoader(dataset, shuffle=False, batch_size=64)

In [None]:
def sampling(model, src, length, cdr, pad_mask):
    """Return the model's joint probability of each target CDR3 in a batch.

    ``model(src, length, pad_mask)`` is expected to return logits of shape
    (batch, positions, vocab); they are log-softmaxed over the vocabulary axis.
    For every example, the log-probabilities of the target tokens ``cdr[b, pos]``
    are summed over the non-padding positions and then exponentiated.

    Args:
        model: callable producing the CDR3 logits (e.g. the loaded checkpoint model).
        src, length, pad_mask: forwarded unchanged to ``model``.
        cdr: int64 tensor (batch, cdr_len) of target token ids. Ids <= 0 are treated
            as padding and contribute a factor of 1. The tensor is right-padded with 0
            (or cropped) to the logits' position dimension.

    Returns:
        Tensor of shape (batch,) on the logits' device, one probability per example.
    """
    log_probs = model(src,
                      length,
                      pad_mask).log_softmax(-1)
    n_positions = log_probs.shape[1]
    # Right-pad with 0 (or crop) so every logit position has a target id.
    cdr = torch.nn.functional.pad(input=cdr, pad=(0, n_positions - cdr.shape[-1]))
    is_token = cdr > 0
    # gather picks the target log-prob at each position on the logits' own device, with no
    # CPU-allocated helpers and no hard-coded vocabulary size. Clamping keeps padding ids
    # (0, or negatives such as -100) valid indices; they are masked out next.
    token_log_probs = log_probs.gather(-1, cdr.clamp(min=0).unsqueeze(-1)).squeeze(-1)
    # Padding contributes log(1) = 0; masked_fill also avoids the 0 * -inf = nan of a one-hot product.
    token_log_probs = token_log_probs.masked_fill(~is_token, 0.0)
    return torch.exp(token_log_probs.sum(-1))

In [None]:
# NOTE(review): disabled snippet kept as a string literal (running the cell only displays it).
# As written it would fail: `seqs` holds lists of sequences, which cannot be used as keys into `counts`.
'''
max_length = 0
for s in seqs:
    for i in range(len(counts[s][1])):
        max_length = max(max_length, len(counts[s][1][i].split())+3)
'''

In [None]:
torch.set_default_dtype(torch.float64)

# Score every batch; the per-record probabilities are concatenated in loader order.
batch_probs = []
with torch.no_grad():
    for batch in tqdm(test_loader):
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        probs = sampling(
            model=model,
            src=batch['seq'],
            length=batch['length'],
            cdr=batch['CDR3_label'],
            pad_mask=batch['pad_mask'],
        )
        batch_probs.append(probs)
seq_probs = torch.cat(batch_probs)
        

In [None]:
# Group the per-record probabilities by CDR3: seq_probs[k] belongs to cdr3_test[k]
# (the loader does not shuffle). Every record is aggregated, not just the first batch.
assert len(seq_probs) == len(cdr3_test), "seq_probs and cdr3_test are misaligned"
cdr3_probs = {}
# A single tolist() transfers everything at once instead of one float() (and device sync) per element.
for cdr, prob in zip(cdr3_test, seq_probs.tolist()):
    cdr3_probs.setdefault(cdr, []).append(prob)

In [None]:
# Per-CDR3 lists of record probabilities (before averaging).
cdr3_probs

In [None]:
# Collapse each CDR3's list of record probabilities into its mean (key order unchanged).
cdr3_probs = {cdr3: sum(prob_list) / len(prob_list) for cdr3, prob_list in cdr3_probs.items()}

In [None]:
# Mean model probability per CDR3.
cdr3_probs

# Load the precomputed P_infer (per-CDR3 model probabilities)

In [8]:
import pickle

In [24]:
# Precomputed per-CDR3 model probabilities (P_infer); presumably produced by the batched
# procedure above — TODO confirm. This replaces any cdr3_probs computed in this session.
# NOTE(review): pickle.load can execute arbitrary code; only load files from a trusted source.
with open('../code/P_infer_CDR3_seqs.pkl', 'rb') as f:
    cdr3_probs = pickle.load(f)

In [35]:
# Pearson correlation pairs record[k] with p_data[k], so the model probabilities must follow the
# order of `cdr3_seqs`. The batched section keys cdr3_probs by s[1:] for s in cdr3_seqs; look the
# values up by that key instead of trusting dict order.
cdr3_keys = [s[1:] for s in cdr3_seqs]
if all(k in cdr3_probs for k in cdr3_keys):
    record = np.array([cdr3_probs[k] for k in cdr3_keys])
else:
    # NOTE(review): keys are not in the expected format; dict order is only correct if
    # cdr3_probs was built in cdr3_seqs order — verify before trusting the correlation.
    print('Warning: cdr3_probs keys do not match cdr3_seqs; falling back to dict order.')
    record = np.array(list(cdr3_probs.values()))

In [42]:
# Normalise the model probabilities to sum to 1 (like p_data), then correlate the two.
# (A KL divergence, as in TCRpeg, is not computed here.)
record_sum = record.sum()
record = record / record_sum
corr = stats.pearsonr(p_data, record)[0]
rounded = round(corr, 4)
print('Pearson correlation coefficient are : {}'.format(str(rounded)))
#return corr,p_data,record

Pearson correlation coefficient are : 0.0365


In [49]:
# First few retained CDR3 keys (each carries the 'C ' prefix added when grouping).
cdr3_seqs[:5]

['C A S S P D R D S P L H F',
 'C A S S R S T G Q G Y T F',
 'C A S S F R A D T E A F F',
 'C A S S L A Y E Q Y F',
 'C A S S F T G D T E A F F']