In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from lib import DihedralAdherence
from lib import PDBMineQuery
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats
from pathlib import Path
import json
from matplotlib.patches import ConnectionPatch
from sklearn.cluster import KMeans, DBSCAN, MeanShift, HDBSCAN
from sklearn.metrics import silhouette_score, silhouette_samples
from collections import defaultdict
from tqdm import tqdm
from scipy.linalg import inv

# Read from the environment; os.getenv returns None when PDBMINE_URL is unset.
# Presumably the PDBMine service endpoint - it is passed to DihedralAdherence below.
PDBMINE_URL = os.getenv("PDBMINE_URL")
# Project directory name handed to DihedralAdherence.
PROJECT_DIR = 'casp_da'

In [4]:
# CASP target identifiers; only the first entry is loaded in this cell.
proteins = [
  'T1024', 'T1030', 'T1030-D2', 'T1024-D1', 'T1032-D1', 'T1053-D1', 'T1027-D1', 'T1029-D1',
  'T1025-D1', 'T1028-D1', 'T1030-D1', 'T1053-D2', 'T1057-D1','T1058-D1', 'T1058-D2'
]
da = DihedralAdherence(proteins[0], [4,5,6,7], PDBMINE_URL, PROJECT_DIR, kdews=[1,1,1,1],
                      mode='full_window', weights_file='ml_runs/best_model-kde_16-32_383.pt', device='cpu')
                    #   mode='ml', weights_file='ml_runs/best_model-kde_16-32_383.pt', device='cpu')

da.load_results_da()

# Use the last query's window to decide which sequence contexts have a full
# window around them.
center_idx_ctxt = da.queries[-1].get_center_idx()
winsize_ctxt = da.queries[-1].winsize
if center_idx_ctxt < 0:
    # A negative centre index counts from the end of the window.
    center_idx_ctxt = winsize_ctxt + center_idx_ctxt

# Number of contexts to drop at the tail. Compute the end index explicitly:
# a negative slice bound of -0 would select nothing when the centre is the
# last position of the window.
n_trailing = winsize_ctxt - center_idx_ctxt - 1
da.seqs_for_window = da.seqs[center_idx_ctxt:len(da.seqs) - n_trailing]

Initializing T1024 ...
Results already exist
Casp ID: T1024 	PDB: 6t1z
Structure exists: 'pdb/pdb6t1z.ent' 
UniProt ID: Q48658


In [8]:
def diff(x1, x2):
    """Element-wise absolute angular separation with wrap-around at 360.

    Returns the smaller of |x1 - x2| and 360 - |x1 - x2|, so the result is
    an unsigned magnitude (assumes angles are expressed in degrees).
    """
    gap = np.abs(x1 - x2)
    wrapped_gap = 360 - gap
    return np.minimum(gap, wrapped_gap)

def get_phi_psi_dist(q, seq_ctxt):
    """Build a wide phi/psi table for the window sequence derived from ``seq_ctxt``.

    Rows of ``q.results_window`` whose ``seq`` equals ``q.get_subseq(seq_ctxt)``
    are pivoted so that each ``match_id`` becomes one row with columns
    ``phi_<window_pos>`` followed by ``psi_<window_pos>``. Rows with any
    missing angle are dropped.
    """
    window_seq = q.get_subseq(seq_ctxt)
    is_match = q.results_window.seq == window_seq
    long_form = q.results_window.loc[is_match, ['match_id', 'window_pos', 'phi', 'psi']]
    wide = long_form.pivot(index='match_id', columns='window_pos', values=['phi', 'psi'])
    # Flatten the (angle, window_pos) MultiIndex into names like 'phi_0'.
    wide.columns = [f'{c[0]}_{c[1]}' for c in wide.columns.to_flat_index()]
    return wide.dropna(axis=0)

def get_xrays(ins, q, seq_ctxt, return_df=False):
    """Collect the X-ray phi/psi values for the window around ``seq_ctxt``.

    The window starts ``q.get_center_idx_pos()`` positions before the first
    row of ``ins.xray_phi_psi`` whose ``seq_ctxt`` matches, and spans
    ``q.winsize`` positions. Returns a 1-D array of all phi values followed
    by all psi values; with ``return_df=True`` the selected rows are
    returned as well.
    """
    offset = q.get_center_idx_pos()
    table = ins.xray_phi_psi
    anchor_pos = table[table.seq_ctxt == seq_ctxt].pos.iloc[0]
    first_pos = anchor_pos - offset
    in_window = (table.pos >= first_pos) & (table.pos < first_pos + q.winsize)
    xrays = table[in_window].copy()
    xray_point = np.concatenate([xrays['phi'].values, xrays['psi'].values])
    return (xray_point, xrays) if return_df else xray_point

def get_afs(ins, q, seq_ctxt, return_df=False):
    """Collect the phi/psi values from ``ins.af_phi_psi`` for the window around ``seq_ctxt``.

    Same windowing as ``get_xrays`` but reading ``ins.af_phi_psi``: the
    window begins ``q.get_center_idx_pos()`` positions before the first
    matching row and covers ``q.winsize`` positions. Returns phi values
    followed by psi values as a 1-D array, plus the selected rows when
    ``return_df`` is true.
    """
    offset = q.get_center_idx_pos()
    table = ins.af_phi_psi
    anchor_pos = table[table.seq_ctxt == seq_ctxt].pos.iloc[0]
    first_pos = anchor_pos - offset
    in_window = (table.pos >= first_pos) & (table.pos < first_pos + q.winsize)
    afs = table[in_window].copy()
    af_point = np.concatenate([afs['phi'].values, afs['psi'].values])
    return (af_point, afs) if return_df else af_point

def get_preds(ins, q, seq_ctxt):
    """Build a wide phi/psi table of model predictions for the window around ``seq_ctxt``.

    Looks up the single position in ``ins.phi_psi_predictions`` whose
    ``seq_ctxt`` matches, selects the ``q.winsize`` positions starting
    ``q.get_center_idx_pos()`` before it, and pivots to one row per
    ``protein_id`` with columns ``phi_<k>`` / ``psi_<k>``, where ``k`` is the
    position relative to the start of the window. Rows with a missing
    angle are dropped.

    Returns:
        A DataFrame indexed by ``protein_id``. When no prediction exists for
        ``seq_ctxt`` an empty DataFrame is returned, so callers can test
        ``preds.shape[0] == 0``.

    Raises:
        ValueError: if ``seq_ctxt`` occurs at more than one position.
    """
    center_idx = q.get_center_idx_pos()
    predictions = ins.phi_psi_predictions
    pred_pos = predictions[predictions.seq_ctxt == seq_ctxt].pos.unique()
    if len(pred_pos) == 0:
        # Previously this fell through to pred_pos[0] and raised IndexError.
        print(f"No predictions for {seq_ctxt}")
        return pd.DataFrame()
    if len(pred_pos) > 1:
        raise ValueError(f"Multiple predictions for {seq_ctxt}")
    pred_pos = pred_pos[0]

    first_pos = pred_pos - center_idx
    in_window = (predictions.pos >= first_pos) & (predictions.pos < first_pos + q.winsize)
    preds = predictions.loc[in_window, ['protein_id', 'pos', 'phi', 'psi']].pivot(
        index='protein_id', columns='pos', values=['phi', 'psi']
    )
    # Rename (angle, absolute pos) to 'angle_<offset within window>'.
    preds.columns = [f'{c[0]}_{c[1]-pred_pos+center_idx}' for c in preds.columns.to_flat_index()]
    return preds.dropna(axis=0)

# def calc_xray_score(phi_psi_dist, xrays, q, precomputed_dists):
#     # Distance to nearest cluster average
#     d = np.linalg.norm(diff(xrays[np.newaxis,:], phi_psi_dist.iloc[:,:q.winsize*2].values), axis=1)
#     d = pd.DataFrame({'d': d, 'c': phi_psi_dist.cluster})
#     nearest_cluster = d.groupby('c').d.mean().idxmin()
#     cluster_points = phi_psi_dist[phi_psi_dist.cluster == nearest_cluster].iloc[:,:q.winsize*2].values
#     # cluster_avg = cluster_points.mean(axis=0)
#     cluster_medoid = get_cluster_medoid(phi_psi_dist, precomputed_dists, nearest_cluster, q)
    
#     xray_dist = np.sqrt((diff(xrays, cluster_medoid)**2).sum())

#     return xray_dist, nearest_cluster

