# Explore pHLApredictor latent spaces

Check whether the latent spaces of the pHLApredictor models capture biological information about the input data, such as the peptide binding mode.



### Useful functions and global variables

In [None]:
import os
import sys
import glob

import torch
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.subplots as sp
import plotly.graph_objects as go

from nimbus.predictors import pHLABindingPredictor, pHLAPseudoseqBindingPredictor
from nimbus.data_processing import SeqTokenizer
from nimbus.data_processing import pHLADataset
from nimbus.utils import LoggerFactory
from nimbus.globals import DEVICE

# Notebook-level logger (name 'explore_pHLApredictor_nb', level 'INFO');
# used e.g. by load_pretrained_model to report which checkpoint is loaded.
logger = LoggerFactory.get_logger('explore_pHLApredictor_nb', 'INFO')

In [None]:
DATA_DIR = '../data'
RAW_DATA = os.path.join(DATA_DIR, 'raw')
PROCESSED_DATA = os.path.join(DATA_DIR, 'processed')
HLA_FP_DIR = os.path.join(PROCESSED_DATA, 'hla_fingerprints')
hla_fp_data_file = os.path.join(HLA_FP_DIR, 'hla_index_netMHCpan_pseudoseq_res_representation.csv')
hla_fp_400_file = os.path.join(HLA_FP_DIR, 'hla_af_patch_emb_patch_r18_pt400.npy')
hla_fp_36_file = os.path.join(HLA_FP_DIR, 'hla_fingerprint_netMHCpan_pseudoseq_res_representation.npy')
hla_pseudoseq_file = os.path.join(RAW_DATA, 'pHLA_binding', 'NetMHCpan_train', 'MHC_pseudo_fixed.dat')
test_netmhcpan_data_file = os.path.join(PROCESSED_DATA, 'pHLA_binding', 'test_set_peptides_data_MaxLenPep15_hla_ABC.csv.gz')
RND_PEPTIDES_DIR = os.path.join(PROCESSED_DATA, 'cleaved_human_proteome')
# Sorted so the file order (and hence dict insertion order) is reproducible;
# raw glob order depends on the filesystem.
RND_PEPTIDES_FILES = sorted(glob.glob(os.path.join(RND_PEPTIDES_DIR, '*.txt')))


def load_rnd_peptides(files):
    """Load random-peptide text files into a ``{length: [peptide, ...]}`` dict.

    Filenames are expected to look like ``random_peptides_length_9.txt``. The
    key is the last ``_``-separated token of the file stem (``'9'`` in that
    example), kept as a string.

    Args:
        files: Iterable of paths to UTF-8 text files, one peptide per line.

    Returns:
        dict mapping the length token (str) to the list of lines in that file.
        If two files yield the same token, the later file wins.
    """
    peptides_by_length = {}
    for path in files:
        stem = os.path.splitext(os.path.basename(path))[0]
        # Use the last token instead of a fixed index so that prefixes with a
        # different number of underscores still yield the length.
        length = stem.rsplit('_', 1)[-1]
        with open(path, 'r', encoding='utf-8') as f:
            peptides_by_length[length] = f.read().splitlines()
    return peptides_by_length


# load rnd peptides data into dictionary
rnd_peptides_data = load_rnd_peptides(RND_PEPTIDES_FILES)

CHECKPOINTS_DIR = '../checkpoints/csv_logger'
v_num = 0  # Default version number used when an experiment does not pin one


def _find_checkpoint(exp_dir, version=None, pattern='ep*'):
    """Return the first checkpoint file matching ``pattern`` for an experiment.

    Searches ``CHECKPOINTS_DIR/<exp_dir>/version_<version>/checkpoints``.
    Matches are sorted so the choice is reproducible when several files match
    (plain ``glob`` order depends on the filesystem).

    Args:
        exp_dir: Experiment directory name under ``CHECKPOINTS_DIR``.
        version: Version number; ``None`` means the notebook-level ``v_num``.
        pattern: Glob pattern for the checkpoint filename.

    Raises:
        FileNotFoundError: if nothing matches, naming the searched path.
    """
    if version is None:
        version = v_num
    search = os.path.join(CHECKPOINTS_DIR, exp_dir, f'version_{version}', 'checkpoints', pattern)
    matches = sorted(glob.glob(search))
    if not matches:
        raise FileNotFoundError(f"No checkpoint found matching {search}")
    return matches[0]


def _experiment(exp_dir, hla_fp_size, hla_representation_type='surface_fp', version=None, pattern='ep*'):
    """Build one ``experiments_dict`` entry for the checkpoint found in ``exp_dir``."""
    return {
        'model_checkpoint': _find_checkpoint(exp_dir, version=version, pattern=pattern),
        'hla_fp_size': hla_fp_size,
        'hla_representation_type': hla_representation_type,
    }


