In [3]:
IS_COLAB: bool = False  # set True when running on Google Colab (mounts Drive and changes directory below)

In [4]:
if IS_COLAB:
    from google.colab import drive

    drive.mount('/content/drive')
    %ls
    %cd drive/MyDrive/dsmp-2024-groupol1/

In [None]:
import pandas as pd

# Raw VDJdb export (tab-separated), relative to the project root.
VDJDB_PATH = './data/vdjdb.txt'
df = pd.read_csv(VDJDB_PATH, sep="\t")

In [None]:
# Columns kept from VDJdb for the analyses below.
selected_features = df[['gene', 'cdr3', 'v.segm', 'j.segm', 'species', 'mhc.a', 'mhc.b',
                        'mhc.class', 'antigen.epitope', 'antigen.species', 'vdjdb.score']]

# Human records with vdjdb.score > 0, de-duplicated, with incomplete rows removed.
is_human = selected_features['species'] == 'HomoSapiens'
has_positive_score = selected_features['vdjdb.score'] > 0
human_data = (
    selected_features[is_human & has_positive_score]
    .drop_duplicates()
    .dropna()
)
human_data.head()


In [None]:
TRB = human_data[human_data['gene'] == 'TRB']
# Rename to tcrdist3's beta-chain column names. `rename` returns a new frame, so
# nothing is written in place onto a slice of `human_data` (which raised
# SettingWithCopyWarning); this also matches how the alpha cell builds its frame.
beta_chains = TRB[['cdr3', 'v.segm', 'j.segm', 'antigen.epitope']].rename(
    columns={'cdr3': 'cdr3_b_aa', 'v.segm': 'v_b_gene', 'j.segm': 'j_b_gene'}
)
beta_chains

In [None]:
!pip install tcrdist3
!pip install umap-learn
!pip install umap-learn[plot]

In [None]:
import os
from tcrdist.repertoire import TCRrep
import umap
import umap.plot
import matplotlib.pyplot as plt

def calculate_dist_and_umap(df: pd.DataFrame,
                            chain: str,
                            gene: str,
                            top_n: int = 7,
                            n_neighbors: int = 100,
                            output_dir: str = 'visualisations',
                            random_state=None) -> pd.DataFrame:
    """Compute a CDR3 tcrdist matrix for one chain and save a UMAP plot of it.

    Each row of the pairwise CDR3 distance matrix is used as a feature vector for
    UMAP. Only clones whose ``gene`` label is among the ``top_n`` most frequent
    labels are plotted, coloured by that label.

    Args:
        df: Cell table passed to ``TCRrep`` (tcrdist3 column names, e.g.
            ``cdr3_b_aa``/``v_b_gene``/``j_b_gene`` for beta) plus the label
            column named by ``gene``.
        chain: ``'alpha'`` or ``'beta'``.
        gene: Name of the label column to colour by (e.g. ``'antigen.epitope'``).
        top_n: Number of most frequent labels kept in the plot.
        n_neighbors: UMAP ``n_neighbors``.
        output_dir: Directory the PNG is written to; created if missing.
        random_state: Optional UMAP seed for reproducible layouts.

    Returns:
        The full (unfiltered) distance matrix: one column per clone, with the
        label column ``gene`` appended as the last column.

    Raises:
        ValueError: If ``chain`` is not ``'alpha'`` or ``'beta'``.
    """
    pw_attribute = {'alpha': 'pw_cdr3_a_aa', 'beta': 'pw_cdr3_b_aa'}
    if chain not in pw_attribute:
        # Previously an unknown chain fell through and hit an unbound local.
        raise ValueError(f"chain must be 'alpha' or 'beta', got {chain!r}")

    tr = TCRrep(cell_df=df,
                organism='human',
                chains=[chain],
                db_file='alphabeta_gammadelta_db.tsv')

    # Rows of the pairwise matrix follow clone_df order; reset the label index so
    # concat aligns positionally rather than on whatever index clone_df carries.
    labels = tr.clone_df[gene].reset_index(drop=True)
    distance_matrix = pd.concat([pd.DataFrame(getattr(tr, pw_attribute[chain])), labels], axis=1)

    top_labels = distance_matrix[gene].value_counts().nlargest(top_n).index
    distance_matrix_filtered = distance_matrix[distance_matrix[gene].isin(top_labels)]

    # The label column is last, so every other column is a distance feature.
    distances_reduced = umap.UMAP(n_components=2,
                                  n_neighbors=n_neighbors,
                                  random_state=random_state).fit(
        distance_matrix_filtered.iloc[:, :-1].values)

    f = umap.plot.points(distances_reduced, labels=distance_matrix_filtered[gene])
    f.set_xlabel('UMAP Dimension 1', fontsize=10)
    f.set_ylabel('UMAP Dimension 2', fontsize=10)
    f.set_title(f'UMAP Visualization of {chain}', fontsize=12)

    os.makedirs(output_dir, exist_ok=True)
    f.get_figure().savefig(os.path.join(output_dir, f'{chain}_chain_umap.png'))
    return distance_matrix

