In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from lib import DihedralAdherence
from lib import PDBMineQuery
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats
from pathlib import Path
import json
from matplotlib.patches import ConnectionPatch
from sklearn.cluster import KMeans, DBSCAN, MeanShift, HDBSCAN
from sklearn.metrics import silhouette_score, silhouette_samples
from collections import defaultdict
from tqdm import tqdm
from scipy.linalg import inv

# Value of the PDBMINE_URL environment variable; os.getenv returns None when
# the variable is unset and no check is made here.
# NOTE(review): presumably the PDBMine service endpoint - confirm.
PDBMINE_URL = os.getenv("PDBMINE_URL")
# Project directory name handed to DihedralAdherence.
PROJECT_DIR = 'casp_da'

In [3]:
# CASP target identifiers available for analysis.
proteins = [
  'T1024', 'T1030', 'T1030-D2', 'T1024-D1', 'T1032-D1', 'T1053-D1', 'T1027-D1', 'T1029-D1',
  'T1025-D1', 'T1028-D1', 'T1030-D1', 'T1053-D2', 'T1057-D1','T1058-D1', 'T1058-D2'
]
# proteins[5] is 'T1053-D1'; window sizes 4-7 are queried.
da = DihedralAdherence(proteins[5], [4,5,6,7], PDBMINE_URL, PROJECT_DIR, kdews=[1,1,1,1], 
                      mode='full_window', weights_file='ml_runs/best_model-kde_16-32_383.pt', device='cpu')
                    #   mode='ml', weights_file='ml_runs/best_model-kde_16-32_383.pt', device='cpu')

da.load_results_da()

# Use the last query as the context window and normalise a negative
# (from-the-end) centre index to a non-negative one.
center_idx_ctxt = da.queries[-1].get_center_idx()
winsize_ctxt = da.queries[-1].winsize
if center_idx_ctxt < 0:
    center_idx_ctxt = winsize_ctxt + center_idx_ctxt

# Trim the sequences for which the context window would run off either end.
# When the centre is the last window position there is nothing to trim on the
# right; a literal `-0` stop would yield an empty slice, so use None instead.
n_trailing = winsize_ctxt - center_idx_ctxt - 1
stop = -n_trailing if n_trailing > 0 else None
da.seqs_for_window = da.seqs[center_idx_ctxt:stop]

Initializing T1053-D1 ...
Results already exist
Casp ID: T1053-D1 	PDB: 7m7a
Structure exists: 'pdb/pdb7m7a.ent' 
UniProt ID: Q5ZRA8


In [4]:
def diff(x1, x2):
    """Element-wise circular distance between angles expressed in degrees.

    The absolute difference ``|x1 - x2|`` is compared with its complement
    ``360 - |x1 - x2|`` and the smaller value is returned, so 170 and -170 are
    20 apart rather than 340.  Inputs broadcast like ordinary NumPy
    arithmetic.  The result is always non-negative (the sign of the deviation
    is not kept).
    """
    direct = np.abs(x1 - x2)
    around = 360 - direct
    return np.minimum(direct, around)

def get_phi_psi_dist(q, seq_ctxt):
    """Return the matched phi/psi angles for one window, one row per match.

    ``q.get_subseq(seq_ctxt)`` gives the window sequence; the rows of
    ``q.results_window`` with that ``seq`` are pivoted so that each
    ``match_id`` becomes a row and each ``window_pos`` contributes a
    ``phi_<pos>`` and a ``psi_<pos>`` column (all phi columns first, then all
    psi columns).  Matches with a missing angle at any position are dropped.
    """
    window_seq = q.get_subseq(seq_ctxt)
    results = q.results_window
    wide = results.loc[
        results.seq == window_seq, ['match_id', 'window_pos', 'phi', 'psi']
    ].pivot(index='match_id', columns='window_pos', values=['phi', 'psi'])
    # Flatten the (angle, position) MultiIndex into 'phi_0', 'phi_1', ...
    wide.columns = [f'{angle}_{pos}' for angle, pos in wide.columns.to_flat_index()]
    return wide.dropna(axis=0)

