# Training the ESM

In [2]:
!pip install fair-esm



In [3]:
#!/usr/bin/env python
# coding: utf-8
# Imports
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.nn.functional as F  # All functions that don't have any parameters
import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
from sklearn import metrics
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, roc_curve, matthews_corrcoef
from google.colab import drive
import os
import torch
import torch.utils.data
import esm


In [4]:
# Mount Google Drive into the Colab filesystem at /content/drive/.
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [5]:
# Work from the project folder on Drive; the relative data paths used by the
# later cells are resolved against this directory.
os.chdir("/content/drive/MyDrive/hackathon_data_scripts/")

## Load data

In [6]:
import glob

# Load every training partition. Each "*input.npz" file is paired with a
# "*labels.npz" file obtained by replacing "input" with "labels" in its path.
# The paths are sorted because glob.glob() returns matches in arbitrary,
# filesystem-dependent order, and the train/val split below depends on which
# partition comes last.
input_paths = sorted(glob.glob("../hackathon_data_scripts/data/train/*input.npz"))
if len(input_paths) < 2:
    raise FileNotFoundError(
        "Expected at least two '*input.npz' partitions in "
        "../hackathon_data_scripts/data/train/, found %d" % len(input_paths)
    )

data_list = []
target_list = []
for fp in input_paths:
    # Context managers close the .npz archives once the arrays are read.
    with np.load(fp) as npz_inputs:
        data = npz_inputs["arr_0"]
    with np.load(fp.replace("input", "labels")) as npz_labels:
        targets = npz_labels["arr_0"]

    data_list.append(data)
    target_list.append(targets)

# Note:
# Choose your own training and val set based on data_list and target_list
# Here using the last partition (in sorted filename order) as val set

X_train = np.concatenate(data_list[:-1])
y_train = np.concatenate(target_list[:-1])
nsamples, nx, ny = X_train.shape
print("Training set shape:", nsamples, nx, ny)

X_val = np.concatenate(data_list[-1:])
y_val = np.concatenate(target_list[-1:])
nsamples, nx, ny = X_val.shape
print("val set shape:", nsamples, nx, ny)

# Every sample needs exactly one label before the two are paired up below.
assert len(X_train) == len(y_train), "train inputs and labels differ in length"
assert len(X_val) == len(y_val), "val inputs and labels differ in length"

# NOTE: despite its name, p_neg holds the percentage of POSITIVE samples
# (label == 1) in the training set; the name is kept so that any other cell
# referring to it keeps working.
p_neg = len(y_train[y_train == 1]) / len(y_train) * 100
print("Percent positive samples in train:", p_neg)

p_pos = len(y_val[y_val == 1]) / len(y_val) * 100
print("Percent positive samples in val:", p_pos)

# Make the data into [features, label] pairs that can go into a DataLoader.
# Each feature matrix is transposed, i.e. a sample goes from (nx, ny) to (ny, nx).
train_ds = [[np.transpose(features), label] for features, label in zip(X_train, y_train)]
val_ds = [[np.transpose(features), label] for features, label in zip(X_val, y_val)]

bat_size = 64
print("\nNOTE:\nSetting batch-size to", bat_size)
train_ldr = torch.utils.data.DataLoader(train_ds, batch_size=bat_size, shuffle=True)
# NOTE(review): shuffling is not needed for validation; kept as in the original.
val_ldr = torch.utils.data.DataLoader(val_ds, batch_size=bat_size, shuffle=True)


# Set device: use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device (CPU/GPU):", device)

Training set shape: 4174 420 54
val set shape: 1532 420 54
Percent positive samples in train: 24.96406324868232
Percent positive samples in val: 25.0

NOTE:
Setting batch-size to 64
Using device (CPU/GPU): cpu


In [7]:
# Preview the example CSV; its header is what copy_as_dataframes (defined in
# the next cell) uses as column names.
pd.read_csv("../hackathon_data_scripts/data/example.csv")

