In [1]:
from IPython.display import display, Markdown
from datetime import datetime, date

# Date the notebook was last executed (ISO format, local time).
todays_date = date.today().isoformat()

# Render the notebook header as Markdown, one element per line.
# NOTE(review): "Last modified" is filled with the execution date, not the date the
# notebook content was last edited.
display(*(
    Markdown(header_line)
    for header_line in (
        '# 002 Notebook - Extract ESM-2 CDR embeddings',
        'Author: Lena Erlach',
        'Created: 2024-01-18',
        f'Last modified: {todays_date}',
    )
))

# 002 Notebook - Extract ESM-2 CDR embeddings

Author: Lena Erlach

Created: 2024-01-18

Last modified: 2024-07-17

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os, configparser
from tqdm import tqdm
import sys
import pickle

# Load the config files (one per dataset)
CONFIG_PATH = {
    'OVA': '../../config_file.txt',
    'RBD': '../../config_file_RBD.txt',
}


def _read_config(path):
    """Parse an INI config file and return the ConfigParser.

    ConfigParser.read() silently skips files it cannot open (it only returns the list
    of files it actually parsed). Without a check, that failure would surface later as
    an unhelpful KeyError on the first section lookup, so raise a clear error here.

    Raises:
        FileNotFoundError: if `path` could not be read.
    """
    config = configparser.ConfigParser()
    if not config.read(path):
        raise FileNotFoundError(f"Could not read config file: {os.path.abspath(path)}")
    return config


config_O = _read_config(CONFIG_PATH['OVA'])
config_R = _read_config(CONFIG_PATH['RBD'])
ROOT_DIR = config_O['ROOT']['ROOT_DIR']

# Make the project source folders importable. Skip entries that are already present
# so re-running this cell does not keep appending duplicates to sys.path.
for _src_dir in (os.path.join(ROOT_DIR, 'src'), os.path.join(ROOT_DIR, 'src', 'Spec_classification')):
    if _src_dir not in sys.path:
        sys.path.append(_src_dir)
import utils_nb as utils
import Load_embs_class as lec

  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()


## Table of Contents

