In [None]:
# Reload edited project modules automatically before each cell executes, so
# changes to `lib` are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from lib import DihedralAdherence
from lib import PDBMineQuery
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats
from pathlib import Path
PDBMINE_URL = os.getenv("PDBMINE_URL")
PROJECT_DIR = 'tests'
import torch
from torch import nn
import torch.nn.functional as F

In [None]:
# Protein identifiers available for analysis; only the first one is loaded below.
proteins = [
    'T1024', 'T1096', 'T1027', 'T1082', 'T1091', 'T1058',
    'T1049', 'T1030', 'T1056', 'T1038', 'T1025', 'T1028',
]

# Adherence object for the first protein, built with four window sizes and
# one KDE weight per window.
da = DihedralAdherence(
    proteins[0],
    [4, 5, 6, 7],
    PDBMINE_URL,
    PROJECT_DIR,
    kdews=[1, 32, 64, 128],
)

# Distinct sequence contexts present in the x-ray phi/psi table.
seqs = da.xray_phi_psi.seq_ctxt.unique()

In [None]:
# Number of (phi, psi) samples kept per segment, in the order the segments are
# concatenated along the last axis of the model input.
lengths = [4096, 512, 256, 256]
# Total width of the concatenated axis.
length = sum(lengths)
# Start offset of each segment inside the concatenated axis.
s = [sum(lengths[:i]) for i in range(len(lengths))]
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
class LSTMNet(nn.Module):
    """Four bidirectional LSTM encoders followed by a two-layer MLP head.

    The input ``x`` is sliced along its last axis into four consecutive
    segments. Each segment is encoded by its own LSTM, the final hidden states
    are concatenated with ``xres``, and the result is mapped to two outputs.

    Args:
        seg_lengths: optional sequence of four segment lengths along the last
            axis of ``x``. When omitted, the module-level ``lengths`` list is
            used. The offsets are captured at construction time so later
            changes to notebook globals cannot silently alter the slicing.

    Raises:
        ValueError: if ``seg_lengths`` does not contain exactly four entries.
    """

    def __init__(self, seg_lengths=None):
        super().__init__()
        self.h = 32
        h = self.h
        nl = 1
        p_drop = 0.0
        mlp_h = 20
        n_res = 20  # width of the ``xres`` input concatenated before fc1

        if seg_lengths is None:
            seg_lengths = lengths  # notebook-level default
        seg_lengths = [int(n) for n in seg_lengths]
        if len(seg_lengths) != 4:
            raise ValueError(
                f"LSTMNet expects 4 segment lengths, got {len(seg_lengths)}"
            )
        # Plain Python list (not a buffer) so the state_dict keys are unchanged.
        self.offsets = [sum(seg_lengths[:k]) for k in range(4)]

        # Attribute names and creation order are kept so that existing
        # checkpoints load and parameter initialisation is unchanged.
        self.lstm1 = nn.LSTM(2, h, nl, batch_first=True, bidirectional=True, dropout=p_drop)
        self.lstm2 = nn.LSTM(2, h, nl, batch_first=True, bidirectional=True, dropout=p_drop)
        self.lstm3 = nn.LSTM(2, h, nl, batch_first=True, bidirectional=True, dropout=p_drop)
        self.lstm4 = nn.LSTM(2, h, nl, batch_first=True, bidirectional=True, dropout=p_drop)
        self.dropout1 = nn.Dropout(0.3)
        # 4 encoders x 2 directions x h hidden units, plus the xres features.
        self.fc1 = nn.Linear(h * 2 * 4 + n_res, mlp_h)
        self.dropout2 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(mlp_h, 2)

    @staticmethod
    def _encode(lstm, segment):
        """Return the last layer's forward/backward final hidden states, flattened per batch item."""
        # (batch, 2, n) -> (batch, n, 2): the LSTMs are batch_first with input size 2.
        _, (h_n, _) = lstm(segment.permute(0, 2, 1))
        # h_n is (layers * directions, batch, h); keep the last two rows
        # (the final layer's two directions) and flatten them.
        return h_n.permute(1, 0, 2)[:, -2:, :].flatten(1)

    def forward(self, x, xres):
        """Compute the two-value prediction for each batch item.

        Args:
            x: tensor sliced as ``x[:, :, a:b]``; after swapping the last two
                axes each slice is fed to an LSTM with input size 2.
            xres: tensor concatenated to the encoded features along dim 1;
                its width must match the 20 extra inputs of ``fc1``.

        Returns:
            Tensor with two values per batch item (output of ``fc2``).
        """
        bounds = self.offsets + [None]  # the last segment runs to the end of the axis
        encoders = (self.lstm1, self.lstm2, self.lstm3, self.lstm4)
        feats = [
            self._encode(lstm, x[:, :, lo:hi])
            for lstm, lo, hi in zip(encoders, bounds[:-1], bounds[1:])
        ]
        feats = torch.cat(feats, dim=1)
        # Cast explicitly: xres may be an integer one-hot tensor.
        out = torch.cat([feats, xres.to(feats.dtype)], dim=1)
        out = self.dropout1(out)
        out = self.fc1(F.relu(out))
        out = self.dropout2(out)
        out = self.fc2(F.relu(out))
        return out