Unnamed: 0,A,C,D,E,F,G,H,I,K,L,M,N,P,Q,R,S,T,V,W,Y,per_res_fa_atr,per_res_fa_rep,per_res_fa_sol,per_res_fa_elec,per_res_fa_dun,per_res_p_aa_pp,per_res_score,foldx_MP,foldx_MA,foldx_MB,foldx_PA,foldx_PB,foldx_AB,global_complex_total_score,global_complex_fa_atr,global_complex_fa_dun,global_complex_fa_elec,global_complex_fa_rep,global_complex_fa_sol,global_complex_p_aa_pp,global_tcr_total_score,global_tcr_fa_atr,global_tcr_fa_dun,global_tcr_fa_elec,global_tcr_fa_rep,global_tcr_fa_sol,global_tcr_p_aa_pp,global_pmhc_total_score,global_pmhc_fa_atr,global_pmhc_fa_dun,global_pmhc_fa_elec,global_pmhc_fa_rep,global_pmhc_fa_sol,global_pmhc_p_aa_pp
0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-1.893,0.119,2.162,-0.315,0.000,0.000,0.560,-5.72761,-0.922174,-0.922063,-0.00758,0.0,-1.31529,-1144.062,526.035,-727.694,4.939,1458.171,-90.02,2.385,-593.753,285.323,-350.705,2.964,681.926,-40.471,2.401,-555.661,246.945,-359.301,2.25,733.179,-39.727,0.481
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,-2.941,0.701,2.150,-1.024,0.133,-0.357,-2.177,-5.72761,-0.922174,-0.922063,-0.00758,0.0,-1.31529,-1144.062,526.035,-727.694,4.939,1458.171,-90.02,2.385,-593.753,285.323,-350.705,2.964,681.926,-40.471,2.401,-555.661,246.945,-359.301,2.25,733.179,-39.727,0.481
2,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-7.574,0.943,5.372,-3.233,1.531,-0.258,-4.839,-5.72761,-0.922174,-0.922063,-0.00758,0.0,-1.31529,-1144.062,526.035,-727.694,4.939,1458.171,-90.02,2.385,-593.753,285.323,-350.705,2.964,681.926,-40.471,2.401,-555.661,246.945,-359.301,2.25,733.179,-39.727,0.481
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,-5.121,0.398,4.949,-3.739,0.560,0.108,-3.345,-5.72761,-0.922174,-0.922063,-0.00758,0.0,-1.31529,-1144.062,526.035,-727.694,4.939,1458.171,-90.02,2.385,-593.753,285.323,-350.705,2.964,681.926,-40.471,2.401,-555.661,246.945,-359.301,2.25,733.179,-39.727,0.481
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-8.474,1.257,3.211,-1.961,1.772,0.234,-2.829,-5.72761,-0.922174,-0.922063,-0.00758,0.0,-1.31529,-1144.062,526.035,-727.694,4.939,1458.171,-90.02,2.385,-593.753,285.323,-350.705,2.964,681.926,-40.471,2.401,-555.661,246.945,-359.301,2.25,733.179,-39.727,0.481
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
415,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.00000,0.000000,0.000000,0.00000,0.0,0.00000,0.000,0.000,0.000,0.000,0.000,0.00,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.00,0.000,0.000,0.000
416,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.00000,0.000000,0.000000,0.00000,0.0,0.00000,0.000,0.000,0.000,0.000,0.000,0.00,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.00,0.000,0.000,0.000
417,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.00000,0.000000,0.000000,0.00000,0.0,0.00000,0.000,0.000,0.000,0.000,0.000,0.00,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.00,0.000,0.000,0.000
418,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.00000,0.000000,0.000000,0.00000,0.0,0.00000,0.000,0.000,0.000,0.000,0.000,0.00,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.00,0.000,0.000,0.000


