In [None]:
cd /rsch/Snowxue/TCR-DeepInsight/

## Imports

In [None]:
import os
import pickle
from collections import Counter

import numpy as np
import torch
import tqdm

from tcr_deep_insight.model.modeling_bert._collator import AminoAcidsCollator
from tcr_deep_insight.model.modeling_bert._config import get_human_config
from tcr_deep_insight.model.modeling_bert._model import (
    TRabModelingBertForPseudoSequence,
    TCRpMHCPairAttention,
    TRabModelingBertForPseudoSequenceWithContactModule
)
from tcr_deep_insight.model.tokenizers._tokenizer import (
    tokenize_tcr_pseudo_sequence_to_fixed_length, 
    trab_tokenizer_for_pseudosequence,
    tokenize_to_fixed_length
)
from tcr_deep_insight.utils._tcr_definitions import (
    _get_hla_pseudo_sequence,
    blosum_align,
    HLA_I_PSEUDO_INDEX,
)

# HLA lookup table; used further down as hla_pseudo_sequence[allele]["pseudosequence"].
hla_pseudo_sequence = _get_hla_pseudo_sequence()

In [None]:
import datasets
import pandas as pd

# Binding table: later grouped by "peptide" and read through its
# "tcr_pseudosequence" / "pmhc_pseudosequence" columns.
df_binding = pd.read_parquet("./data/tcrpmhc_pairing.parquet")

# scRNA observation table: its "pseudosequence" column is used to key decoy lookups.
df_scrna = pd.read_parquet("./data/20240401_huARdb_v2_5.obs.parquet")

# Tokenised scRNA dataset: rows are read through "input_ids" and "attention_mask".
ds_scrna = datasets.load_from_disk("./data/datasets/scrna_dataset_TCRab/")

In [None]:
# Derive two V-region keys from every scRNA pseudosequence (colon-separated fields) and,
# for each key, record the positional row numbers of df_scrna that carry it. The two
# dictionaries are used later to draw decoy TCRs that share a V-region key with a binder.
# NOTE(review): here fields [0:2] are labelled "bv" and fields [3:5] "av", whereas the
# decoy-sampling cells label fields [0:2] of the binding TCR pseudosequence "av" and fields
# [3:5] "bv". Confirm that the two pseudosequence formats really order the chains
# differently; if they share one order, the av/bv lookups are crossed.
def _join_pseudo_fields(pseudoseq, start, stop):
    """Return fields [start:stop] of a colon-separated pseudosequence, re-joined with ':'."""
    return ':'.join(pseudoseq.split(":")[start:stop])


df_scrna['bvpseudo'] = [_join_pseudo_fields(s, 0, 2) for s in df_scrna['pseudosequence']]
df_scrna['avpseudo'] = [_join_pseudo_fields(s, 3, 5) for s in df_scrna['pseudosequence']]

# groupby(...).indices yields {key: array of positional row numbers} in a single pass,
# instead of comparing the whole column against every unique key in turn.
df_scrna_bvpseudo_index = dict(df_scrna.groupby('bvpseudo').indices)
df_scrna_avpseudo_index = dict(df_scrna.groupby('avpseudo').indices)


In [None]:
# "small" BERT configuration with a 36-token vocabulary.
config = get_human_config(
    bert_type="small",
    vocab_size=36,
)

# Instantiate the pseudo-sequence model with its contact module and place it on the first GPU.
model = TRabModelingBertForPseudoSequenceWithContactModule(config)
model = model.to("cuda:0")

## Generating contact dataset

In [None]:
# Build one contact-training example per processed TCR–pMHC pickle that carries a
# 'tcr_pmhc_pseudo_distogram' entry; pickles without one are skipped.
PAIRING_DIR = "./data/processed_pairing_data/"
contact_dataset = []

# sorted(): os.listdir returns entries in arbitrary order, which would make the order of
# contact_dataset (and therefore the training order) differ between machines and runs.
for pkl_name in sorted(name for name in os.listdir(PAIRING_DIR) if name.endswith('.pkl')):
    # NOTE(security): pickle.load can execute arbitrary code — only read trusted files.
    with open(os.path.join(PAIRING_DIR, pkl_name), "rb") as f:
        data = pickle.load(f)

    if 'tcr_pmhc_pseudo_distogram' not in data:
        continue

    # NOTE(review): every complex is paired with the HLA-A*02:01 pseudosequence, whatever
    # the file contains — confirm that all processed structures are A*02:01.
    pmhc_pseudoseq = hla_pseudo_sequence["A*02:01"]["pseudosequence"] + ':' + data['peptide']

    # TCR pseudosequence: the six CDR loops, alpha chain first, joined with ':'.
    tcr_pseudoseq = ':'.join([
        data[k] for k in ['cdr1a','cdr2a', 'cdr3a', 'cdr1b','cdr2b', 'cdr3b']
    ])

    tcr_sequence, tcr_input_ids, tcr_attention_mask = tokenize_tcr_pseudo_sequence_to_fixed_length(tcr_pseudoseq)
    pmhc_sequence, pmhc_input_ids, pmhc_attention_mask = tokenize_to_fixed_length(pmhc_pseudoseq, 50)

    contact_dataset.append({
        "tcr_sequence": tcr_sequence,
        "tcr_input_ids": np.array(tcr_input_ids),
        "tcr_attention_mask": np.array(tcr_attention_mask),
        "pmhc_sequence": pmhc_sequence,
        "pmhc_input_ids": np.array(pmhc_input_ids),
        "pmhc_attention_mask": np.array(pmhc_attention_mask),
        "distogram": data['tcr_pmhc_pseudo_distogram'],
        # File name (including the .pkl suffix) identifies the source complex.
        "pdb_id": pkl_name
    })


