# Prepare dataset for training and testing

First we will generate peptide encodings in the following way:
- Filter epitopes to those with 8 to 12 AA, so they can be encoded by the TEIM autoencoder.
- Encode the epitopes using the TEIM autoencoder.
- Scale the encoding to be between 0 and 1.
- Save epitopes and their encoding in a dataframe in the `interim` folder:
    - `peptide`: the peptide/epitope amino acid sequence
    - `is_mono_allelic`: whether the epitope is presented by a single HLA allele (True, False)
    - `hla_allele`: the HLA allele or alleles that the epitope binds (name format: HLA-A-01-01)
    - `label`: Whether the peptide binds to the HLA allele (1: Binder, 0: Non-binder)
    - `peptide_encoding`: the encoding of the peptide. 
    - `norm_peptide_encoding`: normalized peptide encoding. 

Next we will process the HLA data
- Load HLA alleles and their encoding 


## Setup Autoencoder

Load the autoencoder model and the tokenizer as shown in `2.0-amdr-exploring-peptide-encoding.ipynb`.

The pretrained autoencoder will be stored in the variable `epi_encoder` and used throughout the notebook.

In [1]:
# CODE FROM TEIM PAPER
#     Path on GitHub Repo:     TEIM/scripts/data_process.py

import torch
import torch.nn as nn
import numpy as np

# Device used as `map_location` when loading checkpoints:
# the GPU when CUDA is available, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

    
class View(nn.Module):
    """
    Reshape layer usable inside ``nn.Sequential``.

    Keeps the leading (batch) dimension of the input and views the remaining
    elements with the shape given at construction time.
    """

    def __init__(self, *shape):
        super().__init__()
        # Per-sample target shape, i.e. without the batch dimension.
        self.shape = shape

    def forward(self, input):
        """Return ``input`` viewed as ``(input.shape[0], *self.shape)``."""
        batch_size = input.shape[0]
        return input.view(batch_size, *self.shape)
    

class AutoEncoder(nn.Module):
    """
    1D-convolutional autoencoder over fixed-length token sequences.

    Token ids are embedded with the tokenizer's embedding matrix, encoded by
    two Conv1d blocks, squeezed into a single ``dim_hid`` vector, expanded back
    into a sequence, decoded by two ConvTranspose1d blocks and finally
    projected onto the vocabulary.

    :param tokenizer: object exposing ``embedding_mat()``, which returns a 2D
        array of shape (vocab_size, dim_emb)
    :param dim_hid: number of channels of the convolutional layers and size of
        the bottleneck vector
    :param len_seq: length of the (padded) input sequences
    """

    def __init__(self, tokenizer, dim_hid, len_seq):
        super().__init__()
        emb_matrix = tokenizer.embedding_mat()
        vocab_size, dim_emb = emb_matrix.shape
        flat_size = len_seq * dim_hid

        # NOTE: attribute names and the order of the layers inside each
        # Sequential define the state-dict keys, so they must stay as they are
        # for the pretrained checkpoint to load.
        self.embedding_module = nn.Embedding.from_pretrained(
            torch.FloatTensor(emb_matrix),
            padding_idx=0,
        )

        # (batch, dim_emb, len_seq) -> (batch, dim_hid, len_seq)
        self.encoder = nn.Sequential(
            nn.Conv1d(dim_emb, dim_hid, 3, padding=1),
            nn.BatchNorm1d(dim_hid),
            nn.ReLU(),
            nn.Conv1d(dim_hid, dim_hid, 3, padding=1),
            nn.BatchNorm1d(dim_hid),
            nn.ReLU(),
        )

        # (batch, dim_hid, len_seq) -> (batch, dim_hid)
        self.seq2vec = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_size, dim_hid),
            nn.ReLU()
        )

        # (batch, dim_hid) -> (batch, dim_hid, len_seq)
        self.vec2seq = nn.Sequential(
            nn.Linear(dim_hid, flat_size),
            nn.ReLU(),
            View(dim_hid, len_seq)
        )

        # (batch, dim_hid, len_seq) -> (batch, dim_hid, len_seq)
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(dim_hid, dim_hid, kernel_size=3, padding=1),
            nn.BatchNorm1d(dim_hid),
            nn.ReLU(),
            nn.ConvTranspose1d(dim_hid, dim_hid, kernel_size=3, padding=1),
            nn.BatchNorm1d(dim_hid),
            nn.ReLU(),
        )

        # Per-position projection onto the vocabulary.
        self.out_layer = nn.Linear(dim_hid, vocab_size)

    def forward(self, inputs):
        """
        Run the autoencoder on a batch of token-id sequences.

        :param inputs: tensor of token ids of shape (batch, len_seq); it is
            cast to ``long`` before the embedding lookup
        :return: tuple ``(out, seq_enc, vec, indices)`` where
            * ``out`` holds the per-position vocabulary scores,
              shape (batch, len_seq, vocab_size)
            * ``seq_enc`` is the encoder output, shape (batch, dim_hid, len_seq)
            * ``vec`` is the bottleneck vector, shape (batch, dim_hid)
            * ``indices`` is always None
        """
        token_emb = self.embedding_module(inputs.long())

        # Conv1d expects channels first, hence the transpose.
        seq_enc = self.encoder(token_emb.transpose(1, 2))
        vec = self.seq2vec(seq_enc)
        decoded = self.decoder(self.vec2seq(vec))
        out = self.out_layer(decoded.transpose(1, 2))
        return out, seq_enc, vec, None


