In [3]:
# Re-import edited modules (e.g. the project's `lib` package) automatically before running code.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
from lib import DihedralAdherence
from lib import PDBMineQuery, MultiWindowQuery
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm import tqdm
from tabulate import tabulate
from collections import defaultdict
from dotenv import load_dotenv
import torch
from torch import nn
import torch.nn.functional as F
from scipy.stats import gaussian_kde
from sklearn.model_selection import train_test_split
import pickle
from torch.utils.data import TensorDataset, DataLoader, Dataset, ConcatDataset
from torch.utils.tensorboard import SummaryWriter
from lib.constants import AMINO_ACID_MAP, AMINO_ACID_MAP_INV
from lib.across_window_utils import (
    get_phi_psi_dist_window, get_afs_window, get_xrays_window, get_cluster_medoid, find_clusters,
    precompute_dists, filter_precomputed_dists, 
)
from collections import defaultdict
# load_dotenv was imported but never called, so a local .env file was ignored.
# By default it does not override variables already set in the environment.
load_dotenv()
PDBMINE_URL = os.getenv("PDBMINE_URL")
PROJECT_DIR = 'ml_data'

In [5]:
from matplotlib.patches import ConnectionPatch

def plot(q, seq_ctxt, xrays, afs, clusters, phi_psi_dist, precomputed_dists):
    """Plot the most populated phi/psi clusters of one window, one row per cluster.

    Rows are the (at most 10) largest clusters, ordered by size. Columns are the
    residue positions ``j`` of the window (the ``phi_{j}``/``psi_{j}`` columns of
    ``phi_psi_dist``). Each panel shows the other clusters, the row's cluster, and the
    X-ray, AF and cluster-medoid angles. Dashed connectors link (phi_j, psi_j) to
    (phi_{j+1}, psi_{j+1}) in the next panel for up to 50 randomly sampled cluster
    members and for the X-ray, AF and medoid angles.

    Args:
        q: query object; only ``q.winsize`` (number of columns) is used.
        seq_ctxt: currently unused (intended for a figure title).
        xrays: X-ray angles, reshaped to ``(2, winsize)``: row 0 phi, row 1 psi.
        afs: AlphaFold angles, same layout as ``xrays``.
        clusters: array of cluster labels, one per row of ``phi_psi_dist``.
        phi_psi_dist: DataFrame of matches with ``phi_{j}``/``psi_{j}`` columns.
        precomputed_dists: distance data passed through to ``get_cluster_medoid``.
    """
    n_cluster_plot = 10
    n_clusters = len(np.unique(clusters))
    xrays = xrays.reshape(2, -1)
    afs = afs.reshape(2, -1)
    print(pd.Series(clusters).value_counts())

    # Cluster labels, largest first; only the top `n_cluster_plot` get a row.
    cluster_sizes = phi_psi_dist.groupby(clusters).count().sort_values('phi_0', ascending=False)
    clusters_plot = cluster_sizes.index.values[:n_cluster_plot]
    # Medoids are only needed for plotted clusters; store each one as (2, winsize).
    medoids = [
        np.asarray(get_cluster_medoid(phi_psi_dist, precomputed_dists, clusters, cluster)).reshape(2, -1)
        for cluster in clusters_plot
    ]

    colors = sns.color_palette('Dark2', n_clusters)
    n_rows = len(clusters_plot)
    # squeeze=False keeps `axes` 2-D even for a single row and/or a single column
    # (the previous reshape(1, -1) turned rows into columns when winsize == 1).
    fig, axes = plt.subplots(
        n_rows, q.winsize, figsize=(16, n_rows * 4), sharey=True, sharex=True, squeeze=False
    )

    def add_conn(ax_from, ax_to, xy_from, xy_to, color, lw, **kwargs):
        """Draw a dashed line from `xy_from` in `ax_from` to `xy_to` in `ax_to` (data coords)."""
        fig.add_artist(ConnectionPatch(
            xyA=xy_from, xyB=xy_to,
            coordsA="data", coordsB="data",
            axesA=ax_from, axesB=ax_to,
            color=color, lw=lw, linestyle='--', alpha=0.5, **kwargs,
        ))

    for i, (axrow, cluster) in enumerate(zip(axes, clusters_plot)):
        cluster_dist = phi_psi_dist[clusters == cluster]
        other_dist = phi_psi_dist[clusters != cluster]
        medoid = medoids[i]
        for j, ax in enumerate(axrow):
            phi_j, psi_j = f'phi_{j}', f'psi_{j}'
            sns.scatterplot(data=other_dist, x=phi_j, y=psi_j, ax=ax, label='Other Clusters', color='tab:blue', alpha=0.5)
            sns.scatterplot(data=cluster_dist, x=phi_j, y=psi_j, ax=ax, label=f'Cluster {cluster}', color=colors[i])
            ax.scatter(xrays[0, j], xrays[1, j], color='tab:red', marker='X', label='X-ray', zorder=1000)
            ax.scatter(afs[0, j], afs[1, j], color='tab:orange', marker='X', label='AF', zorder=1000)
            ax.scatter(medoid[0, j], medoid[1, j], color='black', marker='X', label='Cluster Centroid', zorder=1000)

            if j < q.winsize - 1:
                next_ax = axrow[j + 1]
                phi_n, psi_n = f'phi_{j + 1}', f'psi_{j + 1}'
                # TODO draw lines for 50 points closest to centroid (currently a random sample)
                for _, row in cluster_dist.sample(min(cluster_dist.shape[0], 50)).iterrows():
                    add_conn(ax, next_ax, (row[phi_j], row[psi_j]), (row[phi_n], row[psi_n]), colors[i], 1)
                for angles, color in ((xrays, 'tab:red'), (afs, 'tab:orange'), (medoid, 'black')):
                    add_conn(
                        ax, next_ax,
                        (angles[0, j], angles[1, j]), (angles[0, j + 1], angles[1, j + 1]),
                        color, 5, zorder=100,
                    )

            ax.set_xlim(-180, 180)
            ax.set_ylim(-180, 180)
            ax.set_xlabel('')
            if j == q.winsize - 1:
                ax.legend()
            else:
                ax.legend().remove()
            if j == 0:
                ax.set_ylabel(f'Cluster {cluster} [{cluster_dist.shape[0]}]')
    fig.supxlabel('Phi')
    fig.supylabel('Psi')
    # TODO: add a suptitle (seq_ctxt, number of points, scores) once those stats are passed in.
    plt.tight_layout()
    plt.show()

