In [1]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
import os
import random
import seaborn as sns
from scipy.stats import binom_test
from sklearn.metrics import confusion_matrix
from sklearn.metrics import matthews_corrcoef

from Bio import pairwise2
from Bio import Entrez
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq

from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

from AttentionModules import *

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

ModuleNotFoundError: No module named 'Bio'

In [2]:
# --- Ligand tables --------------------------------------------------------
# Tab-separated files without a header row; columns are named
# 'Peptide' and 'Allele' on load. Each table is then passed through
# removeAlleles (defined outside the visible cells -- its exact filtering
# rule is not shown here; TODO confirm).
datDir_ligands = './dat/network_input/'
train_ligs = removeAlleles(
    pd.read_csv(
        os.path.join(datDir_ligands, 'train.txt'),
        sep='\t',
        names=['Peptide', 'Allele'],
    )
)
test_ligs = removeAlleles(
    pd.read_csv(
        os.path.join(datDir_ligands, 'test.txt'),
        sep='\t',
        names=['Peptide', 'Allele'],
    )
)

# --- Amino-acid embedding table ------------------------------------------
# Header-less CSV; lines starting with '#' are skipped. Column 0 is kept
# separately as `alpha`.
datDir_embedding = './dat/embedding/'
df_embedding = pd.read_csv(
    os.path.join(datDir_embedding, 'aa_embedding_window5_dim100.txt'),
    header=None,
    comment='#',
)
alpha = df_embedding[0]

# Number of ligands requested per allele when sub-sampling train / test data.
trainDatNum = 10
testDatNum = 10

# Every distinct allele in the training table (set order is arbitrary).
allAlleles = list(set(train_ligs['Allele'].values))
# Only alleles with more than 100 associated training ligands are modelled.
datAlleles = [
    allele
    for allele, ligands in train_ligs.groupby('Allele')
    if len(ligands) > 100
]

# --- Allele -> MHC sequence look-ups -------------------------------------
# Two mappings built by helpers defined outside the visible cells: one from
# an aligned FASTA file, one from a pseudo-sequence file.
datDir_seqs = './dat/MHC_seqs/'
allele2SeqDict_full = allele2SeqDictWrapper(datDir_seqs, 'X_prot_align.fasta', datAlleles)
allele2SeqDict = allele2SeqDictWrapper2(datDir_seqs, 'MHC_pseudo.dat', datAlleles)

# Reverse look-up (sequence -> allele). If two alleles share a sequence the
# last one iterated wins.
allele2SeqDict_inv = dict((seq, allele) for allele, seq in allele2SeqDict.items())


FileNotFoundError: File b'./dat/network_input_uniq_allele_All/train_EL1.txt' does not exist

In [None]:
# Width of one embedding vector: every column of the table except column 0.
emb_dim = df_embedding.shape[1] - 1

# Three extra rows placed in front of the embedding table: 'SOS' (all 0.0),
# 'EOS' (all 1.0) and '-' (all -1.0), each with the label in column 0.
tokens_df = pd.DataFrame(
    [
        ['SOS'] + emb_dim * [0.0],
        ['EOS'] + emb_dim * [1.0],
        ['-'] + emb_dim * [-1.0],
    ]
)
df_embedding_token = pd.concat([tokens_df, df_embedding])

# Numeric part only (columns 1 onward) as a tensor: one row per vocabulary
# entry, one column per embedding dimension.
_embedding_values = df_embedding_token.loc[:, 1:].values
embeddingTensor = torch.tensor(_embedding_values)
vocab_size, EMBEDDING_DIM = embeddingTensor.shape
embeddingTensor

In [None]:
# Indices of the start / end-of-sequence tokens. These match the positions of
# the 'SOS' and 'EOS' rows prepended when df_embedding_token was built.
SOS_token, EOS_token = 0, 1

# Symbol <-> index look-ups (makeAA2IDX is defined outside the visible cells).
AA2IDX = makeAA2IDX(df_embedding_token)
IDX2AA = dict((idx, aa) for aa, idx in AA2IDX.items())

