Process the predictions obtained from the CLAIRE (Glazer 2022) server for the associations of pairs formed between 10008 Zheng 2021 TCRs and 146 HLA-I alleles

    build a dictionary for the pairs and their scores
    get the scores for all pairs (including duplicates)
    reformat to matrix and save, with rows for TCRs and columns for HLAs

In [1]:
import pandas as pd
import numpy as np

from collections import defaultdict
from collections import Counter

import re
import pickle

In [2]:
# Collect the server output of every batch so that scores are available for all
# pairs (including duplicates) later on.
# The batch results are stored as output_1.csv ... output_20.csv; the columns of
# interest are appended, in file order, to four flat Python lists.

batch_score_folder = "../results/st7_batch_score_folder/"

load_tcrb, load_vb, load_hla, load_score = [], [], [], []

for i in range(1, 21):
    df_cur_scores = pd.read_csv("{}output_{}.csv".format(batch_score_folder, i), header=0)

    load_tcrb.extend(df_cur_scores["tcrb"].tolist())
    load_vb.extend(df_cur_scores["vb"].tolist())
    load_hla.extend(df_cur_scores["mhc"].tolist())
    load_score.extend(df_cur_scores["prediction"].tolist())

# Number of loaded rows per column, then the number of distinct
# (tcrb, vb, hla) triples; equal numbers mean no triple occurs twice.
for loaded_column in (load_tcrb, load_vb, load_hla):
    print(len(loaded_column))
print(len(set(zip(load_tcrb, load_vb, load_hla))))

383026
383026
383026
383026


In [5]:
# Check that every server output file lists exactly the same pairs, in the same
# order, as the batch input file it was produced from.


def _columns_identical(expected, observed):
    """Return True when both lists have the same length and agree element-wise.

    zip() stops at the shorter list, so the lengths are compared explicitly;
    otherwise extra or missing trailing rows would go unnoticed.
    """
    return len(expected) == len(observed) and all(
        x == y for x, y in zip(expected, observed)
    )


def _match_fraction(expected, observed):
    """Return the share of positions at which the two lists agree.

    The denominator is the longer of the two lengths, so the result is 1.0
    only when the lists have the same length and every element matches.
    Two empty lists are reported as a full match (1.0).
    """
    n_rows = max(len(expected), len(observed))
    if n_rows == 0:
        return 1.0
    return sum(x == y for x, y in zip(expected, observed)) / n_rows


file_match_flags = []

for i in range(20):
    # input batches are numbered from 0, server output files from 1
    df_cur_pairs = pd.read_csv("../results/st7_batch_folder/st7_zheng_2021_hla_109_mcpas_glazer_online_unique_pairs_batch_"+str(i)+".csv", header=0)
    df_cur_scores = pd.read_csv(batch_score_folder+"output_"+str(i+1)+".csv", header=0)

    # a batch passes only if all three identifying columns agree row by row
    cur_flag = all(
        _columns_identical(df_cur_pairs[col].tolist(), df_cur_scores[col].tolist())
        for col in ("tcrb", "vb", "mhc")
    )

    file_match_flags += [cur_flag]

print(file_match_flags)

# Double check that the concatenated server output matches the original
# unique-pair file; each printed value is 1.0 only for a complete match.

csv_unique_pairs = pd.read_csv("../results/st7_zheng_2021_hla_109_mcpas_glazer_online_unique_pairs.csv", header=0)
print(_match_fraction(csv_unique_pairs.tcrb.tolist(), load_tcrb))
print(_match_fraction(csv_unique_pairs.vb.tolist(), load_vb))
print(_match_fraction(csv_unique_pairs.mhc.tolist(), load_hla))

[True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True]
1.0
1.0
1.0


In [6]:
# Build a dictionary mapping each (tcrb, vb, hla) triple to its predicted score.
#
# A plain dict is used on purpose: with defaultdict(float), looking up a triple
# that is absent would silently return 0.0 and insert a new entry, which would
# be indistinguishable from a genuine prediction. With a plain dict such a
# lookup raises KeyError instead.
# If a triple occurs more than once, the score of its last occurrence is kept.

score_dict = dict(zip(zip(load_tcrb, load_vb, load_hla), load_score))

# number of distinct triples and the range of the predicted scores
print(len(score_dict))
print(max(score_dict.values()))
print(min(score_dict.values()))

383026
0.9948988556861876
0.0046807117760181


In [9]:
# Read the table of pairs that still contains duplicates; scores are looked up
# for each of its rows in the following cells.
pos_pairs_path = "../results/st7_zheng_2021_hla_109_mcpas_glazer_online.csv"
csv_pos = pd.read_csv(pos_pairs_path, header=0)
csv_pos.shape

(1090872, 8)

In [10]:
# Pull the three identifying columns out of the table as plain Python lists
tcrb_w_dup, vb_w_dup, mhc_w_dup = (
    csv_pos[column].tolist() for column in ("tcrb", "vb", "mhc")
)

