# Cluster Transitions for Dynamic by Design Data

## Imports

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

## Acquire Trajectories

In [None]:
# Locate the trajectory files; `data_dir` and `xtc_dir` are reused by later cells.
data_dir = Path('./data').resolve()
xtc_dir = data_dir / 'Disordered_By_Design/XTC_files'

# One entry per trajectory file: its file name without the '.xtc' extension.
unique_names = [xtc_file.stem for xtc_file in xtc_dir.glob('*.xtc')]

print(f"{len(unique_names)} unique starting configurations:", unique_names, sep="\n")

In [None]:
from Bio import PDB

# Parse one PDB file with Biopython as a quick sanity check of the input data.
pdb_dir = data_dir.joinpath('Disordered_By_Design/2KMV/')
sample_file = pdb_dir.joinpath('2KMV_01_02.pdb')

sample_pdb = PDB.PDBParser().get_structure('sample', sample_file)
# len() of a Bio.PDB structure is its number of child models; atoms are counted
# by consuming the get_atoms() generator. The earlier version counted atoms but
# labelled them "models", and read `_` without binding it in the comprehension.
print(
    f"Sample structure with {len(sample_pdb)} models "
    f"and {sum(1 for _ in sample_pdb.get_atoms())} atoms"
)

In [None]:
import mdtraj

# Load one trajectory with mdtraj, using the matching PDB file as topology.
# Only atoms 0..2833 are read (see `atom_indices`).
# NOTE(review): 2834 is hard-coded; presumably the number of atoms of interest
# in this particular system - confirm before reusing for other inputs.
sample_u = mdtraj.load_xtc(
    data_dir.joinpath('Disordered_By_Design/XTC_files/md_0_1_align_2KMV_01_02.xtc'),
    top=data_dir.joinpath('Disordered_By_Design/2KMV/2KMV_01_02.pdb'),
    atom_indices=range(0, 2834)
)
print(sample_u)

In [None]:
from MDAnalysis import Universe

# Load the same topology/trajectory pair with MDAnalysis.
# This rebinds `sample_u`, which previously held the mdtraj trajectory.
# NOTE(review): the trajectory is passed via the `trajectory=` keyword - confirm
# that the installed MDAnalysis version actually loads coordinates from it.
sample_u = Universe(
    topology=data_dir.joinpath('Disordered_By_Design/2KMV/2KMV_01_02.pdb'), topology_format='pdb',
    trajectory=data_dir.joinpath('Disordered_By_Design/XTC_files/md_0_1_align_2KMV_01_02.xtc'), format='xtc'
)

# Prints one line per frame; the atom count is taken from the universe and
# therefore is the same value on every line.
for traj in sample_u.trajectory:
    print(f"Frame {traj.frame} has {len(sample_u.atoms)} atoms")
print(f"Found {len(sample_u.atoms)} atoms")
print(f"Found {len(sample_u.residues)} residues")
print(f"Found {len(sample_u.segments)} segments")

## Clustering

### Get clusters

In [None]:
# K selects the cluster file by name: K // 1000 yields 'clusters-2K.pkl' for K = 2000.
# Presumably K is the number of cluster centres stored in that file - TODO confirm.
K = 2000
import pickle
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only load cluster files that come from a trusted source.
with open(f'clusters-{K//1000}K.pkl', 'rb') as f:
    clusters = pickle.load(f)

In [None]:
# Cluster centres as a single array: one entry of clusters['X'] per centre.
C = np.stack(clusters['X'])
# One row per centre, columns in the order phi0, psi0, phi1, psi1.
Cangles = np.array([clusters[angle] for angle in ('phi0', 'psi0', 'phi1', 'psi1')]).T