def estimate_icov(phi_psi_dist_c, cluster_medoid):
    """Estimate the inverse covariance of a cluster around its medoid.

    Args:
        phi_psi_dist_c: DataFrame holding the points of one cluster
            (one row per point).
        cluster_medoid: 1-D array with the medoid of that cluster.

    Returns:
        The inverse covariance matrix, or ``None`` (after printing the
        reason) when the cluster has fewer than two points or the estimated
        covariance fails one of the sanity checks below.
    """
    cluster_points = phi_psi_dist_c.values
    n_points = cluster_points.shape[0]
    if n_points < 2:
        # The (n - 1) normalisation below would divide by zero.
        print("Not enough points to estimate covariance matrix")
        return None

    # Wrapped, unsigned angular deviations of every point from the medoid.
    diffs = diff(cluster_points, cluster_medoid)

    # Sum of outer products d_i d_i^T equals diffs.T @ diffs; this avoids the
    # (n, d, d) intermediate of the explicit outer-product formulation.
    cov = (diffs.T @ diffs) / (n_points - 1)
    cov = cov + np.eye(cov.shape[0]) * 1e-6 # add small value to diagonal to avoid singular matrix
    if np.any(cov <= 0):
        print("Non-positive covariance matrix")
        return None
    if np.any(cov.diagonal() < 1):
        print("Covariance matrix less than 1")
        return None
    # cov is symmetric, so use the symmetric solver: it always returns real
    # eigenvalues, whereas the general solver may return complex values.
    eigenvalues = np.linalg.eigvalsh(cov)
    if np.any(eigenvalues < 0):
        print("Negative eigenvalues - non-positive semi-definite covariance matrix")
        return None
    return inv(cov)

def get_target_cluster(phi_psi_dist, clusters, point):
    """Return the cluster label whose members are, on average, closest to ``point``.

    Distance is the Euclidean norm of the wrapped angular differences
    between ``point`` and every row of ``phi_psi_dist``; the per-cluster mean
    of those distances is minimised.
    """
    angular_gaps = diff(point[np.newaxis,:], phi_psi_dist.values)
    per_point = pd.DataFrame({'d': np.linalg.norm(angular_gaps, axis=1), 'c': clusters})
    mean_dist_by_cluster = per_point.groupby('c').d.mean()
    return mean_dist_by_cluster.idxmin()

def get_cluster_medoid(phi_psi_dist, precomputed_dists, clusters, c):
    """Return the medoid of cluster ``c`` as a 1-D array.

    The medoid is the member of the cluster whose summed distance to all
    other members (taken from ``precomputed_dists``) is smallest.
    """
    members = clusters == c
    within_cluster = precomputed_dists[members][:, members]
    medoid_row = within_cluster.sum(axis=1).argmin()
    return phi_psi_dist[members].iloc[medoid_row].values

def calc_maha_xray(phi_psi_dist, xrays, precomputed_dists, clusters, af):
    """Mahalanobis distance from ``xrays`` to the medoid of the cluster nearest ``af``.

    Returns:
        ``(distance, target_cluster)``. ``distance`` is ``None`` when no
        inverse covariance could be estimated for the target cluster.
    """
    target_cluster = get_target_cluster(phi_psi_dist, clusters, af)
    medoid = get_cluster_medoid(phi_psi_dist, precomputed_dists, clusters, target_cluster)
    cluster_members = phi_psi_dist[clusters == target_cluster]
    icov = estimate_icov(cluster_members, medoid)
    if icov is not None:
        # Wrapped deviation of the X-ray point from the medoid.
        offset = diff(xrays, medoid)
        return np.sqrt(offset @ icov @ offset), target_cluster
    return None, target_cluster

def calc_maha_preds(phi_psi_dist, preds, precomputed_dists, clusters, af):
    """Mahalanobis distance from every prediction to the medoid of the cluster nearest ``af``.

    Returns:
        A 1-D array with one distance per row of ``preds``, or ``None`` when
        no inverse covariance could be estimated for the target cluster.
    """
    target_cluster = get_target_cluster(phi_psi_dist, clusters, af)
    medoid = get_cluster_medoid(phi_psi_dist, precomputed_dists, clusters, target_cluster)
    cluster_members = phi_psi_dist[clusters == target_cluster]
    icov = estimate_icov(cluster_members, medoid)
    if icov is None:
        return None

    # One row of wrapped deviations per prediction; the diagonal of the
    # quadratic form holds each prediction's squared distance.
    offsets = diff(preds.values, medoid)
    quadratic_form = offsets @ icov @ offsets.T
    return np.sqrt(quadratic_form.diagonal())

# def calc_score(q, preds, phi_psi_dist, intracluster_dists, xrays=None, afs=None):
#     # Distance to nearest cluster average
#     d = np.linalg.norm(diff(preds.iloc[:,:q.winsize*2].values[:,np.newaxis], phi_psi_dist.iloc[:,:q.winsize*2].values), axis=2)
#     average_dists = []
#     clusters = phi_psi_dist.cluster.unique()
#     for c in clusters:
#         average_dists.append(d[:,phi_psi_dist.cluster == c].mean(axis=1))
#     average_dists = np.array(average_dists).T
#     min_dists = average_dists.min(axis=1)
#     nearest_clusters = average_dists.argmin(axis=1)

#     preds_dist = pd.DataFrame(index=preds.index, columns=['dist'])
#     preds_dist['dist'] = preds_dist['dist'].astype(float)
#     for c in clusters:
#         cluster_points = phi_psi_dist[phi_psi_dist.cluster == c].iloc[:,:q.winsize*2].values
#         cluster_avg = cluster_points.mean(axis=0) # TODO: use medoid
#         preds_c = preds[nearest_clusters == c].iloc[:,:q.winsize*2].values
#         preds_dists_c = np.linalg.norm(diff(preds_c, cluster_avg), axis=1)
#         preds_dist.loc[nearest_clusters == c, 'dist'] = preds_dists_c

#     return preds_dist.dist.values

# def calc_intra_cluster(phi_psi_dist, precomputed_dists):
#     ds = {}
#     for c in phi_psi_dist.cluster.unique():
#         d = precomputed_dists[phi_psi_dist.cluster == c][:,phi_psi_dist.cluster == c]
#         ds[c] = d.sum() / (d.shape[0] * (d.shape[0]-1))
#     return ds

def precompute_dists(phi_psi_dist):
    """Pairwise distance matrix between all rows of ``phi_psi_dist``.

    Each entry is the Euclidean norm of the per-column angular differences,
    where a difference is wrapped as min(|a - b|, 360 - |a - b|).
    Returns an (n_rows, n_rows) array.
    """
    points = phi_psi_dist.values
    # Broadcast to every (row, row) pair, then wrap each angular gap.
    gaps = np.abs(points[:, np.newaxis] - points)
    wrapped_gaps = np.minimum(gaps, 360 - gaps)
    return np.linalg.norm(wrapped_gaps, axis=2)

def filter_precomputed_dists(precomputed_dists, phi_psi_dist, clusters):
    """Drop every point labelled -1 from the distance matrix, the points and the labels.

    Returns the tuple ``(distances, points, labels)`` restricted to entries
    whose label differs from -1.
    """
    assigned = clusters != -1
    kept_dists = precomputed_dists[assigned][:, assigned]
    return kept_dists, phi_psi_dist[assigned], clusters[assigned]
    
# def assign_clusters(phi_psi_dist, precomputed_dists, eps=75):
#     phi_psi_dist['cluster'] = DBSCAN(eps=eps, min_samples=5, metric='precomputed').fit(precomputed_dists).labels_
#     n_clusters = len(phi_psi_dist.cluster.unique())
#     return n_clusters

def find_clusters(phi_psi_dist, precomputed_dists, min_cluster_size=20, cluster_selection_epsilon=30):
    """Cluster points with HDBSCAN on a precomputed distance matrix.

    Args:
        phi_psi_dist: unused here; kept so existing callers keep working.
        precomputed_dists: square pairwise distance matrix.
        min_cluster_size: passed through to HDBSCAN.
        cluster_selection_epsilon: passed through to HDBSCAN (previously
            hard-coded to 30).

    Returns:
        ``(n_clusters, labels)`` where ``n_clusters`` counts only real
        clusters (labels >= 0) and ``labels`` is the per-point label array.
    """
    # Fit on a copy so the caller's matrix is left untouched.
    labels = HDBSCAN(
        min_cluster_size=min_cluster_size,
        # min_samples=5,
        metric='precomputed',
        allow_single_cluster=True,
        cluster_selection_epsilon=cluster_selection_epsilon,
    ).fit(precomputed_dists.copy()).labels_
    # Count clusters directly instead of "unique labels minus one": the latter
    # is off by one whenever no point is labelled as noise.
    n_clusters = len(np.unique(labels[labels >= 0]))
    return n_clusters, labels