In [8]:
def copy_as_dataframes(dataset_X, example_csv="../hackathon_data_scripts/data/example.csv"):
    """
    Return one DataFrame per item of dataset_X, with columns named after the
    header of an example CSV file.

    Parameters
    ----------
    dataset_X : iterable of 2-D array-like
        Each item becomes one DataFrame; its number of columns must equal
        the number of columns in the CSV header.
    example_csv : str, optional
        Path of the CSV file whose header provides the column names.
        Defaults to the hackathon example file.

    Returns
    -------
    list of pandas.DataFrame
        One DataFrame per item of dataset_X, in the same order.
    """
    # Only the header is needed, so skip parsing the data rows (nrows=0).
    feature_names = pd.read_csv(example_csv, nrows=0).columns
    return [pd.DataFrame(arr, columns=feature_names) for arr in dataset_X]


# Wrap every training complex in a DataFrame with named feature columns.
named_dataframes = copy_as_dataframes(X_train)
# Each complex is a (positions x features) matrix, so the DataFrame index runs
# over positions and the columns are the named features.
print(
    "Showing first complex as dataframe. Rows are positions and columns are calculated features"
)
named_dataframes[0]

Showing first complex as dataframe. Columns are positions and indices are calculated features


Unnamed: 0,A,C,D,E,F,G,H,I,K,L,M,N,P,Q,R,S,T,V,W,Y,per_res_fa_atr,per_res_fa_rep,per_res_fa_sol,per_res_fa_elec,per_res_fa_dun,per_res_p_aa_pp,per_res_score,foldx_MP,foldx_MA,foldx_MB,foldx_PA,foldx_PB,foldx_AB,global_complex_total_score,global_complex_fa_atr,global_complex_fa_dun,global_complex_fa_elec,global_complex_fa_rep,global_complex_fa_sol,global_complex_p_aa_pp,global_tcr_total_score,global_tcr_fa_atr,global_tcr_fa_dun,global_tcr_fa_elec,global_tcr_fa_rep,global_tcr_fa_sol,global_tcr_p_aa_pp,global_pmhc_total_score,global_pmhc_fa_atr,global_pmhc_fa_dun,global_pmhc_fa_elec,global_pmhc_fa_rep,global_pmhc_fa_sol,global_pmhc_p_aa_pp
0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-1.691,0.095,1.968,-0.287,0.000,0.000,0.833,-5.31471,-0.01875,-1.57906,0.0,-1.11466,-1.22179,-1198.017,489.407,-716.398,4.547,1444.857,-87.484,2.265,-630.983,252.669,-355.891,2.352,680.346,-49.781,2.017,-564.837,235.9,-369.03,2.223,730.223,-36.813,0.562
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,-2.559,0.106,1.995,-0.806,0.077,-0.431,-2.419,-5.31471,-0.01875,-1.57906,0.0,-1.11466,-1.22179,-1198.017,489.407,-716.398,4.547,1444.857,-87.484,2.265,-630.983,252.669,-355.891,2.352,680.346,-49.781,2.017,-564.837,235.9,-369.03,2.223,730.223,-36.813,0.562
2,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-7.444,0.565,5.379,-3.654,1.607,-0.254,-6.032,-5.31471,-0.01875,-1.57906,0.0,-1.11466,-1.22179,-1198.017,489.407,-716.398,4.547,1444.857,-87.484,2.265,-630.983,252.669,-355.891,2.352,680.346,-49.781,2.017,-564.837,235.9,-369.03,2.223,730.223,-36.813,0.562
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,-4.058,0.197,3.959,-2.659,0.216,-0.488,-4.343,-5.31471,-0.01875,-1.57906,0.0,-1.11466,-1.22179,-1198.017,489.407,-716.398,4.547,1444.857,-87.484,2.265,-630.983,252.669,-355.891,2.352,680.346,-49.781,2.017,-564.837,235.9,-369.03,2.223,730.223,-36.813,0.562
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-9.072,1.205,2.808,-1.844,2.182,0.060,-4.166,-5.31471,-0.01875,-1.57906,0.0,-1.11466,-1.22179,-1198.017,489.407,-716.398,4.547,1444.857,-87.484,2.265,-630.983,252.669,-355.891,2.352,680.346,-49.781,2.017,-564.837,235.9,-369.03,2.223,730.223,-36.813,0.562
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
415,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.00000,0.00000,0.00000,0.0,0.00000,0.00000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.0,0.00,0.000,0.000,0.000,0.000
416,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.00000,0.00000,0.00000,0.0,0.00000,0.00000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.0,0.00,0.000,0.000,0.000,0.000
417,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.00000,0.00000,0.00000,0.0,0.00000,0.00000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.0,0.00,0.000,0.000,0.000,0.000
418,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.00000,0.00000,0.00000,0.0,0.00000,0.00000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.0,0.00,0.000,0.000,0.000,0.000


