In [1]:
import os
import pandas as pd
import numpy as np
import regex as re
from collections import defaultdict
from typing import Tuple, List, NewType
from tqdm.notebook import tqdm_notebook as tqdm

import matplotlib.pyplot as plt
import seaborn as sns 
# %config InlineBackend.figure_format = 'retina'

import Bio
from Bio.PDB import *
import warnings
warnings.filterwarnings('ignore')






In [2]:
# Type aliases used in the annotations of the helper functions below.
# Each one wraps the corresponding Biopython class with typing.NewType.
BioStructure = NewType('BioStructure', Bio.PDB.Structure.Structure)
BioVector = NewType('BioVector', Bio.PDB.vectors.Vector)
BioResidue = NewType('BioResidue', Bio.PDB.Residue.Residue)


In [96]:
#######################################################
# Load EpitopeDB
#######################################################
def load_epitopedb(path:str='../data/20201105_EpitopevsHLA_distance.pickle') -> pd.DataFrame:
    """Load the pickled epitope table.

    Parameters
    ----------
    path : str
        Location of the pickle file. Defaults to the snapshot this notebook
        was written against, so existing calls without arguments are unchanged.

    Returns
    -------
    The object stored in the pickle (find_locations treats it as a pandas DataFrame).
    """
    # NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
    return pd.read_pickle(path)

#######################################################
# Load PDB file
#######################################################
def load_hla_structure(HLA_Molecule:str, path):
    """Parse the PDB file at `path` and return the structure, passing `HLA_Molecule` as its id."""
    return PDBParser().get_structure(HLA_Molecule, path)

#######################################################
# HLA to file name
#######################################################
def hla_to_filename(hla:str) -> Tuple[str, str]:
    """Derive the locus prefix and the PDB file name for an HLA allele name.

    The allele is split at '*' into locus and specificity. The specificity
    fields (separated by ':') are joined to the locus with '_' and suffixed
    with '_V1.pdb'. The returned locus is cut at its first digit.

    Example: 'DRB1*03:03' -> ('DRB', 'DRB1_03_03_V1.pdb')

    Raises
    ------
    ValueError
        If `hla` does not contain exactly one '*'.
    """
    locus, specificity = hla.split('*')
    filename = '_'.join([locus, *specificity.split(':')]) + '_V1.pdb'
    # Raw string: '\d' in a plain literal is an invalid escape sequence.
    return re.split(r'\d', locus)[0], filename
    
#######################################################
# Find HLA molecule path
#######################################################
def find_molecule_path(locus:str, filename:str) -> Tuple[bool, str]:
    """Locate the .pdb file of an allele, given the output of 'hla_to_filename'.

    The search directory is '../data/HLAMolecule/<first two characters of locus>'.
    A file matches when its name contains `filename` without the '_V1.pdb' suffix.

    Returns
    -------
    (True, path) when a match exists. The exact `filename` is preferred,
    otherwise the first match in sorted order is used.
    (False, 'No path exists') when nothing matches or the locus directory is missing.
    """
    path = os.path.expanduser(f'../data/HLAMolecule/{locus[0:2]}') # at most the first 2 characters of locus
    stem = filename.split('_V1.pdb')[0]
    try:
        # Sorted so the chosen file does not depend on the OS directory order.
        pdb_files = sorted(file for file in os.listdir(path) if stem in file)
    except FileNotFoundError:
        # A locus without a structure directory is reported as "not found".
        return False, 'No path exists'
    if not pdb_files:
        return False, 'No path exists'
    # Prefer the exact file over longer names that merely contain the stem.
    chosen = filename if filename in pdb_files else pdb_files[0]
    return True, os.path.join(path, chosen)

#######################################################
# Residue in short
#######################################################
def res_short(residue:BioResidue) -> str:
    """Return the compact label of a residue: sequence number followed by its one-letter code.

    The one-letter code is looked up in Aminoacid_conversion. A residue name
    missing from that table yields None, so the concatenation raises TypeError.
    """
    number = residue.get_id()[1]  # second field of the residue id is the residue number
    one_letter = Aminoacid_conversion.get(residue.get_resname())
    return str(number) + one_letter

# Three-letter residue name -> one-letter amino-acid code.
Aminoacid_conversion = {
    'CYS': 'C', 'ASP': 'D', 'SER': 'S', 'GLN': 'Q',
    'LYS': 'K', 'ILE': 'I', 'PRO': 'P', 'THR': 'T',
    'PHE': 'F', 'ASN': 'N', 'GLY': 'G', 'HIS': 'H',
    'LEU': 'L', 'ARG': 'R', 'TRP': 'W', 'ALA': 'A',
    'VAL': 'V', 'GLU': 'E', 'TYR': 'Y', 'MET': 'M',
}