def plot(q, phi_psi_dist, xrays=None, c=None):
    """Scatter phi against psi for every window position, one panel per position.

    Args:
        q: query object; only ``q.winsize`` is read.
        phi_psi_dist: wide table whose first ``2 * q.winsize`` columns are the
            phi values followed by the psi values. Must carry a ``cluster``
            column when ``c`` is given.
        xrays: optional 1-D array (phi values then psi values) drawn as red X markers.
        c: optional cluster label whose members are highlighted in orange.
    """
    n_positions = q.winsize
    fig, axes = plt.subplots(1, n_positions, figsize=(n_positions*4, 5))
    xray_grid = None if xrays is None else xrays.reshape(2, -1)
    # Shape (n_points, 2, n_positions): axis 1 selects phi (0) or psi (1).
    angle_grid = phi_psi_dist.iloc[:, :n_positions*2].values.reshape(phi_psi_dist.shape[0], 2, -1)
    highlighted = None if c is None else phi_psi_dist.cluster == c
    for i in range(n_positions):
        ax = axes[i]
        ax.scatter(angle_grid[:, 0, i], angle_grid[:, 1, i], marker='.')
        if highlighted is not None:
            ax.scatter(angle_grid[highlighted, 0, i], angle_grid[highlighted, 1, i], c='orange', zorder=5)
        if xray_grid is not None:
            ax.scatter(xray_grid[0, i], xray_grid[1, i], c='r', marker='X', zorder=10, s=100)
        ax.set_xlim(-180, 180)
        ax.set_ylim(-180, 180)

In [11]:
# Distance from xray and preds to cluster nearest to alphafold
ins = da
q = ins.queries[0]
col_name = 'af_score'
ins.phi_psi_predictions[col_name] = np.nan
ins.xray_phi_psi[col_name] = np.nan

for seq_ctxt in tqdm(ins.seqs_for_window):
    phi_psi_dist = get_phi_psi_dist(q, seq_ctxt)
    xrays = get_xrays(ins, q, seq_ctxt)
    preds = get_preds(ins, q, seq_ctxt)
    afs = get_afs(ins, q, seq_ctxt)

    # Skip windows for which any of the inputs is missing or incomplete.
    if xrays.shape[0] != q.winsize*2:
        print(f"Xray data for {seq_ctxt} is incomplete")
        continue
    if preds.shape[0] == 0:
        print(f"No predictions for {seq_ctxt}")
        continue
    if phi_psi_dist.shape[0] == 0:
        print(f"No pdbmine data for {seq_ctxt}")
        continue
    if phi_psi_dist.shape[0] < 100:
        print(f"Not enough pdbmine data for {seq_ctxt}")
        continue
    if afs.shape[0] != q.winsize*2:
        print(f"AF data for {seq_ctxt} is incomplete")
        continue

    precomputed_dists = precompute_dists(phi_psi_dist.iloc[:,:q.winsize*2])
    n_clusters, clusters = find_clusters(phi_psi_dist, precomputed_dists)

    # Remove the points HDBSCAN left unassigned (label -1).
    precomputed_dists, phi_psi_dist, clusters = filter_precomputed_dists(precomputed_dists, phi_psi_dist, clusters)
    if len(clusters) == 0:
        # Every point was noise, so there is no cluster to score against.
        print(f"No clusters found for {seq_ctxt}")
        continue

    xray_maha, c = calc_maha_xray(phi_psi_dist, xrays, precomputed_dists, clusters, afs)
    if xray_maha is None:
        print(f"Error calculating mahalanobis distance for {seq_ctxt}")
        # Labels live in `clusters`; phi_psi_dist has no 'cluster' column here.
        print(f"Cluster size: {(clusters == c).sum()}")
        xray_maha = np.nan

    # Distance from preds to the same target cluster
    preds_maha = calc_maha_preds(phi_psi_dist, preds, precomputed_dists, clusters, afs)
    if preds_maha is None:
        # Covariance could not be estimated; record missing scores explicitly.
        preds_maha = np.nan

    ins.xray_phi_psi.loc[ins.xray_phi_psi.seq_ctxt == seq_ctxt, col_name] = xray_maha

    # Write the per-protein scores back through the original row index.
    view = ins.phi_psi_predictions.loc[ins.phi_psi_predictions.seq_ctxt == seq_ctxt].reset_index().set_index('protein_id')
    view.loc[preds.index, col_name] = preds_maha
    ins.phi_psi_predictions.loc[view['index'], col_name] = view.set_index('index')[col_name]

 50%|█████     | 184/367 [04:06<00:52,  3.47it/s]

Xray data for MTETFKP is incomplete
Xray data for TETFKPT is incomplete


 50%|█████     | 185/367 [04:06<00:46,  3.89it/s]

Xray data for NIFQAYK is incomplete


 54%|█████▍    | 199/367 [04:12<00:38,  4.37it/s]

Not enough pdbmine data for TYMIFMG


 62%|██████▏   | 226/367 [04:22<00:38,  3.68it/s]

Xray data for SNSFKTI is incomplete
Xray data for NSFKTIT is incomplete


 62%|██████▏   | 228/367 [04:22<00:30,  4.54it/s]

Xray data for YGQRMLT is incomplete


100%|██████████| 367/367 [06:53<00:00,  1.13s/it]


In [None]:
# Inspect clustering and the X-ray Mahalanobis score for a single window.
seq_ctxt = 'GIVFLGA'
q = da.queries[0]
# seq_ctxt = da.seqs_for_window[0]
phi_psi_dist = get_phi_psi_dist(q, seq_ctxt)
xrays = get_xrays(da, q, seq_ctxt)
precomputed_dists = precompute_dists(phi_psi_dist.iloc[:,:q.winsize*2])
# Use the helpers as currently defined in this notebook: find_clusters returns
# the labels separately, and calc_maha_xray takes them explicitly.
n_clusters, clusters = find_clusters(phi_psi_dist, precomputed_dists)
print(n_clusters)
precomputed_dists, phi_psi_dist, clusters = filter_precomputed_dists(precomputed_dists, phi_psi_dist, clusters)
# The X-ray point itself selects the target cluster here.
xray_maha, nearest_cluster = calc_maha_xray(phi_psi_dist, xrays, precomputed_dists, clusters, xrays)

print(phi_psi_dist.shape, (clusters == nearest_cluster).sum())
# plot() highlights via a 'cluster' column; attach it only after scoring so
# the distance helpers see angle columns only.
plot(q, phi_psi_dist.assign(cluster=clusters), xrays, c=nearest_cluster)

In [None]:
# Distance from xray to nearest cluster
# NOTE(review): this cell looks stale relative to the helper definitions in this
# notebook: assign_clusters is defined with a required cluster_selection_epsilon
# argument, calc_maha_xray is defined as (phi_psi_dist, xrays, precomputed_dists,
# clusters, af), and calc_xray_score only appears as commented-out code.
# Confirm the intended helpers before re-running.
ins = da
q = ins.queries[0]
col_name = 'xray_near_score'
ins.xray_phi_psi[col_name] = np.nan
ins.xray_phi_psi['new_score'] = np.nan
chosen_cluster_sizes = []

for seq_ctxt in tqdm(ins.seqs_for_window):
    phi_psi_dist = get_phi_psi_dist(q, seq_ctxt)
    xrays = get_xrays(ins, q, seq_ctxt)

    # Skip windows with incomplete X-ray data or too few PDBMine matches.
    if xrays.shape[0] != q.winsize*2:
        print(f"Xray data for {seq_ctxt} is incomplete")
        continue
    if phi_psi_dist.shape[0] == 0:
        print(f"No pdbmine data for {seq_ctxt}")
        continue
    if phi_psi_dist.shape[0] < 100:
        print(f"Not enough pdbmine data for {seq_ctxt}")
        continue

    precomputed_dists = precompute_dists(phi_psi_dist.iloc[:,:q.winsize*2])
    n_clusters = assign_clusters(phi_psi_dist, precomputed_dists)

    # Drop points labelled -1 from both the distance matrix and the table.
    precomputed_dists = precomputed_dists[phi_psi_dist.cluster != -1][:,phi_psi_dist.cluster != -1]
    phi_psi_dist = phi_psi_dist[phi_psi_dist.cluster != -1]

    xray_maha, c = calc_maha_xray(phi_psi_dist, xrays, q, precomputed_dists, xrays)
    chosen_cluster_sizes.append(len(phi_psi_dist[phi_psi_dist.cluster == c]))
    if xray_maha is None:
        print(f"Error calculating mahalanobis distance for {seq_ctxt}")
        print(f"Cluster size: {len(phi_psi_dist[phi_psi_dist.cluster == c])}")
        xray_maha = np.nan

    xray_score, _ = calc_xray_score(phi_psi_dist, xrays, q, precomputed_dists)
    # Both scores are written to every X-ray row with this sequence context.
    ins.xray_phi_psi.loc[ins.xray_phi_psi.seq_ctxt == seq_ctxt, col_name] = xray_maha
    ins.xray_phi_psi.loc[ins.xray_phi_psi.seq_ctxt == seq_ctxt, 'new_score'] = xray_score