In [None]:
from sklearn.neighbors import NearestNeighbors
# Index the cluster centres for nearest-neighbour queries; every centre is
# flattened into a single descriptor row before fitting.
neigh = NearestNeighbors().fit(C.reshape(C.shape[0], -1))
# Example query, kept for reference only (not executed):
# indices = neigh.kneighbors(C.reshape(C.shape[0], -1), n_neighbors=1, return_distance=False)

### Coordinate getter function

In [None]:
from Bio.PDB import Selection, Atom, Residue, Structure

# The data format seems to have changed, so Alex's parser was adapted to give the same results.
# TODO: the original 'get_coordinates' is not documented well enough to be sure of that; ask Alex.
def get_coordinates(traj_path: Path, top_path: Path, filter_atoms=('N','CA','C','O')):
    """Load the coordinates of selected atoms for every complete residue.

    A residue is "complete" when it contains an atom for every name in
    ``filter_atoms``; all other residues are skipped entirely.

    Args:
        traj_path: XTC trajectory file to read.
        top_path: PDB file providing the topology for ``traj_path``.
        filter_atoms: Atom names to keep per residue. Their order defines the
            order of the atoms along axis 2 of the returned array.

    Returns:
        A tuple ``(coords, valid_residue_names)`` where ``coords`` has shape
        ``(n_frames, n_complete_residues, len(filter_atoms), 3)`` and
        ``valid_residue_names`` maps the topology index of every complete
        residue to ``str(residue)``, in topology order.
    """
    # Load the topology on its own first: the atom selection has to be known
    # before the trajectory is read so that only those atoms are loaded.
    topology: mdtraj.Topology = mdtraj.load_pdb(top_path).topology

    wanted = set(filter_atoms)

    valid_atom_ids = []
    valid_residue_names = {}
    for residue in topology.residues:
        # Map atom name -> atom index for the requested atoms found in this
        # residue. Keying by name (instead of counting matches) means a
        # repeated atom name cannot make an incomplete residue look complete
        # or leave an index slot unset.
        found = {atom.name: atom.index for atom in residue.atoms if atom.name in wanted}
        if len(found) != len(wanted):
            continue

        valid_residue_names[residue.index] = str(residue)
        # Emit the indices in filter order, independent of the file's order.
        valid_atom_ids.extend(found[name] for name in filter_atoms)

    # Only load the selected atoms. The already parsed topology is reused so
    # the PDB file is not read a second time.
    # NOTE(review): this relies on mdtraj returning coordinates in the order
    # given by `atom_indices` - confirm if the file order ever differs from
    # `filter_atoms`.
    universe: mdtraj.Trajectory = mdtraj.load_xtc(traj_path, top=topology, atom_indices=valid_atom_ids)

    # extract coordinates as n_frames x n_residues x n_(filter_)atoms x 3
    coords = universe.xyz.reshape((-1, len(valid_residue_names), len(filter_atoms), 3))
    return coords, valid_residue_names

# Try the loader on the first trajectory found.
# The topology file name is the part of the trajectory stem after 'align_'.
# NOTE: raises IndexError if no '.xtc' file was found (unique_names is empty).
sample_top = data_dir.joinpath(f"Disordered_By_Design/2KMV/{unique_names[0].split('align_')[-1]}.pdb")
coords, valid_residue_names = get_coordinates(
    traj_path=xtc_dir.joinpath(unique_names[0]+'.xtc'),
    top_path=sample_top
)

In [None]:
print(valid_residue_names)

### Coordinate canonicalization and cluster assignment functions

In [None]:
import numba