#######################################################
# Find the average location of Residue
#######################################################
def get_residue_avg_coord(residue:Tuple[int,str], structure:BioStructure, chain:str) -> BioVector:
    """Return the mean position of a residue, averaged over all of its atoms.

    Parameters
    ----------
    residue : (sequence number, expected one-letter amino-acid code)
    structure : parsed structure; only its first model (index 0) is used
    chain : id of the chain that holds the residue

    A mismatch between the expected code and the residue found in the structure
    is logged as a warning; the coordinate is still computed and returned.

    Raises
    ------
    KeyError
        Presumably raised by the structure lookup when the chain or residue
        number is absent (find_locations catches KeyError around this call).
    """
    BioChain = structure[0][chain]
    res_num, res_code = int(residue[0]), residue[1]
    _residue = BioChain[res_num]
    res_pdb = Aminoacid_conversion.get(_residue.get_resname(), 'Corresponding code of the amino-acide could not be found')
    # Plain comparison instead of try/assert/except: assert statements are
    # removed under `python -O`, which would silently disable this check.
    if res_code != res_pdb:
        logger.warning(f'Expected residue {res_code}, but got {res_pdb}, sequence number: {res_num}, chain: {chain},  HLA: {structure.get_id()}')
    atoms_coord = [atom.get_vector() for atom in _residue.get_atoms()]
    # The sum and the division run on the vector objects themselves; callers
    # call .get_array() on the result.
    return np.array(atoms_coord).sum()/len(atoms_coord)

#######################################################
# Find the average location of Epitope
#######################################################
def get_epitope_avg_coord(Epitope:List[Tuple[int,str]], structure:BioStructure, HLA_chain:str) -> BioVector:
    """Return the centre of an epitope: the mean of the mean coordinates of its residues."""

    centres = []
    for res in Epitope:
        centres.append(get_residue_avg_coord(res, structure, HLA_chain))
    total = np.array(centres).sum()
    return total/len(centres)

#######################################################
# chain functions for calculating distances
#######################################################
def get_location(poly_residues:List[Tuple[int,str]], structure:BioStructure, locus:str) -> BioVector:
    """Return the average coordinate of an epitope's polymorphic residues on an HLA structure.

    Parameters
    ----------
    poly_residues : residues in the form accepted by get_epitope_avg_coord
    structure : parsed HLA structure
    locus : locus prefix with digits stripped (as returned by hla_to_filename);
        one of 'A', 'B', 'C', 'DRB', 'DQA', 'DQB'

    Raises
    ------
    KeyError
        If `locus` has no chain mapping. A KeyError from the coordinate lookup
        itself also propagates (find_locations catches both).
    """

    HLA_chain = {'A': 'A', 'B': 'A', 'C': 'A', 'DRB': 'B', 'DQA': 'A', 'DQB':'B'}
    if locus not in HLA_chain:
        # Fail with an explicit message instead of looking up chain `None` in the structure.
        raise KeyError(f'No chain is defined for locus {locus!r}')
    return get_epitope_avg_coord(poly_residues, structure, HLA_chain[locus])

def find_locations(EpitopeDB:pd.DataFrame) -> dict:
    """Compute, for every epitope row, its average location on each of its HLA structures.

    For each row, the alleles in 'Luminex Alleles' are resolved to PDB files and
    the row's 'PolymorphicResidues' are averaged on every structure that is found.

    Returns
    -------
    defaultdict mapping Epitope -> list holding one mapping per row of that
    epitope; each mapping is allele -> [x, y, z] rounded to 2 decimals.

    Alleles listed in `hla_exceptions` are skipped with a warning, alleles
    without a PDB file are left out, and a KeyError raised during the
    coordinate lookup is logged as an error.
    """
    # Alleles skipped on purpose (the reason is not recorded in this notebook).
    hla_exceptions = ['DRB1*03:03', 'DRB1*09:02', 'A*02:06'] #'DQA1*05:01', 'DQA1*02:01','A*02:06']
    epitope_locations = defaultdict(list)
    # The same allele appears in many rows: locate and parse its PDB file once.
    # Value is (locus, structure), or None when no file was found.
    structures = {}
    for _, row in tqdm(EpitopeDB.iterrows(), total=len(EpitopeDB)):
        poly_residues = row.PolymorphicResidues
        # Kept as a defaultdict so the returned objects have the same type as before.
        loc = defaultdict(list)
        for hla in row['Luminex Alleles']:
            if hla in hla_exceptions:
                logger.warning(f'Skipped hla: {hla}')
                continue
            if hla not in structures:
                locus, filename = hla_to_filename(hla)
                pdb_exist, pdb_path = find_molecule_path(locus, filename)
                structures[hla] = (locus, load_hla_structure(hla, pdb_path)) if pdb_exist else None
            if structures[hla] is None:
                continue
            locus, structure = structures[hla]
            try:
                loc[hla] = get_location(poly_residues, structure, locus).\
                           get_array().\
                           round(2).\
                           tolist()
            except KeyError as e:
                logger.error(f'Epitope {poly_residues} HLA {structure.get_id()} "KeyError" {e}')
        epitope_locations[row.Epitope].append(loc)
    return epitope_locations

