# Training the ESM

In [1]:
!pip install fair-esm



In [2]:
#!/usr/bin/env python
# coding: utf-8
# Imports
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.nn.functional as F  # All functions that don't have any parameters
import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
from sklearn import metrics
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, roc_curve, matthews_corrcoef
# from google.colab import drive
import os
import torch
import torch.utils.data
import esm


In [3]:
#drive.mount('/content/drive/')

In [4]:
#os.chdir("/content/drive/MyDrive/hackathon_data_scripts/")

## Functions for extracting sequences

In [5]:
def reverseOneHot(encoding):
    """
    Converts one-hot encoded array back to string sequence
    """
    mapping = dict(zip(range(20), "ACDEFGHIKLMNPQRSTVWY"))
    seq = ""
    for i in range(len(encoding)):
        if np.max(encoding[i]) > 0:
            seq += mapping[np.argmax(encoding[i])]
    return seq


def extract_sequences(dataset_X):
    """
    Return DataFrame with MHC, peptide and TCR a/b sequences from
    one-hot encoded complex sequences in dataset X
    """
    mhc_sequences = [reverseOneHot(arr[0:179, 0:20]) for arr in dataset_X]
    pep_sequences = [reverseOneHot(arr[179:190, 0:20]) for arr in dataset_X]
    tcr_sequences = [reverseOneHot(arr[192:, 0:20]) for arr in dataset_X]
    df_sequences = pd.DataFrame(
        {"MHC": mhc_sequences, "peptide": pep_sequences, "tcr": tcr_sequences}
    )
    return df_sequences

## Load data

In [6]:
data_list = []
target_list = []

import glob

for fp in glob.glob("../hackathon_data_scripts/data/train/*input.npz"):
    data = np.load(fp)["arr_0"]
    targets = np.load(fp.replace("input", "labels"))["arr_0"]

    data_list.append(data)
    target_list.append(targets)
# print(data_list)

# Note:
# Choose your own training and val set based on data_list and target_list
# Here using the last partition as val set

X_train = np.concatenate(data_list[:-1])
y_train = np.concatenate(target_list[:-1])

X_val = np.concatenate(data_list[-1:])
y_val = np.concatenate(target_list[-1:])


In [7]:
# Load embeddings
MHC_train, MHC_val, pep_train, pep_val, tcr_val, tcr_train = [], [], [], [], [], []
for fp in glob.glob("mean_embedding/MHC-X_train/*.pt"):
    MHC_train.append(torch.load(fp)['mean_representations'][33].numpy())

for fp in glob.glob("mean_embedding/MHC-X_val/*.pt"):
    MHC_val.append(torch.load(fp)['mean_representations'][33].numpy())

for fp in glob.glob("mean_embedding/peptide-X_train/*.pt"):
    pep_train.append(torch.load(fp)['mean_representations'][33].numpy())

for fp in glob.glob("mean_embedding/peptide-X_val/*.pt"):
    pep_val.append(torch.load(fp)['mean_representations'][33].numpy())

for fp in glob.glob("mean_embedding/tcr-X_train/*.pt"):
    tcr_train.append(torch.load(fp)['mean_representations'][33].numpy())

for fp in glob.glob("mean_embedding/tcr-X_val/*.pt"):
    tcr_val.append(torch.load(fp)['mean_representations'][33].numpy())



In [8]:
print(len(tcr_val), len(tcr_train))

1137 3084


## Putting in embeddings for the sequence

In [25]:
data = [X_val, X_train]


complex_sequences = extract_sequences(X_val)

MHC_list = np.array(complex_sequences["MHC"], dtype=str)
peptide_list = np.array(complex_sequences["peptide"], dtype=str)
tcr_list = np.array(complex_sequences["tcr"], dtype=str)

unique_mhc = np.unique(MHC_list)
unique_peptide = np.unique(peptide_list)
unique_tcr = np.unique(tcr_list)

map_mhc = {}
for i in range(len(MHC_list)):
    for j in range(len(unique_mhc)):
        if MHC_list[i] == unique_mhc[j]:
            map_mhc[i] = j
            break

map_peptide = {}
for i in range(len(peptide_list)):
    for j in range(len(unique_peptide)):
        if peptide_list[i] == unique_peptide[j]:
            map_peptide[i] = j
            break

map_tcr = {}
for i in range(len(tcr_list)):
    for j in range(len(unique_tcr)):
        if tcr_list[i] == unique_tcr[j]:
            map_tcr[i] = j
            break

    # Here we actually do the insertion of the embedding