In [None]:
# Distance from preds to xray itself
# NOTE(review): this cell looks stale relative to the helper definitions in this
# notebook: assign_clusters requires cluster_selection_epsilon, get_target_cluster
# is defined as (phi_psi_dist, clusters, point), get_cluster_medoid as
# (phi_psi_dist, precomputed_dists, clusters, c) and estimate_icov as
# (phi_psi_dist_c, cluster_medoid). Confirm before re-running.
ins = da
q = ins.queries[0]
col_name = 'to_xray_score'
ins.phi_psi_predictions[col_name] = np.nan
ins.phi_psi_predictions['new_score'] = np.nan

for seq_ctxt in tqdm(ins.seqs_for_window):
    phi_psi_dist = get_phi_psi_dist(q, seq_ctxt)
    xrays = get_xrays(ins, q, seq_ctxt)
    preds = get_preds(ins, q, seq_ctxt)

    # Skip windows for which any of the inputs is missing or too small.
    if xrays.shape[0] != q.winsize*2:
        print(f"Xray data for {seq_ctxt} is incomplete")
        continue
    if preds.shape[0] == 0:
        print(f"No predictions for {seq_ctxt}")
        continue
    if phi_psi_dist.shape[0] == 0:
        print(f"No pdbmine data for {seq_ctxt}")
        continue
    if phi_psi_dist.shape[0] < 100:
        print(f"Not enough pdbmine data for {seq_ctxt}")
        continue

    precomputed_dists = precompute_dists(phi_psi_dist.iloc[:,:q.winsize*2])
    n_clusters = assign_clusters(phi_psi_dist, precomputed_dists)
    
    # Drop points labelled -1 from both the distance matrix and the table.
    precomputed_dists = precomputed_dists[phi_psi_dist.cluster != -1][:,phi_psi_dist.cluster != -1]
    phi_psi_dist = phi_psi_dist[phi_psi_dist.cluster != -1]

    # Distance from preds to xray itself
    target_cluster = get_target_cluster(q, phi_psi_dist, xrays)
    cluster_medoid = get_cluster_medoid(phi_psi_dist, precomputed_dists, target_cluster, q)
    icov = estimate_icov(q, phi_psi_dist[phi_psi_dist.cluster == target_cluster], cluster_medoid)
    if icov is None:
        continue
    # Deviations are measured from the X-ray point, not from the medoid; the
    # medoid is only used to estimate the covariance.
    preds_diff = diff(preds.iloc[:,:q.winsize*2].values, xrays)

    preds_maha = np.sqrt((preds_diff @ icov @ preds_diff.T).diagonal())
    preds_dist = np.linalg.norm(preds_diff, axis=1)

    # preds_np is only referenced by the commented-out cosine similarity below.
    preds_np = preds.iloc[:,:q.winsize*2].values
    # preds_cossim = (preds_np @ xrays) / (np.linalg.norm(preds_np, axis=1) * np.linalg.norm(xrays))
    
    # Write the per-protein scores back through the original row index.
    view = ins.phi_psi_predictions.loc[ins.phi_psi_predictions.seq_ctxt == seq_ctxt].reset_index().set_index('protein_id')
    view.loc[preds.index, col_name] = preds_maha
    view.loc[preds.index, 'new_score'] = preds_dist
    ins.phi_psi_predictions.loc[view['index'], col_name] = view.set_index('index')[col_name]
    ins.phi_psi_predictions.loc[view['index'], 'new_score'] = view.set_index('index')['new_score']

In [None]:
# Distance from xray and preds to cluster nearest to alphafold
# NOTE(review): this cell computes the same 'af_score' column as the executed
# cell earlier in the notebook, but calls the helpers with different argument
# lists than their definitions here (assign_clusters without
# cluster_selection_epsilon; calc_maha_xray / calc_maha_preds with `q` as the
# third argument). It appears to be a superseded copy - confirm and remove.
ins = da
q = ins.queries[0]
col_name = 'af_score'
ins.phi_psi_predictions[col_name] = np.nan
ins.xray_phi_psi[col_name] = np.nan

for seq_ctxt in tqdm(ins.seqs_for_window):
    phi_psi_dist = get_phi_psi_dist(q, seq_ctxt)
    xrays = get_xrays(ins, q, seq_ctxt)
    preds = get_preds(ins, q, seq_ctxt)
    afs = get_afs(ins, q, seq_ctxt)

    # Skip windows for which any of the inputs is missing or incomplete.
    if xrays.shape[0] != q.winsize*2:
        print(f"Xray data for {seq_ctxt} is incomplete")
        continue
    if preds.shape[0] == 0:
        print(f"No predictions for {seq_ctxt}")
        continue
    if phi_psi_dist.shape[0] == 0:
        print(f"No pdbmine data for {seq_ctxt}")
        continue
    if phi_psi_dist.shape[0] < 100:
        print(f"Not enough pdbmine data for {seq_ctxt}")
        continue
    if afs.shape[0] != q.winsize*2:
        print(f"AF data for {seq_ctxt} is incomplete")
        continue

    precomputed_dists = precompute_dists(phi_psi_dist.iloc[:,:q.winsize*2])
    n_clusters = assign_clusters(phi_psi_dist, precomputed_dists)
    
    # Drop points labelled -1 from both the distance matrix and the table.
    precomputed_dists = precomputed_dists[phi_psi_dist.cluster != -1][:,phi_psi_dist.cluster != -1]
    phi_psi_dist = phi_psi_dist[phi_psi_dist.cluster != -1]

    xray_maha, c = calc_maha_xray(phi_psi_dist, xrays, q, precomputed_dists, afs)
    if xray_maha is None:
        print(f"Error calculating mahalanobis distance for {seq_ctxt}")
        print(f"Cluster size: {len(phi_psi_dist[phi_psi_dist.cluster == c])}")
        xray_maha = np.nan

    # Distance from preds to xray
    preds_maha = calc_maha_preds(phi_psi_dist, preds, q, precomputed_dists, afs)
    
    ins.xray_phi_psi.loc[ins.xray_phi_psi.seq_ctxt == seq_ctxt, col_name] = xray_maha

    # Write the per-protein scores back through the original row index.
    view = ins.phi_psi_predictions.loc[ins.phi_psi_predictions.seq_ctxt == seq_ctxt].reset_index().set_index('protein_id')
    view.loc[preds.index, col_name] = preds_maha
    ins.phi_psi_predictions.loc[view['index'], col_name] = view.set_index('index')[col_name]

In [None]:
# Per-residue score of the X-ray structure, read from whichever column
# `col_name` currently names.
ax = sns.lineplot(data=ins.xray_phi_psi, x='pos', y=col_name, label='Mahalanobis distance')
ax.set_xlabel('residue')
ax.set_ylabel('distance')
ax.set_title('Mahalanobis distance from xray to cluster nearest alphafold')
plt.show()

In [None]:
# Distribution of the X-ray score stored in `col_name`.
sns.kdeplot(da.xray_phi_psi[col_name], label='X-ray score [Maha to nearest cluster]')
# sns.kdeplot(da.xray_phi_psi['new_score'], label='X-ray score [nearest cluster]')
# The plotted quantity is a Mahalanobis distance, not a silhouette score.
plt.xlabel('Mahalanobis distance')
plt.ylabel('Density')
plt.title('Xray score distribution')
plt.legend()

In [None]:
# Per-residue prediction score; seaborn aggregates the rows that share a
# position into a mean line.
ax = sns.lineplot(data=da.phi_psi_predictions, x='pos', y=col_name, label='Maha to xray')
# sns.lineplot(da.phi_psi_predictions, x='pos', y='new_score', label='Distance to xray')
ax.set_xlabel('Residue')
ax.set_ylabel('Distance')
ax.set_title(f'Average Score for Predictions [Maha to cluster nearest af]')
ax.legend()

In [None]:
# Compare the Mahalanobis score with the plain distance stored in 'new_score'.
# NOTE(review): 'new_score' is only written by the "Distance from preds to xray
# itself" cell - run that cell first.
sns.kdeplot(da.phi_psi_predictions[col_name], label='Mahalanobis Distance')
sns.kdeplot(da.phi_psi_predictions['new_score'], label='Distance')
# Both curves are distances, not silhouette scores.
plt.xlabel('Distance')
plt.ylabel('Density')
# Take the target name from the loaded object instead of hard-coding it.
plt.title(f'Distance to X-ray distribution (predictions for {da.casp_protein_id})')
plt.legend()

In [None]:
# One row per prediction, one column per residue position, values = score.
grouped_preds = da.phi_psi_predictions.pivot(index='protein_id', columns='pos', values=col_name)
# Order rows by GDT_TS (sort_values is ascending, so the lowest GDT_TS comes first).
grouped_preds = grouped_preds.loc[da.grouped_preds.sort_values('GDT_TS').protein_id.values]
# Keep only predictions whose missing-score count is below the 80th percentile.
grouped_preds_na = grouped_preds.isna().sum(axis=1)
grouped_preds = grouped_preds[grouped_preds_na < grouped_preds_na.quantile(0.8)]
sns.heatmap(grouped_preds.values.astype(float))
# The title now matches what is plotted: the `col_name` score, ascending GDT_TS.
plt.title('Mahalanobis score heatmap for predictions (rows sorted by GDT_TS ascending)')