# View complex MHC, peptide and TCR alpha/beta sequences
You may want to view the one-hot encoded sequences as sequences in single-letter amino-acid format. The function `extract_sequences` below returns the MHC, peptide and TCR sequences of a dataset as a DataFrame with three columns (`MHC`, `peptide`, `tcr`).

In [9]:
def oneHot(residue):
    """
    One-hot encode a single amino-acid residue.

    Parameters
    ----------
    residue : str
        Single-letter amino-acid code.

    Returns
    -------
    numpy.ndarray
        Vector of length 20 ordered as "ACDEFGHIKLMNPQRSTVWY", with a 1 at
        the residue's position. Anything that is not exactly one of these
        20 letters (unknown letters, "", multi-letter strings) gives an
        all-zero vector.

    Example usage:
    oneHot("G")
    # encode a whole sequence as a (len(seq), 20) matrix
    seq = "GSHSMRY"
    np.array([oneHot(aa) for aa in seq])
    """

    mapping = dict(zip("ACDEFGHIKLMNPQRSTVWY", range(20)))
    # Look the residue up as a dictionary key. A substring test
    # (`residue in "ACDEF..."`) is also true for "" and for runs such as
    # "AC", which are not keys of the mapping and would raise KeyError.
    index = mapping.get(residue)
    if index is None:
        return np.zeros(20)
    return np.eye(20)[index]


def reverseOneHot(encoding):
    """
    Decode a one-hot encoded matrix into a single-letter amino-acid string.

    Every row of `encoding` is one position. The letter is chosen by the
    index of the row's maximum within "ACDEFGHIKLMNPQRSTVWY". Rows whose
    maximum is not positive (e.g. all-zero rows) contribute no letter.
    """
    index_to_letter = dict(enumerate("ACDEFGHIKLMNPQRSTVWY"))
    rows = (encoding[pos] for pos in range(len(encoding)))
    return "".join(
        index_to_letter[np.argmax(row)] for row in rows if np.max(row) > 0
    )


def extract_sequences(dataset_X):
    """
    Decode the one-hot encoded complexes in dataset_X and return a DataFrame
    with one row per complex and the columns "MHC", "peptide" and "tcr"
    (TCR a/b).

    Only the first 20 feature columns of each complex are decoded; the row
    ranges below select which positions go into which column.
    """
    # (output column, rows of the complex holding that part of the sequence)
    regions = (
        ("MHC", slice(0, 179)),
        ("peptide", slice(179, 190)),
        ("tcr", slice(192, None)),
    )
    decoded = {}
    for column, rows in regions:
        decoded[column] = [
            reverseOneHot(complex_arr[rows, 0:20]) for complex_arr in dataset_X
        ]
    return pd.DataFrame(decoded)

In [10]:
# Decode the validation complexes back into amino-acid strings; the resulting
# DataFrame is displayed as the cell output.
# NOTE(review): only X_val is decoded here, so everything later derived from
# complex_sequences covers the validation partition only — confirm that this
# is intended.
complex_sequences = extract_sequences(X_val)
print("Showing MHC, peptide and TCR alpha/beta sequences for each complex")
complex_sequences

Showing MHC, peptide and TCR alpha/beta sequences for each complex