def load_ae_model(tokenizer, path='./epi_ae.ckpt',):
    """
    Build the epitope autoencoder and load its pretrained weights.

    :param tokenizer: tokenizer handed to ``AutoEncoder`` (provides the
        embedding matrix)
    :param path: path of the checkpoint holding the autoencoder state dict
    :return: ``AutoEncoder`` (dim_hid=32, len_seq=12) in eval mode with the
        checkpoint weights loaded
    """
    model = AutoEncoder(tokenizer=tokenizer, dim_hid=32, len_seq=12)
    model.eval()

    checkpoint = torch.load(path, map_location=device)

    # The checkpoint keys carry a 6-character prefix that the bare model does
    # not have; drop it from every key.
    # NOTE(review): presumably the prefix is "model." - confirm against the checkpoint.
    prefix_len = 6
    renamed_state = {}
    for key, tensor in checkpoint.items():
        renamed_state[key[prefix_len:]] = tensor

    model.load_state_dict(renamed_state)
    return model


class PretrainedEncoder:
    """
    Encode epitope sequences with the pretrained autoencoder.

    Sequences of 8 to 12 residues are tokenized, zero-padded to 12 positions
    and passed through the autoencoder.

    :param tokenizer: tokenizer exposing ``id_list(seq)`` (token ids of a
        sequence); it is also forwarded to ``load_ae_model``
    """

    # Sequence lengths accepted by `infer`; inputs are padded to _MAX_LEN.
    _MIN_LEN = 8
    _MAX_LEN = 12

    def __init__(self, tokenizer):
        self.ae_model = load_ae_model(tokenizer)
        self.tokenizer = tokenizer

    def encode_pretrained_epi(self, epi_seqs):
        """
        Encode a collection of epitope sequences.

        :param epi_seqs: non-empty sequence of epitope strings, each 8 to 12
            residues long
        :return: tuple ``(enc_seq, enc_vec)`` where ``enc_seq`` is the padded
            token-id matrix fed to the autoencoder and ``enc_vec`` is the
            autoencoder bottleneck vector of every epitope
        :raises ValueError: if ``epi_seqs`` is empty or a length is out of range
        """
        enc = self.infer(epi_seqs)
        enc_vec = enc[2]
        enc_seq = enc[-1]
        return enc_seq, enc_vec

    def infer(self, seqs):
        """
        Tokenize, pad and run the autoencoder on ``seqs``.

        :param seqs: non-empty sequence of epitope strings, each 8 to 12
            residues long
        :return: list ``[out, seq_enc, vec, indices, encoding]`` with
            * ``out``: argmax over the last axis of the autoencoder output
            * ``seq_enc``: encoder output as a numpy array
            * ``vec``: bottleneck vectors as a numpy array
            * ``indices``: passed through from the autoencoder
            * ``encoding``: int32 array (n_seqs, 12) of padded token ids
        :raises ValueError: if ``seqs`` is empty or a length is not in [8, 12]
        """
        n_seqs = len(seqs)
        if n_seqs == 0:
            raise ValueError('At least one epitope sequence is required')
        len_seqs = [len(seq) for seq in seqs]
        # Explicit raise instead of `assert`: asserts are stripped under -O.
        if min(len_seqs) < self._MIN_LEN or max(len_seqs) > self._MAX_LEN:
            raise ValueError('Lengths of epitopes must be within [8, 12]')

        encoding = np.zeros([n_seqs, self._MAX_LEN], dtype='int32')
        for i, (seq, len_seq) in enumerate(zip(seqs, len_seqs)):
            # Left padding of 2 for length 8, 1 for lengths 9-10 and
            # 0 for lengths 11-12; the remaining positions stay 0.
            start = (self._MAX_LEN - len_seq) // 2
            encoding[i, start:start + len_seq] = self.tokenizer.id_list(seq)

        inputs = torch.from_numpy(encoding)
        # Inference only: do not build an autograd graph for the whole batch.
        with torch.no_grad():
            out, seq_enc, vec, indices = self.ae_model(inputs)
        out = np.argmax(out.detach().cpu().numpy(), -1)
        return [
            out,
            seq_enc.detach().cpu().numpy(),
            vec.detach().cpu().numpy(),
            indices,
            encoding
        ]
    