for i in range(X_val.shape[0]):
    rows, _ = X_val[i].shape
    embed_arr = np.zeros((rows,1280))
    embed_arr[0,:] = MHC_val[map_mhc[i]]
    embed_arr[1,:] = pep_val[map_peptide[i]]
    embed_arr[2,:] = tcr_val[map_tcr[i]]
    np.concatenate((X_val[i], embed_arr), axis=1))

test[0]
        


IndexError: list index out of range

14

In [None]:
nsamples, nx, ny = X_train.shape
print("Training set shape:", nsamples, nx, ny)
nsamples, nx, ny = X_val.shape
print("val set shape:", nsamples, nx, ny)

p_neg = len(y_train[y_train == 1]) / len(y_train) * 100
print("Percent positive samples in train:", p_neg)

p_pos = len(y_val[y_val == 1]) / len(y_val) * 100
print("Percent positive samples in val:", p_pos)

# make the data set into one dataset that can go into dataloader
train_ds = []
for i in range(len(X_train)):
    train_ds.append([np.transpose(X_train[i]), y_train[i]])

val_ds = []
for i in range(len(X_val)):
    val_ds.append([np.transpose(X_val[i]), y_val[i]])

bat_size = 64
print("\nNOTE:\nSetting batch-size to", bat_size)
train_ldr = torch.utils.data.DataLoader(train_ds, batch_size=bat_size, shuffle=True)
val_ldr = torch.utils.data.DataLoader(val_ds, batch_size=bat_size, shuffle=True)


# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device (CPU/GPU):", device)
# device = torch.device("cpu")

In [37]:
pd.read_csv("../hackathon_data_scripts/data/example.csv")


16

In [8]:
def copy_as_dataframes(dataset_X):
    """
    Returns list of DataFrames with named features from dataset_X,
    using example CSV file
    """
    df_raw = pd.read_csv("../hackathon_data_scripts/data/example.csv")
    return [pd.DataFrame(arr, columns=df_raw.columns) for arr in dataset_X]


named_dataframes = copy_as_dataframes(X_train)
print(
    "Showing first complex as dataframe. Columns are positions and indices are calculated features"
)
named_dataframes[0]

Showing first complex as dataframe. Columns are positions and indices are calculated features


Unnamed: 0,A,C,D,E,F,G,H,I,K,L,M,N,P,Q,R,S,T,V,W,Y,per_res_fa_atr,per_res_fa_rep,per_res_fa_sol,per_res_fa_elec,per_res_fa_dun,per_res_p_aa_pp,per_res_score,foldx_MP,foldx_MA,foldx_MB,foldx_PA,foldx_PB,foldx_AB,global_complex_total_score,global_complex_fa_atr,global_complex_fa_dun,global_complex_fa_elec,global_complex_fa_rep,global_complex_fa_sol,global_complex_p_aa_pp,global_tcr_total_score,global_tcr_fa_atr,global_tcr_fa_dun,global_tcr_fa_elec,global_tcr_fa_rep,global_tcr_fa_sol,global_tcr_p_aa_pp,global_pmhc_total_score,global_pmhc_fa_atr,global_pmhc_fa_dun,global_pmhc_fa_elec,global_pmhc_fa_rep,global_pmhc_fa_sol,global_pmhc_p_aa_pp
0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-1.691,0.095,1.968,-0.287,0.000,0.000,0.833,-5.31471,-0.01875,-1.57906,0.0,-1.11466,-1.22179,-1198.017,489.407,-716.398,4.547,1444.857,-87.484,2.265,-630.983,252.669,-355.891,2.352,680.346,-49.781,2.017,-564.837,235.9,-369.03,2.223,730.223,-36.813,0.562
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,-2.559,0.106,1.995,-0.806,0.077,-0.431,-2.419,-5.31471,-0.01875,-1.57906,0.0,-1.11466,-1.22179,-1198.017,489.407,-716.398,4.547,1444.857,-87.484,2.265,-630.983,252.669,-355.891,2.352,680.346,-49.781,2.017,-564.837,235.9,-369.03,2.223,730.223,-36.813,0.562
2,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-7.444,0.565,5.379,-3.654,1.607,-0.254,-6.032,-5.31471,-0.01875,-1.57906,0.0,-1.11466,-1.22179,-1198.017,489.407,-716.398,4.547,1444.857,-87.484,2.265,-630.983,252.669,-355.891,2.352,680.346,-49.781,2.017,-564.837,235.9,-369.03,2.223,730.223,-36.813,0.562
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,-4.058,0.197,3.959,-2.659,0.216,-0.488,-4.343,-5.31471,-0.01875,-1.57906,0.0,-1.11466,-1.22179,-1198.017,489.407,-716.398,4.547,1444.857,-87.484,2.265,-630.983,252.669,-355.891,2.352,680.346,-49.781,2.017,-564.837,235.9,-369.03,2.223,730.223,-36.813,0.562
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-9.072,1.205,2.808,-1.844,2.182,0.060,-4.166,-5.31471,-0.01875,-1.57906,0.0,-1.11466,-1.22179,-1198.017,489.407,-716.398,4.547,1444.857,-87.484,2.265,-630.983,252.669,-355.891,2.352,680.346,-49.781,2.017,-564.837,235.9,-369.03,2.223,730.223,-36.813,0.562
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
415,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.00000,0.00000,0.00000,0.0,0.00000,0.00000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.0,0.00,0.000,0.000,0.000,0.000
416,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.00000,0.00000,0.00000,0.0,0.00000,0.00000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.0,0.00,0.000,0.000,0.000,0.000
417,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.00000,0.00000,0.00000,0.0,0.00000,0.00000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.0,0.00,0.000,0.000,0.000,0.000
418,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.00000,0.00000,0.00000,0.0,0.00000,0.00000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.0,0.00,0.000,0.000,0.000,0.000