def get_xrays(ins, q, seq_ctxt, return_df=False):
    """Collect the phi/psi angles of ``ins.xray_phi_psi`` for one window.

    The first row whose ``seq_ctxt`` matches gives the anchor position; the
    window spans ``q.winsize`` positions starting ``q.get_center_idx_pos()``
    before the anchor.  Returns the concatenation ``[phi..., psi...]`` of the
    rows in that span and, when ``return_df`` is true, also a copy of those
    rows.  Raises ``IndexError`` (from ``.iloc[0]``) when ``seq_ctxt`` is not
    present.
    """
    table = ins.xray_phi_psi
    anchor_pos = table.loc[table.seq_ctxt == seq_ctxt, 'pos'].iloc[0]
    start = anchor_pos - q.get_center_idx_pos()
    stop = start + q.winsize
    in_window = (table.pos >= start) & (table.pos < stop)
    xrays = table[in_window].copy()
    xray_point = np.concatenate([xrays['phi'].values, xrays['psi'].values])
    return (xray_point, xrays) if return_df else xray_point

def get_afs(ins, q, seq_ctxt, return_df=False):
    """Collect the phi/psi angles of ``ins.af_phi_psi`` for one window.

    Locates the first row with the given ``seq_ctxt``, then selects the
    ``q.winsize`` consecutive positions that place that row at offset
    ``q.get_center_idx_pos()``.  The angles are returned as one flat vector
    (all phi values followed by all psi values); with ``return_df=True`` a
    copy of the selected rows is returned as a second element.
    """
    center_idx = q.get_center_idx_pos()
    af_table = ins.af_phi_psi
    anchor = af_table.pos[af_table.seq_ctxt == seq_ctxt].iloc[0]
    lo = anchor - center_idx
    hi = lo + q.winsize
    afs = af_table[(af_table.pos >= lo) & (af_table.pos < hi)].copy()
    af_point = np.concatenate([afs['phi'].values, afs['psi'].values])
    if not return_df:
        return af_point
    return af_point, afs

def get_preds(ins, q, seq_ctxt):
    """Return the predicted phi/psi angles for one window, one row per model.

    The position of ``seq_ctxt`` is looked up in ``ins.phi_psi_predictions``;
    the ``q.winsize`` positions that place it at offset
    ``q.get_center_idx_pos()`` are pivoted so that each ``protein_id`` becomes
    a row with columns ``phi_0..phi_{w-1}, psi_0..psi_{w-1}`` (window-relative
    positions).  Rows with a missing angle at any position are dropped.

    Raises:
        IndexError: no prediction row has this ``seq_ctxt``.  (Previously the
            function printed a warning and then failed on ``pred_pos[0]`` with
            an unrelated-looking message; the exception type is unchanged.)
        ValueError: ``seq_ctxt`` occurs at more than one position.
    """
    predictions = ins.phi_psi_predictions
    center_idx = q.get_center_idx_pos()
    pred_pos = predictions[predictions.seq_ctxt == seq_ctxt].pos.unique()
    if len(pred_pos) == 0:
        print(f"No predictions for {seq_ctxt}")
        raise IndexError(f"No predictions for {seq_ctxt}")
    if len(pred_pos) > 1:
        print(f"Multiple predictions for {seq_ctxt}")
        raise ValueError(f"Multiple predictions for {seq_ctxt}")
    pred_pos = pred_pos[0]

    start = pred_pos - center_idx  # absolute position of window offset 0
    in_window = (predictions.pos >= start) & (predictions.pos < start + q.winsize)
    preds = predictions.loc[in_window, ['protein_id', 'pos', 'phi', 'psi']].pivot(
        index='protein_id', columns='pos', values=['phi', 'psi']
    )
    # Rename (angle, absolute position) to 'angle_<window offset>'.
    preds.columns = [f'{angle}_{pos - start}' for angle, pos in preds.columns.to_flat_index()]
    return preds.dropna(axis=0)