Unnamed: 0,MHC,peptide,tcr
0,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GLCTLVAML,VEQHPSTLSVQEGDSAVIKCTYSDSASNYFPWYKQELGKGPQLIID...
1,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GILGFVFTL,TQTQPGMFVQEKEAVTLDCTYDTSDPSYGLFWYKQPSSGEMIFLIY...
2,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GLCTLVAML,QVEQSPQSLIILEGKNCTLQCNYTVSPFSNLRWYKQDTGRGPVSLT...
3,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GILGFVFTL,QVEQSPPDLILQEGANSTLRCNFSDSVNNLQWFHQNPWGQLINLFY...
4,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GILGFVFTL,QPSTVASSEGAVVEIFCNHSVSNAYNFFWYLHFPGCAPRLLVKGSK...
...,...,...,...
1527,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,NLVPMVATV,EQSPQFLSIQEGENLTVYCNSSSVFSSLQWYRQEPGEGPVLLVTVV...
1528,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,NLVPMVATV,QSVTQLGSHVSVSEGALVLLRCNYSSSVPPYLFWYVQYPNQGLQLL...
1529,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GLCTLVAML,VQEGEDFTTYCNSSTTLSNIQWYKQRPGGHPVFLIQLVKSGEVKKK...
1530,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GLCTLVAML,VQEGEDFTTYCNSSTTLSNIQWYKQRPGGHPVFLIQLVKSGEVKKK...


In [11]:
# Tally the distinct MHC sequences among the decoded complexes.
MHC_list = np.array(complex_sequences["MHC"], dtype=str)
# np.unique gives the sorted distinct values and, with return_counts=True,
# how many times each one occurs.
unique_mhc, counts_mhc = np.unique(MHC_list, return_counts=True)
# Print one (sequence, count) row per distinct sequence.
print(np.asarray((unique_mhc, counts_mhc)).T)

[['GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEPRAPWIEQEGPEYWDGETRKVKAHSQTHRVDLGTLRGYYNQSEAGSHTVQRMYGCDVGSDWRFLRGYHQYAYDGKDYIALKEDLRSWTAADMAAQTTKHKWEAAHVAEQLRAYLEGTCVEWLRRYLENGKETL'
  '1532']]


In [12]:
# Tally the distinct peptide sequences among the decoded complexes.
peptide_list = np.array(complex_sequences["peptide"], dtype=str)
# np.unique gives the sorted distinct values and, with return_counts=True,
# how many times each one occurs.
unique_peptide, counts_peptide = np.unique(peptide_list, return_counts=True)
# Print one (sequence, count) row per distinct sequence.
print(np.asarray((unique_peptide, counts_peptide)).T)

[['CLGGLLTMV' '2']
 ['FLYALALLL' '31']
 ['GILGFVFTL' '889']
 ['GLCTLVAML' '256']
 ['IMDQVPFSV' '19']
 ['KTWGQYWQV' '12']
 ['LLFGYPVYV' '29']
 ['MLDLQPETT' '16']
 ['NLVPMVATV' '204']
 ['RMFPNAPYL' '13']
 ['RTLNAWVKV' '38']
 ['SLFNTVATL' '9']
 ['SLLMWITQV' '8']
 ['YLLEMLWRL' '6']]


In [13]:
# Tally the distinct TCR sequences among the decoded complexes.
tcr_list = np.array(complex_sequences["tcr"], dtype=str)
# np.unique gives the sorted distinct values and, with return_counts=True,
# how many times each one occurs.
unique_tcr, counts_tcr = np.unique(tcr_list, return_counts=True)
# Print one (sequence, count) row per distinct sequence.
print(np.asarray((unique_tcr, counts_tcr)).T)