# The tokenizer is stored among the hyper-parameters of the base model
# checkpoint, so it is extracted from there manually.
# NOTE(review): torch.load unpickles arbitrary Python objects - only use it on
# a trusted checkpoint file.
base_ckpt = torch.load('base_model.ckpt', map_location=torch.device('cpu'))
tokenizer = base_ckpt['hyper_parameters']['model_args']['tokenizer']
del base_ckpt  # only the tokenizer is needed from this checkpoint

# Pretrained epitope encoder used in the rest of the notebook.
epi_encoder = PretrainedEncoder(tokenizer)

## Load useful functions

In [2]:
import os
import glob
import numpy as np
import pandas as pd
from typing import Tuple
from sklearn.preprocessing import MinMaxScaler

In [3]:
def is_valid_peptide(peptide: str) -> bool:
    """
    Tell whether a peptide only contains the 20 standard amino acids
    :param peptide: str
        Peptide sequence to check (upper-cased before the check)
    :return: bool
        True if every residue is one of ACDEFGHIKLMNPQRSTVWY (an empty
        peptide is therefore valid), False otherwise
    """
    allowed_residues = set('ACDEFGHIKLMNPQRSTVWY')
    return set(peptide.upper()).issubset(allowed_residues)

def filter_peptides_by_len(peptides_list: list,
                           min_len: int = 8, 
                           max_len: int = 12) -> Tuple[list, list]:
    """
    Keep the peptides whose length is within [min_len, max_len]
    (min_len <= len(peptide) <= max_len) and that only contain valid
    amino acids (see `is_valid_peptide`).
    A one-line summary of the number of discarded peptides is printed.
    :param peptides_list: list
        List of peptides (as str) to filter
    :param min_len: int
        Minimum length of peptides to keep
    :param max_len: int
        Maximum length of peptides to keep
    :return: Tuple[list, list]
        * List of peptides that passed the filter, in their original order
        * List of indices of those peptides in the original list
    """
    filt_peptides = []
    index_mask = []
    for i, p in enumerate(peptides_list):
        if min_len <= len(p) <= max_len and is_valid_peptide(p):
            filt_peptides.append(p)
            index_mask.append(i)

    # Report against the list that was actually passed in, not a global.
    n_total = len(peptides_list)
    n_invalid = n_total - len(filt_peptides)
    print(f'\t {n_invalid} peptides out of {n_total} were not between {min_len} and {max_len} AA '
          f'or contained invalid amino acids.')
    return filt_peptides, index_mask

