In [None]:
import itertools
from contextlib import ExitStack

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

from immune_embeddings.data import get_data_root
from immune_embeddings.data.covid_to_esm import EmbedTokensESM
from immune_embeddings.models.embeddings.esm import ESMEmbeddingModel

In [None]:
# Enter torch.no_grad() through an ExitStack that is never closed in the cells
# below, so gradient tracking stays disabled for every later cell in this kernel.
# Calling es.close() would leave the no_grad context again.
es = ExitStack()
es.enter_context(torch.no_grad())

In [None]:
# Working directory <data root>/tcr/raw_cdr: the input CSV is read from here and
# the resulting pickle is written back to the same place.
data_dir = get_data_root() / "tcr" / "raw_cdr"

In [None]:
# Load the input table; the first CSV column becomes the DataFrame index.
# Later cells read the "TRB_CDR3" and "TRA_CDR3" columns from this frame.
data_df = pd.read_csv(data_dir / "dresden_tcr.csv", index_col=0)

In [None]:
# Checkpoint configuration.
# model_depth and model_size are only used to build model_id; repr_layer is
# passed to ESMEmbeddingModel and is also encoded in the output file name.
# Here repr_layer equals model_depth (presumably: take the last layer's
# representation -- confirm against ESMEmbeddingModel).
model_depth = 36
model_size = "3B"
model_id = f"esm2_t{model_depth}_{model_size}_UR50D"
repr_layer = 36

In [None]:
# Build the embedding model for the configured checkpoint and move it to the GPU.
# NOTE(review): whether freeze_weights=True also switches the model to eval mode
# is not visible from this notebook -- confirm in ESMEmbeddingModel.
model = ESMEmbeddingModel(
    model_id=model_id,
    repr_layer=repr_layer,
    fixed_size=True,
    freeze_weights=True,
).to("cuda")

In [None]:
# Tokenizer for the same checkpoint id as the model. compute_embeddings calls it
# with protein_sequences=... and sequence_ids=... and moves every value of the
# returned mapping to the GPU before the forward pass.
tokenizer = EmbedTokensESM(model_id=model_id)

In [None]:
def compute_embeddings(sequence_col, batch_size=128):
    """Embed every distinct value of ``data_df[sequence_col]`` with the ESM model.

    Relies on the notebook-level ``data_df``, ``tokenizer`` and ``model``
    objects; inputs are moved to ``"cuda"`` before the forward pass.

    Args:
        sequence_col: Name of the column in ``data_df`` holding the sequences.
        batch_size: Number of sequences per forward pass. Defaults to 128,
            the value that was previously hard-coded.

    Returns:
        A DataFrame with one row per unique value of ``sequence_col`` and two
        columns: ``sequence_col`` and ``f"{sequence_col}_ESM"``. Each cell of
        the latter holds one slice (along the first axis) of the concatenated
        model output as a NumPy array. An empty frame with the same two
        columns is returned when the column has no values.
    """
    embedding_col = f"{sequence_col}_ESM"

    # NOTE(review): unique() keeps a NaN entry when the column has missing
    # values; it would be handed to the tokenizer unchanged -- confirm the
    # input column is fully populated.
    unique_sequences = data_df[sequence_col].unique()

    if len(unique_sequences) == 0:
        # Without any batch, np.concatenate below would raise on an empty list.
        # Return an empty, correctly-labelled frame so downstream merges work.
        return pd.DataFrame(
            {
                sequence_col: pd.Series([], dtype=object),
                embedding_col: pd.Series([], dtype=object),
            }
        )

    # Pair each sequence with its position; the position is forwarded to the
    # tokenizer as sequence_ids.
    example_sequences = list(enumerate(unique_sequences))
    dl = DataLoader(
        example_sequences,
        batch_size=batch_size,
        num_workers=4,
        prefetch_factor=3,
        shuffle=False,  # keep sequence order stable across batches
    )

    sequences = []
    embedding_batches = []

    # Disable autograd locally so the function does not depend on a no_grad
    # context having been entered in another cell: no graph is kept alive and
    # .numpy() cannot fail on a tensor that requires grad.
    with torch.no_grad():
        for ids, batch_sequences in tqdm(dl):
            batch_inputs = {
                k: v.to("cuda")
                for k, v in tokenizer(protein_sequences=batch_sequences, sequence_ids=ids).items()
            }
            embeddings = model(batch_inputs)
            sequences.extend(batch_sequences)
            embedding_batches.append(embeddings.cpu().numpy())

    embeddings = np.concatenate(embedding_batches, axis=0)
    # list(...) splits the stacked array along axis 0 so each DataFrame cell
    # stores the array belonging to one sequence.
    out_df = pd.DataFrame({sequence_col: sequences, embedding_col: list(embeddings)})
    return out_df

In [None]:
# Embed the unique CDR3 sequences of each chain, then attach the embeddings to
# every row of data_df via two inner joins on the sequence columns.
# NOTE(review): merging on columns ignores the index of both frames, so the
# index read from the CSV (index_col=0) is not carried into the result --
# confirm that this is intended.
trb_cdr3_df_esm = compute_embeddings("TRB_CDR3")
tra_cdr3_df_esm = compute_embeddings("TRA_CDR3")
data_embeddings_df = pd.merge(
    pd.merge(data_df, trb_cdr3_df_esm, on="TRB_CDR3"),
    tra_cdr3_df_esm,
    on="TRA_CDR3",
)

In [None]:
# Persist the merged frame as a pickle in data_dir; the file name encodes the
# model size and the representation layer used for the embeddings.
data_embeddings_df.to_pickle(data_dir / f"dresden_esm_{model_size}_L{repr_layer}.pkl")