In [None]:
from umap import UMAP
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from tqdm.autonotebook import tqdm
import pandas as pd
import ablang
from pathlib import Path
import os
from time import time
from itertools import pairwise
from multiprocessing.pool import ThreadPool
import torch
from time import sleep

In [None]:
# Absolute path of the process's current working directory at the time this cell runs
# (normally the directory containing the notebook when launched from Jupyter).
cwd = Path(os.getcwd())

In [None]:
# Input data lives in <cwd>/../data/antibody/covid_alphaseq.
data_path = Path.cwd().parent.joinpath('data', 'antibody', 'covid_alphaseq')
# Raw AlphaSeq binding measurements (one CSV).
data_file = data_path.joinpath('MITLL_AAlphaBio_Ab_Binding_dataset.csv')

In [None]:
# Outputs (e.g. the embeddings pickle) are written alongside the input CSVs.
savedir = Path(data_path)

In [None]:
# Raw binding data plus the lookup tables mapping sequences / heavy chains / light chains
# to their UUID columns (first CSV column is used as the index of each lookup table).
df_raw = pd.read_csv(data_file)
seqs = pd.read_csv(data_path / "sequence_uuids.csv", index_col=0)
hcs = pd.read_csv(data_path / "hc_uuids.csv", index_col=0)
lcs = pd.read_csv(data_path / "lc_uuids.csv", index_col=0)

# Attach the UUIDs via successive inner joins (pandas' default `how="inner"`), so any row
# whose Sequence/HC/LC is missing from a lookup table is dropped.
df = (
    df_raw
    .merge(seqs, on="Sequence")
    .merge(hcs, on="HC")
    .merge(lcs, on="LC")
)

In [None]:
# Load the pretrained AbLang heavy-chain model. Prefer the first GPU, but fall back to
# CPU (AbLang's own default device) so the notebook can still run, slowly, without CUDA.
ablang_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
if ablang_device == 'cpu':
    print("CUDA not available; running AbLang on CPU (embedding will be slow).")
# ncpu=2: presumably the number of CPU workers AbLang may use — confirm against AbLang docs.
heavy_ablang = ablang.pretrained("heavy", device=ablang_device, ncpu=2)
# NOTE(review): freeze() presumably switches the model to inference/eval mode — verify.
heavy_ablang.freeze()

In [None]:
@torch.no_grad()
def compute_embeddings(model, hc, sz=200):
    """Embed a batch of heavy-chain sequences and return the result as float32.

    Gradient tracking is disabled for the whole call, since this is inference only.

    Args:
        model: AbLang-style callable, invoked as
            ``model(hc, mode='seqcoding', splitSize=sz)``.
        hc: the sequences to embed (in this notebook, a pandas Series of HC strings).
        sz: forwarded as ``splitSize``; presumably the number of sequences per forward
            pass inside the model — confirm against AbLang docs.

    Returns:
        The model's output converted with ``.astype(np.float32)``.
    """
    seq_codings = model(hc, mode='seqcoding', splitSize=sz)
    return seq_codings.astype(np.float32)

In [None]:
# Embed every heavy chain with AbLang in fixed-size batches (bounding how many sequences
# are handed to the model per call), attach one embedding vector per row, and save.
size = 50000  # rows per batch passed to compute_embeddings
n_rows = len(df)
embedded_batches = []

with tqdm(total=n_rows) as pbar:  # context manager closes the bar even if embedding fails
    # range() yields a final short batch, so the trailing len(df) % size rows are embedded too.
    for batch_idx, start in enumerate(range(0, n_rows, size), start=1):
        end = min(start + size, n_rows)
        batch_df = df.iloc[start:end]

        embeddings = compute_embeddings(heavy_ablang, batch_df.HC)
        # Store each row's embedding vector in an object column.
        embedded_batches.append(batch_df.assign(hc_ablang_embedding=list(embeddings)))

        pbar.set_description_str(f"Batch {batch_idx}: {100*end/n_rows:.2f}%")
        pbar.update(n=end - start)

# Concatenate once (avoids quadratic re-copying); an empty df yields an empty result
# that still carries the embedding column.
df_embeddings = (
    pd.concat(embedded_batches)
    if embedded_batches
    else df.assign(hc_ablang_embedding=pd.Series(dtype=object))
)
df_embeddings.to_pickle(savedir / "df_ablang_embeddings_hc.pkl")