# NOTE(review): the body contains no explicit parallel loop, so `parallel=True`
# may have no effect here - confirm whether it is needed.
@numba.njit(parallel=True)
def canonize(X: np.ndarray):
    """Express one set of points in a local frame built from rows 2, 4 and 5.

    X is indexed as a 2-D array with at least 6 rows and 3 columns.
    Presumably these are the stacked backbone atoms of two consecutive
    residues (rows 2, 4, 5 being C of the first and N, CA of the second) -
    TODO confirm against the caller.

    Returns the points translated so that row 2 is the origin and rotated
    into the orthonormal basis (e1, e2, e3) constructed below.
    """
    #X = np.vstack(coords)
    # translate so that row 2 becomes the origin
    X = X - X[2,:]
    
    # e1: unit vector from row 2 to row 4
    e1 = X[4,:]-X[2,:]
    e1 = e1/np.linalg.norm(e1)
    
    # e3: unit normal of the plane through rows 2, 4 and 5
    e3 = np.cross(X[2,:]-X[4,:], X[5,:]-X[4,:])
    e3 = e3/np.linalg.norm(e3)
    
    # e2 completes the basis; it is perpendicular to both e3 and e1
    e2 = np.cross(e3, e1)
    e2 = e2/np.linalg.norm(e2)
    
    # basis vectors are stored as the columns of U
    U = np.zeros((3,3))
    U[:,0] = e1
    U[:,1] = e2
    U[:,2] = e3    
    #U = np.vstack([e1, e2, e3]).T

    # project every point onto the new basis
    X = X @ U
    return X
    
@numba.jit
def clusterize(C, xcan):
    """Return the index of the centre in C closest to xcan, and that distance.

    The distance to each centre is the root of the mean (over axis 1) of the
    squared differences summed over axis 2, i.e. an RMS deviation per centre.
    Assumes C is (n_centres, n_points, 3) and xcan is (n_points, 3) - TODO confirm.
    """
    # broadcast xcan against all centres at once: one distance per centre
    d = np.sqrt(((xcan[None,:,:]-C)**2).sum(2).mean(1))
    i = np.argmin(d)
    return i, d[i]

def clusterize_fast(neigh, xcan):
    """Look up the nearest fitted neighbour for every sample in ``xcan``.

    Each sample (first axis of ``xcan``) is flattened to one descriptor row
    and passed to ``neigh.kneighbors`` with ``n_neighbors=1``.

    Returns:
        ``(indices, distances)``: the first column of the index array and of
        the distance array returned by ``kneighbors``, in that order.
    """
    descriptors = xcan.reshape(xcan.shape[0], -1)
    distances, nearest = neigh.kneighbors(descriptors, n_neighbors=1, return_distance=True)
    return nearest[:, 0], distances[:, 0]

In [None]:
def canonize_batch(X: np.ndarray):
    """Express point sets in a local frame built from rows 2, 4 and 5.

    Vectorised over any number of leading batch axes, including none, so a
    single ``(n_points, 3)`` point set works as well as a batch of them.

    Args:
        X: Array of shape ``(..., n_points, 3)`` with ``n_points >= 6``.
            Presumably the stacked backbone atoms (N, CA, C, O) of two
            consecutive residues, so that rows 2, 4 and 5 are C of the first
            residue and N, CA of the second - TODO confirm against the caller.

    Returns:
        Array with the same shape as ``X``: every point set translated so
        that row 2 is the origin and rotated into the orthonormal basis
        (e1, e2, e3) described below. The result is not flattened; callers
        collapse the last two axes themselves when a flat descriptor is needed.
    """
    X = np.asarray(X)

    # translate so that row 2 becomes the origin of every point set
    X = X - X[..., 2:3, :]

    anchor = X[..., 2, :]
    ahead = X[..., 4, :]
    third = X[..., 5, :]

    # e1: unit vector from row 2 to row 4
    e1 = ahead - anchor
    e1 = e1 / np.linalg.norm(e1, axis=-1, keepdims=True)

    # e3: unit normal of the plane through rows 2, 4 and 5
    e3 = np.cross(anchor - ahead, third - ahead, axis=-1)
    e3 = e3 / np.linalg.norm(e3, axis=-1, keepdims=True)

    # e2 completes the basis; it is perpendicular to both e3 and e1
    e2 = np.cross(e3, e1, axis=-1)
    e2 = e2 / np.linalg.norm(e2, axis=-1, keepdims=True)

    # basis vectors become the columns of one 3x3 matrix per point set
    U = np.stack([e1, e2, e3], axis=-1)
    return np.einsum('...ij,...jk->...ik', X, U)