experiments_dict = {
    'pHLA_balance': _experiment('pHLA_balance', 400),
    'pHLA_balance_hla_pseudoseq': _experiment('pHLA_balance_hla_pseudoseq', 36),
    'pHLA_imbalance': _experiment('pHLA_imbalance', 400),
    'pHLA_imbalance_hla_pseudoseq': _experiment('pHLA_imbalance_hla_pseudoseq', 36),
    'pHLA_balance_FILIP128': _experiment('pHLA_balance_FILIP128', 400),
    'pHLA_imbalance_hla_pseudoseq_ManSplits0123_4': _experiment('pHLA_imbalance_hla_pseudoseq_ManSplits0123_4', 36),
    'pHLA_imbalance_hla_pseudoseq_ManSplits0124_3': _experiment('pHLA_imbalance_hla_pseudoseq_ManSplits0124_3', 36),
    'pHLA_imbalance_hla_pseudoseq_ManSplits0134_2': _experiment('pHLA_imbalance_hla_pseudoseq_ManSplits0134_2', 36),
    'pHLA_imbalance_hla_pseudoseq_ManSplits0234_1': _experiment('pHLA_imbalance_hla_pseudoseq_ManSplits0234_1', 36),
    'pHLA_imbalance_hla_pseudoseq_ManSplits1234_0': _experiment('pHLA_imbalance_hla_pseudoseq_ManSplits1234_0', 36),
    'pseudoseq_pHLA_imbalance_ManSplits0123_4': _experiment('pseudoseq_pHLA_imbalance_ManSplits0123_4', 0, 'pseudoseq'),
    'pHLA_imbalance_newHLAFP_ManSplits0123_4': _experiment('pHLA_imbalance_newHLAFP_ManSplits0123_4', 400),
    'pHLA_imbalance_EL_hla_pseudoseq_splitTrainTest': _experiment('pHLA_imbalance_EL_hla_pseudoseq_splitTrainTest', 36),
    'pHLA_imbalance_EL_splitTrainTest': _experiment('pHLA_imbalance_EL_splitTrainTest', 400),
    'pHLA_imbalance_hla_pseudoseq_AllBA_TestAsVal': _experiment('pHLA_imbalance_hla_pseudoseq_AllBA_TestAsVal', 36),
    'pHLA_imbalance_hla_pseudoseq_AllEL_TestAsVal': _experiment('pHLA_imbalance_hla_pseudoseq_AllEL_TestAsVal', 36),
    # TODO test
    'pHLA_imbalance_newHLAFP_AllBA_TestAsVal': _experiment('pHLA_imbalance_newHLAFP_AllBA_TestAsVal', 400, version=0),
    # TODO test
    'pHLA_imbalance_newHLAFP_AllBA_TestAsVal_v1': _experiment('pHLA_imbalance_newHLAFP_AllBA_TestAsVal', 400, version=1),
    'pHLA_imbalance_newHLAFP_AllEL_TestAsVal': _experiment('pHLA_imbalance_newHLAFP_AllEL_TestAsVal', 400),
    'pseudoseq_pHLA_imbalance_AllBA_TestAsVal': _experiment('pseudoseq_pHLA_imbalance_AllBA_TestAsVal', 0, 'pseudoseq'),
    'pseudoseq_pHLA_imbalance_AllEL_TestAsVal': _experiment('pseudoseq_pHLA_imbalance_AllEL_TestAsVal', 0, 'pseudoseq'),
    # 'l*' selects a checkpoint whose filename starts with "l" (presumably last.ckpt — TODO confirm).
    'pHLA_imbalance_AllBA_HLAAugmented_ManSplit0123_4': _experiment(
        'pHLA_imbalance_AllBA_HLAAugmented_ManSplit0123_4', 400, pattern='l*'),
    'pHLA_imbalance_AllBA_HLAAugmented1Random_ManSplit0123_4': _experiment(
        'pHLA_imbalance_AllBA_HLAAugmented1Random_ManSplit0123_4', 400,
        pattern='epoch13-val_loss0.2769-val_acc0.88.ckpt'),
    'pHLA_imbalance_AllBA_HLAAugmented1Random_ManSplit0123_4_v1': _experiment(
        'pHLA_imbalance_AllBA_HLAAugmented1Random_ManSplit0123_4', 400, version=1,
        pattern='epoch06-val_loss0.2758-val_acc0.88.ckpt'),
    'Immuno_imbalance_PRIME2_ManSplit1234_0_HLAAugmented_pHLApretrained_Freeze_SA': _experiment(
        'Immuno_imbalance_PRIME2_ManSplit1234_0_HLAAugmented_pHLApretrained_Freeze_SA', 400, pattern='epoch*'),
    'Immuno_imbalance_PRIME2_ManSplit1234_0_HLAAugmented_pHLApretrained_Freeze_SA_CA_FILIP': _experiment(
        'Immuno_imbalance_PRIME2_ManSplit1234_0_HLAAugmented_pHLApretrained_Freeze_SA_CA_FILIP', 400, pattern='epoch*'),
}

In [None]:
def load_pretrained_model(checkpoint_file, hla_representation_type='surface_fp'):
    """Load a pretrained pHLA binding predictor from a checkpoint file.

    Args:
        checkpoint_file: Path to the checkpoint to load.
        hla_representation_type: ``'surface_fp'`` loads a
            ``pHLABindingPredictor``; ``'pseudoseq'`` loads a
            ``pHLAPseudoseqBindingPredictor``.

    Returns:
        The model returned by the class's ``load_from_checkpoint``.

    Raises:
        ValueError: if ``hla_representation_type`` is not one of the two
            supported values. Raising (instead of calling ``sys.exit``) gives
            callers a catchable exception and a normal traceback.
    """
    if hla_representation_type == 'surface_fp':
        logger.info(f"Loading pHLABindingPredictor pretrained model {checkpoint_file}")
        return pHLABindingPredictor.load_from_checkpoint(checkpoint_file)
    if hla_representation_type == 'pseudoseq':
        logger.info(f"Loading pHLAPseudoseqBindingPredictor pretrained model {checkpoint_file}")
        return pHLAPseudoseqBindingPredictor.load_from_checkpoint(checkpoint_file)
    raise ValueError(f"Unknown hla_representation_type {hla_representation_type}. "
                     f"Expected 'surface_fp' or 'pseudoseq'")