def calc_xray_score(phi_psi_dist, xrays, q, precomputed_dists):
    """Distance from an X-ray window point to the medoid of its nearest cluster.

    Args:
        phi_psi_dist: frame whose first ``q.winsize*2`` columns are the window
            angles and which has a ``cluster`` column.
        xrays: flat vector of window angles (same column layout).
        q: query object; only ``q.winsize`` is read here.
        precomputed_dists: pairwise distance matrix between the rows of
            ``phi_psi_dist``, used to pick the medoid.

    Returns:
        ``(xray_dist, nearest_cluster)``: the Euclidean norm of the circular
        differences between ``xrays`` and the medoid, and the label of the
        cluster with the smallest mean distance to ``xrays``.
    """
    # Same nearest-cluster rule as everywhere else in this notebook; the
    # previously inlined copy also computed an unused `cluster_points`.
    nearest_cluster = get_target_cluster(q, phi_psi_dist, xrays)
    cluster_medoid = get_cluster_medoid(phi_psi_dist, precomputed_dists, nearest_cluster, q)

    xray_dist = np.sqrt((diff(xrays, cluster_medoid)**2).sum())

    return xray_dist, nearest_cluster

def estimate_icov(q, phi_psi_dist_c, cluster_medoid):
    """Estimate the inverse scatter matrix of a cluster around its medoid.

    Args:
        q: query object; only ``q.winsize`` is read.
        phi_psi_dist_c: rows of one cluster; the first ``q.winsize*2`` columns
            are the window angles.
        cluster_medoid: flat vector of the medoid's window angles.

    Returns:
        The inverse of the (slightly regularised) matrix
        ``sum_i d_i d_i^T / (n - 1)`` where ``d_i = diff(point_i, medoid)``,
        or ``None`` when the estimate is rejected (a message is printed).

    NOTE(review): ``diff`` returns unsigned circular distances, so every entry
    of the matrix is non-negative and the sign of each deviation is lost -
    confirm this is the intended "covariance".
    """
    cluster_points = phi_psi_dist_c.iloc[:,:q.winsize*2].values
    n_points = cluster_points.shape[0]
    if n_points < 2:
        # The (n - 1) normalisation below would divide by zero.
        print("Not enough points to estimate covariance")
        return None
    diffs = diff(cluster_points, cluster_medoid)

    # Sum of per-point outer products, written as one matrix product so no
    # (n, d, d) intermediate is materialised.
    cov = diffs.T @ diffs / (n_points - 1)
    cov = cov + np.eye(cov.shape[0]) * 1e-6 # add small value to diagonal to avoid singular matrix
    if np.any(cov <= 0):
        print("Non-positive covariance matrix")
        return None
    if np.any(cov.diagonal() < 1):
        print("Covariance matrix less than 1")
        return None
    # cov is symmetric, so the symmetric solver applies; it returns real
    # eigenvalues and skips the eigenvectors, which were never used.
    eigenvalues = np.linalg.eigvalsh(cov)
    if np.any(eigenvalues < 0):
        print("Negative eigenvalues - non-positive semi-definite covariance matrix")
        return None
    icov = inv(cov)
    return icov

def get_target_cluster(q, phi_psi_dist, point):
    """Label of the cluster whose members are, on average, closest to ``point``.

    Distances are Euclidean norms of the circular differences between
    ``point`` and the first ``q.winsize*2`` columns of every row; they are
    averaged per value of ``phi_psi_dist.cluster`` and the label with the
    smallest mean is returned.
    """
    n_angle_cols = q.winsize * 2
    window_points = phi_psi_dist.iloc[:, :n_angle_cols].values
    dist_to_point = np.linalg.norm(diff(point[np.newaxis, :], window_points), axis=1)
    mean_by_cluster = pd.Series(dist_to_point).groupby(phi_psi_dist.cluster.values).mean()
    return mean_by_cluster.idxmin()