In [None]:
def clusterize_traj(coords, residues):
    """Assign every consecutive-residue pair in every frame to a cluster centre.

    Slower variant that canonizes the pairs in a multiprocessing pool and
    compares each one against the module-level cluster centres `C`.

    Args:
        coords: array indexed as [frame, residue, ...].
        residues: mapping whose keys are residue ids, in the same order as
            axis 1 of `coords`.

    Returns:
        (labels, pairs): `labels` has shape (n_frames, n_pairs, 2) holding
        (cluster index, distance); `pairs` lists the position of the first
        residue of every pair.
    """
    # this only works because it sneakily pulls the cluster centers C from the global scope
    import multiprocessing as mp
    # NOTE: within this function `canonize` is the one from the external
    # `process` module, not a function defined in this file.
    from process import canonize

    inputs = []
    indices = []
    res_ids = list(residues.keys())
    # keep positions whose next residue has the directly following id
    pairs = [res for res in range(coords.shape[1]-1) if res_ids[res+1] == res_ids[res]+1]
    for frame in range(coords.shape[0]):
        for pair, res in enumerate(pairs):
            # NOTE(review): the two residues are not stacked into one
            # (2*n_atoms, 3) array here - presumably process.canonize does
            # that itself; confirm.
            x = coords[frame,res:res+2,:]
            inputs.append(x)
            indices.append((frame, res, pair))

    labels = np.zeros((coords.shape[0], len(pairs), 2))
    with mp.Pool(mp.cpu_count(),) as pool:
        # imap yields results in input order, so zipping with `indices` is valid
        for k, (xcan, (frame, res, pair)) in enumerate(zip(pool.imap(canonize, inputs, chunksize=100), indices)):
            idx, d = clusterize(C, xcan)
            labels[frame, pair, :] = [idx, d]
    return labels, pairs


def clusterize_traj_fast(coords, residues):
    """Assign every consecutive-residue pair in every frame to a cluster centre.

    Uses the module-level nearest-neighbour index ``neigh`` (not passed as an
    argument) together with ``canonize_batch`` and ``clusterize_fast``.

    Args:
        coords: Array of shape ``(n_frames, n_residues, n_atoms, 3)``.
        residues: Mapping whose keys are residue ids, in the same order as
            axis 1 of ``coords``.

    Returns:
        A tuple ``(labels, pairs)``. ``labels`` has shape
        ``(n_frames, n_pairs, 2)`` and holds, per frame and pair, the index of
        the nearest cluster centre and the distance to it. ``pairs`` lists the
        position (along axis 1 of ``coords``) of the first residue of every
        pair; the second residue is always the next position.
    """
    n_frames, n_residues = coords.shape[0], coords.shape[1]

    # A pair is formed wherever the next residue has the directly following id.
    res_ids = list(residues.keys())
    pairs = [res for res in range(n_residues - 1) if res_ids[res + 1] == res_ids[res] + 1]
    n_pairs = len(pairs)

    # Nothing to classify: return an empty label array instead of failing
    # while stacking an empty list or querying the index with zero samples.
    if n_frames == 0 or n_pairs == 0:
        return np.zeros((n_frames, n_pairs, 2)), pairs

    # Gather both residues of every pair for all frames at once:
    # (n_frames, n_pairs, 2, n_atoms, 3), then merge the two residues into one
    # point set and flatten (frame, pair) frame-major into a single sample axis.
    first = np.asarray(pairs, dtype=int)
    stacked = np.stack([coords[:, first], coords[:, first + 1]], axis=2)
    inputs = stacked.reshape(n_frames * n_pairs, -1, coords.shape[-1])

    # canonical coordinates, still (n_samples, 2*n_atoms, 3);
    # clusterize_fast flattens each sample into one descriptor itself
    xcan = canonize_batch(inputs)
    idxs, dists = clusterize_fast(neigh, xcan)

    # back to the (n_frames, n_pairs) layout, with (index, distance) last
    labels = np.stack([idxs, dists], axis=-1).astype(float).reshape(n_frames, n_pairs, 2)
    return labels, pairs