def scale_peptide_encodings(peptide_encodings: np.ndarray) -> np.ndarray:
    """
    Min-max scale the peptide encodings, one feature (column) at a time,
    so that each feature lies between 0 and 1
    :param peptide_encodings: np.ndarray
        Encodings to scale (rows: peptides, columns: features)
    :return: np.ndarray
        Scaled encodings; the scaler is fitted on the given encodings
    """
    return MinMaxScaler().fit(peptide_encodings).transform(peptide_encodings)


# Globals

In [4]:
# Root of the data tree
DATA_FOLDER = os.path.join('..', 'data')

# Raw data (as downloaded)
RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'raw')
RAW_pHLA_BINDING_DATA_FOLDER = os.path.join(DATA_FOLDER, 'raw', 'pHLA_binding')

# Interim data (filtered / encoded)
INTERIM_DATA_FOLDER = os.path.join(DATA_FOLDER, 'interim')
INTERIM_pHLA_BINDING_DATA_FOLDER = os.path.join(DATA_FOLDER, 'interim', 'pHLA_binding')

# Processed data
PROCESS_DATA_FOLDER = os.path.join(DATA_FOLDER, 'processed')

## Process NetMHCpan data

We will generate the train and test sets using the original split from the paper.

In [5]:
# Raw NetMHCpan inputs: training folder, its allele list file and the test set folder
netmhcpan_raw_train_folder = os.path.join(RAW_pHLA_BINDING_DATA_FOLDER, 'NetMHCpan_train')
netmhcpan_raw_test_folder = os.path.join(RAW_pHLA_BINDING_DATA_FOLDER, 'CD8_benchmark_filtered')
alleles_list_file = os.path.join(RAW_pHLA_BINDING_DATA_FOLDER, 'NetMHCpan_train', 'allelelist')

# Folder receiving the interim (filtered and encoded) NetMHCpan dataset
netmhcpan_interim_folder = os.path.join(INTERIM_pHLA_BINDING_DATA_FOLDER, 'NetMHCpan_dataset')

In [6]:
# glob returns its matches in arbitrary (filesystem-dependent) order; sort them
# so that the row order of the generated datasets is reproducible across runs.

# binding affinity data
ba_files = sorted(glob.glob(f'{netmhcpan_raw_train_folder}/*_ba'))
# eluted ligand data
el_files = sorted(glob.glob(f'{netmhcpan_raw_train_folder}/*el'))
# Test set files
test_files = sorted(glob.glob(f'{netmhcpan_raw_test_folder}/*HLA*'))

# Make dict with allelelist data for Multi-allelic data.
# Every non-empty line is expected to hold exactly two whitespace-separated
# fields: a name and a comma-separated list of HLA alleles.
alleles_dict = {}
with open(alleles_list_file, 'r') as f:
    for line in f:
        line = line.strip()
        if line:
            allele, hla_list = line.split()
            hla_list = hla_list.split(',')
            alleles_dict[allele] = hla_list


In [7]:
# Process Binding Affinity data
#
# For every binding affinity file: keep HLA rows, keep the peptides that can be
# encoded (8 to 12 valid AA), encode them and derive the binder label. The
# per-file results are then concatenated and saved in the interim folder.

# Binding affinity value at or above which a peptide is labelled as a binder.
# NOTE(review): presumably the usual 500 nM cut-off on the rescaled affinity - confirm.
BA_BINDER_THRESHOLD = 0.426

ba_seq_pep_list = [] # List of peptide sequences for all files
ba_epi_vec_list = [] # List of epitope encodings for all files
ba_epi_labels_list = [] # List of epitope labels for all files
ba_hla_list = [] # List of HLA alleles per epitope for all files
ba_is_multi_allelic = [] # List of whether the epitope is presented by multiple alleles