In [None]:
from scipy.stats import pearsonr, linregress

# Sum the per-residue scores of every prediction.
grouped_preds = da.phi_psi_predictions.groupby('protein_id').sum(numeric_only=True)
# Attach the summed score by protein_id without re-indexing da.grouped_preds in
# place, so a failure part-way through cannot leave the frame indexed by
# protein_id and the cell can be re-run safely.
da.grouped_preds[col_name] = da.grouped_preds['protein_id'].map(grouped_preds[col_name])

sns.regplot(x=col_name, y='GDT_TS', data=da.grouped_preds)
plt.ylabel('GDT_TS')
plt.xlabel('Sum Maha score')

print(pearsonr(da.grouped_preds[col_name], da.grouped_preds.GDT_TS))
regr = linregress(da.grouped_preds[col_name], da.grouped_preds.GDT_TS)

# Take the target name from the loaded object instead of hard-coding it.
plt.title(f'GDT_TS vs Sum Mahalanobis score for predictions of {da.casp_protein_id} [$R^2$={regr.rvalue**2:.3f}]')

In [None]:
from scipy.stats import pearsonr, linregress

# Correlation between the summed score in `col_name` and GDT_TS.
# NOTE(review): repeats the statistics printed by the regression-plot cell above.
print(pearsonr(da.grouped_preds[col_name], da.grouped_preds.GDT_TS))
regr = linregress(da.grouped_preds[col_name], da.grouped_preds.GDT_TS)
# Cell output: coefficient of determination of the linear fit.
regr.rvalue**2

In [None]:
# Smallest number of result rows recorded for any single `seq` value in the first query.
da.queries[0].results.groupby('seq').size().min()

In [5]:
def assign_clusters_dbscan(phi_psi_dist, precomputed_dists, eps):
    """Label points with DBSCAN and store the labels in ``phi_psi_dist['cluster']``.

    Args:
        phi_psi_dist: DataFrame of points; modified in place (a ``cluster``
            column is added or overwritten).
        precomputed_dists: square pairwise distance matrix.
        eps: DBSCAN neighbourhood radius.

    Returns:
        Number of clusters found, not counting noise (label -1).
    """
    labels = DBSCAN(eps=eps, min_samples=5, metric='precomputed').fit(precomputed_dists).labels_
    phi_psi_dist['cluster'] = labels
    # Count real clusters directly; "unique labels minus one" is off by one
    # whenever no point is labelled as noise.
    return len(np.unique(labels[labels >= 0]))

def assign_clusters(phi_psi_dist, precomputed_dists, cluster_selection_epsilon):
    """Label points with HDBSCAN and store the labels in ``phi_psi_dist['cluster']``.

    Args:
        phi_psi_dist: DataFrame of points; modified in place (a ``cluster``
            column is added or overwritten).
        precomputed_dists: square pairwise distance matrix.
        cluster_selection_epsilon: passed through to HDBSCAN.

    Returns:
        Number of clusters found, counting only labels >= 0.
    """
    # Fit on a copy so the caller's matrix is left untouched.
    # phi_psi_dist['cluster'] = HDBSCAN(min_cluster_size=20, min_samples=5, metric='precomputed').fit(precomputed_dists).labels_
    labels = HDBSCAN(
        min_cluster_size=20,
        # min_samples=5,
        metric='precomputed',
        allow_single_cluster=True,
        cluster_selection_epsilon=cluster_selection_epsilon
    ).fit(precomputed_dists.copy()).labels_
    phi_psi_dist['cluster'] = labels
    # Count real clusters directly; "unique labels minus one" is off by one
    # whenever no point is labelled as noise.
    return len(np.unique(labels[labels >= 0]))

In [7]:
# Sweep HDBSCAN's cluster_selection_epsilon for every window and record cluster
# counts, unassigned points, X-ray scores and silhouette scores.
# NOTE(review): the recorded run of this cell predates the helper definitions
# shown in this notebook - calc_xray_score only appears as commented-out code, and
# get_target_cluster / get_cluster_medoid / estimate_icov are called here with a
# different argument list than their definitions. Confirm before re-running.
# dbscan_results = []
hdbscan_results = []
q = da.queries[0]

for i, seq_ctxt in enumerate(tqdm(da.seqs_for_window)):
    phi_psi_dist = get_phi_psi_dist(da.queries[0], seq_ctxt)
    xrays = get_xrays(da, da.queries[0], seq_ctxt)

    # Skip windows with incomplete X-ray data or too few PDBMine matches.
    if xrays.shape[0] != da.queries[0].winsize*2:
        print(f"Xray data for {seq_ctxt} is incomplete")
        continue
    if phi_psi_dist.shape[0] == 0:
        print(f"No pdbmine data for {seq_ctxt}")
        continue
    if phi_psi_dist.shape[0] < 100:
        print(f"Not enough pdbmine data for {seq_ctxt}")
        continue

    precomputed_dists = precompute_dists(phi_psi_dist.iloc[:,:da.queries[0].winsize*2])

    # DBSCAN
    # for eps in [15, 30, 45, 60, 75, 90]:
    #     n_clusters_dbscan = assign_clusters_dbscan(phi_psi_dist, precomputed_dists, eps)
    #     dbscan_n_unassigned = (phi_psi_dist.cluster == -1).sum()
    #     phi_psi_dist_trimmed = phi_psi_dist[phi_psi_dist.cluster != -1].copy()
    #     precomputed_dists_trimmed = precomputed_dists[phi_psi_dist.cluster != -1][:,phi_psi_dist.cluster != -1].copy()
    #     dbscan_xray_score = calc_xray_score(phi_psi_dist_trimmed, xrays, da.queries[0], precomputed_dists_trimmed)[0]
    #     dbscan_sil_score = silhouette_score(precomputed_dists_trimmed, phi_psi_dist_trimmed.cluster, metric='precomputed')

    #     target_cluster = get_target_cluster(q, phi_psi_dist_trimmed, xrays)
    #     cluster_medoid = get_cluster_medoid(phi_psi_dist_trimmed, precomputed_dists_trimmed, target_cluster, da.queries[0])
    #     icov = estimate_icov(q, phi_psi_dist_trimmed[phi_psi_dist_trimmed.cluster == target_cluster], cluster_medoid)
    #     if icov is None:
    #         continue    

    #     # xray_maha_
    #     dbscan_xray_diff = diff(xrays, cluster_medoid)
    #     dbscan_xray_maha = np.sqrt(dbscan_xray_diff @ icov @ dbscan_xray_diff)
    #     dbscan_results.append([i, seq_ctxt, eps, n_clusters_dbscan, dbscan_n_unassigned, dbscan_xray_score, dbscan_xray_maha, dbscan_sil_score])

    # HDBSCAN
    # for cluster_selection_epsilon in [5,10,15,20,25,30,35,40]:
    for cluster_selection_epsilon in [45,50,55,60,65,70,75,80]:
        # assign_clusters overwrites phi_psi_dist['cluster'] on every pass.
        n_clusters = assign_clusters(phi_psi_dist, precomputed_dists, cluster_selection_epsilon)
        n_unassigned = (phi_psi_dist.cluster == -1).sum()
        # Work on copies restricted to assigned points so the full table and
        # matrix remain available for the next epsilon.
        phi_psi_dist_trimmed = phi_psi_dist[phi_psi_dist.cluster != -1].copy()
        precomputed_dists_trimmed = precomputed_dists[phi_psi_dist.cluster != -1][:,phi_psi_dist.cluster != -1].copy()
        xray_score = calc_xray_score(phi_psi_dist_trimmed, xrays, da.queries[0], precomputed_dists_trimmed)[0]
        
        target_cluster = get_target_cluster(q, phi_psi_dist_trimmed, xrays)
        cluster_medoid = get_cluster_medoid(phi_psi_dist_trimmed, precomputed_dists_trimmed, target_cluster, da.queries[0])
        icov = estimate_icov(q, phi_psi_dist_trimmed[phi_psi_dist_trimmed.cluster == target_cluster], cluster_medoid)
        if icov is None:
            # No result row is recorded for this epsilon.
            continue    

        # xray_maha_
        xray_diff = diff(xrays, cluster_medoid)
        xray_maha = np.sqrt(xray_diff @ icov @ xray_diff)

        # The silhouette score is only computed when more than one cluster was found.
        if n_clusters > 1:
            sil_score = silhouette_score(precomputed_dists_trimmed, phi_psi_dist_trimmed.cluster, metric='precomputed')
        else:
            sil_score = np.nan

        hdbscan_results.append([i, seq_ctxt, cluster_selection_epsilon, n_clusters, n_unassigned, xray_score, xray_maha, sil_score])

    # dbscan_results_df = pd.DataFrame(dbscan_results, columns=['i', 'seq_ctxt', 'eps', 'n_clusters', 'n_unassigned', 'xray_score', 'xray_maha', 'sil_score'])
    # dbscan_results_df.to_csv('dbscan_results.csv', index=False)

    # These two lines sit inside the outer loop, so the CSV is rebuilt and
    # rewritten after every processed window.
    hdbscan_results_df = pd.DataFrame(hdbscan_results, columns=['i', 'seq_ctxt', 'clu_sel_eps', 'n_clusters', 'n_unassigned', 'xray_score', 'xray_maha', 'sil_score'])
    hdbscan_results_df.to_csv('hdbscan_results2.csv', index=False)

  0%|          | 0/367 [00:00<?, ?it/s]

 50%|█████     | 184/367 [12:53<01:30,  2.02it/s]  