In [None]:
def temporal_rms(labels, C):
    """Measure, per residue pair, how spread out its visited cluster centres are.

    For every pair (axis 1 of ``labels``) the distinct cluster indices found
    in ``labels[:, pair, 0]`` are collected, the corresponding centres in
    ``C`` are flattened to one row each, and the square root of the summed
    per-coordinate variance of those rows is returned.

    Note that each visited centre counts once, no matter how many frames
    were assigned to it.

    Returns:
        1-D array with one value per pair.
    """
    def _spread(assigned):
        # distinct cluster ids visited by this pair, in ascending order
        visited = list(map(int, sorted(set(assigned))))
        centres = C[visited, :].reshape(len(visited), -1)
        return np.sqrt(centres.var(0).sum())

    n_pairs = labels.shape[1]
    return np.array([_spread(labels[:, pair, 0]) for pair in range(n_pairs)])

### Run the clustering on all trajectories

In [None]:
# Assign every frame of every trajectory to its nearest cluster centre.
# results_dict maps trajectory stem -> {'rms': value per residue pair,
#                                       'path': labels array (n_frames, n_pairs, 2)}.
results_dict = {}
for traj_path in xtc_dir.glob('*.xtc'):
    traj_name = traj_path.stem
    # the topology file name is the part of the trajectory stem after 'align_'
    sample_top = data_dir.joinpath(f"Disordered_By_Design/2KMV/{traj_name.split('align_')[-1]}.pdb")
    # by using the standard filter we get the backbone atoms
    coords, residues = get_coordinates(traj_path, sample_top)
    labels, pairs = clusterize_traj_fast(coords, residues)
    rms = temporal_rms(labels, C)
    results_dict[traj_name] = {'rms': rms, 'path': labels}
    # one summary line per trajectory: mean of the per-pair values
    print(f"{traj_path.stem}: {rms.mean()}")

## Cluster transition matrix and hitting times

In [None]:
from numpy import linalg as la

def hitting_times(T, max_iter = 10000, tol = 1e-6, verbose = True):
    """Iterate ``k = 1 + T k`` (off-diagonal) with ``k[j, j] = 0`` to a fixed point.

    The returned matrix satisfies, up to ``tol``,
    ``k[i, j] = 1 + sum_l T[i, l] * k[l, j]`` for ``i != j`` and
    ``k[j, j] = 0``. If ``T`` is a row-stochastic transition matrix this is
    the expected number of steps to first reach state ``j`` from state ``i``.

    Args:
        T: Square ``(n, n)`` matrix.
        max_iter: Maximum number of fixed-point iterations.
        tol: Stop once the Frobenius norm of the change between two
            consecutive iterates drops below this value.
        verbose: Print the iteration counter every 1000 iterations.

    Returns:
        ``(n, n)`` array. If ``max_iter`` is reached first, the last iterate
        is returned without any warning.

    Raises:
        ValueError: If ``T`` is not a square 2-D matrix.
    """
    T = np.asarray(T, dtype=float)
    if T.ndim != 2 or T.shape[0] != T.shape[1]:
        raise ValueError(f"T must be a square matrix, got shape {T.shape}")

    n = T.shape[0]
    # every step costs 1, except "moving" from a state to itself
    step_cost = np.ones((n, n)) - np.eye(n)
    # start from an all-zero matrix so the result is (n, n) even for max_iter=0
    k = np.zeros((n, n))
    for i in range(max_iter):
        if verbose and i % 1000 == 0:
            print(i)
        k_old = k
        k = step_cost + T @ k
        # a state is reached from itself in zero steps
        np.fill_diagonal(k, 0.0)
        if np.linalg.norm(k - k_old) < tol:
            break
    return k