def calc_maha_xray(phi_psi_dist, xrays, q, precomputed_dists, af):
    """Mahalanobis-style distance from the X-ray point to the target cluster.

    The target cluster is the one nearest to ``af`` (not to ``xrays``); its
    medoid and inverse scatter matrix define the metric.  Returns
    ``(distance, target_cluster)``; the distance is ``None`` when
    ``estimate_icov`` rejects the cluster.
    """
    target_cluster = get_target_cluster(q, phi_psi_dist, af)
    medoid = get_cluster_medoid(phi_psi_dist, precomputed_dists, target_cluster, q)
    members = phi_psi_dist[phi_psi_dist.cluster == target_cluster]
    icov = estimate_icov(q, members, medoid)
    if icov is None:
        return None, target_cluster

    # sqrt(d^T * icov * d) with d the circular offset from the medoid
    offset = diff(xrays, medoid)
    return np.sqrt(offset @ icov @ offset), target_cluster

def calc_maha_preds(phi_psi_dist, preds, q, precomputed_dists, af):
    """Mahalanobis-style distance from every prediction to the target cluster.

    Args:
        phi_psi_dist: clustered window angles (needs a ``cluster`` column).
        preds: one row per prediction; the first ``q.winsize*2`` columns are
            the window angles.
        q: query object; ``q.winsize`` is read here.
        precomputed_dists: pairwise distances between rows of ``phi_psi_dist``.
        af: flat window point used to choose the target cluster.

    Returns:
        1-D array with one distance per row of ``preds``, or ``None`` when
        ``estimate_icov`` rejects the target cluster.
    """
    target_cluster = get_target_cluster(q, phi_psi_dist, af)
    cluster_medoid = get_cluster_medoid(phi_psi_dist, precomputed_dists, target_cluster, q)
    icov = estimate_icov(q, phi_psi_dist[phi_psi_dist.cluster == target_cluster], cluster_medoid)
    if icov is None:
        return None

    # Distance from preds to target.  Only the per-row quadratic forms
    # d_i^T * icov * d_i are needed, so compute them row-wise instead of
    # building the full (n_preds, n_preds) product and taking its diagonal.
    preds_diff = diff(preds.iloc[:,:q.winsize*2].values, cluster_medoid)
    preds_maha = np.sqrt(((preds_diff @ icov) * preds_diff).sum(axis=1))
    return preds_maha

def calc_score(q, preds, phi_psi_dist, intracluster_dists, xrays=None, afs=None):
    """Distance from each prediction to the centre of its nearest cluster.

    Args:
        q: query object; only ``q.winsize`` is read.
        preds: one row per prediction; the first ``q.winsize*2`` columns are
            the window angles.
        phi_psi_dist: clustered window angles with a ``cluster`` column.
        intracluster_dists, xrays, afs: accepted for interface compatibility;
            not used.

    Returns:
        1-D float array aligned with the rows of ``preds``.  Each entry is the
        Euclidean norm of the circular differences between the prediction and
        the arithmetic mean of the cluster whose members are on average
        closest to it.  Entries are NaN only when there are no clusters.

    NOTE(review): every label in ``cluster`` is a candidate, including a
    noise label such as -1 if present - confirm that is intended.
    """
    width = q.winsize * 2
    pred_points = preds.iloc[:, :width].values
    dist_points = phi_psi_dist.iloc[:, :width].values
    labels = phi_psi_dist.cluster.values
    clusters = phi_psi_dist.cluster.unique()

    preds_dist = np.full(len(preds), np.nan)
    if len(clusters) == 0:
        return preds_dist

    # (n_preds, n_points) distances, then the mean distance to each cluster.
    d = np.linalg.norm(diff(pred_points[:, np.newaxis], dist_points), axis=2)
    average_dists = np.stack([d[:, labels == c].mean(axis=1) for c in clusters], axis=1)
    # argmin yields a position within `clusters`; map it back to the cluster
    # LABEL before comparing with labels below (positions and labels differ
    # whenever labels are not 0..k-1 in order of first appearance).
    nearest_clusters = clusters[average_dists.argmin(axis=1)]

    for c in clusters:
        assigned = nearest_clusters == c
        cluster_avg = dist_points[labels == c].mean(axis=0) # TODO: use medoid
        preds_dist[assigned] = np.linalg.norm(diff(pred_points[assigned], cluster_avg), axis=1)

    return preds_dist