In [9]:
# Find the last entry of proteins.json (scanning from the end) that already has a
# directory in PROJECT_DIR, i.e. where the previous data-collection run stopped.
import json
with open('proteins.json') as fh:
    pdb_codes = json.load(fh)
ml_data = [f.name.split('_')[0] for f in Path(PROJECT_DIR).iterdir()]
ml_data_set = set(ml_data)  # O(1) membership tests
# Index of the LAST listed code with data, or None if nothing was processed yet.
# (pdb_codes.index() would return the first occurrence of a duplicated code.)
last_idx = next((i for i in reversed(range(len(pdb_codes))) if pdb_codes[i] in ml_data_set), None)
pdb_code = pdb_codes[last_idx] if last_idx is not None else None
print(pdb_code)
last_idx

5NUP


646

In [11]:
# Resume point for the medoid export below: the last project directory whose sample
# file already exists. Path.iterdir() order is arbitrary (filesystem-dependent), so the
# index is only meaningful against a list built the same way, as the export cell does.
pdb_codes = [f.name.split('_')[0] for f in Path(PROJECT_DIR).iterdir() if f.is_dir()]
samples_dir = Path('ml_samples/medoids')
# The export cell creates this directory; on a fresh run it may not exist yet.
ml_samples = [f.stem for f in samples_dir.iterdir()] if samples_dir.is_dir() else []
ml_samples_set = set(ml_samples)
# Codes can repeat (several directories per protein), so search for the LAST occurrence
# directly instead of calling pdb_codes.index(), which returns the first one.
last_idx = next((i for i in reversed(range(len(pdb_codes))) if pdb_codes[i] in ml_samples_set), None)
pdb_code = pdb_codes[last_idx] if last_idx is not None else None
print(pdb_code)
last_idx

1FTG


592