# Training pairs: sub-sample trainDatNum ligands per allele, convert to
# ligand/receptor pairs, then swap the two columns so the peptide is the target.
train_ligs_select_sample = selectXnumOfClass(train_ligs, 'Allele', datAlleles, trainDatNum)
pairs = ligAllele2ligReceptor(train_ligs_select_sample).values[:, [1, 0]]

# Test pairs: same sub-sampling and conversion.
# NOTE(review): unlike `pairs`, the columns of `pairs2` are NOT swapped --
# confirm this asymmetry is intended by whatever consumes both arrays.
test_ligs_select_sample = selectXnumOfClass(test_ligs, 'Allele', datAlleles, testDatNum)
pairs2 = ligAllele2ligReceptor(test_ligs_select_sample).values

# Longest sequence found in any column of the full training pair table, plus one.
peptRec_train = ligAllele2ligReceptor(train_ligs)
MAX_LENGTH = max(len(seq) for row in peptRec_train.values for seq in row) + 1

In [None]:
# NOTE(review): wildcard import kept so that any later cell relying on other
# Bio.PDB names keeps working; only PDBParser is used in this cell.
from Bio.PDB import *

datDir_structure = './dat/struct/'
pdbPath = os.path.join(datDir_structure,'1i7r.pdb')

# PDBParser's first positional argument is the PERMISSIVE flag, not a file
# name. The path used to be passed there, where it only acted as a truthy
# flag (equal to the default). The file is given to get_structure() instead.
parser = PDBParser()
structure = parser.get_structure('MHCI',pdbPath)

# Start from a clean slate so a missing chain is reported instead of silently
# reusing residues left in the kernel by an earlier run.
mhcResidues = None
peptResidues = None
for model in structure:
    for chain in model:
        if chain.id=='A':
            # Chain A: keep residues whose sequence number is at most 181.
            mhcResidues = [res for res in chain.get_residues() if res.id[1]<=181]
        if chain.id=='C':
            # Chain C: keep standard residues only (blank hetero flag), which
            # drops waters and other hetero entries.
            peptResidues = [res for res in chain.get_residues() if res.id[0]==' ']

if mhcResidues is None or peptResidues is None:
    raise ValueError("Chains 'A' and 'C' must both be present in %s" % pdbPath)

# C-alpha to C-alpha distance for every (chain C residue, chain A residue)
# pair; subtracting two Bio.PDB atoms yields their distance.
distMat = np.zeros((len(peptResidues),len(mhcResidues)))
for i,AAp in enumerate(peptResidues):
    for j,AAm in enumerate(mhcResidues):
        distMat[i,j] = AAp['CA']-AAm['CA']

# 1-based residue numbers of the selected chain A positions, shifted to
# 0-based column indices of distMat.
# NOTE(review): this assumes chain A is numbered from 1 without gaps, so that
# column index == residue number - 1 -- confirm for this structure.
prox = [7,9,24,45,59,62,63,66,67,69,70,73,74,76,77,80,81,84,95,97,99,114,116,118,143,147,150,152,156,158,159,163,167,171]
prox = list(map(lambda x:x-1,prox))
distMat_proxy = distMat[:,prox]

# First figure: distances restricted to the selected positions.
plt.imshow(distMat_proxy)

# Second figure (matshow opens a new one): the full distance matrix with a colorbar.
plt.matshow(distMat)
plt.colorbar()
plt.show()

In [None]:
# Hyper-parameters for this run.
teacher_forcing_ratio = 1.0
hidden_size = 20

# NOTE(review): `dropout_p` is not assigned in any visible cell; presumably it
# arrives through the wildcard import of AttentionModules -- confirm.
encoder2 = EncoderRNN2(
    vocab_size,
    EMBEDDING_DIM,
    hidden_size,
    dropout=dropout_p,
).to(device)
attn_decoder2 = BahdanauAttnDecoderRNN(
    hidden_size,
    EMBEDDING_DIM,
    vocab_size,
    dropout_p=dropout_p,
).to(device)