Xray data for MTETFKP is incomplete
Xray data for TETFKPT is incomplete


 50%|█████     | 185/367 [12:53<01:10,  2.59it/s]

Xray data for NIFQAYK is incomplete


 54%|█████▍    | 198/367 [13:09<01:23,  2.04it/s]

Not enough pdbmine data for TYMIFMG


 62%|██████▏   | 226/367 [13:31<01:09,  2.03it/s]

Xray data for SNSFKTI is incomplete
Xray data for NSFKTIT is incomplete


 62%|██████▏   | 227/367 [13:31<00:54,  2.57it/s]

Xray data for YGQRMLT is incomplete


100%|██████████| 367/367 [22:09<00:00,  3.62s/it]


In [126]:
# Load a previously saved DBSCAN sweep. The code that writes this file is
# commented out in the sweep cell, so the CSV must already exist on disk.
dbscan_results = pd.read_csv('dbscan_results.csv')

In [131]:
# Display the loaded sweep table; the trailing comment keeps an alternative
# per-eps summary at hand.
dbscan_results#.groupby('eps').describe()

Unnamed: 0,i,seq_ctxt,eps,n_clusters,n_unassigned,xray_score,sil_score
0,0,WNLDKNL,15,56,1453,30.519999,0.618063
1,0,WNLDKNL,30,60,688,52.928243,0.434653
2,0,WNLDKNL,45,42,355,76.317348,0.406587
3,0,WNLDKNL,60,31,227,163.761011,0.047976
4,0,WNLDKNL,75,21,159,163.761011,0.147090
...,...,...,...,...,...,...,...
2155,366,VAVNRHQ,30,20,149,29.290299,0.708535
2156,366,VAVNRHQ,45,19,90,30.380898,0.730533
2157,366,VAVNRHQ,60,18,65,30.380898,0.754351
2158,366,VAVNRHQ,75,18,56,30.380898,0.759585


In [None]:
# Compare silhouette-score distributions of the two clustering methods.
# NOTE(review): `cluster_results` is not defined in any visible cell of this
# notebook - confirm where it comes from (possibly a merge of the two sweep CSVs).
sns.kdeplot(cluster_results.sil_score_dbscan, label='DBSCAN')
sns.kdeplot(cluster_results.sil_score, label='HDBSCAN')
plt.legend()

In [14]:
# Cluster one window and draw a grid of phi/psi scatter plots: one row per
# cluster, one column per window position, with the cluster medoid overlaid.
# NOTE(review): the recorded output of this cell is
# "TypeError: unhashable type: 'Series'". calc_intra_cluster and calc_xray_score
# only appear as commented-out code in this notebook, and get_cluster_medoid is
# called below as (phi_psi_dist, precomputed_dists, cluster, q) while it is
# defined as (phi_psi_dist, precomputed_dists, clusters, c). Confirm and update.
n_cluster_plot = 10
q = da.queries[0]
# seq_ctxt = da.seqs_for_window[0]
seq_ctxt = da.seqs[da.seqs.tolist().index('AGILAGF')]
phi_psi_dist = get_phi_psi_dist(q, seq_ctxt)
n_points = phi_psi_dist.shape[0]
xray_point, xrays = get_xrays(da, q, seq_ctxt, return_df=True)
preds = get_preds(da, q, seq_ctxt)
# pred_id = 'T1024TS063_5'
pred_id = da.protein_ids[0]
# Reshape one prediction to (2, n): row 0 holds phi, row 1 holds psi.
pred = preds.loc[pred_id].values.reshape(2,-1)

precomputed_dists = precompute_dists(phi_psi_dist.iloc[:,:q.winsize*2])
n_clusters = assign_clusters(phi_psi_dist, precomputed_dists, 80)
intracluster_dists = calc_intra_cluster(phi_psi_dist, precomputed_dists)
print(n_clusters, phi_psi_dist.cluster.value_counts())
# Silhouette score over assigned points only; iloc[:,:-1] drops the trailing 'cluster' column.
sil_score = silhouette_score(phi_psi_dist[phi_psi_dist.cluster != -1].iloc[:,:-1], phi_psi_dist[phi_psi_dist.cluster != -1].cluster)
# print(sil_score)
# xray_sil, nearest_cluster = xray_sil_score(phi_psi_dist, xray_point, q)
n_unassigned = (phi_psi_dist.cluster == -1).sum()
# Drop points labelled -1 from both the distance matrix and the table.
precomputed_dists = precomputed_dists[phi_psi_dist['cluster'] != -1][:,phi_psi_dist['cluster'] != -1]
phi_psi_dist = phi_psi_dist[phi_psi_dist['cluster'] != -1]
xray_sil, nearest_cluster = calc_xray_score(phi_psi_dist, xray_point, q, intracluster_dists)
pred_sil, nearest_cluster_pred = calc_xray_score(phi_psi_dist, pred.flatten(), q, intracluster_dists)
print('Intracluster', nearest_cluster, intracluster_dists[nearest_cluster])
print('Xray:', xray_sil)
print('Pred:', pred_sil)

# Order clusters by size (largest first), then move the X-ray's nearest cluster to the front.
clusters = phi_psi_dist.groupby('cluster').count().sort_values('phi_0', ascending=False).index.values
clusters = np.concatenate([[nearest_cluster], clusters[clusters != nearest_cluster]])
# clusters = np.concatenate([[nearest_cluster_pred], clusters[clusters != nearest_cluster_pred]])
clusters_plot = clusters[:n_cluster_plot]
# cluster_aves = phi_psi_dist.groupby('cluster').mean().loc[clusters_plot]
medoids = []
for cluster in clusters:
    medoid = get_cluster_medoid(phi_psi_dist, precomputed_dists, cluster, q)
    medoids.append(medoid)
medoids = np.array(medoids)
print(medoids)

colors = sns.color_palette('Dark2', n_clusters)
fig, axes = plt.subplots(len(clusters_plot), q.winsize, figsize=(16, min(n_cluster_plot, len(clusters_plot))*4), sharey=True, sharex=True)
for i,axrow in enumerate(axes):
    for j, ax in enumerate(axrow):
        cluster_dist = phi_psi_dist[phi_psi_dist.cluster == clusters_plot[i]]

        sns.scatterplot(data=phi_psi_dist[phi_psi_dist.cluster != clusters_plot[i]], x=f'phi_{j}', y=f'psi_{j}', ax=ax, label='Other Clusters', color='tab:blue', alpha=0.5)
        sns.scatterplot(data=cluster_dist, x=f'phi_{j}', y=f'psi_{j}', ax=ax, label=f'Cluster {clusters_plot[i]}', color=colors[i])
        # ax.scatter(xrays.phi.iloc[j], xrays.psi.iloc[j], color='tab:red', marker='X', label='X-ray', zorder=1000)
        # ax.scatter(pred[0,j], pred[1,j], color='tab:orange', marker='X', label=pred_id, zorder=1000)
        ax.scatter(medoids[i].reshape(2,-1)[0,j], medoids[i].reshape(2,-1)[1,j], color='black', marker='X', label='Cluster Centroid', zorder=1000)

        # Draws a dashed line from a point in this panel to a point in the panel
        # to its right; redefined on every iteration and closes over axrow and j.
        def add_conn(xyA, xyB, color, lw, **kwargs):
            con = ConnectionPatch(
                xyA=xyA, 
                xyB=xyB, 
                coordsA="data", coordsB="data", 
                axesA=axrow[j], axesB=axrow[j+1], 
                color=color, lw=lw, linestyle='--', alpha=0.5, **kwargs
            )
            fig.add_artist(con)
        if j < q.winsize - 1:
            # TODO draw lines for 50 points closest to centroid
            # Random sample of at most 50 cluster members (unseeded, so the drawn lines vary between runs).
            for k, row in cluster_dist.sample(min(cluster_dist.shape[0], 50)).iterrows():
                add_conn((row[f'phi_{j}'], row[f'psi_{j}']), (row[f'phi_{j+1}'], row[f'psi_{j+1}']), colors[i], 1)
            # add_conn((xrays.phi.iloc[j], xrays.psi.iloc[j]), (xrays.phi.iloc[j+1], xrays.psi.iloc[j+1]), 'tab:red', 5, zorder=100)
            # add_conn((pred[0,j], pred[1,j]), (pred[0,j+1], pred[1,j+1]), 'tab:orange', 5, zorder=100)
            add_conn((medoids[i].reshape(2,-1)[0,j], medoids[i].reshape(2,-1)[1,j]), (medoids[i].reshape(2,-1)[0,j+1], medoids[i].reshape(2,-1)[1,j+1]), 'black', 5, zorder=100)

        ax.set_xlim(-180, 180)
        ax.set_ylim(-180, 180)
        ax.set_xlabel('')
        # Only the right-most panel of each row keeps its legend.
        if j == q.winsize - 1:
            ax.legend()
        else:
            ax.legend().remove()
        if i == 0:
            ax.set_title(xrays.iloc[j].res)
        if j == 0:
            ax.set_ylabel(f'Cluster {clusters_plot[i]} [{cluster_dist.shape[0]}]')