# def calc_intra_cluster(phi_psi_dist, precomputed_dists):
#     ds = {}
#     for c in phi_psi_dist.cluster.unique():
#         d = precomputed_dists[phi_psi_dist.cluster == c][:,phi_psi_dist.cluster == c]
#         ds[c] = d.sum() / (d.shape[0] * (d.shape[0]-1))
#     return ds

def get_cluster_medoid(phi_psi_dist, precomputed_dists, c, q):
    """Return the window angles of the medoid of cluster ``c``.

    The medoid is the cluster member with the smallest summed distance to the
    other members, read from ``precomputed_dists`` (rows/columns in the same
    order as ``phi_psi_dist``).  Only the first ``q.winsize*2`` columns of
    that member are returned.
    """
    in_cluster = (phi_psi_dist.cluster == c).values
    within = precomputed_dists[in_cluster][:, in_cluster]
    medoid_row = within.sum(axis=1).argmin()
    members = phi_psi_dist[in_cluster]
    return members.iloc[medoid_row, :q.winsize*2].values

def precompute_dists(phi_psi_dist):
    """Pairwise distance matrix between all rows of ``phi_psi_dist``.

    Every column of the frame is treated as an angle in degrees: the circular
    difference (the smaller of ``|a - b|`` and ``360 - |a - b|``) is taken per
    column and the Euclidean norm over columns gives the row-to-row distance.
    Returns an ``(n_rows, n_rows)`` array.
    """
    points = phi_psi_dist.values
    gap = np.abs(points[:, np.newaxis] - points)
    circular_gap = np.minimum(gap, 360 - gap)
    return np.linalg.norm(circular_gap, axis=2)
    
# def assign_clusters(phi_psi_dist, precomputed_dists, eps=75):
#     phi_psi_dist['cluster'] = DBSCAN(eps=eps, min_samples=5, metric='precomputed').fit(precomputed_dists).labels_
#     n_clusters = len(phi_psi_dist.cluster.unique())
#     return n_clusters

def assign_clusters(phi_psi_dist, precomputed_dists):
    """Cluster the rows with HDBSCAN and record the labels in place.

    Args:
        phi_psi_dist: frame to label; a ``cluster`` column is added or
            overwritten.
        precomputed_dists: pairwise distance matrix between the rows; a copy
            is handed to HDBSCAN so the caller's array is left untouched.

    Returns:
        Number of real clusters, i.e. distinct non-negative labels.  Negative
        labels (noise) are not counted.  Counting labels directly, rather than
        subtracting one from the number of distinct labels, stays correct when
        no row is labelled as noise.
    """
    # phi_psi_dist['cluster'] = HDBSCAN(min_cluster_size=20, min_samples=5, metric='precomputed').fit(precomputed_dists).labels_
    labels = HDBSCAN(
        min_cluster_size=20, 
        # min_samples=5, 
        metric='precomputed', 
        allow_single_cluster=True,
        cluster_selection_epsilon=30
    ).fit(precomputed_dists.copy()).labels_
    phi_psi_dist['cluster'] = labels
    labels = np.asarray(labels)
    return int(np.unique(labels[labels >= 0]).size)