# Masked-language-model collator; not referenced by the training cells visible in this notebook.
collator = AminoAcidsCollator(
    mask_token_id=4,
    max_length=100,
    mlm_probability=0.15,
)


In [None]:
# Cache the assembled contact examples on disk so later sessions can skip the rebuild.
with open("./data/processed_contact_data.pkl",'wb+') as cache_file:
    pickle.dump(contact_dataset, cache_file)

In [None]:
# Restore the cached contact examples (pickle: only load files you produced yourself).
with open("./data/processed_contact_data.pkl",'rb') as cache_file:
    contact_dataset = pickle.load(cache_file)

In [None]:
# Make sure the model sits on the first GPU (the construction cell already does this once).
model = model.to("cuda:0")

## Split the binding dataset into training and test sets

In [None]:
from sklearn.model_selection import train_test_split

# Fixed seed so the per-peptide split is reproducible when the cell is re-run.
SPLIT_SEED = 0

# One row per peptide; each cell holds the list of TCR / pMHC pseudosequences seen with it.
agg = df_binding.groupby("peptide").agg({
    'tcr_pseudosequence': list,
    'pmhc_pseudosequence': list,
})


def _split_positions(n_items, test_size=0.2, seed=SPLIT_SEED):
    """Return (train_positions, test_positions) for a peptide that has n_items pairs.

    train_test_split raises ValueError when one side of the split would be empty, which
    happens for a peptide with a single pair; such peptides are kept for training only.
    """
    if n_items < 2:
        return list(range(n_items)), []
    return train_test_split(range(n_items), test_size=test_size, random_state=seed)


def _select_split(side):
    """Build a frame shaped like `agg` holding only the train (side=0) or test (side=1) pairs."""
    rows = [
        ([tcrs[j] for j in split[side]], [pmhcs[j] for j in split[side]])
        for split, tcrs, pmhcs in zip(
            train_test_indices,
            agg["tcr_pseudosequence"],
            agg["pmhc_pseudosequence"],
        )
    ]
    return pd.DataFrame(rows, index=agg.index, columns=agg.columns)


# The same positions are applied to both columns so TCR/pMHC pairs stay aligned.
train_test_indices = [_split_positions(len(tcrs)) for tcrs in agg['tcr_pseudosequence']]
agg_train = _select_split(0)
agg_test = _select_split(1)
# Peptides with no held-out pair would only contribute empty lists downstream; drop them.
agg_test = agg_test[agg_test['tcr_pseudosequence'].map(len) > 0]

In [None]:
# Persist both halves of the split so later sessions reuse exactly the same partition.
for split_frame, split_path in (
    (agg_train, "./data/20240430_agg_train.pkl"),
    (agg_test, "./data/20240430_agg_test.pkl"),
):
    with open(split_path, "wb+") as split_file:
        pickle.dump(split_frame, split_file)

In [None]:
# Reload the saved split (pickle: only load files you produced yourself).
with open("./data/20240430_agg_train.pkl","rb") as train_file:
    agg_train = pickle.load(train_file)

with open("./data/20240430_agg_test.pkl","rb") as test_file:
    agg_test = pickle.load(test_file)

## Training contact task

In [None]:
# Train on the contact (distogram) objective only: 50 passes over contact_dataset,
# one complex per optimisation step.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for _ in range(50):
    epoch_loss = 0
    epoch_contact_loss = 0
    for i in range(0, len(contact_dataset)):
        example = contact_dataset[i]

        # unsqueeze(0) adds the batch dimension; masks are cast to float32 before the model call.
        tcr_input_ids = torch.from_numpy(example['tcr_input_ids']).unsqueeze(0).to("cuda:0")
        tcr_attention_mask = torch.from_numpy(example['tcr_attention_mask']).unsqueeze(0).to(torch.float32).to("cuda:0")
        # Fix: the pMHC token ids were always read from example 0, while the pMHC mask and
        # the distogram came from example i, so inputs and targets did not match.
        pmhc_input_ids = torch.from_numpy(example['pmhc_input_ids']).unsqueeze(0).to("cuda:0")
        pmhc_attention_mask = torch.from_numpy(example['pmhc_attention_mask']).unsqueeze(0).to(torch.float32).to("cuda:0")
        distogram = torch.from_numpy(example['distogram']).unsqueeze(0).to("cuda:0")

        output1 = model(
            tcr_input_ids=tcr_input_ids,
            tcr_attention_mask=tcr_attention_mask,
            pmhc_input_ids=pmhc_input_ids,
            pmhc_attention_mask=pmhc_attention_mask,
            tcr_pmhc_distogram=distogram,
        )

        loss = output1['contact_loss']

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Only the contact loss is optimised here, so both running totals are equal.
        step_loss = loss.item()
        epoch_loss += step_loss
        epoch_contact_loss += step_loss

    print(epoch_loss, epoch_contact_loss)

## Training binding task only