fig.supxlabel('Phi')
fig.supylabel('Psi')
fig.suptitle(
    # f'Clustered Phi/Psi Distributions for {seq_ctxt} in protein {da.casp_protein_id}: N={n_points} Silhouette Score: {sil_score:.2f}, X-ray Score [Cluster {nearest_cluster}]: {xray_sil:.2f}, Prediction Score [Cluster {nearest_cluster_pred}]: {pred_sil:.2f}', 
    f'Clustered Phi/Psi Distributions for {seq_ctxt} in protein {da.casp_protein_id}: N={n_points} ({n_unassigned} unassigned) Silhouette Score: {sil_score:.2f}, X-ray Score [Cluster {nearest_cluster}]: {xray_sil:.2f}', 
    y=1.01
)
plt.tight_layout()

7 cluster
 1    1871
-1    1225
 0    1065
 3     300
 6     192
 2     121
 5     118
 4      90
Name: count, dtype: int64


TypeError: unhashable type: 'Series'

In [None]:
import re

# Read 'seq.ss' and stitch together the per-line fragments of the three
# tracks it carries: residues ("seq:"), secondary structure ("SS:") and
# per-residue confidence ("conf:").
# NOTE(review): the sequence fragment is taken from the third whitespace
# token while the other two tracks use the second token - confirm against
# the file format.
seq_parts, ss_parts, conf_parts = [], [], []
with open('seq.ss', 'r') as file:
    for line in file:
        if line.startswith("seq:"):
            seq_parts.append(line.split()[2])
        elif line.startswith("SS:"):
            ss_parts.append(line.split()[1])
        elif line.startswith("conf:"):
            conf_parts.append(line.split()[1])

sequence = "".join(seq_parts)
ss = "".join(ss_parts)
conf = "".join(conf_parts)

# Integer codes for the secondary-structure symbols found in `ss`
# (upper- and lower-case 'c' share code 2).
jpred_map = {'H': 0, 'E': 1, 'c': 2, 'C': 2, '-': 3}
jpred_map_inv = {code: symbol for symbol, code in jpred_map.items()}

# Walk the X-ray rows in order while advancing a cursor `j` through
# `da.sequence`. Sequence positions whose residue does not match the current
# X-ray row are skipped (presumably residues absent from the structure -
# TODO confirm), and each matched position contributes its `ss` code.
jpred = []
j = 0
for i in range(len(da.xray_phi_psi)):
    row = da.xray_phi_psi.iloc[i]
    while da.sequence[j] != row['res']:
        j += 1
    jpred.append(jpred_map[ss[j]])
    j += 1
da.xray_phi_psi['jpred'] = jpred

# Collapse consecutive rows that share the same `jpred` code into "bouts".
# Each bout is a tuple (start, end, code): `start` is the index of the first
# row of the run and `end` is the index of the first row of the following run
# (or the last row index for the final run), so adjacent bouts share a
# boundary when drawn with axvspan.
# NOTE(review): `i` is the index label yielded by iterrows; this assumes
# da.xray_phi_psi has a default 0..n-1 index - confirm.
running_val = da.xray_phi_psi.iloc[0]['jpred']
bouts = []
bout_start = 0
for i,row in da.xray_phi_psi.iterrows():
    if row['jpred'] != running_val:
        bouts.append((bout_start, i, running_val))
        # Row `i` is the one that broke the run, so it is itself the first
        # row of the next bout (starting at i + 1 would shift every later
        # bout by one and make single-row bouts zero-width).
        bout_start = i
        running_val = row['jpred']
bouts.append((bout_start, i, running_val))

# Plot the per-residue Maha score of the X-ray structure and shade the
# residue ranges whose jpred code is 2 (the 'c'/'C' entries of jpred_map).
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
# ax.plot(np.arange(len(da.xray_phi_psi)), da.xray_phi_psi.da, label='Old Score', linestyle='--', alpha=0.75)
ax.plot(np.arange(len(da.xray_phi_psi)), da.xray_phi_psi.maha, label='Maha Score')

# Only the first shaded span carries the 'Loop' label so the legend shows a
# single entry; keying this on the first *loop* bout (rather than on the first
# bout overall) keeps the entry present when the protein does not start with a
# loop.
loop_labelled = False
for bout in bouts:
    if bout[2] != 2:
        continue
    ax.axvspan(
        bout[0], bout[1], color='red', alpha=0.3,
        label='_nolegend_' if loop_labelled else 'Loop'
    )
    loop_labelled = True

ax.legend(loc='upper right')
ax.set_title(f'Score per residue for X-Ray structure of {da.casp_protein_id}')
ax.set_xlabel('Residue')
ax.set_ylabel('Score')
plt.show()

In [None]:
# Plot the `af_score` column of the predictions against `pos` and shade the
# residue ranges whose jpred code is 2 (taken from `bouts`).
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
# Pass ax explicitly instead of relying on the implicit current axes.
sns.lineplot(da.phi_psi_predictions, x='pos', y='af_score', label='Average Maha Score', ax=ax)

# Only the first shaded span carries the 'Loop' label so the legend shows a
# single entry; keying this on the first *loop* bout (rather than on the first
# bout overall) keeps the entry present when the first bout is not a loop.
loop_labelled = False
for bout in bouts:
    if bout[2] != 2:
        continue
    ax.axvspan(
        bout[0], bout[1], color='red', alpha=0.3,
        label='_nolegend_' if loop_labelled else 'Loop'
    )
    loop_labelled = True

ax.legend(loc='upper right')
ax.set_title(f'Average Score per residue for Predicted Structures of {da.casp_protein_id}')
ax.set_xlabel('Residue')
ax.set_ylabel('Score')
plt.show()

In [None]:
# Print every medoid as a brace-wrapped, space-separated list of its values,
# e.g. "{v0 v1 v2}".
for m in medoids:
    # The text is assembled outside an f-string: nesting the same quote
    # character inside an f-string expression only parses on Python >= 3.12.
    print('{' + ' '.join(str(mi) for mi in m) + '}')

In [None]:
def get_sil_score(phi_psi_dist, precomputed_dists):
    """Compute silhouette scores from a precomputed pairwise distance matrix.

    For each point, ``a`` is the mean distance to the other members of its own
    cluster and ``b`` is the smallest mean distance to the members of any other
    cluster; the point's silhouette is ``(b - a) / max(a, b)``.

    Parameters
    ----------
    phi_psi_dist : pandas.DataFrame
        Must have a ``cluster`` column with one label per row. Every distinct
        label is treated as an ordinary cluster (no label is special-cased).
    precomputed_dists : numpy.ndarray
        Square matrix of pairwise distances; row/column ``k`` corresponds to
        row ``k`` of ``phi_psi_dist``.

    Returns
    -------
    s : float
        Mean silhouette over all points.
    sil_samples : numpy.ndarray
        Per-point silhouette values, in the row order of ``phi_psi_dist``.

    Raises
    ------
    ValueError
        If a multi-point cluster has no other cluster to be compared against,
        or if there are no points at all.

    Notes
    -----
    A cluster with a single member is given a silhouette of 1, as in the
    original notebook code. NOTE(review): scikit-learn assigns 0 to singleton
    clusters - confirm which convention is wanted.
    """
    labels = phi_psi_dist.cluster.to_numpy()
    # Build each cluster's boolean membership mask once instead of re-deriving
    # it inside the nested loops.
    masks = [labels == c for c in phi_psi_dist.cluster.unique()]

    ss = []
    sil_samples = np.zeros(phi_psi_dist.shape[0])
    for own_idx, own in enumerate(masks):
        n_points = int(own.sum())
        if n_points == 1:
            # `a` is undefined for a singleton. Record the score as a length-1
            # array (so it can be concatenated) and keep processing the
            # remaining clusters.
            s = np.ones(1)
            sil_samples[own] = s
            ss.append(s)
            continue

        own_rows = precomputed_dists[own]
        # Mean intra-cluster distance; the zero self-distance is excluded by
        # dividing by n_points - 1.
        a = own_rows[:, own].sum(axis=1) / (n_points - 1)

        # Mean distance from every point of this cluster to each other cluster.
        bs = [
            own_rows[:, other].sum(axis=1) / other.sum()
            for other_idx, other in enumerate(masks)
            if other_idx != own_idx
        ]
        if not bs:
            raise ValueError('silhouette score requires at least 2 clusters')
        b = np.stack(bs).min(axis=0)

        s = (b - a) / np.maximum(a, b)
        sil_samples[own] = s
        ss.append(s)

    s = np.concatenate(ss).mean()
    return s, sil_samples