def plot(q, phi_psi_dist, xrays=None, c=None):
    """Scatter the matched phi/psi angles, one panel per window position.

    Args:
        q: query object; ``q.winsize`` gives the number of panels.
        phi_psi_dist: frame whose first ``q.winsize*2`` columns are all phi
            angles followed by all psi angles; needs a ``cluster`` column
            when ``c`` is given.
        xrays: optional flat ``[phi..., psi...]`` vector drawn as a red X.
        c: optional cluster label whose members are highlighted in orange.
    """
    # squeeze=False keeps `axes` two-dimensional, so indexing also works
    # when winsize == 1 (plt.subplots would otherwise return a bare Axes).
    fig, axes = plt.subplots(1, q.winsize, figsize=(q.winsize*4,5), squeeze=False)
    axes = axes[0]
    if xrays is not None:
        xrays = xrays.reshape(2, -1)
    # (n_rows, 2, winsize): axis 1 selects phi (0) or psi (1).
    phi_psi_dist_points = phi_psi_dist.iloc[:,:q.winsize*2].values.reshape(phi_psi_dist.shape[0], 2, -1)
    in_cluster = (phi_psi_dist.cluster == c).values if c is not None else None
    for i, ax in enumerate(axes):
        ax.scatter(phi_psi_dist_points[:,0,i], phi_psi_dist_points[:,1,i], marker='.')
        if in_cluster is not None:
            ax.scatter(phi_psi_dist_points[in_cluster,0,i], phi_psi_dist_points[in_cluster,1,i], c='orange', zorder=5)
        if xrays is not None:
            ax.scatter(xrays[0,i], xrays[1,i], c='r', marker='X', zorder=10, s=100)
        ax.set_xlim(-180,180)
        ax.set_ylim(-180,180)
        ax.set_xlabel('phi')
        ax.set_ylabel('psi')

In [5]:
class PDBMineWindow():
    """Cluster medoids of the matched phi/psi angles for one sequence window.

    Attributes:
        seq_ctxt: context sequence identifying the window.
        index: position of the window in the caller's list.
        q: query object used for lookups; ``q.winsize`` is the window size.
        _medoids: ``(n_clusters, q.winsize*2)`` array, one medoid per row.
        medoids: ``Medoid`` wrappers, one per row of ``_medoids``.
    """
    def __init__(self, q, seq_ctxt, index):
        self.seq_ctxt = seq_ctxt
        self.index = index
        self.medoids = None
        self.q = q
        self.find_medoids()
    
    def find_medoids(self):
        """Cluster this window's matches and store one medoid per cluster."""
        phi_psi_dist = get_phi_psi_dist(self.q, self.seq_ctxt)
        precomputed_dists = precompute_dists(phi_psi_dist)
        assign_clusters(phi_psi_dist, precomputed_dists)
        # Take the labels actually present (negative = noise) rather than
        # trusting a count, so no cluster is skipped if the count is off.
        labels = sorted(label for label in phi_psi_dist.cluster.unique() if label >= 0)
        rows = [get_cluster_medoid(
            phi_psi_dist, precomputed_dists, label, self.q
        ) for label in labels]
        if rows:
            self._medoids = np.array(rows)
        else:
            # Keep the array two-dimensional so later column indexing works.
            self._medoids = np.empty((0, self.q.winsize * 2))
        self.medoids = [Medoid(i, self._medoids) for i in range(len(self._medoids))]
    
    def extend_and_spawn_threads(self, other, threshold=30):
        """Link this window's medoids to those of ``other`` and start new threads.

        A medoid here is linked to its closest medoid in ``other`` when their
        distance over the compared positions is below ``threshold``.  If the
        other medoid has no thread yet it joins this medoid's thread (whose
        length grows by one); otherwise this medoid adopts the other's
        thread.  Every medoid of ``other`` left without a thread gets a new
        one.

        NOTE(review): positions 0..w-2 of this window are compared with
        positions 1..w-1 of ``other`` - confirm this matches the direction in
        which consecutive windows are shifted.
        """
        # Distance from each medoid of this window to each medoid of other window
        w = self.q.winsize  # use this window's own query, not a notebook global
        range1 = np.concatenate([np.arange(w-1), np.arange(w, w*2-1)])
        range2 = np.concatenate([np.arange(1, w), np.arange(w+1, w*2)])

        # With no medoids on either side there is nothing to match (and argmin
        # over an empty axis would raise).
        if len(self._medoids) > 0 and len(other._medoids) > 0:
            dists = np.linalg.norm(diff(self._medoids[:,np.newaxis,range1], other._medoids[:,range2]), axis=2)

            # Extend clusters with close matches
            closest = dists.argmin(axis=1)
            min_dists = dists.min(axis=1)
            for i,medoid in enumerate(self.medoids):
                if min_dists[i] < threshold:
                    target = other.medoids[closest[i]]
                    # extend thread
                    if target.thread is None:
                        medoid.thread.length += 1
                        target.assign_thread(medoid.thread)
                    # other medoid already assigned to a thread, merge them
                    else:
                        medoid.assign_thread(target.thread)
                        # TODO backtrack and reassign all medoids in this thread to the new thread

        # Spawn new threads from any medoids in other window that were not matched
        for medoid in other.medoids:
            if medoid.thread is None:
                # spawn new thread
                medoid.thread = Thread()
    
    def __repr__(self):
        return f"PDBMineWindow({self.seq_ctxt}, {self.index})={len(self.medoids)}"