# Put the network in eval mode so its Dropout layers are disabled at inference;
# otherwise every forward pass randomly zeroes 30% of the features.
model = LSTMNet().to(device).eval()
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a different device still loads.
model.load_state_dict(torch.load('ml_data/best_model_xres_h32_nl1_mlp20_dropout30_1.7k.pt', map_location=device))

In [None]:
def plot(Xp, y, i, logits=None, logits2=None, res=None):
    """Scatter the sampled points of one batch item together with its targets.

    Args:
        Xp: tensor indexed as ``Xp[i, 0, :]`` (x values) and ``Xp[i, 1, :]``
            (y values). The last axis holds four consecutive segments whose
            sizes come from the module-level ``lengths``. Entries equal to 0
            are replaced by NaN on a copy so they are not drawn.
        y: array-like indexed as ``y[i, 0]``, ``y[i, 1]``; drawn as 'true'.
        i: index of the batch item to draw.
        logits: optional tensor indexed like ``y``; drawn as 'pred'.
        logits2: optional array or tensor indexed like ``y``; drawn as 'pred2'.
        res: unused; kept so existing calls remain valid.
    """
    # Work on a detached CPU copy so the caller's tensor is not modified.
    Xp = Xp.cpu().clone().detach()
    Xp[Xp == 0] = np.nan

    # Segment start offsets; the last segment runs to the end of the axis.
    bounds = [sum(lengths[:k]) for k in range(4)] + [None]
    for label, lo, hi in zip(('4', '5', '6', '7'), bounds[:-1], bounds[1:]):
        plt.plot(Xp[i, 0, lo:hi], Xp[i, 1, lo:hi], 'o', label=label)

    plt.plot(y[i, 0], y[i, 1], 'X', label='true', color='purple', markersize=10)
    if logits is not None:
        logits = logits.cpu().clone().detach()
        plt.plot(logits[i, 0], logits[i, 1], 'X', label='pred', color='black', markersize=10)
    if logits2 is not None:
        # Accept tensors as well as arrays: matplotlib cannot consume tensors
        # that live on the GPU or track gradients.
        if isinstance(logits2, torch.Tensor):
            logits2 = logits2.detach().cpu()
        plt.plot(logits2[i, 0], logits2[i, 1], 'X', label='pred2', color='orange', markersize=10)
    plt.legend()

In [None]:
from lib.utils import get_phi_psi_dist, find_kdepeak
from lib.constants import AMINO_ACID_MAP
# Index into `seqs` of the sequence context to visualise.
i = 4
phi_psi = get_phi_psi_dist(da.queries, seqs[i])[0]
# X-ray (phi, psi) rows for the same sequence context; used as the target.
y = da.xray_phi_psi[da.xray_phi_psi.seq_ctxt == seqs[i]][['phi','psi']].values
# KDE peak flattened to a single row so `plot` can index it like the predictions.
kde = find_kdepeak(phi_psi, None)[['phi','psi']].values.reshape(1,-1)

X = []
for weight, l in zip([1,32,64,128], lengths):
    # Rows are (phi, psi) pairs belonging to this weight.
    angles = phi_psi[phi_psi.weight == weight][['phi','psi']].values
    n = angles.shape[0]
    if n < l:
        # Too few samples: zero-pad the rows up to the fixed segment length.
        angles = np.pad(angles, ((0, l - n), (0, 0)), mode='constant', constant_values=0)
    else:
        # Subsample rows with ONE index draw so each phi stays paired with its
        # own psi (drawing the two columns independently scrambles the pairs).
        keep = np.random.choice(n, l, replace=False)
        angles = angles[keep]
    X.append(angles.T)  # (2, l): phi row, psi row

# One-hot encoding (20 classes) of the centre residue of the sequence context.
xres = F.one_hot(torch.Tensor([AMINO_ACID_MAP[da.get_center(seqs[i])]]).to(torch.int64), 20).to(device)
# Concatenate the segments along the sample axis and add a batch dimension.
X = np.hstack(X)
X = torch.Tensor(X).unsqueeze(0).to(device)
logits = model(X, xres)
plot(X, y, 0, logits, kde)