In [None]:
# this calculates the transition matrix of going from one cluster to another
# P[from_c, to_c]

# Accumulate cluster-to-cluster transitions between consecutive frames,
# P[from_c, to_c]. Every trajectory contributes a total weight of 1 per
# residue pair, because each transition is weighted by 1 / (n_frames - 1).
P = np.zeros((C.shape[0], C.shape[0]))
for results in results_dict.values():
    traj = results['path']
    n_steps = traj.shape[0] - 1
    if n_steps < 1:
        # fewer than two frames: no transitions to count
        continue

    # cluster index of every pair at frame t and at frame t + 1
    src = traj[:-1, :, 0].reshape(-1).astype(int)
    dst = traj[1:, :, 0].reshape(-1).astype(int)
    # np.add.at accumulates repeated (src, dst) index pairs, unlike
    # `P[src, dst] += w`, which would count each distinct pair only once
    np.add.at(P, (src, dst), 1 / n_steps)

In [None]:
# Restrict the transition counts to clusters that are both left and entered
# at least once, then normalise each remaining row by its outgoing weight.
time_step = 0.01 # ns = 10ps
n_from = P.sum(1)  # total outgoing weight per cluster (row sums)
n_to = P.sum(0)    # total incoming weight per cluster (column sums)
# `&` is required here: with `|`, clusters with n_from == 0 would be kept and
# the division below would produce 0/0 rows.
idx_nonzero = (n_from>0) & (n_to > 0) # keep clusters with outgoing AND incoming transitions
# NOTE(review): rows are divided by the outgoing weight computed BEFORE the
# column restriction, so a row of Pnrm sums to less than 1 whenever the
# cluster has transitions into a removed cluster - confirm this is intended.
Pnrm = P[idx_nonzero,:][:,idx_nonzero] / n_from[idx_nonzero][:,None]
Cnrm = C[idx_nonzero,:]
Cangles_nrm = Cangles[idx_nonzero,:]

In [None]:
# Lookup table from original cluster index to its row in the reduced
# (idx_nonzero-filtered) arrays; clusters that were filtered out map to -1.
label_lut = {}
count = 0
for l, idx in enumerate(idx_nonzero):
    # `count` is the number of kept clusters seen so far, i.e. the next free row
    label_lut[l] = count if idx else -1
    count += bool(idx)

In [None]:
# Translate stored cluster labels to reduced indices and drop every entry
# whose cluster was filtered out (mapped to -1 in label_lut).
# NOTE(review): `clust_labels` is not defined anywhere in the visible cells;
# this fails with NameError in a fresh kernel - confirm where it comes from.
all_labels = np.array([label_lut[l] for l in clust_labels['labels']])
idx_valid = (all_labels >= 0)
all_labels = all_labels[idx_valid]
all_angles = clust_labels['angles'][idx_valid,:]

In [None]:
# Show the normalised transition matrix on a log scale.
# Entries that are exactly zero become -inf under np.log (numpy emits a
# RuntimeWarning for them).
plt.imshow(np.log(Pnrm[:, :]))
plt.colorbar()

In [None]:
# Hitting times in steps, scaled by time_step (documented above as ns per step).
H = hitting_times(Pnrm, max_iter = 10000) * time_step

In [None]:
# Plot exp(-h / t), where h is the smaller of the two directed hitting times
# between each pair of clusters (which makes the plotted matrix symmetric).
# t is in the same unit as H.
t = 5
plt.imshow(np.exp(-np.minimum(H,H.T)/t))
plt.colorbar()
#ticks = np.array(list(range(0,1000)))
#plt.xticks(ticks, ticks)
#plt.yticks(ticks, ticks)
print()