def write_location_df(EpitopeDB:pd.DataFrame) -> pd.DataFrame:
    """Attach a 'Location' column (the output of find_locations) to EpitopeDB, merged on 'Epitope'."""
    locations = pd.DataFrame(find_locations(EpitopeDB)).T
    locations = locations.reset_index().rename(columns={'index':'Epitope', 0:'Location'})
    return EpitopeDB.merge(locations, on='Epitope')



In [97]:
# NOTE(review): `new` is not used by any cell visible here -- confirm it is still needed.
new = 'EpitopevsHLA_clean.xlsx'
# Load the epitope table that find_locations iterates over.
EpitopeDB = load_epitopedb()


# Run the location script

In [98]:
# %%timeit
# NOTE(review): this import belongs in the imports cell at the top of the notebook.
import logging
# Drop handlers left over from earlier runs so that basicConfig below takes effect
# (basicConfig does nothing when the root logger already has handlers).
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)
    
# Module-level logger used by get_residue_avg_coord and find_locations; it must
# exist before those functions are called.
logger = logging.getLogger(__name__)
logging.basicConfig(filename= 'ep_distance.log',
                    filemode = 'w',  # the log file is overwritten on every run
                    format= '%(name)s - %(levelname)s - %(message)s',
                    level=logging.DEBUG,
                   )


# Compute the location of every epitope on each of its alleles.
dictionary = find_locations(EpitopeDB)
# EpitopeDB_new = write_location_df(EpitopeDB)


HBox(children=(FloatProgress(value=0.0, max=424.0), HTML(value='')))




In [100]:
# One row per epitope; column 0 holds the allele -> location mapping.
df = pd.DataFrame(dictionary).transpose()

In [71]:
# find_locations already stores each location as a rounded [x, y, z] list, so the
# former .get_array().round(2).tolist() chain would fail on a fresh run.
# .iloc selects the first row by position; the index itself holds epitope names.
df[0].iloc[0]['C*08:02']

[17.51, 19.17, 13.44]

In [101]:
# Display the per-epitope table (index: epitope, column 0: allele -> [x, y, z]).
df

Unnamed: 0,0
1C,"{'C*08:02': [17.51, 19.17, 13.44], 'C*12:02': ..."
9D,"{'C*07:02': [16.85, -7.29, 12.52], 'B*08:01': ..."
9F[A],"{'A*32:01': [2.45, 4.33, 6.64], 'A*02:01': [20..."
9H,"{'B*40:01': [-3.53, 9.98, -29.58], 'B*37:01': ..."
9S,"{'A*23:01': [-24.06, 8.52, -14.82], 'C*04:01':..."
...,...
160A,"{'DQA1*03:01': [73.51, 67.51, 103.62], 'DQA1*0..."
160AD,"{'DQA1*03:01': [74.45, 66.88, 102.11], 'DQA1*0..."
160D,"{'DQA1*03:02': [72.38, 67.98, 103.95], 'DQA1*0..."
160S,"{'DQA1*05:03': [25.66, -20.38, -26.74]}"


In [102]:
# df[0][0]

In [103]:
# Rebuild the table with named columns: 'Epitope' (former index) and 'Location'
# (the allele -> [x, y, z] mapping that was stored under column 0).
df = (
    pd.DataFrame(dictionary)
    .T
    .reset_index()
    .rename(columns={'index':'Epitope', 0:'Location'})
)
# df['Location_mean'] =  df['Location'].apply(lambda x: np.array([_.get_array()[1] for _ in x.items()]).mean())
# df['Location_std'] =  df['Location'].apply(lambda x: np.array([_[1] for _ in x.items()]).std())

In [127]:
# Inspect the allele -> location mapping of the row labelled 3.
df.loc[3, 'Location']

defaultdict(list,
            {'B*40:01': [-3.53, 9.98, -29.58],
             'B*37:01': [-4.05, 8.66, -28.62],
             'B*41:01': [-3.53, 10.07, -29.6],
             'B*49:01': [22.06, 9.32, 26.37],
             'B*73:01': [4.23, 0.12, -0.95],
             'B*27:05': [4.0, -0.05, -1.1],
             'B*27:08': [4.31, -0.1, -1.26],
             'B*45:01': [-3.57, 9.96, -29.55],
             'B*40:06': [3.72, -9.57, -29.43],
             'B*40:02': [3.67, -9.82, -29.34],
             'B*18:01': [-3.93, 8.82, -28.52],
             'B*50:01': [3.71, -9.42, -29.57],
             'B*27:03': [4.13, 103.92, -0.92]})

In [116]:
# df['Location'].apply(lambda x: np.round(sum([np.array(_) for _ in x.values()])/len(x), 2))

0        [17.71, 19.09, 13.19]
1          [13.7, -7.55, 2.44]
2             [1.79, 7.0, 9.7]
3         [2.4, 10.15, -16.31]
4           [0.23, 8.52, 7.04]
                ...           
419       [49.02, 15.24, 21.2]
420        [50.12, 17.4, 27.9]
421      [43.99, 39.83, 59.98]
422    [25.66, -20.38, -26.74]
423       [44.98, 3.76, 39.71]
Name: Location, Length: 424, dtype: object