[['ALNIQEGKTATLTCNYTNYSPAYLQWYRQDPGRGPVFLLLIRENEKEKRKERLKVTFDTTLKQSLFHITASQPADSATYLCALDENNNARLMFGDGTQLVVKPVSQHPSWVICKSGTSVKIECRSLDFQATTMFWYRQFPKQSLMLMATSNEGSKATYEQGVEKDKFLINHASLTLSTLTVTSAHPEDSSFYICSARDFQGDEQYFGPGT'
  '1']
 ['ALNIQEGKTATLTCNYTNYSPAYLQWYRQDPGRGPVFLLLIRENEKEKRKERLKVTFDTTLKQSLFHITASQPADSATYLCALDGLMDSNYQLIWGAGTKLIIKPAGVIQSPRHEVTEMGQEVTLRCKPISGHDYLFWYRQTMMRGLELLIYFNNNVPIDDSGMPEDRFSAKMPNASFSTLKIQPSEPRDSAVYFCASSFRGGAGNEQFFGPGTR'
  '1']
 ['ALNIQEGKTATLTCNYTNYSPAYLQWYRQDPGRGPVFLLLIRENEKEKRKERLKVTFDTTLKQSLFHITASQPADSATYLCALDIIGPSGTYKYIFGTGTRLKVAGVTQTPKFRVLKTGQSMTLLCAQDMNHEYMYWYRQDPGMGLRLIHYSVGEGTTAKGEVPDGYNVSRLKKQNFLLGLESAAPSQTSVYFCASKPGPTYEQYFGPGT'
  '1']
 ...
 ['VQEGEDFTTYCNSSTTLSNIQWYKQRPGGHPVFLIQLVKSGEVKKKRLTFQFGEAKKNSSLHITATQTTDVGTYFCVDSNYQLIWGAGTKLIIKPGVTQTPRHLVMGMTNKKSLKCEQHLGHNAMYWYKQSAKKPLELMFVYSLEERVENNSVPSRFSPECPNSSHLFLHLHTLQPEDSALYLCASSQRLVGEGTEAFFGQGT'
  '1']
 ['VTQSQPEMSVQEAETVTLSCTYDTSENNYYLFWYKQPPSRQMILVIRQEAYKQQNATENRFSVNFQKAAKSFSLKISDSQLGDTAMYFCAFMRDYGGATNKLIFGT

## Compute ESM embeddings for MHC (one common sequence)

In [None]:
# Load the pretrained ESM-1b model together with its alphabet
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
# Inference only: eval mode disables dropout so that the extracted
# representations are deterministic.
model.eval()
# Converts a list of (label, sequence) pairs into labels, strings and tokens
batch_converter = alphabet.get_batch_converter()

In [None]:
# Prepare data: one (label, sequence) pair for the first entry of unique_mhc
data = [
    ("protein1", unique_mhc[0]),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)

# Extract per-residue representations (on CPU)
# no_grad: forward pass only, no gradient tracking
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
# Keep the layer-33 representations requested through repr_layers.
# NOTE(review): per the ESM README the token axis also contains special
# begin/end tokens — confirm before aligning these rows to residue positions.
token_representations_mhc = results["representations"][33]

## Compute ESM embeddings for peptide (14 sequences)

In [None]:
# Prepare data: one (label, sequence) pair per unique peptide, labelled by its
# index in unique_peptide. batch_converter needs both elements of each pair;
# the earlier version appended bare labels and then replaced the list with a
# single 1-tuple, so no peptide sequence ever reached the model.
# NOTE(review): unique_peptide is derived from complex_sequences — confirm it
# covers every peptide that needs an embedding.
data = [(str(i), str(peptide)) for i, peptide in enumerate(unique_peptide)]
batch_labels, batch_strs, batch_tokens = batch_converter(data)

# Extract per-residue representations (on CPU)
# no_grad: forward pass only, no gradient tracking
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
# Layer-33 representations; the first axis follows the order of `data`.
# NOTE(review): per the ESM README the token axis also contains special
# begin/end tokens (and padding for shorter sequences) — confirm before
# aligning these rows to residue positions.
token_representations_peptide = results["representations"][33]

In [None]:
# TODO: do the same for the TCR sequences (unique_tcr)

In [None]:
# Use pickle to save the matrices