In [11]:
# mhc values of the first 109 rows; used further down as the column labels of the score matrix.
# NOTE(review): assumes the first 109 rows list each of the 109 HLAs exactly once — confirm.
raw_columns = mhc_w_dup[:109]

In [12]:
# Look up the score of every row of the table (duplicates included) by its
# (tcrb, vb, hla) key.
pair_keys_w_dup = zip(tcrb_w_dup, vb_w_dup, mhc_w_dup)
scores_w_dup = [score_dict[pair_key] for pair_key in pair_keys_w_dup]

# the number of scores should equal the number of rows in the table
print(csv_pos.shape)
len(scores_w_dup)

(1090872, 8)


1090872

In [13]:
# Entry count of score_dict after the lookups in the previous cell; compare with the
# count printed when the dictionary was built to confirm the lookups added no new keys.
len(score_dict)

383026

In [14]:
# Convert the per-row scores to a NumPy array so they can be reshaped into a matrix
yhat = np.array(scores_w_dup)

# Number of rows divided by 109 (the number of score columns used below);
# a whole number means the rows split evenly into blocks of 109
len(tcrb_w_dup)/109

10008.0

In [15]:
# Arrange the flat score vector as a matrix with one row per block of
# len(raw_columns) consecutive table rows and one column per HLA.
#
# The row-major reshape only puts each score under the right column label if
# every block lists the HLAs in the same order as raw_columns, so that layout
# is verified first instead of being assumed.
n_hla = len(raw_columns)
n_tcr, n_leftover = divmod(len(tcrb_w_dup), n_hla)  # integer arithmetic, no float rounding

assert n_leftover == 0, "number of rows is not a multiple of the number of HLA columns"
assert len(mhc_w_dup) == len(tcrb_w_dup) == len(yhat), "pair lists and scores differ in length"
assert (np.array(mhc_w_dup).reshape(n_tcr, n_hla) == np.array(raw_columns)).all(), \
    "HLA order is not identical in every block of rows"

yhat_reshape = yhat.reshape(n_tcr, n_hla)

# preview: first two rows, first six HLA columns
yhat_reshape[:2, :6]

array([[0.51867741, 0.92720938, 0.40561506, 0.87267429, 0.8846783 ,
        0.84183466],
       [0.51867741, 0.92720938, 0.40561506, 0.87267429, 0.8846783 ,
        0.84183466]])

In [16]:
# First six scores of the first block of 109 rows (compare with the first matrix row above)
scores_w_dup[:6]

[0.5186774134635925,
 0.9272093772888184,
 0.4056150615215301,
 0.872674286365509,
 0.884678304195404,
 0.8418346643447876]

In [17]:
# The same six positions in the second block of 109 rows (compare with the second matrix row above)
scores_w_dup[109:(109+6)]

[0.5186774134635925,
 0.9272093772888184,
 0.4056150615215301,
 0.872674286365509,
 0.884678304195404,
 0.8418346643447876]

In [18]:
# Insert a "*" after the first five characters of every HLA name and use the
# resulting names as column labels of the score table.
star_columns = ["*".join((hla_name[:5], hla_name[5:])) for hla_name in raw_columns]

df_scores = pd.DataFrame(data=yhat_reshape, columns=star_columns)

In [19]:
# Load the conversion table used to extend the scores to all 146 HLAs,
# following the format for DePTH, for the convenience of later processing.

convert_filename = "st7_hla_chowell_146_mcpas_convert_39.csv"
df_convert = pd.read_csv("".join(("../results/", convert_filename)), header=0)

In [20]:
# Preview the first ten rows of the conversion table
df_convert.head(10)

Unnamed: 0,hla_chowell_146,hla_convert
0,A0101,HLA-A01:01
1,A0102,HLA-A01:01
2,A0103,HLA-A01:01
3,A0201,HLA-A02:01
4,A0202,HLA-A02:02
5,A0203,HLA-A02:03
6,A0205,HLA-A02:05
7,A0206,HLA-A02:06
8,A0207,HLA-A02:07
9,A0217,HLA-A02:17


In [22]:
# For every allele in hla_chowell_146, copy the score column of the allele it
# is converted to (hla_convert, with "*" inserted after the first five
# characters so that it matches the column labels of df_scores).
dict_extend = defaultdict(list)

hla_name_pairs = zip(df_convert["hla_chowell_146"], df_convert["hla_convert"])
for chowell_name, claire_name in hla_name_pairs:
    starred_name = "*".join((claire_name[:5], claire_name[5:]))
    dict_extend[chowell_name] = df_scores[starred_name].tolist()

# number of alleles collected
len(dict_extend)

146

In [29]:
# Turn the per-allele score lists into a table (one column per key of
# dict_extend) and write it to disk without the index column.
extended_scores_path = "../results/st8_Glazer_2022_server_on_zheng_2021_pos_extended_hlas_146.csv"

df_extend = pd.DataFrame.from_dict(dict_extend, orient="columns")
df_extend.to_csv(extended_scores_path, index=False)

In [25]:
# Dimensions of the extended score table as (rows, columns)
df_extend.shape

(10008, 146)