In [None]:
from contextlib import ExitStack

import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, Pooling
from torch.utils.data import DataLoader

from immune_embeddings.data import get_data_root
from immune_embeddings.data.titan_to_tcrbert import get_tcrbert_tokenizer
from tqdm.notebook import tqdm

import itertools

In [None]:
# Disable autograd for the rest of the session: the no_grad context is entered
# here and stays active until `es.close()` is called.
es = ExitStack()
no_grad_ctx = torch.no_grad()
es.enter_context(no_grad_ctx)

In [None]:
# Directory holding the input CSV and receiving the output pickle.
data_root = get_data_root()
data_dir = data_root / "tcr" / "raw_cdr"

In [None]:
# Load the table; the first CSV column is used as the DataFrame index.
dresden_csv_path = data_dir / "dresden_tcr.csv"
data_df = pd.read_csv(dresden_csv_path, index_col=0)

In [None]:
# Number of leading encoder layers to keep; also used in the embedding column
# name and the output file name further down.
L = 8

tok = get_tcrbert_tokenizer()

# Load the checkpoint through the sentence-transformers `Transformer` wrapper
# and keep only the first L encoder layers.
tr = Transformer("wukevin/tcr-bert-mlm-only").to("cuda")
tr.auto_model.encoder.layer = tr.auto_model.encoder.layer[:L]

# Mean pooling over token embeddings yields one vector per input sequence.
pooling = Pooling(
    word_embedding_dimension=tr.get_word_embedding_dimension(),
    pooling_mode="mean",
)

In [None]:
# Chain the truncated transformer and the pooling layer into a single model,
# then place it on the GPU.
model = SentenceTransformer(modules=[tr, pooling])
model = model.to("cuda")

In [None]:
def compute_embeddings(sequence_col):
    """Embed every distinct sequence found in ``data_df[sequence_col]``.

    Relies on the notebook-level globals ``data_df`` (input table), ``tok``
    (tokenizer), ``model`` (embedding model, expected on the GPU) and ``L``
    (used only to name the embedding column).

    Parameters
    ----------
    sequence_col : str
        Name of the ``data_df`` column that holds the sequences.

    Returns
    -------
    pandas.DataFrame
        One row per distinct, non-missing sequence with two columns:
        ``sequence_col`` (the sequence) and
        ``f"{sequence_col}_TCRBERT_L{L}"`` (the ``sentence_embedding`` output
        of ``model`` for that sequence, as a 1-D NumPy array).
        If the column contains no non-missing values, an empty frame with
        the same two columns is returned.

    Notes
    -----
    Missing values (NaN) are skipped, so an inner merge of the result back
    onto ``data_df`` drops the rows whose sequence is missing.
    """
    embedding_col = f'{sequence_col}_TCRBERT_L{L}'

    # NaN entries are not strings and cannot be split into characters below,
    # so they are removed before batching.
    example_sequences = list(data_df[sequence_col].dropna().unique())
    if not example_sequences:
        # np.concatenate raises on an empty list; return an empty result instead.
        return pd.DataFrame({
            sequence_col: pd.Series(dtype=object),
            embedding_col: pd.Series(dtype=object),
        })

    dl = DataLoader(example_sequences, batch_size=128, num_workers=4, prefetch_factor=3, shuffle=False)

    data = {'seq': [], 'embedding': []}
    # Disable autograd locally so the function works (and `.numpy()` succeeds)
    # regardless of whether a notebook-wide no_grad context is active.
    with torch.no_grad():
        for batch in tqdm(dl):
            # Insert a space between every character of each sequence —
            # presumably the input format the tokenizer expects (confirm).
            prepared_batch = [" ".join(list(s)) for s in batch]
            # Every sequence is padded to 64 tokens. NOTE(review): no truncation
            # is requested, so longer inputs are not handled here — confirm
            # that none occur in the data.
            encoded = tok(
                prepared_batch,
                return_tensors="pt",
                return_token_type_ids=False,
                padding='max_length',
                max_length=64,
            )
            batch_inputs = {k: v.to("cuda") for k, v in encoded.items()}
            embeddings = model(batch_inputs)["sentence_embedding"]
            data['seq'].append(batch)
            data['embedding'].append(embeddings.cpu().numpy())

    # Flatten the per-batch lists back into one sequence list / one 2-D array.
    sequences = list(itertools.chain.from_iterable(data['seq']))
    embeddings = np.concatenate(data['embedding'], axis=0)
    # list(embeddings) stores one 1-D array per row in the embedding column.
    out_df = pd.DataFrame({sequence_col: sequences, embedding_col: list(embeddings)})
    return out_df

In [None]:
# Embed the distinct values of the TRB_CDR3 column.
trb_cdr3_df = compute_embeddings(sequence_col="TRB_CDR3")

In [None]:
# Attach each row's embedding by joining on the sequence column
# (inner join, which is the default for DataFrame.merge).
data_and_embeddings_df = pd.merge(
    data_df,
    trb_cdr3_df,
    on="TRB_CDR3",
    how="inner",
)

In [None]:
# Persist the merged table; the file name records the number of layers used (L).
output_path = data_dir / f"dresden_data_tcrbert_l{L}.pkl"
data_and_embeddings_df.to_pickle(output_path)