In [None]:
beta_dist_matrix = calculate_dist_and_umap(df=beta_chains, chain='beta', gene='antigen.epitope')  # CDR3-beta tcrdist + UMAP plot

In [None]:
# Same preparation as for the beta chains, applied to the TRA rows.
TRA = human_data[human_data['gene'] == 'TRA']
alpha_column_names = {'cdr3': 'cdr3_a_aa', 'v.segm': 'v_a_gene', 'j.segm': 'j_a_gene'}
alpha_chains = TRA[['cdr3', 'v.segm', 'j.segm', 'antigen.epitope']].rename(columns=alpha_column_names)
print(alpha_chains.columns)

alpha_df = calculate_dist_and_umap(alpha_chains, 'alpha', 'antigen.epitope')

In [None]:
import torch

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
from transformers import BertModel, BertTokenizer

model_name = 'wukevin/tcr-bert'

# Load the TCR-BERT encoder without the pooler head, then place it on `device`.
model = BertModel.from_pretrained(model_name, add_pooling_layer=False)
model = model.to(device)

In [None]:
from typing import  Sequence, Any
from math import floor

# Special-token characters passed to the TCR-BERT tokenizer.
PAD, MASK, UNK, SEP, CLS = "$", ".", "?", "|", "*"

def is_whitespaced(seq: str) -> bool:
    """Heuristically detect a sequence already in space-separated form, e.g. "C A S S".

    Returns True when the number of whitespace characters equals half the
    length rounded down. That is exactly what joining single characters with
    one space produces. Note that "" and any 1-character string also qualify.
    """
    whitespace_count = sum(1 for ch in seq if ch.isspace())
    return whitespace_count == len(seq) // 2

def get_pretrained_bert_tokenizer(path: str):
    """Load a BertTokenizer from `path` configured with this notebook's special tokens.

    Basic tokenization, lower-casing and Chinese-character splitting are all
    disabled, and padding is applied on the right.
    """
    special_tokens = dict(
        unk_token=UNK,
        sep_token=SEP,
        pad_token=PAD,
        cls_token=CLS,
        mask_token=MASK,
    )
    return BertTokenizer.from_pretrained(
        path,
        do_basic_tokenize=False,
        do_lower_case=False,
        tokenize_chinese_chars=False,
        padding_side="right",
        **special_tokens,
    )

def chunkify(x: Sequence[Any], chunk_size: int = 128):
    """Split `x` into consecutive slices of `chunk_size` items (the last may be shorter)."""
    chunks = []
    for start in range(0, len(x), chunk_size):
        stop = start + chunk_size
        chunks.append(x[start:stop])
    return chunks

def insert_whitespace(seq: str) -> str:
    """Return `seq` with a single space between consecutive characters ("CASS" -> "C A S S")."""
    return " ".join(seq)

In [None]:
# `model` was already loaded from `model_name` ('wukevin/tcr-bert') and moved to
# `device` in the earlier model cell; reloading the same checkpoint here only
# repeated that work. Load the matching tokenizer with the special tokens above.
model_tokenizer = get_pretrained_bert_tokenizer(model_name)

In [None]:
chains = beta_chains.loc[:, 'cdr3_b_aa'].to_list()  # beta CDR3 strings in beta_chains row order

In [None]:
import numpy as np

# Hidden layers whose mean-pooled outputs are concatenated per sequence (-1 = last).
layers = [-1]
# Sequences per forward pass. Kept at 1 to match the original run; padded
# positions are excluded from pooling below, so raising it is mainly a speed knob.
EMBED_BATCH_SIZE = 1
# Every input is padded to this many tokens.
EMBED_MAX_LENGTH = 64

# The tokenizer is fed space-separated residues ("C A S S ...").
seqs = [s if is_whitespaced(s) else insert_whitespace(s) for s in chains]

embedding_rows = []
with torch.no_grad():
    for seq_batch in chunkify(seqs, EMBED_BATCH_SIZE):
        encoded = model_tokenizer(
            seq_batch, padding="max_length", max_length=EMBED_MAX_LENGTH, return_tensors="pt"
        )
        encoded = {k: v.to(device) for k, v in encoded.items()}
        # Only hidden states are used; attention maps are no longer requested.
        output = model(**encoded, output_hidden_states=True)

        for i, seq in enumerate(seq_batch):
            seq_len = len(seq.split())
            # Skip position 0 (the first token) and average exactly `seq_len`
            # positions after it, so trailing special/padding tokens are excluded.
            # hidden_states[l] presumably is (batch, tokens, hidden) — TODO confirm.
            per_layer = [
                output.hidden_states[l][i, 1 : 1 + seq_len]
                .cpu().numpy().astype(np.float64).mean(axis=0)
                for l in layers
            ]
            embedding_rows.append(np.hstack(per_layer))