# View complex MHC, peptide and TCR alpha/beta sequences
You may want to view the one-hot encoded sequences as sequences in single-letter amino-acid format. The below function will return the TCR, peptide and MHC sequences for the dataset as 3 lists.

In [7]:
def oneHot(residue):
    """
    Converts string sequence to one-hot encoding
    Example usage:
    seq = "GSHSMRY"
    oneHot(seq)
    """

    mapping = dict(zip("ACDEFGHIKLMNPQRSTVWY", range(20)))
    if residue in "ACDEFGHIKLMNPQRSTVWY":
        return np.eye(20)[mapping[residue]]
    else:
        return np.zeros(20)


def reverseOneHot(encoding):
    """
    Converts one-hot encoded array back to string sequence
    """
    mapping = dict(zip(range(20), "ACDEFGHIKLMNPQRSTVWY"))
    seq = ""
    for i in range(len(encoding)):
        if np.max(encoding[i]) > 0:
            seq += mapping[np.argmax(encoding[i])]
    return seq


def extract_sequences(dataset_X):
    """
    Return DataFrame with MHC, peptide and TCR a/b sequences from
    one-hot encoded complex sequences in dataset X
    """
    mhc_sequences = [reverseOneHot(arr[0:179, 0:20]) for arr in dataset_X]
    pep_sequences = [reverseOneHot(arr[179:190, 0:20]) for arr in dataset_X]
    tcr_sequences = [reverseOneHot(arr[192:, 0:20]) for arr in dataset_X]
    df_sequences = pd.DataFrame(
        {"MHC": mhc_sequences, "peptide": pep_sequences, "tcr": tcr_sequences}
    )
    return df_sequences

In [8]:
complex_sequences = extract_sequences(X_val)
print("Showing MHC, peptide and TCR alpha/beta sequences for each complex")
complex_sequences

Showing MHC, peptide and TCR alpha/beta sequences for each complex


Unnamed: 0,MHC,peptide,tcr
0,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GILGFVFTL,EQSPQFLSIQEGENLTVYCNSSSVFSSLQWYRQEPGEGPVLLVTVV...
1,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GILGFVFTL,EQSPQFLSIQEGENLTVYCNSSSVFSSLQWYRQEPGEGPVLLVTVV...
2,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GILGFVFTL,QQVKQNSPSLSVQEGRISILNCDYTNSMFDYFLWYKKYPAEGPTFL...
3,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GILGFVFTL,EQSPQFLSIQEGENLTVYCNSSSVFSSLQWYRQEPGEGPVLLVTVV...
4,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GILGFVFTL,EQSPQFLSIQEGENLTVYCNSSSVFSSLQWYRQEPGEGPVLLVTVV...
...,...,...,...
1521,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,FLYALALLL,QSPQSMFIQEGEDVSMNCTSSSIFNTWLWYKQEPGEGPVLLIALYK...
1522,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,LLFGYPVYV,PQALSIQEGENATMNCSYKTSINNLQWYRQNSGRGLVHLILIRSNE...
1523,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GLCTLVAML,EQSPQFLSIQEGENLTVYCNSSSVFSSLQWYRQEPGEGPVLLVTVV...
1524,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,FLYALALLL,QQVKQNSPSLSVQEGRISILNCDYTNSMFDYFLWYKKYPAEGPTFL...


In [9]:
MHC_list = np.array(complex_sequences["MHC"], dtype=str)
unique_mhc, counts_mhc = np.unique(MHC_list, return_counts=True)
print(np.asarray((unique_mhc, counts_mhc)).T)

