In [None]:
import numpy as np
from tqdm.autonotebook import tqdm
import pandas as pd
import ablang
from pathlib import Path
from time import time, sleep
import torch

In [None]:
# Current working directory of the kernel process at the time this cell runs.
cwd = Path.cwd()

In [None]:
# Directory containing the sequence tables, resolved relative to the parent of
# the current working directory.
data_path = Path.cwd().parent / 'data' /'bcr' / 'bcells_guikema' / 'sequences'
# is_dir() rather than exists(): this path is globbed for input files later, so
# a regular file with the same name must not pass the check.
assert data_path.is_dir(), f"Expected sequence directory not found: {data_path}"

In [None]:
# Results are written into the same directory the input tables are read from.
savedir = data_path

In [None]:
# All tables whose name ends in "Nt_info.csv". Sorted because glob yields
# entries in a filesystem-dependent order; sorting makes the processing order
# reproducible across runs and machines. sorted() still returns a list.
input_files = sorted(data_path.glob("*Nt_info.csv"))

In [None]:
def compute_embeddings(model, aa_seq, sz=10000):
    """Run ``model`` in ``'seqcoding'`` mode on ``aa_seq`` and cast the result to float32.

    Args:
        model: Callable invoked as ``model(aa_seq, mode='seqcoding', splitSize=sz)``.
            Its return value must support ``.astype`` (presumably a NumPy
            array -- TODO confirm against the model's API).
        aa_seq: Sequences, forwarded to ``model`` unchanged.
        sz: Value forwarded to ``model`` as ``splitSize``.

    Returns:
        The model output after ``.astype(np.float32)``.
    """
    # Gradient tracking is switched off for the whole call.
    with torch.no_grad():
        raw_output = model(aa_seq, mode='seqcoding', splitSize=sz)
        return raw_output.astype(np.float32)

In [None]:
def preprocess_df(data_df):
    """Filter and normalise the ``Sequence_AA`` column before embedding.

    Rows whose ``Sequence_AA`` contains a literal ``*`` are removed, as are
    rows whose ``Sequence_AA`` is missing. In the remaining rows every ``X``
    is replaced by ``*`` (presumably the placeholder the downstream model
    expects for unknown residues -- TODO confirm). The input frame is not
    modified.

    Args:
        data_df: DataFrame with a string column named ``Sequence_AA``.

    Returns:
        A new DataFrame holding the retained rows (original index preserved)
        with the rewritten ``Sequence_AA`` column.
    """
    # Literal substring match: avoids the invalid "\*" escape sequence that a
    # non-raw regex pattern needs. na=True flags missing sequences so they are
    # dropped instead of putting NaN into the boolean mask.
    has_star = data_df.Sequence_AA.str.contains("*", regex=False, na=True)
    kept = data_df.loc[~has_star.astype(bool)]
    # regex=False makes the single-character replacement independent of the
    # pandas version's default for the `regex` argument.
    return kept.assign(
        Sequence_AA=kept.Sequence_AA.str.replace("X", "*", regex=False)
    )

In [None]:
# Number of sequences per batch. Used both as the on-disk batch size and as the
# splitSize forwarded to the model through compute_embeddings.
sz = 10000
# Forwarded to ablang.pretrained as `ncpu`.
ncpu = 56

# Pretrained heavy-chain AbLang model on the first CUDA device; freeze() is
# called once before any inference.
heavy_ablang = ablang.pretrained("heavy", device='cuda:0', ncpu=ncpu)
heavy_ablang.freeze()

for data_file in tqdm(input_files):
    df = preprocess_df(pd.read_csv(data_file, sep="\t", index_col="Sequence ID"))

    # Per-file directory for the intermediate batch pickles; created once per
    # input file rather than once per batch.
    save_subdir = savedir / data_file.stem
    save_subdir.mkdir(exist_ok=True, parents=True)

    # range() yields exactly the start offsets of non-empty batches, so a frame
    # whose length is a multiple of `sz` (or an empty frame) never sends an
    # empty batch to the model. Batch numbering starts at 1.
    embedded_batches = []
    for batch_idx, start in enumerate(range(0, len(df), sz), start=1):
        batch_df = df.iloc[start:start + sz]

        embeddings = compute_embeddings(heavy_ablang, batch_df.Sequence_AA, sz=sz)
        # One embedding per row, stored as an object column.
        batch_df = batch_df.assign(ablang_embedding=list(embeddings))

        # Persist each batch immediately so partial results survive a crash.
        batch_df.to_pickle(save_subdir / f"df_ablang_embeddings_{batch_idx}_.pkl")
        embedded_batches.append(batch_df)

    # Concatenate once at the end instead of growing a frame inside the loop.
    if embedded_batches:
        df_embeddings = pd.concat(embedded_batches)
    else:
        # No usable rows: still write a (zero-row) result with the expected column.
        df_embeddings = df.assign(
            ablang_embedding=pd.Series(index=df.index, dtype=object)
        )

    df_embeddings.to_pickle(savedir / f"df_ablang_embeddings_{data_file.stem}.pkl")