- [1. Load sequences & extract CDR embeddings – VH_VL (paired chains)](#Preprocess)
- [2. Load sequences & extract CDR embeddings – VH (heavy chain only)](#Extract-embeddings)
    - [2.1 Test loading CDR embeddings](#Load-embeddings)
- [3. PCA plot of embeddings](#PCA-embeddings)

#### Define paths, sequence columns and CDR numbering ranges

In [3]:
# Scratch directory for the ANARCI numbering files
anarci_dir = os.path.join(ROOT_DIR, config_O['PATHS']['ANARCI_DIR'])

# Output paths for the extracted CDR embeddings, keyed by dataset:
#   emb_outputPath_ESM_CDR   -> paired VH_VL sequences
#   emb_outputPath_ESM_CDR_H -> heavy-chain (VH) only sequences
emb_outputPath_ESM_CDR = {
    ds: os.path.join(ROOT_DIR, cfg['VH_VL_EMBEDPATH']['ESM2_CDRextract'])
    for ds, cfg in (('OVA', config_O), ('RBD', config_R))
}
emb_outputPath_ESM_CDR_H = {
    ds: os.path.join(ROOT_DIR, cfg['VH_EMBEDPATH']['ESM2_CDRextract'])
    for ds, cfg in (('OVA', config_O), ('RBD', config_R))
}

# Sequence column names in seq_df (other options: 'VDJ_aaSeqCDR3', 'cdr_comb', ...)
seq_col_H = 'VDJ_aaSeq'   # heavy chain
seq_col_L = 'VJ_aaSeq'    # light chain

# Position ranges [start, end] per chain, passed to utils.extract_cdr_embedding as
# `all_regions`. Presumably CDR1-3 in Chothia numbering; whether `end` is inclusive
# is decided inside utils — TODO confirm.
chothia_nums2 = {
    'H': [[24, 34], [50, 58], [94, 103]],
    'L': [[24, 34], [48, 54], [89, 98]],
}
# NOTE(review): not passed to the extraction calls in this section.
scheme = 'chothia'


<a id='Preprocess'></a>
## 1. Load sequences and ESM-2 embeddings, and extract CDR embeddings – VH_VL (paired chains)

In [10]:
######## LOAD ESM EMBEDDINGS & EXTRACT CDR EMBEDDINGS (VH_VL) ########
# Number of sequences per dataset passed to CDR extraction; set to None to process all.
# NOTE: save_embeddings=True below, so with a number here the file written to
# emb_outputPath_ESM_CDR contains only the first N_SEQS_CDR sequences.
N_SEQS_CDR = 3

ESM_var_embeddings = {}
seqs_H = {}
seqs_L = {}
names = {}

embeddings_esm_CDR_fl = {}

for dataset in ['OVA', 'RBD']:
    # Only the OVA dataset is loaded with the filter_192 option.
    f_192 = dataset == 'OVA'

    Embeddings = lec.LoadEmbeddings_VH_VL(CONFIG_PATH[dataset], seq_col='VDJ_VJ_aaSeq', filter_192=f_192)
    Embeddings.load_embeddings(embedding_type='esm')

    ### Load mAb sequences (rows with seq_complete == True only)
    seq_df = Embeddings.seq_df
    seq_df = seq_df[seq_df['seq_complete'] == True]
    names[dataset] = Embeddings.names

    seqs_H[dataset] = seq_df[seq_col_H].reset_index(drop=True)
    seqs_L[dataset] = seq_df[seq_col_L].reset_index(drop=True)

    ## Load embeddings - ESM2 - VH_VL
    ESM_var_embeddings[dataset] = Embeddings.emb_ESM_var
    print(f"ESM - embeddings loaded for {dataset}")

    # Sequences, names and embeddings are passed to extract_cdr_embedding as parallel
    # sequences. The seq_complete filter above is applied to seq_df only, so make sure
    # it did not drop rows that names/embeddings still contain (silent misalignment).
    assert len(seqs_H[dataset]) == len(names[dataset]) == len(ESM_var_embeddings[dataset]), (
        f"{dataset}: sequences ({len(seqs_H[dataset])}), names ({len(names[dataset])}) and "
        f"embeddings ({len(ESM_var_embeddings[dataset])}) are not aligned"
    )

    # Extract the CDR embeddings for heavy and light chain sequences
    embeddings_esm_CDR_fl[dataset] = utils.extract_cdr_embedding(
        seqs_H[dataset][:N_SEQS_CDR], seqs_L[dataset][:N_SEQS_CDR],
        names=names[dataset][:N_SEQS_CDR], embeddings=ESM_var_embeddings[dataset][:N_SEQS_CDR],
        save_embeddings=True, emb_outputPath=emb_outputPath_ESM_CDR[dataset],
        anarci_dir=anarci_dir, all_regions=chothia_nums2,
    )

    print(f"ESM - CDR embeddings extracted for {dataset}")




Load embeddings for VH_VL sequences
ESM - VH_VL embeddings loaded
ESM - embeddings loaded for OVA
paired


100%|██████████| 3/3 [00:03<00:00,  1.04s/it]


ESM - CDR embeddings extracted for OVA
Load embeddings for VH_VL sequences
ESM - VH_VL embeddings loaded
ESM - embeddings loaded for RBD
paired


100%|██████████| 3/3 [00:02<00:00,  1.17it/s]

ESM - CDR embeddings extracted for RBD





<a id='Extract-embeddings'></a>
## 2. Load sequences and ESM-2 embeddings, and extract CDR embeddings – VH (heavy chain only)


In [14]:
######## LOAD ESM EMBEDDINGS & EXTRACT CDR EMBEDDINGS (VH only) ########
# Number of sequences per dataset passed to CDR extraction; set to None to process all.
# NOTE: save_embeddings=True below, so with a number here the file written to
# emb_outputPath_ESM_CDR_H contains only the first N_SEQS_CDR_H sequences.
N_SEQS_CDR_H = 3

embeddings_esm_CDR_fl_H = {}

# NOTE(review): this loop overwrites the names / seqs_H / seqs_L / ESM_var_embeddings
# entries created by the VH_VL cell above (those dicts must already exist). After this
# cell they hold the VH-only data.
for dataset in ['OVA', 'RBD']:
    # Only the OVA dataset is loaded with the filter_192 option.
    f_192 = dataset == 'OVA'

    Embeddings = lec.LoadEmbeddings_VH_VL(CONFIG_PATH[dataset], seq_col='VDJ_aaSeq', filter_VH_complete=True, filter_192=f_192)
    Embeddings.load_embeddings(embedding_type='esm')

    ### Load mAb sequences (rows with seq_complete == True only)
    # NOTE(review): confirm 'seq_complete' is the right completeness flag for VH-only data
    # (the loader is already called with filter_VH_complete=True).
    seq_df = Embeddings.seq_df
    seq_df = seq_df[seq_df['seq_complete'] == True]
    names[dataset] = Embeddings.names

    seqs_H[dataset] = seq_df[seq_col_H].reset_index(drop=True)
    seqs_L[dataset] = []  # heavy chain only: no light-chain sequences

    ## Load embeddings - ESM2 - VH
    ESM_var_embeddings[dataset] = Embeddings.emb_ESM_var
    print(f"ESM - embeddings loaded for {dataset}")

    # Sequences, names and embeddings are passed as parallel sequences; guard against
    # the seq_df-only filter above dropping rows that names/embeddings still contain.
    assert len(seqs_H[dataset]) == len(names[dataset]) == len(ESM_var_embeddings[dataset]), (
        f"{dataset}: sequences ({len(seqs_H[dataset])}), names ({len(names[dataset])}) and "
        f"embeddings ({len(ESM_var_embeddings[dataset])}) are not aligned"
    )

    # Extract the CDR embeddings for heavy-chain-only sequences
    embeddings_esm_CDR_fl_H[dataset] = utils.extract_cdr_embedding(
        seqs_H[dataset][:N_SEQS_CDR_H], seqs_L[dataset],
        names=names[dataset][:N_SEQS_CDR_H], embeddings=ESM_var_embeddings[dataset][:N_SEQS_CDR_H],
        save_embeddings=True, emb_outputPath=emb_outputPath_ESM_CDR_H[dataset],
        paired_chains=False, anarci_dir=anarci_dir, all_regions=chothia_nums2,
    )

    print(f"ESM - CDR embeddings extracted for {dataset}")



else


100%|██████████| 3/3 [00:01<00:00,  2.29it/s]