In [6]:
class Medoid():
    """One cluster medoid of a window, plus the thread it currently belongs to."""

    def __init__(self, index, medoids):
        self.medoids = medoids  # shared container holding every medoid of the window
        self.index = index      # position of this medoid inside that container
        self.thread = None      # thread this medoid belongs to; None until assigned

    @property
    def medoid(self):
        """Coordinates of this medoid, looked up in the shared container."""
        return self.medoids[self.index]

    def assign_thread(self, thread):
        """Attach this medoid to ``thread``, replacing any previous one."""
        self.thread = thread

    def __repr__(self):
        return "Medoid[{}]({})".format(self.thread, self.medoid)

In [7]:
class MedoidThreads():
    """Bookkeeping for the threads started from an initial window.

    On construction every medoid of ``initial_window`` receives its own new
    ``Thread``; the threads are kept in ``self.threads`` in medoid order.
    """

    def __init__(self, initial_window):
        self.window = initial_window # last window of live threads
        self.threads = []
        for medoid in self.window.medoids:
            new_thread = Thread()
            self.threads.append(new_thread)
            medoid.assign_thread(new_thread)


In [8]:
class Thread():
    """A chain of medoids linked across consecutive windows.

    Each instance takes the next value of the class-level counter ``_id`` as
    its ``id``; the counter is shared by all instances and is never reset
    here.  A new thread starts with ``length`` 1 and ``alive`` set to True.
    """
    _id = 0

    def __init__(self):
        # The thread does not keep a reference to its medoids; medoids point
        # at their thread instead.
        self.id = Thread._id
        Thread._id += 1
        self.length = 1
        self.alive = True

    def __repr__(self):
        return "Thread[{}]={}".format(self.id, self.length)

In [37]:
# Build one PDBMineWindow per context sequence (using the first query) and
# link the medoids of each new window to those of the previous one.
# Thread ids keep growing across re-runs of this cell because Thread._id is a
# class-level counter.
windows = []
for i,seq_ctxt in enumerate(da.seqs_for_window):
    window = PDBMineWindow(da.queries[0], seq_ctxt, i)
    
    # initialize medoid threads: every medoid of the first window starts its own thread
    if i == 0:
        for medoid in window.medoids:
            medoid.assign_thread(Thread())
    # Extend existing threads or create new ones
    else:
        windows[-1].extend_and_spawn_threads(window)

    windows.append(window)
    # plot(da.queries[0], phi_psi_dist, c=phi_psi_dist.groupby('cluster').size().idxmax())
    # NOTE(review): stops after the first two windows - looks like a
    # debugging limit; remove to process every sequence.
    if i == 1:
        break

In [38]:
# Display the medoids of the first window together with their thread assignments.
windows[0].medoids

