In [2]:
from datetime import datetime
from IPython.display import display, Markdown

# Render the notebook header (title, author, dates) as Markdown output.
# NOTE: "Last modified" is the date on which this cell was last executed.
todays_date = datetime.now().date().isoformat()

header_lines = [
    "# 2. Notebook - Generate ESM variable length embeddings of OVA and RBD antibody sequences",
    "Author: Lena Erlach",
    "Created: 2024-01-17",
    f"Last modified: {todays_date}",
]
for header_line in header_lines:
    display(Markdown(header_line))

# 2. Notebook - Generate ESM variable length embeddings of OVA and RBD antibody sequences

Author: Lena Erlach

Created: 2024-01-17

Last modified: 2024-07-17

## Table of Contents

- [Load datasets & preprocess](#Preprocess)
- [Generate distance matrix](#Generate-dist-mat)
- [Generate embeddings](#Generate-embeddings)
    - [1. Embed VH_VL and VH sequences](#EmbedVH_VL)

In [4]:
import pandas as pd
import abmap
import torch
import os
import sys
import configparser


# Paths to the per-dataset config files (relative to the notebook's working directory)
CONFIG_PATH_OVA = "../../config_file.txt"
CONFIG_PATH_RBD = "../../config_file_RBD.txt"

# Set up one parser per config file.
# ConfigParser.read() silently skips files it cannot open and returns the list of
# files it actually parsed, so check that list and fail with a clear message here
# instead of hitting an opaque KeyError on the first section lookup.
config_O = configparser.ConfigParser()
config_R = configparser.ConfigParser()
for _parser, _cfg_path in ((config_O, CONFIG_PATH_OVA), (config_R, CONFIG_PATH_RBD)):
    if not _parser.read(_cfg_path):
        raise FileNotFoundError(
            f"Could not read config file: {os.path.abspath(_cfg_path)}"
        )

# ROOT dir (taken from the OVA config)
ROOT_DIR = config_O["ROOT"]["ROOT_DIR"]

# Add the src directory to the import path so that the utils_nb file can be imported;
# the membership check avoids piling up duplicate entries when the cell is re-run.
UTILS_DIR = os.path.join(ROOT_DIR, "src")
if UTILS_DIR not in sys.path:
    sys.path.append(UTILS_DIR)
import utils_nb as utils

  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()


### Define variables

1. Sequence region to model: `VDJ_VJ_aaSeq` (paired heavy- and light-chain sequences)
2. Dataset input paths (one per dataset: OVA and RBD)

### Load dataset and calculate dist matrix

In [5]:
# Column name of the sequence to filter for (VDJ_VJ_aaSeq, VDJ_aaSeq, ...).
# NOTE(review): seq_col is not referenced again in this cell; the sequence
# columns below are selected by their literal names.
seq_col = "VDJ_VJ_aaSeq"

# Parsed config per dataset, used to build the per-dataset path dictionaries
configs_by_dataset = {"OVA": config_O, "RBD": config_R}

# Input path of the preprocessed sequence dataframe
seq_df_inputPath = {
    ds_name: os.path.join(ROOT_DIR, ds_config["PATHS"]["SEQ_DF"])
    for ds_name, ds_config in configs_by_dataset.items()
}
# Distance matrix paths for the VH_VL sequences
dist_matrix_outputPath_VH_VL = {
    ds_name: os.path.join(ROOT_DIR, ds_config["VH_VL_EMBEDPATH"]["DISTANCE_MATRIX"])
    for ds_name, ds_config in configs_by_dataset.items()
}
# Distance matrix paths for the VH sequences
dist_matrix_outputPath_VH = {
    ds_name: os.path.join(ROOT_DIR, ds_config["VH_EMBEDPATH"]["DISTANCE_MATRIX"])
    for ds_name, ds_config in configs_by_dataset.items()
}

cuda_dev_num = 0

# embedding paths for VH_VL embeddings
out_folder = {
    "OVA": os.path.join(ROOT_DIR, "data/processed/embeddings/OVA/"),
    "RBD": os.path.join(ROOT_DIR, "data/processed/embeddings/RBD/"),
}

# Per-dataset containers: filtered dataframe, sequence ids, VH and VH_VL sequences
seq_df, seqs, seqs_VH_VL, names = {}, {}, {}, {}

for dataset in ("OVA", "RBD"):
    # load the preprocessed dataframe
    loaded_df = pd.read_csv(seq_df_inputPath[dataset])
    # keep only rows flagged as complete sequences, just in case
    complete_df = loaded_df[loaded_df["seq_complete"] == True]
    seq_df[dataset] = complete_df

    # pull ids/names and sequences out into plain lists
    names[dataset] = complete_df["seq_id"].tolist()
    seqs[dataset] = complete_df["VDJ_aaSeq"].tolist()
    seqs_VH_VL[dataset] = complete_df["VDJ_VJ_aaSeq"].tolist()

    # # Calculate distance matrix
    # for s, path in zip([seqs_VH_VL[dataset], seqs[dataset]], [dist_matrix_outputPath_VH_VL[dataset], dist_matrix_outputPath_VH[dataset]]):
    #     distance_matrix = utils.calc_norm_levens_dist(s)
    #     np.savetxt(path, distance_matrix, delimiter=',')

In [6]:
# Select the compute device: the configured GPU when CUDA is available, else the CPU.
if torch.cuda.is_available():
    dev = "cuda:{}".format(cuda_dev_num)
    device = torch.device(dev)
    # Make the selected GPU the current CUDA device.
    # Only done on the CUDA path: torch.cuda.set_device() rejects a CPU device
    # with a ValueError, which previously made the CPU fallback unusable.
    torch.cuda.set_device(device)
else:
    dev = "cpu"
    device = torch.device(dev)

In [7]:
# Using ESM2 (best for functional prediction, e.g. affinity, paratope prediction, etc.)
# Paths of the pretrained H and L AbMAP checkpoints
pretrained_path_H = "/data/cb/scratch/lenae/p-GP-LLM-AbPred/models/pretrained/AbMAP_esm2_H.pt"
pretrained_path_L = "/data/cb/scratch/lenae/p-GP-LLM-AbPred/models/pretrained/AbMAP_esm2_L.pt"
# NOTE: despite the "_H_ls" name, this list holds both the H and the L checkpoint path
pretrained_path_H_ls = [pretrained_path_H, pretrained_path_L]

# Arguments shared by both model loads
abmap_load_kwargs = {"device": cuda_dev_num, "plm_name": "esm2"}

abmap_H = abmap.load_abmap(pretrained_path=pretrained_path_H, **abmap_load_kwargs)
abmap_L = abmap.load_abmap(pretrained_path=pretrained_path_L, **abmap_load_kwargs)

esm2 loaded to cuda:0


Using cache found in /afs/csail.mit.edu/u/l/lenae011/.cache/torch/hub/facebookresearch_esm_main


Loaded the Pre-trained Model!
esm2 loaded to cuda:0


Using cache found in /afs/csail.mit.edu/u/l/lenae011/.cache/torch/hub/facebookresearch_esm_main


Loaded the Pre-trained Model!


Loading the models onto the GPU might take a few minutes.

<a id='EmbedVH_VL'></a> 
# 1. Embed VH_VL and VH sequences 

<a id='EmbedVH_VL-esm'></a>
Pass sequences through foundational PLM (ESM-2)

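This step takes time: the progress output below shows roughly 3 s per sequence, i.e. an estimated 3 hours or more for the 3622 sequences of the first run.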


In [10]:
# Protein language model used for the embeddings; passed on to both runs below
model_typ = "esm2"

# Settings shared by the VH_VL and the VH-only embedding runs.
# NOTE(review): cuda_dev_num receives the device string `dev` (e.g. "cuda:0"),
# not the integer `cuda_dev_num` - confirm this is what utils.generate_ESM_embedding expects.
esm_embedding_kwargs = dict(
    seq_column_HC="VDJ_aaSeq",
    seq_column_LC="VJ_aaSeq",
    augment=False,
    model_typ=model_typ,
    save_plm=True,
    save_PLM_aug=False,
    cuda_dev_num=dev,
)

# Results per dataset: embeddings and the ids reported back as "to drop"
embedding_list, ids_to_drop = {}, {}
embedding_list_VH, ids_to_drop_VH = {}, {}

for dataset in ["OVA", "RBD"]:
    df = seq_df[dataset]

    # Generate the ESM embeddings (VH_VL)
    emb_ids, embedding_list[dataset], ids_to_drop[dataset] = (
        utils.generate_ESM_embedding(
            df,
            out_folder=out_folder[dataset],
            **esm_embedding_kwargs,
        )
    )
    print(f"ESM VH_VL {dataset} embeddings done!!")

    # Generate the ESM embeddings for VH only
    # (emb_ids is overwritten here; only the last run's ids remain after the loop)
    emb_ids, embedding_list_VH[dataset], ids_to_drop_VH[dataset] = (
        utils.generate_ESM_embedding(
            df,
            VH_only=True,
            out_folder=out_folder[dataset],
            **esm_embedding_kwargs,
        )
    )
    print(f"ESM VH {dataset} embeddings done!!")

VH_VL


  0%|          | 0/3622 [00:00<?, ?it/s]

[2024-07-17-09:11:39] Saving 0_s1 H




[2024-07-17-09:11:40] Saving 0_s1 L


  0%|          | 1/3622 [00:02<2:57:46,  2.95s/it]

[2024-07-17-09:11:41] Saving 1_s1 H




[2024-07-17-09:11:43] Saving 1_s1 L


  0%|          | 2/3622 [00:06<3:24:22,  3.39s/it]

[2024-07-17-09:11:45] Saving 2_s1 H




[2024-07-17-09:11:47] Saving 2_s1 L


  0%|          | 3/3622 [00:10<3:50:41,  3.82s/it]

[2024-07-17-09:11:49] Saving 3_s1 H




[2024-07-17-09:11:51] Saving 3_s1 L


: 

: 

In [None]:
# NOTE(review): this whole cell is commented out. `config` is not defined in the
# cells shown above (only `config_O` / `config_R`), and `names` is a dict keyed by
# dataset there, so the code needs adapting before it is re-enabled.
# # Load augmented embeddings HC embeddings
# emb_inputPath_ESM_aug = os.path.join(ROOT_DIR, config['VH_EMBEDPATH']['ESM2_aug_100_var'])

# ### ESM2 augmented - VH_VL
# aug_esm_embeddings = utils.load_pickle_embeddings_VH_VL(names, emb_inputPath_ESM_aug, file_suffix= '', embedding_type = 'var')
# # aug_esm_embeddings_VH = utils.load_pickle_embeddings(names, emb_inputPath_ESM_aug, file_suffix = '_H')