In [None]:
# max_sil_avg = -1
# for k in range(2, min(phi_psi_dist.shape[0], 2**q.winsize)):
#     kmeans = KMeans(n_clusters=k).fit(phi_psi_dist.values)
#     sil_avg = silhouette_score(phi_psi_dist.values, kmeans.labels_)
#     if sil_avg > max_sil_avg:
#         max_sil_avg = sil_avg
#         phi_psi_dist['cluster'] = kmeans.labels_
#         chosen_centroids = kmeans.cluster_centers_
#         n_clusters = k

# phi_psi_dist['cluster'] = DBSCAN(eps=75, min_samples=5).fit(phi_psi_dist.values).labels_
# n_clusters = len(phi_psi_dist.cluster.unique())
# print(n_clusters, phi_psi_dist.cluster.value_counts())

In [None]:

def calc_score(q, preds, phi_psi_dist, intracluster_dists, xrays=None, afs=None):
    """Score each prediction by its distance to the average of its nearest cluster.

    The nearest cluster of a prediction is the one whose members have the
    smallest mean distance to that prediction. The returned score is the
    Euclidean norm of the periodic angle differences (via ``diff``) between
    the prediction and that cluster's average point.

    Parameters
    ----------
    q : object with a ``winsize`` attribute
        The first ``q.winsize * 2`` columns of ``preds`` and ``phi_psi_dist``
        are used as the angle coordinates.
    preds : pandas.DataFrame
        One row per prediction.
    phi_psi_dist : pandas.DataFrame
        Reference points; must have a ``cluster`` column of labels.
    intracluster_dists, xrays, afs
        Not used by this scoring variant; kept for call compatibility.

    Returns
    -------
    numpy.ndarray
        One float distance per row of ``preds``.

    Notes
    -----
    The cluster average is the plain arithmetic mean of the angle columns
    (no wrap-around handling). NOTE(review): confirm this is acceptable for
    clusters that straddle the +/-180 boundary.
    """
    n_angles = q.winsize * 2
    pred_angles = preds.iloc[:, :n_angles].values
    ref_angles = phi_psi_dist.iloc[:, :n_angles].values
    labels = phi_psi_dist.cluster.to_numpy()

    # d[p, r] = distance from prediction p to reference point r.
    d = np.linalg.norm(diff(pred_angles[:, np.newaxis], ref_angles), axis=2)

    clusters = np.asarray(phi_psi_dist.cluster.unique())
    # average_dists[p, k] = mean distance from prediction p to cluster clusters[k].
    average_dists = np.array([d[:, labels == c].mean(axis=1) for c in clusters]).T

    # argmin yields a position into `clusters` (which is in order of first
    # appearance and may contain labels such as -1), so translate it to the
    # actual cluster label before comparing against labels below.
    nearest_clusters = clusters[average_dists.argmin(axis=1)]

    preds_dist = np.full(pred_angles.shape[0], np.nan, dtype=float)
    for c in clusters:
        assigned = nearest_clusters == c
        cluster_avg = ref_angles[labels == c].mean(axis=0)
        preds_dist[assigned] = np.linalg.norm(diff(pred_angles[assigned], cluster_avg), axis=1)

    return preds_dist

    # # Distance to cluster with most points
    # cluster_counts = phi_psi_dist.groupby('cluster').size()
    # c = cluster_counts.idxmax()
    # cluster_points = phi_psi_dist[phi_psi_dist.cluster == c].iloc[:,:q.winsize*2].values
    # cluster_avg = cluster_points.mean(axis=0)

    # preds_dist = np.linalg.norm(diff(preds.iloc[:,:q.winsize*2].values, cluster_avg), axis=1)
    # return preds_dist

    # # Distance to cluster average nearest to xray point (control)
    # d = np.linalg.norm(diff(xrays[np.newaxis,:], phi_psi_dist.iloc[:,:q.winsize*2].values), axis=1)
    # d = pd.DataFrame({'d': d, 'c': phi_psi_dist.cluster})
    # nearest_cluster = d.groupby('c').d.mean().idxmin()
    # cluster_points = phi_psi_dist[phi_psi_dist.cluster == nearest_cluster].iloc[:,:q.winsize*2].values
    # cluster_avg = cluster_points.mean(axis=0)
    # preds_dist = np.linalg.norm(diff(preds.iloc[:,:q.winsize*2].values, cluster_avg), axis=1)
    # return preds_dist

    # # Distance to cluster average nearest to af
    # d = np.linalg.norm(diff(afs[np.newaxis,:], phi_psi_dist.iloc[:,:q.winsize*2].values), axis=1)
    # d = pd.DataFrame({'d': d, 'c': phi_psi_dist.cluster})
    # nearest_cluster = d.groupby('c').d.mean().idxmin()
    # cluster_points = phi_psi_dist[phi_psi_dist.cluster == nearest_cluster].iloc[:,:q.winsize*2].values
    # cluster_avg = cluster_points.mean(axis=0)
    # preds_dist = np.linalg.norm(diff(preds.iloc[:,:q.winsize*2].values, cluster_avg), axis=1)
    # return preds_dist

    # # Distance to xray point
    # preds_dist = np.linalg.norm(diff(preds.iloc[:,:q.winsize*2].values, xrays), axis=1)
    # return preds_dist


# def calc_xray_score(phi_psi_dist, xrays, q, intracluster_dists):
def calc_xray_score(phi_psi_dist, xrays, q, intracluster_dists, afs=None):
    """Distance from the X-ray point to the average of its nearest cluster.

    The nearest cluster is the one whose members have the smallest mean
    distance to ``xrays``; distances are Euclidean norms of the periodic
    angle differences returned by ``diff``, taken over the first
    ``q.winsize * 2`` columns of ``phi_psi_dist``.

    ``intracluster_dists`` and ``afs`` are not used by this variant.

    Returns a tuple ``(xray_dist, nearest_cluster)`` where ``nearest_cluster``
    is the label of the selected cluster.
    """
    n_angles = q.winsize * 2
    labels = phi_psi_dist.cluster

    # Distance from the X-ray point to every reference point, then averaged
    # per cluster to pick the closest cluster.
    point_dists = np.linalg.norm(
        diff(xrays[np.newaxis, :], phi_psi_dist.iloc[:, :n_angles].values),
        axis=1,
    )
    per_point = pd.DataFrame({'d': point_dists, 'c': labels})
    nearest_cluster = per_point.groupby('c').d.mean().idxmin()

    # Average point of the chosen cluster (arithmetic mean of each column).
    members = phi_psi_dist[labels == nearest_cluster]
    cluster_avg = members.iloc[:, :n_angles].values.mean(axis=0)

    squared = diff(xrays, cluster_avg) ** 2
    xray_dist = np.sqrt(squared.sum())

    return xray_dist, nearest_cluster

    # # Distance to cluster average with most points
    # cluster_counts = phi_psi_dist.groupby('cluster').size()
    # c = cluster_counts.idxmax()
    # cluster_points = phi_psi_dist[phi_psi_dist.cluster == c].iloc[:,:q.winsize*2].values
    # cluster_avg = cluster_points.mean(axis=0)
    # xray_dist = np.sqrt((diff(xrays, cluster_avg)**2).sum())
    # return xray_dist, c

    # # Distance to cluster average nearest to af
    # d = np.linalg.norm(diff(afs[np.newaxis,:], phi_psi_dist.iloc[:,:q.winsize*2].values), axis=1)
    # d = pd.DataFrame({'d': d, 'c': phi_psi_dist.cluster})
    # nearest_cluster = d.groupby('c').d.mean().idxmin()

    # cluster_points = phi_psi_dist[phi_psi_dist.cluster == nearest_cluster].iloc[:,:q.winsize*2].values
    # cluster_avg = cluster_points.mean(axis=0)
    # xray_dist = np.sqrt((diff(xrays, cluster_avg)**2).sum())
    # return xray_dist, nearest_cluster

    # # Distance to xray point
    # d = np.linalg.norm(diff(xrays[np.newaxis,:], phi_psi_dist.iloc[:,:q.winsize*2].values), axis=1)
    # d = pd.DataFrame({'d': d, 'c': phi_psi_dist.cluster})
    # nearest_cluster = d.groupby('c').d.mean().idxmin()
    # return 0, nearest_cluster