In [13]:
# Export per-protein training samples. For every residue whose largest context window
# has complete AF and X-ray angles, store, per window size, the largest cluster medoids
# of the PDBMine matches and the AF angles, plus a one-hot residue and the X-ray target.
pdb_codes = [f.name.split('_')[0] for f in Path(PROJECT_DIR).iterdir() if f.is_dir()]
winsizes = [4, 5, 6, 7]
outdir = Path('ml_samples/medoids')
outdir.mkdir(exist_ok=True, parents=True)
# Medoids kept per window size, aligned with `winsizes` (missing medoids are zero rows).
X_lens = [15, 5, 3, 2]
assert len(X_lens) == len(winsizes), 'X_lens needs one entry per window size'
# Index into `pdb_codes` to resume from (see the resume-point cell above). Path.iterdir()
# order is arbitrary, so this is only valid for a listing from the same filesystem.
START_IDX = 593
MAX_DIST_POINTS = 10000  # subsample larger PDBMine match sets before clustering
SAMPLE_SEED = 0          # makes that subsample reproducible
# Window size -> cluster counts found by `_window_medoids`; read by the statistics cell below.
n_clusters_all = defaultdict(list)


def _window_medoids(phi_psi_dist, n_medoids, winsize):
    """Clean and cluster one window's PDBMine phi/psi matches; return its largest medoids.

    Returns ``(medoids, n_clusters)``. ``medoids`` is an ``(n_medoids, winsize * 2)``
    array filled largest-cluster-first and zero-padded. It stays all zeros when no usable
    matches remain or the column count does not match the window. With a single usable
    match, that match is the only medoid. ``n_clusters`` is the count returned by
    ``find_clusters``, or None when clustering was not run.
    """
    medoids = np.zeros([n_medoids, winsize * 2])
    phi_psi_dist = phi_psi_dist.dropna()
    phi_psi_dist = phi_psi_dist[(phi_psi_dist <= 180).all(axis=1)]
    if phi_psi_dist.shape[0] == 0 or phi_psi_dist.shape[1] != winsize * 2:
        return medoids, None
    if phi_psi_dist.shape[0] == 1:
        medoids[0] = phi_psi_dist.iloc[0].values
        return medoids, None
    if phi_psi_dist.shape[0] > MAX_DIST_POINTS:
        phi_psi_dist = phi_psi_dist.sample(MAX_DIST_POINTS, random_state=SAMPLE_SEED)

    dists = precompute_dists(phi_psi_dist)
    n_clusters, clusters = find_clusters(
        dists, min_cluster_size=min(phi_psi_dist.shape[0], 20), cluster_selection_epsilon=30
    )
    # Nothing found: retry with the smallest clusters allowed and growing epsilon.
    for eps in (60, 120):
        if n_clusters > 0:
            break
        n_clusters, clusters = find_clusters(dists, min_cluster_size=2, cluster_selection_epsilon=eps)

    if n_clusters > 0:
        dists, phi_psi_dist, clusters = filter_precomputed_dists(dists, phi_psi_dist, clusters)
        cluster_counts = pd.Series(clusters).value_counts().sort_values(ascending=False)
        for k, cluster in enumerate(cluster_counts.index[:n_medoids]):
            medoids[k] = get_cluster_medoid(phi_psi_dist, dists, clusters, cluster)
    return medoids, n_clusters