if not embedding_rows:
    raise ValueError("No CDR3 sequences to embed: `chains` is empty")
# One row per input sequence, in `chains` order.
embeddings = np.stack(embedding_rows)


In [None]:
# Number of most frequent epitopes kept in the plot.
TOP_N_EPITOPES = 7

# Rows of `embeddings` follow `beta_chains` row order, so take the epitope labels
# positionally (fresh 0..n-1 index). Resetting `beta_chains` in place instead
# mutated a shared frame and failed once the cell had been run three times.
assert len(embeddings) == len(beta_chains), "embeddings and beta_chains are out of sync"
epitope_labels = beta_chains['antigen.epitope'].reset_index(drop=True)
embedding_df = pd.concat([pd.DataFrame(embeddings), epitope_labels], axis=1)

top_epitopes = embedding_df['antigen.epitope'].value_counts().nlargest(TOP_N_EPITOPES).index
embedding_df_filtered = embedding_df[embedding_df['antigen.epitope'].isin(top_epitopes)]
print(embedding_df_filtered.shape)

# Every column except the trailing label is an embedding dimension.
distances_reduced = umap.UMAP(n_components=2).fit(embedding_df_filtered.iloc[:, :-1].values)

output_dir = 'visualisations'
os.makedirs(output_dir, exist_ok=True)
f = umap.plot.points(distances_reduced, labels=embedding_df_filtered['antigen.epitope'])
f.set_xlabel('UMAP Dimension 1', fontsize=10)
f.set_ylabel('UMAP Dimension 2', fontsize=10)
f.set_title('Beta Chain by antigen specificity - Bert Embedding', fontsize=12)
f.get_figure().savefig(os.path.join(output_dir, 'beta_chain_umap_bert.png'))


# Distances for paired alpha/beta chains

In [None]:
# Alpha and beta chains of the same TCR share a `complex.id`. `human_data` was
# built without that column (indexing it raised KeyError), so rebuild the same
# human / vdjdb.score > 0 subset from `df` with `complex.id` kept.
# NOTE(review): VDJdb uses complex.id 0 for chains without a recorded partner —
# confirm against the VDJdb README; those rows are skipped below.
pairing_columns = ['complex.id', 'gene', 'cdr3', 'v.segm', 'j.segm', 'species', 'mhc.a', 'mhc.b',
                   'mhc.class', 'antigen.epitope', 'antigen.species', 'vdjdb.score']
is_human_scored = (df['species'] == 'HomoSapiens') & (df['vdjdb.score'] > 0)
paired_source = df.loc[is_human_scored, pairing_columns].drop_duplicates().dropna()
_ids = paired_source['complex.id']

# One dict per complex made of exactly one TRA row and one TRB row. Grouping
# scans the data once (the old per-row filter was quadratic), and each chain is
# picked by its `gene` value instead of assuming TRA comes first.
list_to_combine = []
for complex_id, complex_rows in paired_source.groupby('complex.id', sort=False):
    if complex_id == 0:
        continue
    tra_rows = complex_rows[complex_rows['gene'] == 'TRA']
    trb_rows = complex_rows[complex_rows['gene'] == 'TRB']
    if len(complex_rows) != 2 or len(tra_rows) != 1 or len(trb_rows) != 1:
        continue
    tra_row = tra_rows.iloc[0]
    trb_row = trb_rows.iloc[0]
    list_to_combine.append({
        'tcr_id_a': tra_row['complex.id'],
        'tcr_id_b': trb_row['complex.id'],
        'cdr3_a_aa': tra_row['cdr3'],
        'cdr3_b_aa': trb_row['cdr3'],
        'v_b_gene': trb_row['v.segm'],
        'j_b_gene': trb_row['j.segm'],
        'v_a_gene': tra_row['v.segm'],
        'j_a_gene': tra_row['j.segm'],
    })

In [None]:
# Table of paired TCRs, one row per alpha/beta complex, in tcrdist3 column names.
paired_table = pd.DataFrame(list_to_combine)
if paired_table.empty:
    raise ValueError("No complexes with exactly one TRA and one TRB chain were found")
# Bare expressions mid-cell are not rendered, so display explicitly.
display(paired_table.head())

# Get paired distances
tr_paired = TCRrep(cell_df=paired_table,
                   organism='human',
                   chains=['alpha', 'beta'],
                   db_file='alphabeta_gammadelta_db.tsv')

# Alpha-chain distance matrix for the paired TCRs.
paired_matrix_alpha_chain = tr_paired.pw_alpha
paired_alpha_distances = pd.DataFrame(paired_matrix_alpha_chain)
display(paired_alpha_distances)

# Beta-chain distance matrix for the paired TCRs (shown as the cell output).
paired_matrix_beta_chain = tr_paired.pw_beta
paired_beta_distances = pd.DataFrame(paired_matrix_beta_chain)
paired_beta_distances