In [5]:
import pandas as pd
import os
from transformers import AutoModel, AutoTokenizer
import torch
import numpy as np
import pickle

In [2]:
# Fetch the lemanic_2024 repository; the CSVs read below live under lemanic_2024/data.
# NOTE(review): on a re-run git refuses to clone into an existing non-empty directory;
# a `!` shell line does not raise, so the notebook carries on with the existing checkout.
!git clone https://github.com/adaptyvbio/lemanic_2024.git

Cloning into 'lemanic_2024'...
remote: Enumerating objects: 41, done.
remote: Counting objects: 100% (41/41), done.
remote: Compressing objects: 100% (27/27), done.
remote: Total 41 (delta 16), reused 36 (delta 13), pack-reused 0
Receiving objects: 100% (41/41), 1.99 MiB | 4.17 MiB/s, done.
Resolving deltas: 100% (16/16), done.


In [3]:
# Load the literature and experiment train/test splits shipped with the cloned repo.
root = "lemanic_2024/data"
path_exp_train = f"{root}/experiment_train.csv"
path_exp_test = f"{root}/experiment_test.csv"
path_lit_train = f"{root}/literature_train.csv"
path_lit_test = f"{root}/literature_test.csv"

# Read in the same order as the path definitions above.
experiment_train_df, experiment_test_df, literature_train_df, literature_test_df = (
    pd.read_csv(csv_path)
    for csv_path in (path_exp_train, path_exp_test, path_lit_train, path_lit_test)
)

In [4]:
def find_start_indices(df):
    """Locate the 0-based start of each row's CDRH3 inside its VHorVHH sequence.

    Args:
        df: DataFrame with columns ``CDRH3`` and ``VHorVHH`` holding sequence
            strings.

    Returns:
        A list with one entry per row, in ``df`` order. Each entry is the index
        of the first occurrence of ``CDRH3`` in ``VHorVHH``, or ``float('nan')``
        in any of these cases:
        - the substring is absent;
        - either value is missing or not a string;
        - ``CDRH3`` is empty (an empty "match" at index 0 would give a
          zero-length span downstream).
    """
    start_indices = []
    for cdrh3, vhorvhh in zip(df['CDRH3'], df['VHorVHH']):
        # Missing values arrive as None/NaN (not str); str.find would raise on them.
        if not (isinstance(cdrh3, str) and isinstance(vhorvhh, str)) or not cdrh3:
            start_indices.append(float('nan'))
            continue
        start_index = vhorvhh.find(cdrh3)
        start_indices.append(float('nan') if start_index == -1 else start_index)
    return start_indices

# Add the CDRH3 span [start_index_CDRH3, stop_index_CDRH3) to every split.
# start is NaN when CDRH3 is not found in VHorVHH, and NaN propagates to stop.
for split_df in (literature_train_df, experiment_train_df, experiment_test_df, literature_test_df):
    split_df["start_index_CDRH3"] = find_start_indices(split_df)
    split_df["stop_index_CDRH3"] = split_df["start_index_CDRH3"] + split_df["CDRH3"].str.len()

In [14]:
# AbLang light- and heavy-chain models with their tokenizers.
# NOTE(review): trust_remote_code=True executes model code downloaded from the Hub;
# consider pinning `revision=` to a known commit for safety and reproducibility.
max_len = 157  # max residues fed to the model (VHorVHH is sliced to this length); TODO document why 157

tokenizer_light, model_light = (
    AutoTokenizer.from_pretrained('qilowoq/AbLang_light', truncation=True, max_length=max_len),
    AutoModel.from_pretrained('qilowoq/AbLang_light', trust_remote_code=True),
)

tokenizer_heavy, model_heavy = (
    AutoTokenizer.from_pretrained('qilowoq/AbLang_heavy', truncation=True, max_length=max_len),
    AutoModel.from_pretrained('qilowoq/AbLang_heavy', trust_remote_code=True),
)

In [18]:
def get_embeddings_heavy(df, model=model_heavy, tokenizer=tokenizer_heavy):
  """Mean-pool per-residue model outputs over each row's CDRH3 region.

  Processing per row of ``df``:
  1. Truncate the ``VHorVHH`` sequence to its first ``max_len`` residues.
  2. Space-separate the residues, tokenize them and run them through ``model``
     without gradient tracking.
  3. Average the hidden states covering the span
     ``[start_index_CDRH3, stop_index_CDRH3)``.

  The whole (truncated) sequence is averaged instead when the CDRH3 was not
  located (NaN start index) or lies entirely beyond the truncation point.

  Args:
    df: DataFrame with columns ``VHorVHH``, ``start_index_CDRH3`` and
      ``stop_index_CDRH3``.
    model: model whose output exposes ``last_hidden_state``.
    tokenizer: tokenizer matching ``model``.

  Returns:
    numpy array stacking one pooled embedding per row, in ``df`` order.
  """
  embeddings = []
  for _, row in df.iterrows():
    residues = row['VHorVHH'][:max_len]
    n_residues = len(residues)

    # The +1 offsets assume one special token is prepended by the tokenizer,
    # so residue i sits at output position i + 1.
    start, end = 1, n_residues + 1  # default: whole truncated sequence
    if not pd.isna(row['start_index_CDRH3']):
      cdr_start = int(row['start_index_CDRH3']) + 1
      # Clamp to the truncated sequence so trailing special tokens are excluded.
      cdr_end = min(int(row['stop_index_CDRH3']) + 1, n_residues + 1)
      # A CDRH3 cut off by truncation would give an empty slice (NaN mean).
      if cdr_start < cdr_end:
        start, end = cdr_start, cdr_end

    encoded_input = tokenizer(' '.join(residues), return_tensors='pt')
    with torch.no_grad():
      model_output = model(**encoded_input).last_hidden_state

    embedding = model_output[:, start:end, :].mean(dim=1)
    embeddings.append(embedding.squeeze())

  return torch.stack(embeddings).numpy()

In [19]:
# Run only if you want to (re)compute embeddings; otherwise load the cached
# pickles in the cell further down. The "heavey" spelling matches the names
# used by the save/load cells, so keep them in sync if renaming.
X_heavy_lit_test, X_heavy_lit_train = (
    get_embeddings_heavy(literature_test_df),
    get_embeddings_heavy(literature_train_df),
)

X_heavey_exp_train, X_heavey_exp_test = (
    get_embeddings_heavy(experiment_train_df),
    get_embeddings_heavy(experiment_test_df),
)

In [22]:
# Cache each embedding matrix in its own pickle so the expensive
# embedding cell can be skipped on later runs.
for _pkl_path, _array in (
    ('X_heavy_lit_test.pkl', X_heavy_lit_test),
    ('X_heavy_lit_train.pkl', X_heavy_lit_train),
    ('X_heavey_exp_train.pkl', X_heavey_exp_train),
    ('X_heavey_exp_test.pkl', X_heavey_exp_test),
):
    with open(_pkl_path, 'wb') as f:
        pickle.dump(_array, f)

In [8]:
def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    pickle.load can execute arbitrary code, so only use this on files written
    by this notebook's own save cell, never on downloaded/untrusted files.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


X_heavy_lit_test = _load_pickle('X_heavy_lit_test.pkl')
X_heavy_lit_train = _load_pickle('X_heavy_lit_train.pkl')
X_heavey_exp_train = _load_pickle('X_heavey_exp_train.pkl')
X_heavey_exp_test = _load_pickle('X_heavey_exp_test.pkl')