# dict.fromkeys de-duplicates codes (a protein can own several directories) keeping order.
# NOTE: the loop variable stays `id` because a later cell reads it to reload the last export.
for id in dict.fromkeys(pdb_codes[START_IDX:]):
    out_path = outdir / f'{id}.pt'
    if out_path.exists():
        print('Skipping', id)
        continue
    try:
        da = MultiWindowQuery(id, winsizes, PDBMINE_URL, PROJECT_DIR)
        da.load_results()
    except FileNotFoundError as e:
        print(e)
        continue
    if da.af_phi_psi is None:
        continue
    print(id)

    # The last query is the context window used for completeness checks and the target
    # (presumably the largest window, winsizes[-1]).
    q_ctxt = da.queries[-1]
    center_idx_ctxt = q_ctxt.get_center_idx_pos()
    winsize_ctxt = q_ctxt.winsize
    # Residues whose full context window fits in the sequence. Use an explicit end index:
    # a negative stop of -0 (center at the last window position) would give an empty slice.
    n_after_center = winsize_ctxt - center_idx_ctxt - 1
    seqs_for_window = pd.DataFrame({'seq_ctxt': da.seqs[center_idx_ctxt:len(da.seqs) - n_after_center]})

    seqs = pd.merge(
        seqs_for_window,
        da.af_phi_psi[['seq_ctxt']],
        on='seq_ctxt'
    ).rename(columns={'seq_ctxt': 'seq'})
    if seqs.shape[0] == 0:
        print('No sequences for', id)
        continue
    print(seqs.shape, seqs.seq.nunique())

    x_medoids = defaultdict(list)
    x_af = defaultdict(list)
    x_res = []
    y = []

    for row in tqdm(seqs.itertuples(index=False), total=len(seqs)):
        # Require complete AlphaFold and X-ray angles for the context window
        afs_ctxt = get_afs_window(da, q_ctxt, row.seq)
        if afs_ctxt is None or afs_ctxt.shape[0] != winsize_ctxt * 2 or np.isnan(afs_ctxt).any():
            continue
        xrays_ctxt = get_xrays_window(da, q_ctxt, row.seq)
        if xrays_ctxt is None or xrays_ctxt.shape[0] != winsize_ctxt * 2 or np.isnan(xrays_ctxt).any():
            continue
        for j, q in enumerate(da.queries):
            afs = get_afs_window(da, q, row.seq)
            phi_psi_dist = get_phi_psi_dist_window(q, row.seq)
            medoids, n_clusters = _window_medoids(phi_psi_dist, X_lens[j], q.winsize)
            if n_clusters is not None:
                n_clusters_all[q.winsize].append(n_clusters)
            x_medoids[j].append(torch.tensor(medoids, dtype=torch.float32))
            x_af[j].append(torch.tensor(afs, dtype=torch.float32))
        x_res.append(AMINO_ACID_MAP[row.seq[center_idx_ctxt]])
        # Target: X-ray (phi, psi) of the centre residue of the (already NaN-checked) context window
        y.append(torch.tensor(xrays_ctxt.reshape(2, -1)[:, center_idx_ctxt], dtype=torch.float32))
    if len(y) == 0:
        print('No data for', id)
        continue

    x_medoids = [torch.stack(x_medoids[j]) for j in range(len(da.queries))]
    x_af = [torch.stack(x_af[j]) for j in range(len(da.queries))]
    x_res = F.one_hot(torch.tensor(x_res, dtype=torch.int64), num_classes=20).float()
    y = torch.stack(y)
    torch.save((x_medoids, x_af, x_res, y), out_path)

Skipping 1FTG
Results already exist
Structure exists: 'pdb/pdb6cj0.ent' 
UniProt ID: B3G2E1
6CJ0
(237, 1) 237


237it [02:09,  1.83it/s]


In [200]:
# Peek at one exported sample file (Path.iterdir already returns an iterator).
next(Path('ml_samples/medoids').iterdir())

PosixPath('ml_samples/medoids/3DE6.pt')

In [202]:
# One exported sample: (per-window medoid tensors, per-window AF tensors, one-hot residues, X-ray targets)
x_medoids, x_af, x_res, y = torch.load(Path('ml_samples', 'medoids', '3DE6.pt'))

In [209]:
# Expected (n_samples, X_lens[3], 2 * window size)
x_medoids[3].size()

torch.Size([260, 2, 14])

In [210]:
# Sanity check: each window's stacked medoid tensor should have 2 * window size columns
# (phi and psi per residue). Checking the stacked tensor once covers every row, since
# torch.stack requires equal shapes. Assumes the file was exported with `winsizes` above.
for j, (window_medoids, winsize) in enumerate(zip(x_medoids, winsizes)):
    if window_medoids.shape[-1] != 2 * winsize:
        print(j, window_medoids.shape)

In [125]:
# Relies on `id` left over from the export loop (the last protein visited); the file may
# not exist if that protein was skipped for missing data.
X, af, xres, y = torch.load(outdir.joinpath(f'{id}.pt'))

In [20]:
# Mean / std of the number of clusters found per window size.
# `n_clusters_all` (window size -> list of cluster counts; keys 4-7 in the results recorded
# below) is not defined in this cell and must be populated by a clustering loop run earlier.
cluster_counts_by_winsize = {k: np.asarray(v) for k, v in n_clusters_all.items()}
average_clusters = {k: v.mean() for k, v in cluster_counts_by_winsize.items()}
std_clusters = {k: v.std() for k, v in cluster_counts_by_winsize.items()}
pd.DataFrame({'mean': average_clusters, 'std': std_clusters})

# avg:
# 4: 12.675570539419088
# 5: 2.2909803921568628
# 6: 1.2014057853473912
# 7: 1.0849134377576257

# std:
# 4: 7.717139186664446
# 5: 1.4648062017775114
# 6: 0.5645546946235267
# 7: 0.3251658412701237
# 7: 0.3251658412701237