In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [40]:
from lib import DihedralAdherence
from lib import PDBMineQuery, MultiWindowQuery
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm import tqdm
from tabulate import tabulate
from collections import defaultdict
from dotenv import load_dotenv
import torch
from torch import nn
import torch.nn.functional as F
from scipy.stats import gaussian_kde
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader, Dataset, ConcatDataset
from torch.utils.tensorboard import SummaryWriter
from lib.constants import AMINO_ACID_MAP, AMINO_ACID_MAP_INV
# Populate os.environ from a local .env file (if one exists) before reading the
# PDBMine endpoint. Without this call the imported `load_dotenv` is unused and
# GREEN_PDBMINE_URL is only found when it was exported in the launching shell.
# By default load_dotenv() does not override variables that are already set.
load_dotenv()
PDBMINE_URL = os.getenv("GREEN_PDBMINE_URL")  # None if the variable is not defined
PROJECT_DIR = 'ml_data_new'

In [63]:
# Configuration for a single-structure PDBMine query.
PDBMINE_URL = os.getenv("GREEN_PDBMINE_URL")
PROJECT_DIR = 'ml_data'

# PDB codes available on disk: each sub-directory name of PROJECT_DIR up to the first '_'.
pdb_codes = [f.name.split('_')[0] for f in Path(PROJECT_DIR).iterdir() if f.is_dir()]

winsizes = [7]
lengths = [4096, 512, 256, 256]

# Keep `pid` equal to the structure that is actually queried. It used to be
# pdb_codes[0] while the query ran on a hard-coded '6t1z', so later cells that
# filter on `pid` (e.g. excluding self-matches) compared against the wrong code.
# To process a structure from PROJECT_DIR instead, set pid = pdb_codes[0] and
# pass PROJECT_DIR as the last argument below.
pid = '6t1z'
da = MultiWindowQuery(pid, winsizes, PDBMINE_URL, 'del')
da.compute_structure()
da.query_pdbmine()

Structure exists: 'pdb/pdb6t1z.ent' 
Computing phi-psi for xray
Computing phi-psi for alphafold
Querying PDBMine - 7


  0%|          | 0/5 [00:00<?, ?it/s]

{'status': 'Running', 'queryID': '0dc43abe-3989-11ef-ba2c-0242ac110002'}
Waiting
Waiting
Waiting


 20%|██        | 1/5 [01:45<07:00, 105.03s/it]

Received matches - 0
{'status': 'Running', 'queryID': '4c5f311e-3989-11ef-ba2c-0242ac110002'}
Waiting
Waiting


 40%|████      | 2/5 [03:15<04:48, 96.21s/it] 

Received matches - 1
{'status': 'Running', 'queryID': '820a1d11-3989-11ef-ba2c-0242ac110002'}
Waiting
Waiting


 60%|██████    | 3/5 [04:45<03:06, 93.46s/it]

Received matches - 2
{'status': 'Running', 'queryID': 'b7cc02f2-3989-11ef-ba2c-0242ac110002'}
Waiting
Waiting


 80%|████████  | 4/5 [06:15<01:32, 92.11s/it]

Received matches - 3
{'status': 'Running', 'queryID': 'ed765a65-3989-11ef-ba2c-0242ac110002'}


100%|██████████| 5/5 [07:15<00:00, 87.07s/it]

Received matches - 4





In [84]:
# Build per-residue arrays from the X-ray / AlphaFold dihedrals stored on `da`
# and the PDBMine window matches in `da.queries`.
#
# Produces:
#   X          - stacked per-residue arrays; row 0 holds the matched phi values,
#                row 1 the matched psi values, concatenated over all queries
#   y          - stacked [phi, psi] taken from the X-ray table
#   x_res      - one-hot residue identity (20 classes)
#   af_phi_psi - stacked [phi_af, psi_af] taken from the AlphaFold table

# Join X-ray and AlphaFold angles on the sequence-context key; AlphaFold columns
# receive the '_af' suffix.
# NOTE(review): if 'seq_ctxt' is not unique in either table this merge yields a
# row per combination - confirm uniqueness upstream.
seqs = pd.merge(
    da.xray_phi_psi[['seq_ctxt', 'res', 'phi', 'psi']],
    da.af_phi_psi[['seq_ctxt', 'phi', 'psi']],
    on='seq_ctxt', suffixes=('', '_af')
).rename(columns={'seq_ctxt': 'seq'})

X = []
y = []
x_res = []
af_phi_psi = []
lengths = [4096, 512, 256, 256]
# Fixed number of match slots per window size (zip truncates to len(winsizes)).
lengths_dict = {w: l for w, l in zip(winsizes, lengths)}