for f in ba_files:
    print(f'Processing {f}')
    df = pd.read_csv(f, sep=' ', header=None, names=('epitope', 'binding_affinity', 'hla_allele'))
    df = df[df['hla_allele'].str.startswith('HLA')] # Only consider HLAs
    
    seqs_epi_raw = df['epitope'].values.tolist()
    valid_epi, valid_epi_idx = filter_peptides_by_len(seqs_epi_raw)
    if not valid_epi:
        # Nothing to encode in this file; the encoder cannot handle an empty batch.
        print(f'\t No valid peptides in {f}, skipping.')
        continue
    _, epi_vec = epi_encoder.encode_pretrained_epi(valid_epi)
    
    # Indices are positions in the (already HLA-filtered) dataframe, hence iloc
    df = df.iloc[valid_epi_idx] # Filter out invalid epitopes
    binding_labels = df['binding_affinity'] >= BA_BINDER_THRESHOLD # At or above is considered a binder
    binding_labels_arr = binding_labels.to_numpy().astype(int)
    
    # Normalize HLA naming
    hla_array = df['hla_allele'].str.replace(':', '-').values
    
    ba_seq_pep_list.append(df['epitope'].values)
    ba_epi_vec_list.append(epi_vec)
    ba_epi_labels_list.append(binding_labels_arr)
    ba_hla_list.append(hla_array)
    # Binding affinity entries are all flagged as not multi-allelic
    ba_is_multi_allelic.append(np.zeros(binding_labels_arr.shape[0], dtype=bool))
    
all_seq_pep = np.concatenate(ba_seq_pep_list)
all_ba_epi_vec = np.concatenate(ba_epi_vec_list)
all_ba_epi_labels = np.concatenate(ba_epi_labels_list)
all_ba_hla = np.concatenate(ba_hla_list)
all_ba_is_multi_allelic = np.concatenate(ba_is_multi_allelic)
all_ba_epi_vec_norm = scale_peptide_encodings(all_ba_epi_vec)

# Rows of the dataframe and of both encoding arrays must stay aligned
assert (all_seq_pep.shape[0] == all_ba_epi_vec.shape[0] == all_ba_epi_vec_norm.shape[0] ==
        all_ba_epi_labels.shape[0] == all_ba_hla.shape[0] ==
        all_ba_is_multi_allelic.shape[0]), 'Mismatch in data shapes.'

# The encodings are not stored in the dataframe; they are saved separately as
# .npy arrays whose rows follow the dataframe order.
binding_affinity_df = pd.DataFrame({
    'peptide': all_seq_pep,
    'is_mono_allelic': ~all_ba_is_multi_allelic,
    'hla_allele': all_ba_hla,
    'label': all_ba_epi_labels,
})

# Make sure the output folder exists before writing into it
os.makedirs(netmhcpan_interim_folder, exist_ok=True)
binding_affinity_df.to_csv(os.path.join(netmhcpan_interim_folder, 'train_binding_affinity_peptides_data.csv.gz'), index=False)
np.save(os.path.join(netmhcpan_interim_folder, 'train_binding_affinity_peptides_encodings.npy'), all_ba_epi_vec)
np.save(os.path.join(netmhcpan_interim_folder, 'train_binding_affinity_peptides_encodings_norm.npy'), all_ba_epi_vec_norm)

Processing ../data/raw/pHLA_binding/NetMHCpan_train/c003_ba
	 219 peptides out of 33848 were not betweem 8 and 12 AA.
Processing ../data/raw/pHLA_binding/NetMHCpan_train/c002_ba
	 122 peptides out of 33974 were not betweem 8 and 12 AA.
Processing ../data/raw/pHLA_binding/NetMHCpan_train/c000_ba
	 82 peptides out of 33507 were not betweem 8 and 12 AA.
Processing ../data/raw/pHLA_binding/NetMHCpan_train/c004_ba
	 59 peptides out of 34613 were not betweem 8 and 12 AA.
Processing ../data/raw/pHLA_binding/NetMHCpan_train/c001_ba
	 125 peptides out of 34528 were not betweem 8 and 12 AA.