In [None]:
def predict_likelihood_from_dataloader(model, dataloader, device, save_attn=False):
    """Run ``model`` over every batch of ``dataloader`` and collect its outputs.

    Args:
        model: Predictor called as ``model(pep, hla, save_attn=...)`` and
            exposing ``linear_to_logits`` and ``to_pred``.
        dataloader: Iterable of ``(peptide_input, hla_input, label)`` batches;
            the labels are ignored.
        device: Torch device the inputs are moved to.
        save_attn: If True, the model is expected to return
            ``(reps, attn_dict)``, and the per-batch attention dicts are merged.

    Returns:
        ``(all_probs, reps_arr)``, or ``(all_probs, reps_arr, merged_attns_dict)``
        when ``save_attn`` is True.
        - ``all_probs`` is a flat list of sigmoid probabilities (via ``tolist``).
        - ``reps_arr`` is the per-batch representations concatenated along axis 0.
        - In ``merged_attns_dict``, keys containing ``'filip'`` are concatenated
          along axis 0. Other keys are converted per batch with ``np.array`` and
          concatenated along axis 1.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    all_probs = []
    all_attn_dict = []
    all_reps = []
    model.eval()
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for peptide_data_input, hla_data_input, _label in dataloader:
            pep_tensor = torch.Tensor(peptide_data_input).to(device)
            hla_tensor = torch.Tensor(hla_data_input).to(device)
            if save_attn:
                reps, attn_dict = model(pep_tensor, hla_tensor, save_attn=save_attn)
                all_attn_dict.append(attn_dict)
            else:
                # Only returns FILIP representation
                reps = model(pep_tensor, hla_tensor, save_attn=save_attn)

            logits = model.to_pred(model.linear_to_logits(reps))
            prob = torch.sigmoid(logits).detach().cpu().numpy()
            all_probs.extend(prob.tolist())
            all_reps.append(reps.detach().cpu().numpy())

    if not all_reps:
        raise ValueError("dataloader yielded no batches; nothing to predict")
    reps_arr = np.concatenate(all_reps, axis=0)

    if not save_attn:
        return all_probs, reps_arr

    # Merge the per-batch attention dicts key by key.
    merged_attns_dict = {}
    for attn_type in all_attn_dict[0].keys():
        if 'filip' in attn_type:
            filip_interactions = [attn_dict[attn_type] for attn_dict in all_attn_dict]
            merged_attns_dict[attn_type] = np.concatenate(filip_interactions, axis=0)
        else:
            # Non-FILIP entries are treated as per-layer lists; axis 1 of the
            # stacked array is the batch axis being joined.
            attn_list = [np.array(attn_dict[attn_type]) for attn_dict in all_attn_dict]
            merged_attns_dict[attn_type] = np.concatenate(attn_list, axis=1)
    return all_probs, reps_arr, merged_attns_dict
        

In [None]:
def predict_likelihood(model, seq: str, hla_fp, hla_fp_data, device, save_attn=False, top_k=5):
    """Predict the likelihood of a peptide binding to every HLA allele in ``hla_fp``.

    The peptide is tokenized, zero-padded to 15 positions and paired with each
    entry of ``hla_fp``. All pairs are scored in one forward pass, and the
    ``top_k`` highest-scoring alleles are printed.

    Args:
        model: Predictor called as ``model(pep, hla, save_attn=...)`` and
            exposing ``linear_to_logits`` and ``to_pred``.
        seq: Peptide sequence (at most 15 residues).
        hla_fp: Sequence of HLA representations; entry ``i`` is the allele whose
            ``hla_fp_data`` value is ``i``.
        hla_fp_data: dict mapping HLA name -> index into ``hla_fp``.
        device: Torch device for the model inputs.
        save_attn: If True, also return the model's attention dict.
        top_k: Number of top-ranked alleles to print (default 5).

    Returns:
        ``(prob_by_hla, reps)``, or ``(prob_by_hla, reps, attn_dict)`` when
        ``save_attn`` is True. ``prob_by_hla`` maps each HLA name to its
        probability. It follows ``hla_fp_data`` key order and is deliberately
        not sorted by probability.

    Raises:
        ValueError: if ``seq`` is longer than 15 residues.
    """
    max_pep_len = 15  # peptides are zero-padded to this many token positions
    if len(seq) > max_pep_len:
        raise ValueError(f"Peptide {seq!r} has {len(seq)} residues; at most {max_pep_len} are supported")

    tokenizer = SeqTokenizer()
    pe = np.zeros(max_pep_len, dtype='int32')
    pe[:len(seq)] = tokenizer.encode(seq)

    # Pair the same peptide with every HLA so all alleles are scored in one batch.
    hla_rows = list(hla_fp)
    peptide_data_input = np.array([pe] * len(hla_rows)).reshape(len(hla_rows), 1, -1)
    hla_data_input = np.array(hla_rows)

    model.eval()
    with torch.no_grad():  # inference only
        pep_tensor = torch.Tensor(peptide_data_input).to(device)
        hla_tensor = torch.Tensor(hla_data_input).to(device)
        if save_attn:
            reps, attn_dict = model(pep_tensor, hla_tensor, save_attn=save_attn)
        else:
            # Only returns FILIP representation
            reps = model(pep_tensor, hla_tensor, save_attn=save_attn)
        logits = model.to_pred(model.linear_to_logits(reps))
        prob = torch.sigmoid(logits).detach().cpu().numpy()

    # prob[i] is the score for hla_fp[i], and hla_fp_data maps name -> i. Look
    # each name up by its index rather than assuming the dict's key order
    # matches the row order.
    prob_by_hla = {hla: prob[idx] for hla, idx in hla_fp_data.items()}

    idx_to_hla = {idx: hla for hla, idx in hla_fp_data.items()}
    for idx in prob.argsort()[::-1][:top_k]:
        print(f'{idx_to_hla.get(idx, idx)}:  {prob[idx]}')

    reps_np = reps.detach().cpu().numpy()
    if save_attn:
        return prob_by_hla, reps_np, attn_dict
    return prob_by_hla, reps_np

In [None]:
def predict_likelihood_several_peptides(model, seq: list, hla_fp, hla_fp_data, device, save_attn=False, top_k=5):
    """Predict the likelihood of each peptide in ``seq`` binding to every HLA in ``hla_fp``.

    Every (peptide, HLA) pair is scored in a single forward pass. The batch is
    peptide-major: rows ``[k * len(hla_fp), (k + 1) * len(hla_fp))`` belong to
    ``seq[k]``. The ``top_k`` highest-scoring alleles are printed per peptide.

    Args:
        model: Predictor called as ``model(pep, hla, save_attn=...)`` and
            exposing ``linear_to_logits`` and ``to_pred``.
        seq: Sequence of peptide strings (each at most 15 residues).
        hla_fp: Sequence of HLA representations; entry ``i`` is the allele whose
            ``hla_fp_data`` value is ``i``.
        hla_fp_data: dict mapping HLA name -> index into ``hla_fp``.
        device: Torch device for the model inputs.
        save_attn: If True, also return the model's attention dict.
        top_k: Number of top-ranked alleles to print per peptide (default 5).

    Returns:
        ``(all_peptides_prob_by_phla, reps)``, or the same plus ``attn_dict``
        when ``save_attn`` is True.
        - Keys are ``f'{peptide}_{hla}'``; within each peptide they follow
          ``hla_fp_data`` key order.
        - A peptide repeated in ``seq`` overwrites its earlier entries.

    Raises:
        ValueError: if a peptide is longer than 15 residues.
    """
    max_pep_len = 15  # peptides are zero-padded to this many token positions
    tokenizer = SeqTokenizer()
    hla_rows = list(hla_fp)
    n_hla = len(hla_rows)

    peptide_data_input = []
    hla_data_input = []
    for s in seq:
        if len(s) > max_pep_len:
            raise ValueError(f"Peptide {s!r} has {len(s)} residues; at most {max_pep_len} are supported")
        pe = np.zeros(max_pep_len, dtype='int32')
        pe[:len(s)] = tokenizer.encode(s)
        peptide_data_input.extend([pe] * n_hla)
        hla_data_input.extend(hla_rows)
    peptide_data_input = np.array(peptide_data_input).reshape(len(peptide_data_input), 1, -1)
    hla_data_input = np.array(hla_data_input)

    model.eval()
    with torch.no_grad():  # inference only
        pep_tensor = torch.Tensor(peptide_data_input).to(device)
        hla_tensor = torch.Tensor(hla_data_input).to(device)
        if save_attn:
            reps, attn_dict = model(pep_tensor, hla_tensor, save_attn=save_attn)
        else:
            # Only returns FILIP representation
            reps = model(pep_tensor, hla_tensor, save_attn=save_attn)
        logits = model.to_pred(model.linear_to_logits(reps))
        prob = torch.sigmoid(logits).detach().cpu().numpy()

    idx_to_hla = {idx: hla for hla, idx in hla_fp_data.items()}
    all_peptides_prob_by_phla = {}
    # Print the top_k HLA alleles for each peptide.
    for n_pep, peptide in enumerate(seq):
        print(f'Peptide: {peptide}')
        pep_prob = prob[n_pep * n_hla:(n_pep + 1) * n_hla]
        # Look each HLA up by its row index rather than relying on the dict's
        # key order matching the row order; keys stay in hla_fp_data order.
        all_peptides_prob_by_phla.update(
            {f'{peptide}_{hla}': pep_prob[idx] for hla, idx in hla_fp_data.items()}
        )
        for idx in pep_prob.argsort()[::-1][:top_k]:
            print(f'{idx_to_hla.get(idx, idx)}:  {pep_prob[idx]}')

    reps_np = reps.detach().cpu().numpy()
    if save_attn:
        return all_peptides_prob_by_phla, reps_np, attn_dict
    return all_peptides_prob_by_phla, reps_np


In [None]:
def select_pool_hla_fp(list_hla_names, hla_fp, hla_fp_data):
    """Subset ``hla_fp``/``hla_fp_data`` to the alleles in ``list_hla_names``.

    Args:
        list_hla_names: HLA names to keep, in the desired output order.
        hla_fp: Indexable collection of HLA representations.
        hla_fp_data: dict mapping HLA name -> index into ``hla_fp``.

    Returns:
        ``(fp_array, name_to_row)``. ``fp_array`` stacks the selected entries
        in ``list_hla_names`` order. ``name_to_row`` maps each name to its row
        in ``fp_array``; for a repeated name, the last row wins.

    Raises:
        KeyError: if a name is missing from ``hla_fp_data``.
    """
    names = list(list_hla_names)
    selected_rows = [hla_fp[hla_fp_data[name]] for name in names]
    name_to_row = {name: row for row, name in enumerate(names)}
    return np.array(selected_rows), name_to_row

## Case study 1:

A molecular switch in immunodominant HIV-1-specific CD8 T-cell epitopes shapes differential HLA-restricted escape - Retrovirology https://link.springer.com/article/10.1186/s12977-015-0149-5

### Info from the paper

We studied four members of the HLA-B7 superfamily, HLA-B\*07:02, HLA-B\*42:01, HLA-B\*42:02 and HLA-B\*81:01. These closely-related HLAI molecules restrict both distinct and identical HIV-1 epitopes
The differences in sequence between HLA-B\*07:02, B\*42:01, B\*42:02 and B\*81:01 are small [...] all of the polymorphic positions are part of the HLA peptide binding pockets and therefore potentially affect peptide presentation to CD8+ T cells.

We focused on two epitopes, TL9-p24 and RM9-Nef, that dominate the HIV-1-specific CD8 T-cell response in the Southern African HIV-1 epidemic.
These two epitopes were presented by all four of these HLA molecules, other than TL9-p24 which was presented by all except HLA-B\*42:02 (Figure 2A). TL9-p24, dominantly targeted through HLA-B\*42:01 and HLA-B\*81:01 and subdominantly through HLA-B\*07:02 [...].
RM9-Nef, however, was presented by all 4 different HLAs.

[...]

TL9-p24 exhibits a unique conformation when presented by HLA-B\*81:01.

### Summary

#### TL9-p24
Peptide: `TPQDLNTML`
HLA alleles that present it:
- HLA-B\*42:01 (dominant)
- HLA-B\*81:01 (dominant, shows a unique peptide binding mode)
- HLA-B\*07:02 (subdominant)

HLA alleles that do not present it:
- HLA-B\*42:02

#### RM9-Nef
Peptide: `RPQVPLRPM`
HLA alleles that present it:
- HLA-B\*07:02
- HLA-B\*42:01
- HLA-B\*42:02
- HLA-B\*81:01



In [None]:
# Experiment to explore (must be a key of experiments_dict). Other experiments
# previously explored here: 'pHLA_imbalance_AllBA_HLAAugmented_ManSplit0123_4',
# 'pHLA_imbalance_newHLAFP_ManSplits0123_4',
# 'pHLA_imbalance_hla_pseudoseq_ManSplits1234_0', 'pHLA_balance_hla_pseudoseq'.
exp_name = 'Immuno_imbalance_PRIME2_ManSplit1234_0_HLAAugmented_pHLApretrained_Freeze_SA_CA_FILIP'
experiment = experiments_dict[exp_name]
hla_representation_type = experiment['hla_representation_type']

if hla_representation_type not in ('surface_fp', 'pseudoseq'):
    raise NotImplementedError(f"Unknown hla_representation_type {hla_representation_type}. Expected 'surface_fp' or 'pseudoseq'")

# HLA name -> row index; both representation types start from this table.
hla_fp_data = pd.read_csv(hla_fp_data_file,
                          index_col=1,
                          names=['index'],
                          header=0).to_dict()['index']

if hla_representation_type == 'surface_fp':
    hla_fp_files_by_size = {400: hla_fp_400_file, 36: hla_fp_36_file}
    hla_fp_size = experiment['hla_fp_size']
    if hla_fp_size not in hla_fp_files_by_size:
        # Previously this fell through and failed later with a NameError on hla_fp_file.
        raise ValueError(f"Unsupported hla_fp_size {hla_fp_size} for experiment {exp_name}; "
                         f"expected one of {sorted(hla_fp_files_by_size)}")
    hla_fp_file = hla_fp_files_by_size[hla_fp_size]
    hla_fp = np.load(hla_fp_file)
else:  # 'pseudoseq'
    # Only keep pseudosequences for alleles present in the index table (set for O(1) lookups).
    hla_alleles_subset = set(hla_fp_data)
    hla_pseudoseq_dict = pd.read_csv(hla_pseudoseq_file, sep=r'\s+', names=['pseudoseq'], header=None).to_dict()['pseudoseq']
    hla_pseudoseq_dict = {hla: pseudoseq for hla, pseudoseq in hla_pseudoseq_dict.items() if hla in hla_alleles_subset}
    seq_tokenizer = SeqTokenizer()
    hla_fp_dict = {hla: torch.Tensor(seq_tokenizer.encode(pseudoseq))
                   for hla, pseudoseq in hla_pseudoseq_dict.items()}
    hla_fp = list(hla_fp_dict.values())
    # Rebuild hla_fp_data so its values index into hla_fp (same convention as the surface_fp case).
    hla_fp_data = {hla: i for i, hla in enumerate(hla_fp_dict)}

model = load_pretrained_model(experiment['model_checkpoint'],
                              hla_representation_type=hla_representation_type)

#### TL9-p24
##### NetMHCpan baseline
NetMHCpan cannot differentiate between HLA-B\*42:01 and HLA-B\*42:02 because its HLA-B\*42:02 predictions are based on the nearest neighbor, HLA-B\*42:01 (see the "Distance to training data" lines in the output below).
```
# NetMHCpan version 4.1b

# Tmpdir made /var/www/html/services/NetMHCpan-4.1/tmp/netMHCpanvigqPD
# Input is in FSA format

# Peptide length 9

# Make EL predictions

HLA-B42:01 : Distance to training data  0.000 (using nearest neighbor HLA-B42:01)
HLA-B81:01 : Distance to training data  0.146 (using nearest neighbor HLA-B42:01)
HLA-B07:02 : Distance to training data  0.000 (using nearest neighbor HLA-B07:02)
HLA-B42:02 : Distance to training data  0.028 (using nearest neighbor HLA-B42:01)

# Rank Threshold for Strong binding peptides   0.500
# Rank Threshold for Weak binding peptides   2.000
---------------------------------------------------------------------------------------------------------------------------
 Pos         MHC        Peptide      Core Of Gp Gl Ip Il        Icore        Identity  Score_EL %Rank_EL BindLevel
---------------------------------------------------------------------------------------------------------------------------
   1 HLA-B*42:01      TPQDLNTML TPQDLNTML  0  0  0  0  0    TPQDLNTML        Sequence 0.9286300    0.069 <= SB
   1 HLA-B*81:01      TPQDLNTML TPQDLNTML  0  0  0  0  0    TPQDLNTML        Sequence 0.7545360    0.040 <= SB
   1 HLA-B*07:02      TPQDLNTML TPQDLNTML  0  0  0  0  0    TPQDLNTML        Sequence 0.7424500    0.189 <= SB
   1 HLA-B*42:02      TPQDLNTML TPQDLNTML  0  0  0  0  0    TPQDLNTML        Sequence 0.6575040    0.076 <= SB
---------------------------------------------------------------------------------------------------------------------------
```

In [None]:
seq = 'TPQDLNTML'  # TL9-p24
save_attention = True
# predict_likelihood returns a third element (the attention dict) only when save_attn=True.
likelihood_outputs = predict_likelihood(model, seq, hla_fp, hla_fp_data, DEVICE, save_attn=save_attention)
if save_attention:
    prob_by_hla, reps, attns_dict = likelihood_outputs
else:
    prob_by_hla, reps = likelihood_outputs

In [None]:
print('HLA of interest and their probabilities:')
print('\t Binders:')
# (printed label, prob_by_hla key, optional annotation)
for label, key, *annotation in (
    ('HLA-B*42:01:', 'HLA-B42-01'),
    ('HLA-B*81:01:', 'HLA-B81-01'),
    ('HLA-B*07:02:', 'HLA-B07-02', '(subdominant)'),
):
    print(label, prob_by_hla[key], *annotation)
print('\t Non-binders:')
print('HLA-B*42:02:', prob_by_hla['HLA-B42-02'])

In [None]:
# Inspect the shape of the representations returned by predict_likelihood.
# Axis 0 is presumably the batch (one row per entry of hla_fp) — confirm against the model.
reps.shape

In [None]:
# Bar chart of the predicted binding probability for every HLA allele,
# in prob_by_hla (i.e. hla_fp_data key) order, with the y range fixed to [0, 1].
px.bar(x=prob_by_hla.keys(), y=prob_by_hla.values(), range_y=[0, 1])

In [None]:
from umap import umap_

# NOTE: `umap` is bound to the reducer object (not the module); later cells
# call umap.fit_transform on it.
umap = umap_.UMAP(n_neighbors=15, min_dist=0.1)

# Average reps over axis 1, then embed each remaining row as one 2-D point.
coords = umap.fit_transform(reps.mean(axis=1))

In [None]:
fig = go.Figure()

# Marker symbol and size encode the HLA gene letter taken from names like 'HLA-B42-01'.
symbol_dict = {'A': 'circle', 'B': 'x', 'C': 'diamond'}
size_dict = {'A': 10, 'B': 12, 'C': 12}
case_study_hlas = ('HLA-B42-01', 'HLA-B42-02', 'HLA-B81-01', 'HLA-B07-02')

hla_names = list(prob_by_hla.keys())
hla_probs = list(prob_by_hla.values())
hla_genes = [name.split('-')[1][0] for name in hla_names]

scatter_marker = dict(
    color=hla_probs,
    colorscale='Portland_r',
    symbol=[symbol_dict[gene] for gene in hla_genes],
    size=[size_dict[gene] for gene in hla_genes],
    showscale=True,
    # Case-study alleles are drawn opaque, everything else faded.
    opacity=[1 if name in case_study_hlas else 0.3 for name in hla_names],
    colorbar=dict(
        len=0.75,
        title_text='binding score',
        xanchor="right", x=1.3,
        yanchor='bottom', y=0.1,
        thickness=20,
    ),
    line=dict(width=1),
)

fig.add_traces(
    go.Scatter(
        x=coords[:, 0],
        y=coords[:, 1],
        mode='markers',
        marker=scatter_marker,
        hovertext=[f'name: {n}<br>prob: {p}' for n, p in zip(hla_names, hla_probs)],
    )
)

# Last expression of the cell so the notebook displays the figure.
fig.update_layout(
    template='simple_white',
    font={'family': "Arial"}, width=800, height=800,
    xaxis_title='UMAP 1',
    yaxis_title='UMAP 2',
)



In [None]:
# Summarize the attention outputs: FILIP entries are indexed by batch, the others by layer.
for attn_name, attn_values in attns_dict.items():
    if not attn_name.startswith('filip'):
        print(f'{attn_name} has {len(attn_values)} layers each with shape {attn_values[0].shape} each')
        continue
    print(f'{attn_name} has {len(attn_values)} BatchSize each batch with shape {attn_values[0].shape}. So the object shape is {attn_values.shape}')

In [None]:
filip_interactions = attns_dict['filip_interactions']
hla_name = 'HLA-B42-01'
# Batch row in which the peptide was paired with hla_name.
hla_index = hla_fp_data[hla_name]
head_idx = [2, 4, 5, 11, 12, 13, 50, 63]

heads_to_plot = filip_interactions[hla_index][head_idx]
# Shared colour range so the heads are directly comparable.
global_min = min(head.min() for head in heads_to_plot)
global_max = max(head.max() for head in heads_to_plot)

fig = sp.make_subplots(rows=1, cols=len(head_idx), subplot_titles=[f'{hla_name}_H{h}' for h in head_idx])
for col, head_matrix in enumerate(heads_to_plot, start=1):
    fig.add_trace(
        go.Heatmap(z=head_matrix, colorscale='Portland_r', zmin=global_min, zmax=global_max),
        row=1, col=col,
    )
    fig.update_xaxes(title_text="Peptide Sequence", row=1, col=col)
fig.update_yaxes(title_text="HLA Fingerprint", row=1, col=1)

fig.update_layout(
    height=800,
    width=150 * len(head_idx),
    showlegend=False,
    title_text="HLA Interactions",
    template='simple_white',
    font={'family': "Arial"},
)
fig.show()

In [None]:
attn_layer = 3
# attns_dict['pep_cross_attn'] is indexed by layer first. After selecting one
# layer the array is presumably (batch_size, num_heads, rows, cols). TODO:
# confirm which axis holds peptide vs HLA positions; the axis titles below
# assume x = peptide, y = HLA.
pep_cross_attn = attns_dict['pep_cross_attn'][attn_layer]
hla_name = 'HLA-B42-01'
hla_index = hla_fp_data[hla_name]  # batch row pairing the peptide with hla_name
head_idx = list(range(8))

heads_to_plot = pep_cross_attn[hla_index][head_idx]
# Shared colour range so the heads are directly comparable.
global_min = min(head.min() for head in heads_to_plot)
global_max = max(head.max() for head in heads_to_plot)

fig = sp.make_subplots(rows=1, cols=len(head_idx), subplot_titles=[f'Layer{attn_layer}_Head{h}' for h in head_idx])
for col, head_matrix in enumerate(heads_to_plot, start=1):
    fig.add_trace(
        go.Heatmap(z=head_matrix, colorscale='Portland_r', zmin=global_min, zmax=global_max),
        row=1, col=col,
    )
    fig.update_xaxes(title_text="Peptide Sequence", row=1, col=col)
fig.update_yaxes(title_text="HLA Fingerprint", row=1, col=1)

fig.update_layout(
    height=800,
    width=150 * len(head_idx),
    showlegend=False,
    title_text=f"Peptide Cross Attn with {hla_name}",
    template='simple_white',
    font={'family': "Arial"},
)
fig.show()

In [None]:
# HLA cross-attention (presumably the HLA representation attending to the peptide).
# This might give information of which peptide residues are pointing towards
# the inner binding groove of the HLA.

attn_layer = 3
# attns_dict['hla_cross_attn'] is indexed by layer first. After selecting one
# layer the array is presumably (batch_size, num_heads, hla_n_fp, pep_seq_len);
# TODO confirm against the model.
hla_cross_attn = attns_dict['hla_cross_attn'][attn_layer]
hla_name = 'HLA-B42-01'
hla_index = hla_fp_data[hla_name]  # batch row pairing the peptide with hla_name
head_idx = [0, 1, 2, 3, 4, 5, 6, 7]

heads_to_plot = hla_cross_attn[hla_index][head_idx]
# Shared colour range so the heads are directly comparable.
global_min = min(head.min() for head in heads_to_plot)
global_max = max(head.max() for head in heads_to_plot)

fig = sp.make_subplots(rows=1, cols=len(head_idx), subplot_titles=[f'Layer{attn_layer}_Head{i}' for i in head_idx])
for i, head_matrix in enumerate(heads_to_plot):
    fig.add_trace(
        go.Heatmap(z=head_matrix, colorscale='Portland_r', zmin=global_min, zmax=global_max),
        row=1, col=i + 1,
    )
    fig.update_xaxes(title_text="Peptide Sequence", row=1, col=i + 1)
fig.update_yaxes(title_text="HLA Fingerprint", row=1, col=1)

fig.update_layout(
    height=800,
    width=150 * len(head_idx),
    showlegend=False,
    title_text=f"HLA Cross Attn with {hla_name}",
    template='simple_white',
    font={'family': "Arial"},
)
fig.show()

In [None]:
# HLA self-attention for one layer: attention among the HLA representation's
# own positions, so both heatmap axes index HLA fingerprint positions.

attn_layer = 0
# Indexed by layer first; the selected layer is presumably
# (batch_size, num_heads, hla_n_fp, hla_n_fp) — TODO confirm against the model.
hla_self_attn = attns_dict['hla_self_attn'][attn_layer]
hla_name = 'HLA-B42-01'
hla_index = hla_fp_data[hla_name]  # batch row pairing the peptide with hla_name
head_idx = [0, 1, 2, 3, 4, 5, 6, 7]

heads_to_plot = hla_self_attn[hla_index][head_idx]
# Shared colour range so the heads are directly comparable.
global_min = min(head.min() for head in heads_to_plot)
global_max = max(head.max() for head in heads_to_plot)

fig = sp.make_subplots(rows=1, cols=len(head_idx), subplot_titles=[f'Layer{attn_layer}_Head{i}' for i in head_idx])
for i, head_matrix in enumerate(heads_to_plot):
    fig.add_trace(
        go.Heatmap(z=head_matrix, colorscale='Portland_r', zmin=global_min, zmax=global_max),
        row=1, col=i + 1,
    )
    # Self-attention: columns are HLA positions as well, not peptide residues.
    fig.update_xaxes(title_text="HLA Fingerprint", row=1, col=i + 1)
fig.update_yaxes(title_text="HLA Fingerprint", row=1, col=1)

fig.update_layout(
    height=500,
    width=500 * len(head_idx),
    showlegend=False,
    title_text=f"HLA Self-Attn with {hla_name}",
    template='simple_white',
    font={'family': "Arial"},
)
fig.show()

In [None]:
# Peptide self-attention for one layer: attention among the peptide's own
# residues, so both heatmap axes index peptide positions.

attn_layer = 1
# Indexed by layer first; the selected layer is presumably
# (batch_size, num_heads, pep_seq_len, pep_seq_len) — TODO confirm against the model.
pep_self_attn = attns_dict['pep_self_attn'][attn_layer]
hla_name = 'HLA-B42-01'
hla_index = hla_fp_data[hla_name]  # batch row pairing the peptide with hla_name
head_idx = [0, 1, 2, 3, 4, 5, 6, 7]

heads_to_plot = pep_self_attn[hla_index][head_idx]
# Shared colour range so the heads are directly comparable.
global_min = min(head.min() for head in heads_to_plot)
global_max = max(head.max() for head in heads_to_plot)

fig = sp.make_subplots(rows=1, cols=len(head_idx), subplot_titles=[f'Layer{attn_layer}_Head{i}' for i in head_idx])
for i, head_matrix in enumerate(heads_to_plot):
    fig.add_trace(
        go.Heatmap(z=head_matrix, colorscale='Portland_r', zmin=global_min, zmax=global_max),
        row=1, col=i + 1,
    )
    fig.update_xaxes(title_text="Peptide Sequence", row=1, col=i + 1)
# Self-attention: rows are peptide residues as well, not HLA positions.
fig.update_yaxes(title_text="Peptide Sequence", row=1, col=1)

fig.update_layout(
    height=500,
    width=500 * len(head_idx),
    showlegend=False,
    title_text=f"Peptide Self-Attn with {hla_name}",
    template='simple_white',
    font={'family': "Arial"},
)
fig.show()

#### RM9-Nef
##### NetMHCpan baseline
```
# NetMHCpan version 4.1b

# Tmpdir made /var/www/html/services/NetMHCpan-4.1/tmp/netMHCpandu0tne
# Input is in FSA format

# Peptide length 9

# Make EL predictions

HLA-B42:01 : Distance to training data  0.000 (using nearest neighbor HLA-B42:01)
HLA-B81:01 : Distance to training data  0.146 (using nearest neighbor HLA-B42:01)
HLA-B07:02 : Distance to training data  0.000 (using nearest neighbor HLA-B07:02)
HLA-B42:02 : Distance to training data  0.028 (using nearest neighbor HLA-B42:01)

# Rank Threshold for Strong binding peptides   0.500
# Rank Threshold for Weak binding peptides   2.000
---------------------------------------------------------------------------------------------------------------------------
 Pos         MHC        Peptide      Core Of Gp Gl Ip Il        Icore        Identity  Score_EL %Rank_EL BindLevel
---------------------------------------------------------------------------------------------------------------------------
   1 HLA-B*42:01      RPQVPLRPM RPQVPLRPM  0  0  0  0  0    RPQVPLRPM        Sequence 0.9776250    0.022 <= SB
   1 HLA-B*81:01      RPQVPLRPM RPQVPLRPM  0  0  0  0  0    RPQVPLRPM        Sequence 0.7529500    0.040 <= SB
   1 HLA-B*07:02      RPQVPLRPM RPQVPLRPM  0  0  0  0  0    RPQVPLRPM        Sequence 0.9877480    0.010 <= SB
   1 HLA-B*42:02      RPQVPLRPM RPQVPLRPM  0  0  0  0  0    RPQVPLRPM        Sequence 0.8408470    0.020 <= SB
---------------------------------------------------------------------------------------------------------------------------
```

In [None]:
# RM9-Nef epitope; attention maps are not requested for this one.
seq = 'RPQVPLRPM'
prob_by_hla, reps = predict_likelihood(model, seq, hla_fp, hla_fp_data, DEVICE, save_attn=False)

In [None]:
print('HLA of interest and their probabilities:')
print('\t Binders:')
# RM9-Nef is presented by all four case-study alleles.
rm9_alleles = [
    ('HLA-B*07:02:', 'HLA-B07-02'),
    ('HLA-B*42:01:', 'HLA-B42-01'),
    ('HLA-B*42:02:', 'HLA-B42-02'),
    ('HLA-B*81:01:', 'HLA-B81-01'),
]
for label, key in rm9_alleles:
    print(label, prob_by_hla[key])

In [None]:
# Bar chart of the RM9-Nef binding probability for every HLA allele
# (prob_by_hla / hla_fp_data key order), with the y range fixed to [0, 1].
px.bar(x=prob_by_hla.keys(), y=prob_by_hla.values(), range_y=[0, 1])

In [None]:
# Re-fit the UMAP reducer (the `umap` object created in the earlier UMAP cell)
# on the RM9-Nef representations, averaged over axis 1.
coords = umap.fit_transform(reps.mean(1))

In [None]:
# UMAP of the representations of `seq` paired with every HLA, coloured by predicted binding
# probability. The four alleles NetMHCpan-4.1 calls strong binders for this peptide are drawn
# opaque and all others faded. Marker symbol and size encode the HLA locus (A/B/C).
# `go` is already imported in the first cell of the notebook.
# NOTE(review): assumes the rows of `coords` follow the iteration order of `prob_by_hla`
# — confirm in predict_likelihood.
fig = go.Figure()

symbol_dict = {'A': 'circle', 'B': 'x', 'C': 'diamond'}
size_dict = {'A': 10, 'B': 12, 'C': 12}
strong_binder_hlas = {'HLA-B42-01', 'HLA-B42-02', 'HLA-B81-01', 'HLA-B07-02'}

fig.add_traces(
    go.Scatter(
        x=coords[:, 0],
        y=coords[:, 1],
        mode='markers',
        marker=dict(
            color=list(prob_by_hla.values()),
            colorscale='Portland_r',
            # Keys look like 'HLA-B07-02': the first character after 'HLA-' is the locus.
            symbol=[symbol_dict[h.split('-')[1][0]] for h in prob_by_hla],
            size=[size_dict[h.split('-')[1][0]] for h in prob_by_hla],
            showscale=True,
            opacity=[1 if h in strong_binder_hlas else 0.3 for h in prob_by_hla],
            colorbar=dict(
                len=0.75,
                title_text='binding score',
                xanchor="right", x=1.3,
                yanchor='bottom', y=0.1,
                thickness=20,
            ),
            line=dict(width=1),
        ),
        hovertext=[f'name: {n}<br>prob: {p}' for n, p in prob_by_hla.items()],
    )
)

# Axis titles match the other UMAP figures in this notebook.
fig.update_layout(
    template='simple_white',
    font={'family': "Arial"}, width=800, height=800,
    xaxis_title='UMAP 1',
    yaxis_title='UMAP 2',
)

## Case study 2:

Pymm, P., Illing, P., Ramarathinam, S. et al. MHC-I peptides get out of the groove and enable a novel mechanism of HIV-1 escape. Nat Struct Mol Biol 24, 387–394 (2017). https://doi.org/10.1038/nsmb.3381 


Structures of HLA-B*57:01 presenting N-terminally extended peptides, including the immunodominant HIV-1 Gag epitope TW10 (TSTLQEQIGW), showed that the N terminus protrudes from the peptide-binding groove. The common escape mutant TSNLQEQIGW bound HLA-B*57:01 canonically, adopting a dramatically different conformation than the TW10 peptide. This affected recognition by killer cell immunoglobulin-like receptor (KIR) 3DL1 expressed on NK cells.


### Summary

HLA allele: HLA-B*57:01

- TW10 Peptide: `TSTLQEQIGW`
- T3N Peptide:  `TSNLQEQIGW` (common escape mutant: binds canonically, adopting a different conformation than TW10, and escapes the immune response)


Crystals of interest:
- 5T6Z: HLA-B*57:01 Pep: TW10
- 5T70: HLA-B*57:01 Pep: T3N
- 5V5L: HLA-B*58:01 Pep: TW10

There are 24 more pHLA structures for HLA-B*57:01 (see [here](https://github.com/annadiarov/seq2HLAallele/blob/master/databases/mhc1_pdb/prep_mhc1_pdb_list.csv)).

In [None]:
# Alleles compared for the TW10 / T3N peptides of Case study 2.
phla_list = [
    'HLA-B57-01',  # From paper
    'HLA-B58-01',  # Has crystal with TW10. Should bind with T3N according to BA
    'HLA-B58-02',  # It's a binder (BA)
    'HLA-B57-03',  # It's a binder (BA)
]
seq_list = ['TSTLQEQIGW', 'TSNLQEQIGW']  # [TW10, T3N escape mutant]

# The PDB list lives in a separate repository and the default path below is machine-specific,
# so it can be overridden through the MHC1_PDB_LIST environment variable.
pdb_data_file = os.environ.get(
    'MHC1_PDB_LIST',
    '/home/bsccns/epfl/seq2HLAallele/databases/mhc1_pdb/prep_mhc1_pdb_list.csv',
)

# Presumably restricts the HLA fingerprints to the alleles in `phla_list` — see select_pool_hla_fp.
selected_fp, selected_fp_data = select_pool_hla_fp(phla_list, hla_fp, hla_fp_data)

In [None]:
pdb_data = pd.read_csv(pdb_data_file)

# Unique epitopes listed in the PDB table for each allele (MHC_Name uses the 'B*57:01' format).
hla_b5701_uniq_pep = pdb_data.loc[pdb_data['MHC_Name'] == 'B*57:01', 'Epitope'].unique()
hla_b5801_uniq_pep = pdb_data.loc[pdb_data['MHC_Name'] == 'B*58:01', 'Epitope'].unique()

for allele_name, allele_peptides in (
    ('HLA-B57-01', hla_b5701_uniq_pep),
    ('HLA-B58-01', hla_b5801_uniq_pep),
):
    print(f'{allele_name} unique peptides with PDB:', len(allele_peptides))
print('Common peptides:', set(hla_b5701_uniq_pep) & set(hla_b5801_uniq_pep))

In [None]:
# Score every HLA-B*57:01 PDB epitope against the selected alleles, also collecting attention maps.
prob_by_hla, reps, attns_dict = predict_likelihood_several_peptides(
    model,
    hla_b5701_uniq_pep,
    selected_fp,
    selected_fp_data,
    DEVICE,
    save_attn=True,
)

In [None]:
# Import the module under its own alias so the fitted reducer can be bound to `umap`
# (later cells call `umap.fit_transform`) without shadowing the module object.
import umap.umap_ as umap_module

umap = umap_module.UMAP(
    n_neighbors=15,
    min_dist=0.1,
    # No random_state is set, so the embedding can differ between runs;
    # pass one if the figure must be reproducible.
)

# NOTE(review): averages `reps` over axis 1 — assumes that is the token/sequence axis; confirm.
coords = umap.fit_transform(reps.mean(1))

In [None]:
# Case study 2 UMAP: one point per key of `prob_by_hla`, coloured by predicted binding probability.
# Points whose key contains TW10 or T3N are opaque and the rest faded. Marker symbol and size
# encode the HLA locus.
# NOTE(review): the highlight uses a substring test, so an N-terminally extended peptide that
# contains TW10/T3N would be highlighted too — switch to exact matching once the key format is confirmed.
symbol_dict = {'A': 'circle', 'B': 'x', 'C': 'diamond'}
size_dict = {'A': 10, 'B': 12, 'C': 12}

fig = go.Figure(
    data=go.Scatter(
        x=coords[:, 0],
        y=coords[:, 1],
        mode='markers',
        marker={
            'color': list(prob_by_hla.values()),
            'colorscale': 'Portland_r',
            'symbol': [symbol_dict[key.split('-')[1][0]] for key in prob_by_hla],
            'size': [size_dict[key.split('-')[1][0]] for key in prob_by_hla],
            'showscale': True,
            'opacity': [
                1 if any(pep in key for pep in seq_list[:2]) else 0.3
                for key in prob_by_hla
            ],
            'colorbar': {
                'len': 0.75,
                'title_text': 'binding score',
                'xanchor': "right", 'x': 1.3,
                'yanchor': 'bottom', 'y': 0.1,
                'thickness': 20,
            },
            'line': {'width': 1},
        },
        hovertext=[f'name: {key}<br>prob: {prob}' for key, prob in prob_by_hla.items()],
    )
)

fig.update_layout(
    template='simple_white',
    font={'family': "Arial"}, width=800, height=800,
    xaxis_title='UMAP 1',
    yaxis_title='UMAP 2',
)

## Case study 3:

PRIME2 immunogenicity data (immunogenic and non-immunogenic peptide–HLA pairs).

In [None]:
# PRIME2 (2023) training table; later cells use its 'peptide', 'hla_allele', 'label' and 'split' columns.
prime2_data_file = os.path.join(
    PROCESSED_DATA,
    'immunogenicity',
    'PRIME',
    'train_PRIME2_2023_peptides_hla_ABC_with_BalancedSplits_remove2HLA.csv',
)
prime2_data = pd.read_csv(prime2_data_file)
prime2_data.head(5)

In [None]:
# Choose the PRIME2 rows to evaluate on.
# - Experiments whose `exp_name` contains 'ManSplit' hold out one split: the split id is the
#   token after 'ManSplit_' (e.g. '..._ManSplit_3_...' -> 3).
# - Otherwise every row is used.
TEST_ON_IMMUNOGENIC_ONLY = False  # If True, only immunogenic peptides will be used for testing

if 'ManSplit' in exp_name:
    test_split = exp_name.split('ManSplit')[1].split('_')[1]
    test_imuno_data = prime2_data[prime2_data['split'] == int(test_split)]
    logger.info(f'Test split: {test_split}. Using {len(test_imuno_data)} peptides as test data.')
else:
    test_imuno_data = prime2_data
    logger.info('Using all data as test data.')

if TEST_ON_IMMUNOGENIC_ONLY:
    # `Logger.warn` is a deprecated alias of `Logger.warning`.
    logger.warning('Testing only on immunogenic peptides.')
    test_imuno_data = test_imuno_data[test_imuno_data.label == 1]
    logger.info(f'Test data has {len(test_imuno_data)} immunogenic peptides, {len(test_imuno_data.hla_allele.unique())} unique HLA alleles.')
else:
    # Count labels with boolean sums instead of materialising two filtered frames.
    n_immuno = int((test_imuno_data.label == 1).sum())
    n_non_immuno = int((test_imuno_data.label == 0).sum())
    logger.info(f'Test data has {len(test_imuno_data)} peptides, {len(test_imuno_data.hla_allele.unique())} unique HLA alleles. There are {n_immuno} immunogenic pairs and {n_non_immuno} non-immunogenic pairs.')

In [None]:
USE_AUGMENTED_HLA_DATA = False  # If the FP provided are augmented
USE_ALL_AUGMENTED_HLA_FP_DATA = False  # Whether to use all augmented FP or just pick one randomly

# HLA name -> its fingerprint row from `hla_fp`, wrapped as a torch.Tensor.
hla_fp_dict = {
    hla_name: torch.Tensor(hla_fp[fp_row])
    for hla_name, fp_row in hla_fp_data.items()
}

test_immuno_dataset = pHLADataset(
    peptide_seq_arr=test_imuno_data['peptide'].values,
    hla_names_arr=test_imuno_data['hla_allele'].values,
    hla_fp_dict=hla_fp_dict,
    labels=test_imuno_data['label'].values,
    has_augmented_hla=USE_AUGMENTED_HLA_DATA,
    use_all_augmented_data=USE_ALL_AUGMENTED_HLA_FP_DATA,
)

In [None]:
# shuffle=False: later cells pair predictions with `test_imuno_data` rows by position
# (assuming pHLADataset keeps every row — TODO confirm).
test_immuno_loader = torch.utils.data.DataLoader(
    dataset=test_immuno_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Predict binding probabilities and representations for the PRIME2 test pairs.
# Attention maps are optional. When they are not collected, `attns_dict` is reset so the
# inspection cell below does not silently report the maps left over from Case study 2.
SAVE_ATTN = False
if SAVE_ATTN:
    all_probs, reps, attns_dict = predict_likelihood_from_dataloader(
        model, test_immuno_loader, DEVICE, save_attn=True
    )
else:
    all_probs, reps = predict_likelihood_from_dataloader(
        model, test_immuno_loader, DEVICE, save_attn=False
    )
    attns_dict = {}

In [None]:
# Print the shape of each collected attention entry. 'filip*' entries are reported per batch,
# all others per layer.
# NOTE(review): `attns_dict` comes from the most recent prediction call that collected attentions
# — confirm it belongs to the current dataset.
for attn_name, attn_maps in attns_dict.items():
    if not attn_name.startswith('filip'):
        print(f'{attn_name} has {len(attn_maps)} layers each with shape {attn_maps[0].shape} each')
        continue
    print(f'{attn_name} has {len(attn_maps)} BatchSize each batch with shape {attn_maps[0].shape}. So the object shape is {attn_maps.shape}')

In [None]:
# Import the module under its own alias so the fitted reducer can be bound to `umap`
# without shadowing the module object.
import umap.umap_ as umap_module

umap = umap_module.UMAP(
    n_neighbors=15,
    min_dist=0.1,
    # No random_state is set, so the embedding can differ between runs;
    # pass one if the figure must be reproducible.
)

# NOTE(review): averages `reps` over axis 1 — assumes that is the token/sequence axis; confirm.
coords = umap.fit_transform(reps.mean(1))


In [None]:
# Case study 3 UMAP: one point per PRIME2 test pair, coloured by predicted binding probability.
# Immunogenic pairs (label == 1) are opaque and non-immunogenic ones faded. Marker symbol and
# size encode the HLA locus.
# NOTE(review): relies on `coords`/`all_probs` following the row order of `test_imuno_data`.
symbol_dict = {'A': 'circle', 'B': 'x', 'C': 'diamond'}
size_dict = {'A': 10, 'B': 12, 'C': 12}

fig = go.Figure(
    data=go.Scatter(
        x=coords[:, 0],
        y=coords[:, 1],
        mode='markers',
        marker={
            'color': list(all_probs),
            'colorscale': 'Portland_r',
            'symbol': [symbol_dict[allele.split('-')[1][0]] for allele in test_imuno_data.hla_allele],
            'size': [size_dict[allele.split('-')[1][0]] for allele in test_imuno_data.hla_allele],
            'showscale': True,
            'opacity': [1 if is_immuno == 1 else 0.3 for is_immuno in test_imuno_data.label],
            'colorbar': {
                'len': 0.75,
                'title_text': 'binding score',
                'xanchor': "right", 'x': 1.3,
                'yanchor': 'bottom', 'y': 0.1,
                'thickness': 20,
            },
            'line': {'width': 1},
        },
        hovertext=[
            f'label: {lab}<br>name: {allele}-{pep}<br>prob: {prob}'
            for lab, prob, pep, allele in zip(
                test_imuno_data.label,
                all_probs,
                test_imuno_data.peptide,
                test_imuno_data.hla_allele,
            )
        ],
    )
)

fig.update_layout(
    template='simple_white',
    font={'family': "Arial"}, width=800, height=800,
    xaxis_title='UMAP 1',
    yaxis_title='UMAP 2',
)