for i, row in tqdm(seqs.iterrows(), total=len(seqs)):
    # Skip residues with an undefined angle in either structure.
    if np.isnan(row.phi) or np.isnan(row.psi) or np.isnan(row.phi_af) or np.isnan(row.psi_af):
        print('nan')
        continue

    phis = []
    psis = []
    found_match = False  # explicit flag instead of testing whether the angles sum to 0
    for q in da.queries:
        n_keep = lengths_dict[q.winsize]
        inner_seq = q.get_subseq(row.seq)
        matches = q.results[q.results.seq == inner_seq][['seq', 'phi', 'psi']]
        n_found = matches.shape[0]
        if n_found == 0:
            # No hits for this window: fill its slots with zeros.
            phis.append(np.zeros(n_keep))
            psis.append(np.zeros(n_keep))
            print('no matches', q.winsize)
            continue
        found_match = True
        phi = matches.phi.values
        psi = matches.psi.values
        if n_found >= n_keep:
            # Draw ONE set of row indices and apply it to both angles so that
            # every kept (phi, psi) pair still comes from the same match.
            # Sampling phi and psi independently would scramble the pairing.
            keep = np.random.choice(n_found, n_keep, replace=False)
            phi = phi[keep]
            psi = psi[keep]
        else:
            # Right-pad with zeros up to the fixed slot count.
            pad = (0, n_keep - n_found)
            phi = np.pad(phi, pad)
            psi = np.pad(psi, pad)
        phis.append(phi)
        psis.append(psi)

    if not found_match:  # no query produced a hit for this residue
        print('no matches')
        continue

    X.append(np.stack([np.concatenate(phis), np.concatenate(psis)]))
    y.append(np.array([row.phi, row.psi]))
    # presumably AMINO_ACID_MAP maps a residue code to an integer class index < 20 - TODO confirm
    x_res.append(AMINO_ACID_MAP[row.res])
    af_phi_psi.append([row.phi_af, row.psi_af])

if len(X) == 0:
    # np.stack([]) would raise an opaque ValueError; fail with a clear message instead.
    raise ValueError('No matches: no residue had any PDBMine hit, nothing to stack')

X = np.stack(X)
y = np.stack(y)
# Build the index tensor directly as int64 rather than round-tripping through float32.
x_res = F.one_hot(torch.tensor(x_res, dtype=torch.int64), num_classes=20)
af_phi_psi = np.stack(af_phi_psi)

0it [00:00, ?it/s]

373it [00:00, 3772.10it/s]

no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
(4096,) (4096,)
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches
no matches 7
no matches





In [79]:
# Dimensions of the stacked target array (one [phi, psi] row per kept residue).
np.shape(y)

(20, 2)

In [64]:
# Number of PDBMine matches per window sequence for the last query.
match_counts = da.queries[-1].results.groupby('seq').size()

# Only the final expression of a cell is rendered, so the summary statistics
# must be displayed explicitly or they are computed and thrown away.
display(match_counts.describe())

# Largest groups first (ascending sort, then reversed).
match_counts.sort_values()[::-1]

seq
LEVLFQG    210
ASILAGL     50
VILLLTV     49
MGANIAT     36
VLLMTTL     27
QTLGADL     24
LLTVLVS     24
LTVLVSF     24
NLEVLFQ     22
LLVILLL     20
LTTTFTP     16
LGAALAI     13
ILGAALG     11
LGAALGA     11
IKAIGVS      7
AALAIAS      3
NGVAAIK      2
YYNQYLG      2
VLVAVNR      1
LAGLLVS      1
ILLLTVL      1
AGLLVSI      1
dtype: int64

In [44]:
import requests

# Base URL of the PDBMine service, read from the environment (None when unset).
PDBMINE_URL = os.environ.get("GREEN_PDBMINE_URL")

In [54]:
# Submit a single 7-residue window to PDBMine and keep the returned query id.
response = requests.post(
    PDBMINE_URL + '/v1/api/query',
    json={
        "residueChain": 'AGLLVSI',
        "codeLength": 1,
        "windowSize": 7
    },
    timeout=30,  # requests waits indefinitely by default; don't let the kernel hang
)
# Surface HTTP errors here rather than as a confusing JSON/KeyError below.
response.raise_for_status()
qid = response.json()['queryID']

In [55]:
# Fetch the result of the query submitted above.
# NOTE(review): the service appears to report a 'Running' status while a query is
# still in progress - confirm the payload is complete before parsing it.
response = requests.get(PDBMINE_URL + f'/v1/api/query/{qid}', timeout=30)
response.raise_for_status()  # fail loudly on 4xx/5xx instead of parsing an error body
response

<Response [200]>

In [56]:
# Count the hits in the first frame of the response: all of them, and those
# whose key does not start with `pid` (the part before the first '_').
# NOTE(review): `pid` comes from an earlier cell - confirm it names the
# structure whose own entries should be excluded here.
frames = response.json()['frames']
matches = next(iter(frames.values()))

total_matches = 0
n_matches = 0
for key, hits in matches.items():
    total_matches += len(hits)
    if key.split('_')[0] != pid:
        n_matches += len(hits)

(total_matches, n_matches)

(2, 2)