# Train with 1200 as the third positional argument, reporting and recording
# every 10; the four returned values feed the plotting cells.
plot_losses, trainIdent, testIdent, corr = trainIters2(
    encoder2,
    attn_decoder2,
    1200,
    print_every=10,
    plot_every=10,
)

In [None]:

# Training curve: loss on the left axis, sequence identity (train and test)
# on a twin right axis.
plotEvery=10

# The x axis is derived from the number of recorded points rather than from a
# hard-coded iteration count, so it always matches the length of plot_losses.
# (np.linspace also requires an integer number of samples; the former float
# argument raises TypeError on current NumPy.)
numPoints = len(plot_losses)
numEpochs = plotEvery*numPoints
epochX = np.linspace(plotEvery,numEpochs,numPoints)

col1='xkcd:royal blue'
col2='xkcd:light orange'
col3='xkcd:green'

fig, ax1 = plt.subplots()
t = epochX
s1 = plot_losses
ax1.plot(t, s1, col1,label='Train loss')
ax1.set_xlabel('Epochs',fontsize=20)
# Make the y-axis label, ticks and tick labels match the line color.
ax1.set_ylabel('NLLoss', color=col1,fontsize=20)
ax1.tick_params('y', colors=col1,labelsize=16)
ax1.tick_params('x', colors='k',labelsize=16)

ax2 = ax1.twinx()
s2 = trainIdent
s3 = testIdent
ax2.plot(t, s3, col3,label='Test identity')
# This curve is the training identity; it used to be labelled 'Test'.
ax2.plot(t, s2, col2,label='Train identity')
ax2.set_ylabel('%Seq-Identity', color=col2,fontsize=20)
ax2.tick_params('y', colors=col2,labelsize=16)

# One legend for both axes; entries are named so that the two training
# curves (loss vs. identity) can be told apart.
patch1 = mpatches.Patch(color=col1, label='Train loss')
patch2 = mpatches.Patch(color=col2, label='Train identity')
patch3 = mpatches.Patch(color=col3, label='Test identity')
fig.legend(handles=[patch1,patch2,patch3],ncol=1,loc=[0.64,0.4],fontsize=14)

fig.tight_layout()
plt.show()

In [None]:
# Sequence fed to the encoder. The same string labels the attention columns,
# so the axis can never show a different sequence from the one evaluated
# (the labels used to be a separate, different hard-coded string).
mhcSeq_fig = 'YHTEYREICAKTYENTAYLNYHDYTWAVLAYEWY'
output_words,attentions_fig = evaluate2(encoder2, attn_decoder2,mhcSeq_fig)

# Drop the last row and the last column of the attention matrix.
# A former extra line replaced this with the mean of a list `atts` that is
# not defined in any visible cell (hidden kernel state); it was removed so
# the cell runs on a fresh kernel. To plot an average, build that list
# explicitly in this notebook first.
atts_fig_USE = attentions_fig.numpy()[:-1,:-1]
print(attentions_fig.shape)

fig = plt.figure(figsize=(8,5))

# Top panel: structural distances at the selected positions.
ax = fig.add_subplot(211)
cax = ax.matshow(distMat_proxy, cmap='Blues')
ax.tick_params(axis='both', which='major', labelsize=16)

# Bottom panel: attention weights for the evaluated sequence.
ax = fig.add_subplot(212)
cax = ax.matshow(atts_fig_USE, cmap='Blues')

# Fix the tick positions before assigning labels: one tick per matrix row /
# column, labels padded or truncated to the matrix size so the counts agree.
numRows, numCols = atts_fig_USE.shape
colLabels = (list(mhcSeq_fig) + ['']*numCols)[:numCols]
rowLabels = (list(output_words) + ['<EOS>'] + ['']*numRows)[:numRows]
ax.set_xticks(np.arange(numCols))
ax.set_xticklabels(colLabels)
ax.set_yticks(np.arange(numRows))
ax.set_yticklabels(rowLabels, rotation=90)
ax.tick_params(axis='both', which='major', labelsize=16)

# Save before show(): show() may leave an empty current figure behind.
plt.savefig(os.path.join('./','matrix_comparison.eps'))

plt.show()