[['GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEPRAPWIEQEGPEYWDGETRKVKAHSQTHRVDLGTLRGYYNQSEAGSHTVQRMYGCDVGSDWRFLRGYHQYAYDGKDYIALKEDLRSWTAADMAAQTTKHKWEAAHVAEQLRAYLEGTCVEWLRRYLENGKETL'
  '1526']]


In [10]:
peptide_list = np.array(complex_sequences["peptide"], dtype=str)
unique_peptide, counts_peptide = np.unique(peptide_list, return_counts=True)
print(np.asarray((unique_peptide, counts_peptide)).T)

[['FLYALALLL' '38']
 ['GILGFVFTL' '866']
 ['GLCTLVAML' '278']
 ['IMDQVPFSV' '6']
 ['KLQCVDLHV' '4']
 ['KTWGQYWQV' '7']
 ['KVAELVHFL' '5']
 ['KVLEYVIKV' '9']
 ['LLFGYPVYV' '55']
 ['MLDLQPETT' '16']
 ['NLVPMVATV' '172']
 ['RMFPNAPYL' '14']
 ['RTLNAWVKV' '33']
 ['SLFNTVATL' '5']
 ['SLLMWITQV' '5']
 ['YLLEMLWRL' '13']]


In [11]:
tcr_list = np.array(complex_sequences["tcr"], dtype=str)
unique_tcr, counts_tcr = np.unique(tcr_list, return_counts=True)
print(np.asarray((unique_tcr, counts_tcr)).T)

[['ALNIQEGKTATLTCNYTNYSPAYLQWYRQDPGRGPVFLLLIRENEKEKRKERLKVTFDTTLKQSLFHITASQPADSATYLCALDMGGGSQGNLIFGKGTKLSVKPGITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQVPGQGLRLIYYSHIVNDFQKGDIAEGYSVSREKKESFPLTVTSAQKNPTAFYLCASSIRAADTQYFGPGT'
  '2']
 ['ALNIQEGKTATLTCNYTNYSPAYLQWYRQDPGRGPVFLLLIRENEKEKRKERLKVTFDTTLKQSLFHITASQPADSATYLCALDPDTDKLIFGTGTRLQVFPGITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLIYYSQIVNDFQKGDIAEGYSVSREKKESFPLTVTSAQKNPTAFYLCASSMRASVEQFFGPGT'
  '2']
 ['ALNIQEGKTATLTCNYTNYSPAYLQWYRQDPGRGPVFLLLIRENEKEKRKERLKVTFDTTLKQSLFHITASQPADSATYLCALDPRGASKIIFGSGTRLSIRPDVTQTPRNRITKTGKRIMLECSQTKGHDRMYWYRQDPGLGLRLIYYSFDVKDINKGEISDGYSVSRQAQAKFSLSLESAIPNQTALYFCATSDTQGGGQPQHFGDGTR'
  '1']
 ...
 ['VQEGEDFTTYCNSSTTLSNIQWYKQRPGGHPVFLIQLVKSGEVKKKRLTFQFGEAKKNSSLHITATQTTDVGTYFCAGSYGGSQGNLIFGKGTKLSVKPGITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLIYYSQIVNDFQKGDIAEGYSVSREKKESFPLTVTSAQKNPTAFYLCASSSRSHQPQHFGDGT'
  '2']
 ['VQEGEDFTTYCNSSTTLSNIQWYKQRPGGHPVFLIQLVKSGEVKKKRLTFQFGEAKKNSSLHITATQTTDVGTYFCALGSGNTGKLIFGQGTTLQVKPVAQSPRYKITEKSQ

## Train for MHC (one common sequence)

In [12]:
# Load ESM-1b model
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
batch_converter = alphabet.get_batch_converter()

In [16]:
# Prepare data
data = [
    ("protein1", unique_mhc[0]),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)

# Extract per-residue representations (on CPU)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations_mhc = results["representations"][33]

## Train for peptide (14 sequences)

In [26]:
# Prepare data
data = []
for i in range(len(unique_peptide)):
  data.append((str(i), unique_peptide[i]))

batch_labels, batch_strs, batch_tokens = batch_converter(data)

# Extract per-residue representations (on CPU)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations_peptide = results["representations"][33]

In [32]:
# Here do the same for thc
data = []
for i in range(len(tcr_list)):
  data.append((str(i), tcr_list[i]))

batch_labels, batch_strs, batch_tokens = batch_converter(data)

# Extract per-residue representations (on CPU)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations_peptide = results["representations"][33]

In [31]:
# Use pickle to save the matrices


(16, 16)