<a href="https://colab.research.google.com/github/JiahuaQu/conga/blob/master/TCRdist_GPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**`Please note: this is experimental code for accelerated TCRdist computation on CPU and GPU`**

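It has not been extensively tested. Before use, upload `params_v2.tsv` (the encoder dictionary mapping features to integer codes), `tmp_tcr.tsv` (the TCR table with `va`, `cdr3a`, `vb` and `cdr3b` columns) and `TCRdist_matrix_mega.tsv` (the substitution matrix).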

contact: mikhail.pogorelyy@stjude.org (Mikhail Pogorelyy)

In [None]:
import cupy as mx #cuda python backend, connect to T4 runtime or other with GPU
import numpy as np
#import numpy as mx #use this for CPU only
#import mlx.core as mx #use this for apple silicon
import pandas as pd
import time

In [None]:
def pad_center(seq, target_length):
    """Insert '_' gap symbols into the middle of `seq` until it has `target_length` items.

    Sequences that are already at least `target_length` long are not padded;
    they are cut from the end (``seq[:target_length]``). For shorter inputs the
    run of gaps is placed after the first ``len(seq) // 2`` items, so for odd
    lengths the right-hand part keeps the extra item.

    `seq` is expected to be a list (e.g. ``list("CASSF")``): the padding branch
    concatenates it with a list of '_' strings.
    """
    n = len(seq)
    if n >= target_length:
        return seq[:target_length]
    mid = n // 2
    gap = ['_'] * (target_length - n)
    return seq[:mid] + gap + seq[mid:]

#load TCR file (columns used below: va, cdr3a, vb, cdr3b)
tst_tcr = pd.read_csv("tmp_tcr.tsv", sep="\t")

#load encoder dictionary: used to look up integer codes for V gene names and for
#every CDR3 character, including the '_' gap symbol inserted by pad_center
params_df = pd.read_csv("params_v2.tsv", sep="\t", header=None, names=["feature", "value"])
params_vec = dict(zip(params_df["feature"], params_df["value"]))
#load substitution matrix; the integer codes are later used as its row/column indices
submat = mx.array(np.loadtxt('TCRdist_matrix_mega.tsv', delimiter='\t', dtype=np.int16)) #this is substitution matrix

CDR3_PAD_LENGTH = 29 # every CDR3 is gap-padded (or truncated) to this many positions
cols_to_use = slice(3, -2) #truncate CDR3s: drop the first 3 and last 2 padded positions (29 -> 24)


def encode_features(values, mapping):
    """Return an array shaped like `values` with every element replaced by ``mapping[element]``.

    ``np.vectorize(mapping.get)`` silently turns unknown features into None,
    yielding an object array that only fails much later (when converted to a
    backend array / cast to uint8). Here unknown features raise immediately.

    Raises:
        KeyError: if any element of `values` is not a key of `mapping`; the
            message lists (up to 10 of) the missing features.
    """
    arr = np.asarray(values)
    missing = {item for item in arr.ravel().tolist() if item not in mapping}
    if missing:
        raise KeyError(
            f"{len(missing)} feature(s) missing from the encoder dictionary, e.g. {list(missing)[:10]}"
        )
    return np.vectorize(mapping.__getitem__)(arr)


#encode TCRs
cdr3amat = np.array([pad_center(list(seq), CDR3_PAD_LENGTH) for seq in tst_tcr['cdr3a']])
cdr3amatint = encode_features(cdr3amat, params_vec)
cdr3bmat = np.array([pad_center(list(seq), CDR3_PAD_LENGTH) for seq in tst_tcr['cdr3b']])
cdr3bmatint = encode_features(cdr3bmat, params_vec)

# column layout per TCR: [Va code, 24 CDR3a codes, Vb code, 24 CDR3b codes]
encoded = np.column_stack([
    encode_features(tst_tcr['va'], params_vec),
    cdr3amatint[:, cols_to_use],
    encode_features(tst_tcr['vb'], params_vec),
    cdr3bmatint[:, cols_to_use],
])


In [None]:
#convert to backend arrays (cupy / numpy / mlx, whichever is imported as mx)
# The codes are stored as uint8 and used as indices into submat. Check that the
# cast cannot silently wrap, and that every code is a valid submat index
# (CuPy wraps out-of-bounds indices instead of raising, which would give wrong distances).
assert encoded.min() >= 0 and encoded.max() <= np.iinfo(np.uint8).max, "feature codes do not fit in uint8"
assert encoded.max() < min(submat.shape), "feature code is outside the substitution matrix"
tcrs1=mx.array(encoded).astype(mx.uint8)
tcrs2=mx.array(encoded).astype(mx.uint8) # could be a different dataset, like a database

kbest=1000 # nbest neighbour we are looking for
# Upper bound on chunk_size * len(tcrs2), i.e. (query, target) pairs per chunk; the
# temporary 3d tensor in the search loop has this many times tcrs1.shape[1] elements.
# Decrease to limit memory consumption; increase if you have a more powerful GPU.
MAX_PAIRS_PER_CHUNK = 20_000_000
# at least one row per chunk, even for huge tcrs2 (and no division by zero for an empty one)
chunk_size = max(1, min(tcrs1.shape[0], MAX_PAIRS_PER_CHUNK // max(1, tcrs2.shape[0])))
n_chunks = -(-tcrs1.shape[0] // chunk_size) # ceiling division: the last chunk may be partial
print('total number of chunks', n_chunks)

total number of chunks 20


In [None]:
# Encoded TCR matrix: one row per TCR; 50 columns = Va code + 24 CDR3a codes + Vb code + 24 CDR3b codes
tcrs1.shape # (20000, 50)

(20000, 50)

In [None]:
start_time = time.time()

n_query = tcrs1.shape[0]
# argpartition requires kth < tcrs2.shape[0]; never ask for more neighbours than exist
k_neighbours = min(kbest, tcrs2.shape[0])
result = mx.zeros((n_query, k_neighbours), dtype=mx.uint32) # row i: indices into tcrs2 of the lowest-distance TCRs for tcrs1[i]
for ch in range(0, n_query, chunk_size): # process tcrs1 in chunks so the (chunk, len(tcrs2), n_columns) temporary fits in memory
    chunk_end = min(ch + chunk_size, n_query)
    row_range = slice(ch, chunk_end)
    # TCRdist matrix for this chunk:
    # distances[i, j] = sum over columns p of submat[tcrs1[ch + i, p], tcrs2[j, p]]
    # (sum widens the int16 scores; the author notes 16 bit would be enough)
    distances = mx.sum(submat[tcrs1[row_range, None, :], tcrs2[None, :, :]], axis=2)
    # kth = k_neighbours - 1 moves the k_neighbours smallest distances into the first
    # k_neighbours columns. Note they are NOT sorted within that partition!
    result[row_range, :] = mx.argpartition(distances, kth=k_neighbours - 1, axis=1)[:, :k_neighbours]

# GPU kernels run asynchronously; copying one element back to the host waits for
# all queued work, so end_time covers the whole computation (no-op cost on CPU).
if result.size:
    result[-1, -1].item()
end_time = time.time()

print(result)
print(f"Time taken: {end_time - start_time:.6f} seconds") #6.7 seconds on T4 GPU, 120 seconds on T4 CPU


[[    0  8427 16854 ... 17448 17501 18398]
 [    1  8428 16855 ... 18036 18181 18283]
 [    2  8429 16856 ... 10387 11150 11254]
 ...
 [ 3143 11570 19997 ...  6140  6744  7490]
 [ 3144 11571 19998 ... 15842 16659 16994]
 [ 3145 11572 19999 ...  9970 10082 11150]]
Time taken: 6.797993 seconds