[Medoid[Thread[71]=1]([-70.4 -81.1 -97.9 -64.1 110.1 -32.8 163.4 -41.6]),
 Medoid[Thread[72]=2]([-64.8 -65.2 -67.3  56.9 -41.2 -30.9 -13.7  37.6]),
 Medoid[Thread[73]=1]([ -71.2  -85.8 -143.8 -134.6  127.   -50.8  155.2  126.2]),
 Medoid[Thread[74]=1]([-62.4 -76.6  67.  -87.  -15.7 -34.3  40.5 148.3]),
 Medoid[Thread[75]=1]([-72.1 -69.8 -36.9 -58.8 161.3 149.6 -30.6 -48.5]),
 Medoid[Thread[76]=2]([-67.  -83.2 -67.6 -58.1 -26.3 -12.5 148.1 134.9]),
 Medoid[Thread[77]=1]([-69.7 -76.1 -66.7 -59.5 -23.9 147.8 176.3 -31.2]),
 Medoid[Thread[78]=1]([ -87.4 -105.5  -96.7  -54.4   -4.6  144.4  133.6  -41.5]),
 Medoid[Thread[79]=1]([-137.4  -71.9 -145.  -122.2  153.2  127.7  102.6   -6.1]),
 Medoid[Thread[80]=1]([-115.9 -148.3  -58.6 -153.6  125.8  161.3  113.2  155. ]),
 Medoid[Thread[81]=1]([ -92.  -110.6 -157.1  -70.6  117.9  145.6  124.3   -4.9]),
 Medoid[Thread[82]=2]([-76.6 -64.7 -61.  -65.9 155.6 -33.1 -29.5 -20.3]),
 Medoid[Thread[83]=1]([-56.4 -59.9 -74.1 -77.  154.3 -30.8 -21.1 101. ])

In [41]:
# Display the medoids of the second window, each prefixed with its index in the list.
[f'{i}: {m}' for i,m in enumerate(windows[1].medoids)]

['0: Medoid[Thread[92]=1]([ -81.    68.6  -85.  -167.3  -35.6   45.   156.   175.3])',
 '1: Medoid[Thread[93]=1]([ -79.7   70.2 -161.2 -167.3   -7.8  -40.8  136.5  167.8])',
 '2: Medoid[Thread[94]=1]([-63.2 -65.8  57.7 -72.8 -42.4 -17.2  33.2 161.2])',
 '3: Medoid[Thread[95]=1]([-82.  -82.2 -55.2  54.5 -33.9 -15.8 133.7  23.1])',
 '4: Medoid[Thread[96]=1]([ -96.2 -141.9 -127.4 -163.8  -48.1  150.7  149.   151.5])',
 '5: Medoid[Thread[97]=1]([-108.6  -65.4  -72.6  -84.9  133.5  -29.8   -3.    -7.9])',
 '6: Medoid[Thread[98]=1]([ -89.2  -80.7 -152.9 -156.7  144.3  -31.1  142.1  155.8])',
 '7: Medoid[Thread[99]=1]([ -88.5 -110.4  -59.5  -67.9   -2.1  151.3  -29.9   -5.6])',
 '8: Medoid[Thread[82]=2]([-91.4 -97.2 -53.4 -66.  150.8 156.3 -43.6 -22.5])',
 '9: Medoid[Thread[100]=1]([-100.5 -106.3  -66.7 -126.1  141.   165.4  -22.2  123.4])',
 '10: Medoid[Thread[101]=1]([-65.3 -75.1 -92.8 -67.7 -25.7 -13.3 168.7 147.4])',
 '11: Medoid[Thread[102]=1]([-131.2 -138.   -85.7 -138.5  138.9  116.2  

In [28]:
# Module-level query handle plus two column-index sets into a flat medoid
# vector of length 2*winsize.
# range1: indices 0..winsize-2 and winsize..2*winsize-2 (first winsize-1
#         entries of each half).
# range2: indices 1..winsize-1 and winsize+1..2*winsize-1 (last winsize-1
#         entries of each half).
# NOTE(review): assumes the vector is laid out as all phi values followed by
# all psi values - confirm.
q = da.queries[0]
range1 = np.concatenate([np.arange(q.winsize-1), np.arange(q.winsize, q.winsize*2-1)])
range2 = np.concatenate([np.arange(1, q.winsize), np.arange(q.winsize+1, q.winsize*2)])

In [40]:
# For medoid 1 of the first window: index of the second-window medoid with the
# smallest distance, comparing the range1 columns of the first window with the
# range2 columns of the second.
np.linalg.norm(diff(windows[0]._medoids[:,np.newaxis,range1], windows[1]._medoids[:,range2]), axis=2).argmin(axis=1)[1]

14