In [None]:
# Fine-tune on the binding objective alone. Every epoch rebuilds the example set:
#   * each known (TCR, pMHC) pair is a positive (binding = 1);
#   * up to three "hard" negatives per positive pair the same pMHC with an scRNA TCR that
#     shares the first key, the second key, or both keys with the binder;
#   * one extra negative per positive pairs its pMHC with a uniformly drawn scRNA TCR.
# NOTE(review): the "av"/"bv" slices below must match how df_scrna_avpseudo_index and
# df_scrna_bvpseudo_index were keyed — confirm the chain order of both pseudosequence formats.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
n_per_batch = 10
for _ in range(30):
    epoch_loss = 0
    epoch_binding_loss = 0
    epoch_contact_loss = 0

    # Shuffle the binders inside each peptide and flatten to (tcr, pmhc) pairs. Written as
    # a nested comprehension so the cell does not need the FLATTEN helper, which is only
    # defined in a later cell.
    batch_binding_data = [
        (tcrs[j], pmhcs[j])
        for tcrs, pmhcs in zip(agg_train['tcr_pseudosequence'], agg_train['pmhc_pseudosequence'])
        for j in np.random.permutation(len(tcrs))
    ]

    binding_data = []
    for tcr_pseudoseq, pmhc_pseudoseq in batch_binding_data:
        tcr_sequence, tcr_input_ids, tcr_attention_mask = tokenize_tcr_pseudo_sequence_to_fixed_length(tcr_pseudoseq)
        pmhc_sequence, pmhc_input_ids, pmhc_attention_mask = tokenize_to_fixed_length(pmhc_pseudoseq, 50)
        pmhc_features = {
            "pmhc_sequence": pmhc_sequence,
            "pmhc_input_ids": np.array(pmhc_input_ids),
            "pmhc_attention_mask": np.array(pmhc_attention_mask),
        }
        binding_data.append({
            "tcr_input_ids": np.array(tcr_input_ids),
            "tcr_attention_mask": np.array(tcr_attention_mask),
            **pmhc_features,
            "binding": 1
        })

        avpseudo = ':'.join(tcr_pseudoseq.split(":")[:2])
        bvpseudo = ':'.join(tcr_pseudoseq.split(":")[3:5])
        av_rows = df_scrna_avpseudo_index.get(avpseudo)
        bv_rows = df_scrna_bvpseudo_index.get(bvpseudo)
        shared_rows = None
        if av_rows is not None and bv_rows is not None:
            shared_rows = list(set(av_rows).intersection(set(bv_rows)))

        # One decoy per non-empty candidate pool, in the order: first key, second key, both.
        for pool in (av_rows, bv_rows, shared_rows):
            if pool is None or len(pool) == 0:
                continue
            decoy_data = ds_scrna[int(np.random.choice(pool))]
            binding_data.append({
                "tcr_input_ids": decoy_data['input_ids'],
                "tcr_attention_mask": decoy_data['attention_mask'],
                **pmhc_features,
                "binding": 0
            })

    # Uniformly drawn scRNA TCRs, one per positive pair.
    batch_decoy = ds_scrna[np.random.choice(np.arange(len(ds_scrna)), len(batch_binding_data), replace=False)]

    for tcr_input_ids, tcr_attention_mask, (_tcr, pmhc_pseudoseq) in zip(
        batch_decoy['input_ids'], batch_decoy['attention_mask'], batch_binding_data
    ):
        # Fix: the tokenizer result used to be discarded, so every random decoy was paired
        # with whatever pMHC the previous loop happened to process last.
        pmhc_sequence, pmhc_input_ids, pmhc_attention_mask = tokenize_to_fixed_length(pmhc_pseudoseq, 50)
        binding_data.append({
            "tcr_input_ids": tcr_input_ids,
            "tcr_attention_mask": tcr_attention_mask,
            "pmhc_sequence": pmhc_sequence,
            "pmhc_input_ids": np.array(pmhc_input_ids),
            "pmhc_attention_mask": np.array(pmhc_attention_mask),
            "binding": 0
        })
    batch_binding_dataset = datasets.Dataset.from_pandas(pd.DataFrame(binding_data))
    batch_binding_dataset = batch_binding_dataset.shuffle()

    # Floor division: a trailing partial batch is not used for training.
    for i in tqdm.trange(0, len(batch_binding_dataset) // n_per_batch):
        batch = batch_binding_dataset[i*n_per_batch:i*n_per_batch+n_per_batch]

        tcr_input_ids = torch.tensor(batch['tcr_input_ids']).to("cuda:0")
        tcr_attention_mask = torch.tensor(batch['tcr_attention_mask']).to(torch.float32).to("cuda:0")
        pmhc_input_ids = torch.tensor(batch['pmhc_input_ids']).to("cuda:0")
        pmhc_attention_mask = torch.tensor(batch['pmhc_attention_mask']).to(torch.float32).to("cuda:0")
        binding =  torch.tensor(batch['binding']).to("cuda:0").to(torch.float32).unsqueeze(1)

        output2 = model(
            tcr_input_ids=tcr_input_ids,
            tcr_attention_mask=tcr_attention_mask,
            pmhc_input_ids=pmhc_input_ids,
            pmhc_attention_mask=pmhc_attention_mask,
            tcr_pmhc_binding=binding    
        )

        loss = output2['binding_loss']
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Fix: epoch_loss used to add the running binding total on every step, so it grew
        # quadratically with the number of batches instead of summing the step losses.
        step_loss = loss.item()
        epoch_binding_loss += step_loss
        epoch_loss += step_loss
        del output2
        torch.cuda.empty_cache()

    print(epoch_loss, epoch_binding_loss)


## Training both contact and binding task

In [None]:
def FLATTEN(x):
    """Concatenate an iterable of iterables into a single flat list."""
    return [i for s in x for i in s]


# Joint training: every step runs one contact update (one complex) followed by one binding
# update (one mini-batch). The binding example set is rebuilt each epoch:
#   * each known (TCR, pMHC) pair is a positive (binding = 1);
#   * up to three "hard" negatives per positive pair the same pMHC with an scRNA TCR that
#     shares the first key, the second key, or both keys with the binder;
#   * one extra negative per positive pairs its pMHC with a uniformly drawn scRNA TCR.
# NOTE(review): the "av"/"bv" slices below must match how df_scrna_avpseudo_index and
# df_scrna_bvpseudo_index were keyed — confirm the chain order of both pseudosequence formats.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
for _ in range(24):
    epoch_loss = 0
    epoch_binding_loss = 0
    epoch_contact_loss = 0

    # Shuffle the binders inside each peptide and flatten to (tcr, pmhc) pairs.
    batch_binding_data = [
        (tcrs[j], pmhcs[j])
        for tcrs, pmhcs in zip(agg_train['tcr_pseudosequence'], agg_train['pmhc_pseudosequence'])
        for j in np.random.permutation(len(tcrs))
    ]

    binding_data = []
    for tcr_pseudoseq, pmhc_pseudoseq in batch_binding_data:
        tcr_sequence, tcr_input_ids, tcr_attention_mask = tokenize_tcr_pseudo_sequence_to_fixed_length(tcr_pseudoseq)
        pmhc_sequence, pmhc_input_ids, pmhc_attention_mask = tokenize_to_fixed_length(pmhc_pseudoseq, 50)
        pmhc_features = {
            "pmhc_sequence": pmhc_sequence,
            "pmhc_input_ids": np.array(pmhc_input_ids),
            "pmhc_attention_mask": np.array(pmhc_attention_mask),
        }
        binding_data.append({
            "tcr_input_ids": np.array(tcr_input_ids),
            "tcr_attention_mask": np.array(tcr_attention_mask),
            **pmhc_features,
            "binding": 1
        })

        avpseudo = ':'.join(tcr_pseudoseq.split(":")[:2])
        bvpseudo = ':'.join(tcr_pseudoseq.split(":")[3:5])
        av_rows = df_scrna_avpseudo_index.get(avpseudo)
        bv_rows = df_scrna_bvpseudo_index.get(bvpseudo)
        shared_rows = None
        if av_rows is not None and bv_rows is not None:
            shared_rows = list(set(av_rows).intersection(set(bv_rows)))

        # One decoy per non-empty candidate pool, in the order: first key, second key, both.
        # (Fix: the first of these checks was `if len(indices)  0:`, a syntax error.)
        for pool in (av_rows, bv_rows, shared_rows):
            if pool is None or len(pool) == 0:
                continue
            decoy_data = ds_scrna[int(np.random.choice(pool))]
            binding_data.append({
                "tcr_input_ids": decoy_data['input_ids'],
                "tcr_attention_mask": decoy_data['attention_mask'],
                **pmhc_features,
                "binding": 0
            })

    # Uniformly drawn scRNA TCRs, one per positive pair.
    batch_decoy = ds_scrna[np.random.choice(np.arange(len(ds_scrna)), len(batch_binding_data), replace=False)]

    for tcr_input_ids, tcr_attention_mask, (_tcr, pmhc_pseudoseq) in zip(
        batch_decoy['input_ids'], batch_decoy['attention_mask'], batch_binding_data
    ):
        # Fix: the tokenizer result used to be discarded, so every random decoy was paired
        # with whatever pMHC the previous loop happened to process last.
        pmhc_sequence, pmhc_input_ids, pmhc_attention_mask = tokenize_to_fixed_length(pmhc_pseudoseq, 50)
        binding_data.append({
            "tcr_input_ids": tcr_input_ids,
            "tcr_attention_mask": tcr_attention_mask,
            "pmhc_sequence": pmhc_sequence,
            "pmhc_input_ids": np.array(pmhc_input_ids),
            "pmhc_attention_mask": np.array(pmhc_attention_mask),
            "binding": 0
        })
    batch_binding_dataset = datasets.Dataset.from_pandas(pd.DataFrame(binding_data))
    batch_binding_dataset = batch_binding_dataset.shuffle()

    # Binding batch size chosen so one batch per contact example fits in the binding set;
    # clamped to at least 1 so the division below cannot hit zero.
    n_per_batch = max(1, min(10, len(batch_binding_dataset) // len(contact_dataset)))
    # Random resample of contact examples; not consumed by the loop below, kept for
    # compatibility with any later cell that reads it.
    batch_contact_dataset = np.array(contact_dataset)[
        np.vstack(
            [
                np.random.choice(range(len(contact_dataset)), 1)
                for _ in range(len(batch_binding_dataset) // n_per_batch)
            ]
        ).flatten()
    ]
    for i in tqdm.trange(0, len(contact_dataset)):
        # ---- contact step: one complex ----
        example = contact_dataset[i]
        tcr_input_ids = torch.from_numpy(example['tcr_input_ids']).unsqueeze(0).to("cuda:0")
        tcr_attention_mask = torch.from_numpy(example['tcr_attention_mask']).unsqueeze(0).to(torch.float32).to("cuda:0")
        # Fix: the pMHC token ids were always read from example 0, while the pMHC mask and
        # the distogram came from example i.
        pmhc_input_ids = torch.from_numpy(example['pmhc_input_ids']).unsqueeze(0).to("cuda:0")
        pmhc_attention_mask = torch.from_numpy(example['pmhc_attention_mask']).unsqueeze(0).to(torch.float32).to("cuda:0")
        distogram = torch.from_numpy(example['distogram']).unsqueeze(0).to("cuda:0")

        output1 = model(
            tcr_input_ids=tcr_input_ids,
            tcr_attention_mask=tcr_attention_mask,
            pmhc_input_ids=pmhc_input_ids,
            pmhc_attention_mask=pmhc_attention_mask,
            tcr_pmhc_distogram=distogram,
        )

        loss = output1['contact_loss'] 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        contact_step_loss = loss.item()
        epoch_contact_loss += contact_step_loss
        # Fix: epoch_loss used to add both running totals on every step; it now sums the
        # individual step losses.
        epoch_loss += contact_step_loss
        del output1
        torch.cuda.empty_cache()

        # ---- binding step: one mini-batch ----
        batch = batch_binding_dataset[i*n_per_batch:i*n_per_batch+n_per_batch]
        if len(batch['binding']) == 0:
            # Fewer binding examples than contact examples: nothing left to train on.
            continue

        tcr_input_ids = torch.tensor(batch['tcr_input_ids']).to("cuda:0")
        tcr_attention_mask = torch.tensor(batch['tcr_attention_mask']).to(torch.float32).to("cuda:0")
        pmhc_input_ids = torch.tensor(batch['pmhc_input_ids']).to("cuda:0")
        pmhc_attention_mask = torch.tensor(batch['pmhc_attention_mask']).to(torch.float32).to("cuda:0")
        binding =  torch.tensor(batch['binding']).to("cuda:0").to(torch.float32).unsqueeze(1)

        output2 = model(
            tcr_input_ids=tcr_input_ids,
            tcr_attention_mask=tcr_attention_mask,
            pmhc_input_ids=pmhc_input_ids,
            pmhc_attention_mask=pmhc_attention_mask,
            tcr_pmhc_binding=binding    
        )

        loss = output2['binding_loss']
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        binding_step_loss = loss.item()
        epoch_binding_loss += binding_step_loss
        epoch_loss += binding_step_loss
        del output2
        torch.cuda.empty_cache()

    print(epoch_loss, epoch_contact_loss, epoch_binding_loss)

In [None]:
# Save only the model weights (state_dict) to a checkpoint file in the working directory.
torch.save(model.state_dict(), "./20240418_tmp.ckpt")

In [None]:
# Restore the weights saved by the previous cell into the already-constructed model.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(torch.load("./20240418_tmp.ckpt"))

In [None]:
# Decode the TCR token ids of one binding example back into a string via the tokenizer.
# NOTE(review): depends on `i`, `n_per_batch` and `batch_binding_dataset` left in the
# kernel by an earlier loop, so the example shown depends on which cell ran last.
''.join([trab_tokenizer_for_pseudosequence.id_to_token(x) for x in batch_binding_dataset[i*n_per_batch+2]['tcr_input_ids']])

In [None]:
# Shape of element 1 of the contact output from the most recent contact forward pass.
# NOTE(review): requires `output1` to still exist — the joint-training cell deletes it
# after every step, so run a cell that leaves `output1` defined first.
output1['contact_output'][1].detach().cpu().numpy().shape

In [None]:
# Side-by-side heatmaps for the last processed complex: left = ground-truth distogram,
# right = model prediction, each reduced to the index of its largest entry along the
# last axis (topk(1)[1]).
# Relies on `distogram` and `output1` left in the kernel by an earlier cell.
# NOTE(review): createSubplots, sns and plt are not imported anywhere in the visible part
# of the notebook — confirm they are provided elsewhere.
fig,axes=createSubplots(1,2,figsize=(10,5))
contact_groud_truth = distogram[0].topk(1)[1].squeeze().detach().cpu().numpy()
# Positions where element 1 of the contact output is 0 are overwritten with 64 —
# presumably a validity mask, with 64 acting as the "masked" bin; TODO confirm.
contact_groud_truth[output1['contact_output'][1].detach().cpu().numpy()[0]==0]=64
sns.heatmap(
    contact_groud_truth,
    ax=axes[0]
)
# Element 2 of the contact output is treated as the predicted distogram.
contact_prediction = output1['contact_output'][2][0].topk(1)[1].squeeze().detach().cpu().numpy()
# contact_prediction[output1['contact_output'][1].detach().cpu().numpy()[0]==0]=64
sns.heatmap(
    contact_prediction,
    ax=axes[1]
)
plt.show()

## Train Dataset Evaluation

In [None]:
# Score the model on the training pairs. The example set is built the same way as in the
# training cells: positives from agg_train, up to three key-matched scRNA decoys per
# positive, plus one uniformly drawn scRNA decoy per positive.
# NOTE(review): the "av"/"bv" slices below must match how df_scrna_avpseudo_index and
# df_scrna_bvpseudo_index were keyed — confirm the chain order of both pseudosequence formats.
# NOTE(review): the model is not switched to eval mode here; if it contains dropout,
# consider calling model.eval() before scoring.
all_binding_predictions = []
# Batch size is set here so the cell does not depend on a value left by a training cell.
n_per_batch = 10

# Shuffle the binders inside each peptide and flatten to (tcr, pmhc) pairs.
batch_binding_data = [
    (tcrs[j], pmhcs[j])
    for tcrs, pmhcs in zip(agg_train['tcr_pseudosequence'], agg_train['pmhc_pseudosequence'])
    for j in np.random.permutation(len(tcrs))
]

binding_data = []
for tcr_pseudoseq, pmhc_pseudoseq in batch_binding_data:
    tcr_sequence, tcr_input_ids, tcr_attention_mask = tokenize_tcr_pseudo_sequence_to_fixed_length(tcr_pseudoseq)
    pmhc_sequence, pmhc_input_ids, pmhc_attention_mask = tokenize_to_fixed_length(pmhc_pseudoseq, 50)
    pmhc_features = {
        "pmhc_sequence": pmhc_sequence,
        "pmhc_input_ids": np.array(pmhc_input_ids),
        "pmhc_attention_mask": np.array(pmhc_attention_mask),
    }
    binding_data.append({
        "tcr_input_ids": np.array(tcr_input_ids),
        "tcr_attention_mask": np.array(tcr_attention_mask),
        **pmhc_features,
        "binding": 1
    })

    avpseudo = ':'.join(tcr_pseudoseq.split(":")[:2])
    bvpseudo = ':'.join(tcr_pseudoseq.split(":")[3:5])
    av_rows = df_scrna_avpseudo_index.get(avpseudo)
    bv_rows = df_scrna_bvpseudo_index.get(bvpseudo)
    shared_rows = None
    if av_rows is not None and bv_rows is not None:
        shared_rows = list(set(av_rows).intersection(set(bv_rows)))

    # One decoy per non-empty candidate pool, in the order: first key, second key, both.
    for pool in (av_rows, bv_rows, shared_rows):
        if pool is None or len(pool) == 0:
            continue
        decoy_data = ds_scrna[int(np.random.choice(pool))]
        binding_data.append({
            "tcr_input_ids": decoy_data['input_ids'],
            "tcr_attention_mask": decoy_data['attention_mask'],
            **pmhc_features,
            "binding": 0
        })

# Uniformly drawn scRNA TCRs, one per positive pair.
batch_decoy = ds_scrna[np.random.choice(np.arange(len(ds_scrna)), len(batch_binding_data), replace=False)]

for tcr_input_ids, tcr_attention_mask, (_tcr, pmhc_pseudoseq) in zip(
    batch_decoy['input_ids'], batch_decoy['attention_mask'], batch_binding_data
):
    # Fix: the tokenizer result used to be discarded, so every random decoy was paired
    # with (and reported under) whatever pMHC the previous loop processed last.
    pmhc_sequence, pmhc_input_ids, pmhc_attention_mask = tokenize_to_fixed_length(pmhc_pseudoseq, 50)
    binding_data.append({
        "tcr_input_ids": tcr_input_ids,
        "tcr_attention_mask": tcr_attention_mask,
        "pmhc_sequence": pmhc_sequence,
        "pmhc_input_ids": np.array(pmhc_input_ids),
        "pmhc_attention_mask": np.array(pmhc_attention_mask),
        "binding": 0
    })
batch_binding_dataset = datasets.Dataset.from_pandas(pd.DataFrame(binding_data))
batch_binding_dataset = batch_binding_dataset.shuffle()

with torch.no_grad():
    # Step by n_per_batch so the trailing partial batch is scored too.
    for i in tqdm.trange(0, len(batch_binding_dataset), n_per_batch):
        batch = batch_binding_dataset[i:i+n_per_batch]
        tcr_input_ids = torch.tensor(batch['tcr_input_ids']).to("cuda:0")
        tcr_attention_mask = torch.tensor(batch['tcr_attention_mask']).to(torch.float32).to("cuda:0")
        pmhc_input_ids = torch.tensor(batch['pmhc_input_ids']).to("cuda:0")
        pmhc_attention_mask = torch.tensor(batch['pmhc_attention_mask']).to(torch.float32).to("cuda:0")
        binding =  torch.tensor(batch['binding']).to("cuda:0").to(torch.float32).unsqueeze(1)

        output2 = model(
            tcr_input_ids=tcr_input_ids,
            tcr_attention_mask=tcr_attention_mask,
            pmhc_input_ids=pmhc_input_ids,
            pmhc_attention_mask=pmhc_attention_mask,
            tcr_pmhc_binding=binding
        )
        # Element 3 of the contact output is collected as the per-example binding score.
        all_binding_predictions.append(
            output2['contact_output'][3].detach().cpu().numpy()
        )


In [None]:

# Per-peptide tally of correct vs. incorrect calls for the examples scored by the cell above.
# An example counts as a predicted binder when its stacked model output is > 0; the peptide
# label is the text after the last ':' of pmhc_sequence, cut at the first '.'.
pd.DataFrame([
    (np.vstack(all_binding_predictions).flatten() > 0) == (np.array(batch_binding_dataset['binding']) == 1),
    list(map(lambda x: x.split(":")[-1].split(".")[0], batch_binding_dataset['pmhc_sequence']))
], index=['correct','peptide']).T.groupby("peptide").agg({"correct": Counter})

## Test Dataset Evaluation

In [None]:
# Score the model on the held-out pairs. The example set is built the same way as in the
# training cells: positives from agg_test, up to three key-matched scRNA decoys per
# positive, plus one uniformly drawn scRNA decoy per positive.
# NOTE(review): the "av"/"bv" slices below must match how df_scrna_avpseudo_index and
# df_scrna_bvpseudo_index were keyed — confirm the chain order of both pseudosequence formats.
# NOTE(review): the model is not switched to eval mode here; if it contains dropout,
# consider calling model.eval() before scoring.
all_binding_predictions = []
# Batch size is set here so the cell does not depend on a value left by a training cell.
n_per_batch = 10

# Shuffle the binders inside each peptide and flatten to (tcr, pmhc) pairs.
batch_binding_data = [
    (tcrs[j], pmhcs[j])
    for tcrs, pmhcs in zip(agg_test['tcr_pseudosequence'], agg_test['pmhc_pseudosequence'])
    for j in np.random.permutation(len(tcrs))
]

binding_data = []
for tcr_pseudoseq, pmhc_pseudoseq in batch_binding_data:
    tcr_sequence, tcr_input_ids, tcr_attention_mask = tokenize_tcr_pseudo_sequence_to_fixed_length(tcr_pseudoseq)
    pmhc_sequence, pmhc_input_ids, pmhc_attention_mask = tokenize_to_fixed_length(pmhc_pseudoseq, 50)
    pmhc_features = {
        "pmhc_sequence": pmhc_sequence,
        "pmhc_input_ids": np.array(pmhc_input_ids),
        "pmhc_attention_mask": np.array(pmhc_attention_mask),
    }
    binding_data.append({
        "tcr_input_ids": np.array(tcr_input_ids),
        "tcr_attention_mask": np.array(tcr_attention_mask),
        **pmhc_features,
        "binding": 1
    })

    avpseudo = ':'.join(tcr_pseudoseq.split(":")[:2])
    bvpseudo = ':'.join(tcr_pseudoseq.split(":")[3:5])
    av_rows = df_scrna_avpseudo_index.get(avpseudo)
    bv_rows = df_scrna_bvpseudo_index.get(bvpseudo)
    shared_rows = None
    if av_rows is not None and bv_rows is not None:
        shared_rows = list(set(av_rows).intersection(set(bv_rows)))

    # One decoy per non-empty candidate pool, in the order: first key, second key, both.
    for pool in (av_rows, bv_rows, shared_rows):
        if pool is None or len(pool) == 0:
            continue
        decoy_data = ds_scrna[int(np.random.choice(pool))]
        binding_data.append({
            "tcr_input_ids": decoy_data['input_ids'],
            "tcr_attention_mask": decoy_data['attention_mask'],
            **pmhc_features,
            "binding": 0
        })

# Uniformly drawn scRNA TCRs, one per positive pair.
batch_decoy = ds_scrna[np.random.choice(np.arange(len(ds_scrna)), len(batch_binding_data), replace=False)]

for tcr_input_ids, tcr_attention_mask, (_tcr, pmhc_pseudoseq) in zip(
    batch_decoy['input_ids'], batch_decoy['attention_mask'], batch_binding_data
):
    # Fix: the tokenizer result used to be discarded, so every random decoy was paired
    # with (and reported under) whatever pMHC the previous loop processed last.
    pmhc_sequence, pmhc_input_ids, pmhc_attention_mask = tokenize_to_fixed_length(pmhc_pseudoseq, 50)
    binding_data.append({
        "tcr_input_ids": tcr_input_ids,
        "tcr_attention_mask": tcr_attention_mask,
        "pmhc_sequence": pmhc_sequence,
        "pmhc_input_ids": np.array(pmhc_input_ids),
        "pmhc_attention_mask": np.array(pmhc_attention_mask),
        "binding": 0
    })
batch_binding_dataset = datasets.Dataset.from_pandas(pd.DataFrame(binding_data))
batch_binding_dataset = batch_binding_dataset.shuffle()

with torch.no_grad():
    # Step by n_per_batch so the trailing partial batch is scored too.
    for i in tqdm.trange(0, len(batch_binding_dataset), n_per_batch):
        batch = batch_binding_dataset[i:i+n_per_batch]
        tcr_input_ids = torch.tensor(batch['tcr_input_ids']).to("cuda:0")
        tcr_attention_mask = torch.tensor(batch['tcr_attention_mask']).to(torch.float32).to("cuda:0")
        pmhc_input_ids = torch.tensor(batch['pmhc_input_ids']).to("cuda:0")
        pmhc_attention_mask = torch.tensor(batch['pmhc_attention_mask']).to(torch.float32).to("cuda:0")
        binding =  torch.tensor(batch['binding']).to("cuda:0").to(torch.float32).unsqueeze(1)

        output2 = model(
            tcr_input_ids=tcr_input_ids,
            tcr_attention_mask=tcr_attention_mask,
            pmhc_input_ids=pmhc_input_ids,
            pmhc_attention_mask=pmhc_attention_mask,
            tcr_pmhc_binding=binding
        )
        # Element 3 of the contact output is collected as the per-example binding score.
        all_binding_predictions.append(
            output2['contact_output'][3].detach().cpu().numpy()
        )

# Per-peptide tally of correct vs. incorrect calls: an example counts as a predicted binder
# when its score is > 0; the peptide label is the text after the last ':' of pmhc_sequence,
# cut at the first '.'.
pd.DataFrame([
    (np.vstack(all_binding_predictions).flatten() > 0) == (np.array(batch_binding_dataset['binding']) == 1),
    list(map(lambda x: x.split(":")[-1].split(".")[0], batch_binding_dataset['pmhc_sequence']))
], index=['correct','peptide']).T.groupby("peptide").agg({"correct": Counter})

## Extract attention scores

In [None]:
# Run every contact example through the model with output_triangular_attentions=True and
# collect, per layer, the two attention tensors found in the last element of the contact output.
all_attentions_tcr = []
all_attentions_pmhc = []
n_layers = config.num_hidden_layers
with torch.no_grad():
    for i in tqdm.trange(0, len(contact_dataset)):
        example = contact_dataset[i]
        tcr_input_ids = torch.from_numpy(example['tcr_input_ids']).unsqueeze(0).to("cuda:0")
        tcr_attention_mask = torch.from_numpy(example['tcr_attention_mask']).unsqueeze(0).to(torch.float32).to("cuda:0")
        pmhc_input_ids = torch.from_numpy(example['pmhc_input_ids']).unsqueeze(0).to("cuda:0")
        pmhc_attention_mask = torch.from_numpy(example['pmhc_attention_mask']).unsqueeze(0).to(torch.float32).to("cuda:0")
        distogram = torch.from_numpy(example['distogram']).unsqueeze(0).to("cuda:0")
        output1 = model(
            tcr_input_ids=tcr_input_ids,
            tcr_attention_mask=tcr_attention_mask,
            pmhc_input_ids=pmhc_input_ids,
            pmhc_attention_mask=pmhc_attention_mask,
            tcr_pmhc_distogram=distogram,
            output_triangular_attentions=True
        )
        # Per layer j, the last output element holds a pair: item 1 goes to the TCR list,
        # item 0 to the pMHC list.
        layer_attentions = output1['contact_output'][-1]
        all_attentions_tcr.append([layer_attentions[j][1].detach().cpu() for j in range(n_layers)])
        all_attentions_pmhc.append([layer_attentions[j][0].detach().cpu() for j in range(n_layers)])


# Concatenate the per-example tensors layer by layer, then add a leading layer axis.
all_attentions_tcr = [
    torch.vstack([per_example[layer] for per_example in all_attentions_tcr])
    for layer in range(n_layers)
]
all_attentions_pmhc = [
    torch.vstack([per_example[layer] for per_example in all_attentions_pmhc])
    for layer in range(n_layers)
]
# Layer, Batch, pMHC pseudo, Head, TCR pseudo, TCR pseudo
all_attentions_tcr = torch.vstack([per_layer.unsqueeze(0) for per_layer in all_attentions_tcr])
# Layer, Batch, TCR pseudo, Head, pMHC pseudo, pMHC pseudo
all_attentions_pmhc = torch.vstack([per_layer.unsqueeze(0) for per_layer in all_attentions_pmhc])

# all_attentions[0].mean([0,1,2,3,])


In [None]:
import tqdm

# --- 1. Tokenize every (TCR, pMHC) pseudo-sequence pair listed in `agg` -------
# Each row of `agg` holds parallel lists of TCR and pMHC pseudo-sequences; all
# pairs are recorded as binders (binding == 1).
binding_data = []
n_per_batch = 10
for tcr_pseudoseqs, pmhc_pseudoseqs in zip(agg['tcr_pseudosequence'], agg['pmhc_pseudosequence']):
    for tcr_pseudoseq, pmhc_pseudoseq in zip(tcr_pseudoseqs, pmhc_pseudoseqs):
        tcr_sequence, tcr_input_ids, tcr_attention_mask = tokenize_tcr_pseudo_sequence_to_fixed_length(tcr_pseudoseq)
        pmhc_sequence, pmhc_input_ids, pmhc_attention_mask = tokenize_to_fixed_length(pmhc_pseudoseq, 50)
        binding_data.append({
            "tcr_input_ids": np.array(tcr_input_ids),
            "tcr_attention_mask": np.array(tcr_attention_mask),
            "pmhc_sequence": pmhc_sequence,
            "pmhc_input_ids": np.array(pmhc_input_ids),
            "pmhc_attention_mask": np.array(pmhc_attention_mask),
            "binding": 1
        })

all_batch_binding_dataset = datasets.Dataset.from_pandas(pd.DataFrame(binding_data))

# --- 2. Per pMHC, run the model in mini-batches and dump attention scores ------
# Make sure the output folder exists so np.save does not fail on a fresh checkout.
os.makedirs('./data/attention_score/', exist_ok=True)

# Materialize the pMHC column once; the original rebuilt a Python-level mask over
# the whole dataset for every distinct pMHC.
pmhc_sequences = np.asarray(all_batch_binding_dataset['pmhc_sequence'])

for pmhc in np.unique(pmhc_sequences):
    # 1-D row indices of this pMHC (np.argwhere alone would yield a (k, 1) array;
    # the rest of the notebook flattens it before use).
    batch_binding_dataset = all_batch_binding_dataset.select(np.flatnonzero(pmhc_sequences == pmhc))

    # File-name stem: last ':'-separated field of the pMHC string with dots removed.
    pmhc_tag = pmhc.split(":")[-1].replace(".", "")

    with torch.no_grad():
        for i in tqdm.trange(0, len(batch_binding_dataset), n_per_batch):
            # Slice the dataset once per step instead of once per field.
            batch = batch_binding_dataset[i:i+n_per_batch]
            tcr_input_ids = torch.tensor(batch['tcr_input_ids']).to("cuda:0")
            tcr_attention_mask = torch.tensor(batch['tcr_attention_mask']).to(torch.float32).to("cuda:0")
            pmhc_input_ids = torch.tensor(batch['pmhc_input_ids']).to("cuda:0")
            pmhc_attention_mask = torch.tensor(batch['pmhc_attention_mask']).to(torch.float32).to("cuda:0")
            binding = torch.tensor(batch['binding']).to("cuda:0").to(torch.float32).unsqueeze(1)
            output2 = model(
                tcr_input_ids=tcr_input_ids,
                tcr_attention_mask=tcr_attention_mask,
                pmhc_input_ids=pmhc_input_ids,
                pmhc_attention_mask=pmhc_attention_mask,
                tcr_pmhc_binding=binding,
                output_triangular_attentions=True
            )
            # Stack slot 1 of every layer's triangular attention along a new
            # leading layer axis; the tensors are already detached and on CPU.
            attention = torch.stack(
                [output2['contact_output'][-1][j][1].detach().cpu() for j in range(config.num_hidden_layers)],
                dim=0,
            )
            # One file per (pMHC, batch offset).
            np.save(f'./data/attention_score/{pmhc_tag}_{i}.npy', attention.numpy())
            # Release the GPU-side outputs before the next batch to keep memory flat.
            del output2
            torch.cuda.empty_cache()
