In [None]:
import matplotlib.pyplot as plt
from matplotlib import colors
import collections
import numpy as np
import math
import MDAnalysis as mda
from MDAnalysis.analysis import align, rms, diffusionmap
import ipynb_importer
import compute, parsers
%matplotlib inline

def visualize_clustering(data, i = None, key = None):
    '''Plot the points of one superfeature, uncoloured and coloured by cluster.

    input: dict
           data of one superfeature; this function reads data['points']
           (indexed as a 2-D array using columns 0, 1 and 2) and
           data['clustering'].labels (one colour value per point)
           i, key: only used to build the figure title

    Panels 1-2 show the column pairs (0, 1) and (1, 2) without colouring,
    panels 3-4 show the same two projections coloured by cluster label.
    '''
    points = data['points']

    fig, axes = plt.subplots(1, 4, figsize=(10,2))
    fig.suptitle(f"{i} {key}")

    for d in (0, 1):
        axes[d].scatter(points[:, d], points[:, d+1], marker='.', s=5)

    # plt.cm.get_cmap was removed in matplotlib 3.9; plt.get_cmap performs the
    # same lookup and is available in old and new releases alike.
    cluster_cmap = plt.get_cmap('tab20c').reversed()
    cluster_labels = data['clustering'].labels
    for d in (0, 1):
        axes[d+2].scatter(points[:, d], points[:, d+1],
                          c = cluster_labels,
                          cmap = cluster_cmap,
                          marker='.',
                          s=5,
                         )
    fig.tight_layout(pad=0.5)
    
    
def plot_histogram(distances, i = None, key = None, min_value = None):
    '''Plot a normalised 100-bin histogram of pairwise distances.

    input: ndarray n*n (flattened before binning)
           i, key: only used to build the plot title
           min_value: optional x position of a vertical marker line
    output: height of each bin; every sample is weighted by 1/size, so the
            heights are relative frequencies that sum to 1, not raw counts
    '''
    flat = distances.flatten()
    # One equal weight per sample so that the bin heights sum to 1.
    weights = np.zeros_like(flat) + 1. / flat.size

    plt.figure(figsize=(3,2))
    y, _, _ = plt.hist(flat, bins=100, color='y', weights=weights)
    plt.ylabel("frequency")
    plt.xlabel("pairwise distance")
    plt.title(f"{i} {key}")

    if min_value is None:
        # Backward compatibility: this function used to read a module-level
        # `min_value`, raising NameError when none existed. Keep honouring
        # such a global if present, otherwise draw no marker line.
        min_value = globals().get("min_value")
    if min_value is not None:
        plt.axvline(min_value)
    plt.show()

    return y


def plot_bar_code(pam, n_drop = 0):
    '''Draw the cluster label of every frame as a bar-code style scatter plot.

    input: pam: object exposing `labels_`, one label per frame
           n_drop: offset added to the frame index on the x axis
    Also prints the number of distinct labels and the frame count per label.
    '''
    state_of_frame = pam.labels_
    frame_index = np.array(range(len(state_of_frame))) + n_drop

    plt.figure(figsize=(15, 7), dpi=80)
    plt.scatter(frame_index, state_of_frame, marker = "|")
    plt.xlabel("original frame")
    plt.ylabel("binding state")

    # Summary printed below the plot.
    frames_per_state = collections.Counter(state_of_frame)
    counter = dict(frames_per_state)
    n_cluster = len(np.unique(state_of_frame))
    print("There are", n_cluster, "clusters")
    print(f"Frames within each binding state: {counter}")
    
    

def plot_radar(feature_per_state, xmin = 0.5, xmax = 0):
    '''Plot interaction frequency within each cluster as radar plot

    input: feature_per_state: dict of dicts; the keys of feature_per_state[0]
           provide the axis labels, and the values of each inner dict are
           plotted in their own iteration order
           xmin, xmax: radial limits. Note the call is set_rlim(xmax, xmin),
           so with the defaults the radial axis runs from 0 to 0.5.
    '''
    # Axis labels come from the entry stored under key 0.
    labels = list(feature_per_state[0].keys())
    # split polar coordinates
    angles = np.linspace(0, 2*np.pi, len(labels), endpoint=False)

    feature = [list(state.values()) for state in feature_per_state.values()]

    # Repeat the first point at the end so each polygon is drawn closed.
    angles = np.concatenate((angles, [angles[0]]))
    labels = np.concatenate((labels, [labels[0]]))

    plt.figure(figsize=(8, 6), dpi=100)
    ax = plt.subplot(111, polar=True)

    # Cycle through the palette so more than seven states no longer raise
    # IndexError. The list is deliberately not called `colors`, which is the
    # matplotlib module imported at the top of the file.
    palette = ["g", "b", "r", "y", "m", "k", "c"]
    for i, values in enumerate(feature):
        closed_values = np.concatenate((values, [values[0]]))
        ax.plot(angles, closed_values,
                color = palette[i % len(palette)], label = f"state {i}")

    ax.set_thetagrids(angles*180/np.pi, labels)
    ax.set_theta_zero_location('N')
    ax.set_rlim(xmax, xmin)
    ax.set_rlabel_position(270)
    ax.set_title("interaction frequency of cluters")
    plt.legend(bbox_to_anchor = (1.2, 1.05))
    plt.show()
    

def plot_whole_rmsd(pdb_path, dcd_path, select = 'chainID X', alignment = True, n_drop = 0):
    '''Plot the per-frame RMSD of the selected atoms over a whole trajectory.

    input: pdb_path, dcd_path: topology and trajectory passed to mda.Universe
           select: MDAnalysis selection string of the atoms to evaluate
           alignment: if True, the in-memory trajectory is first aligned onto
                      the time-averaged coordinates of the selection
           n_drop: index of the first frame to evaluate; earlier frames are
                   skipped (previously ignored whenever alignment was True)

    The mean RMSD, rounded to two decimals, is shown in the plot title.
    '''
    u = mda.Universe(pdb_path, dcd_path, in_memory=True)
    ligand = u.select_atoms(select)

    if alignment:
        # reference = average structure of the selection
        reference_coordinates = u.trajectory.timeseries(asel = ligand).mean(axis = 1)
        reference = mda.Merge(ligand).load_new(
                    reference_coordinates[:, None, :], order = "afc")
        align.AlignTraj(u, reference, select = select, in_memory = True).run()

    # NOTE(review): the RMSD reference is the selection's position in the
    # current frame of `u`, not the averaged structure built above, and the
    # frames compared against it are re-read from disk without alignment.
    # plot_cluster_rmsd compares against the averaged structure instead --
    # confirm which reference is intended here.
    reference_positions = ligand.positions

    X = np.arange(n_drop, len(u.trajectory))

    # Open the on-disk trajectory once instead of once per frame.
    mobile = mda.Universe(pdb_path, dcd_path)
    rmsd_ls = []
    for frame in X:
        mobile.trajectory[frame]  # seek; positions below follow this frame
        mobile_ca = mobile.select_atoms(select)
        rmsd_ls.append(
            rms.rmsd(mobile_ca.positions, reference_positions, superposition=False))

    plt.figure(figsize=(15,5))
    plt.plot(X, rmsd_ls)
    plt.ylim(bottom = 0)
    plt.title(f"RMSD of ligand: {round(np.mean(np.array(rmsd_ls)), 2)}")
    plt.show()

    
def plot_cluster_rmsd(pam, pdb_path, dcd_path = "output/", select = 'chainID X'):
    '''Plot, per cluster, the RMSD of every frame to the cluster's average structure.

    input: pam: object exposing `labels_`; each distinct label k must have a
                trajectory file f"{dcd_path}cluster_{k}.dcd"
           pdb_path: topology file
           dcd_path: prefix (directory including trailing separator) of the
                     per-cluster trajectory files
           select: MDAnalysis selection string of the atoms to evaluate

    One subplot is drawn per label, three per row, at grid position k + 1.
    '''
    numCols = 3
    # Size the grid from the largest label (subplot k + 1 must exist) rather
    # than padding it with a fixed number of empty cells.
    n_cluster = np.max(pam.labels_) + 1
    numRows = math.ceil(n_cluster/numCols)

    # Scale the figure height with the number of rows so panels keep a
    # constant size and no blank rows are left behind.
    plt.figure(figsize=(20, 3.75*numRows))
    for k in np.unique(pam.labels_):
        cluster_dcd = f"{dcd_path}cluster_{k}.dcd"
        u = mda.Universe(pdb_path, cluster_dcd, in_memory = True)
        ligand = u.select_atoms(select)
        # reference = average structure
        reference_coordinates = u.trajectory.timeseries(asel = ligand).mean(axis = 1)
        reference = mda.Merge(ligand).load_new(
                    reference_coordinates[:, None, :], order = "afc")
        ref_ca = reference.select_atoms(select)
        # NOTE(review): this aligns the in-memory copy `u`, but the RMSD below
        # is computed on `mobile`, which is loaded separately from disk --
        # confirm whether the aligned frames were meant to be used.
        align.AlignTraj(u, reference, select = select, in_memory = True).run()

        X =  range(len(u.trajectory))
        mobile = mda.Universe(pdb_path, cluster_dcd)
        reference_positions = ref_ca.positions

        rmsd_ls = []
        for j in X:
            mobile.trajectory[j]  # seek; positions below follow this frame
            mobile_ca = mobile.select_atoms(select)
            rmsd_ls.append(
                rms.rmsd(mobile_ca.positions, reference_positions, superposition = False))

        # plot
        plt.subplot(numRows, numCols, (k + 1))
        max_rmsd = np.max(np.array(rmsd_ls)) + 0.5
        plt.ylim(0, max_rmsd)
        plt.plot(X, rmsd_ls)
        plt.title(f"State {k} avg. RMSD = {round(np.mean(np.array(rmsd_ls)), 2)}")
        plt.xlabel("Frames")
        # Raw string: "\A" is not a valid escape sequence in a normal literal.
        plt.ylabel(r"RMSD ($\AA$)")
    plt.tight_layout()
    
def plot_diffusionmap(pdb_path, dcd_path, select = "protein or chainID X", reduce = True,
                      matrix_select = 'chainID X', final_n_frame = 500,
                      reduced_path = "output/reduced.dcd"):
    '''Show the pairwise frame-to-frame distance matrix of a trajectory as a heat map.

    input: pdb_path, dcd_path: topology and trajectory passed to mda.Universe
           select: selection used for the frame reduction and for aligning
                   the trajectory onto itself
           reduce: if True, compute.reduce_frames first writes a reduced
                   trajectory to `reduced_path`, which is then analysed
                   instead of `dcd_path`
           matrix_select: selection the distance matrix is computed on
           final_n_frame: frame count handed to compute.reduce_frames
           reduced_path: output file of the reduction step
    '''
    if reduce:
        compute.reduce_frames(pdb_path, dcd_path, select = select,
                              out_path = reduced_path, final_n_frame = final_n_frame)
        dcd_path = reduced_path

    u = mda.Universe(pdb_path, dcd_path)
    # Align on `select` before measuring distances on `matrix_select`.
    align.AlignTraj(u, u, select = select, in_memory = True).run()

    diffusion_matrix = diffusionmap.DistanceMatrix(u, select = matrix_select).run()
    plt.imshow(diffusion_matrix.dist_matrix, cmap='viridis')
    plt.xlabel('Frame')
    plt.ylabel('Frame')
    plt.colorbar(label=r'RMSD ($\AA$)')
    plt.show()
    

def draw_2d_wrap_data(data, dynophore_dict, wrap_data = None):
    '''2D drawing of data after clustering of all superfeatures

    input: data: passed to parsers.get_wrap_data when wrap_data is not given
           dynophore_dict: dict superfeature id -> dict with a 'color' entry
                           (hex string without the leading '#'); its order
                           must match the order of wrap_data
           wrap_data: optional sequence of 2-D arrays, one per superfeature;
                      columns 0 and 1 are plotted, and rows whose column 3
                      equals 0 are drawn in grey as noise
    '''
    # `is None` instead of `== None`: comparing an ndarray with `==` is
    # element-wise and raises when the result is used as a condition.
    if wrap_data is None:
        wrap_data = parsers.get_wrap_data(data)

    # Resolve names and colours once, in dictionary order.
    superfeature_name = list(dynophore_dict.keys())
    superfeature_color = [tuple(colors.hex2color(f"#{entry['color']}"))
                          for entry in dynophore_dict.values()]

    fig, ax = plt.subplots(figsize = (4, 4))
    for i, superfeature_data in enumerate(wrap_data):
        is_noise = superfeature_data[:, 3] == 0
        cluster_data, noise = superfeature_data[~is_noise], superfeature_data[is_noise]
        color = [superfeature_color[i]] * len(cluster_data)

        ax.scatter(cluster_data[:, 0], cluster_data[:, 1], c = color, label = superfeature_name[i], s = 1)
        ax.scatter(noise[:, 0], noise[:, 1], c = 'grey', s = 2)
    plt.xlabel("$x$")
    plt.ylabel("$y$")
    ax.set_title("Clustering of all superfeatures")
    ax.legend(bbox_to_anchor=(1.1, 1.05))

def draw_3d_wrap_data(data, dynophore_dict, wrap_data = None):
    %matplotlib notebook
    '''3D drawing of data after clustering of all superfeatures'''
    if wrap_data == None:
        wrap_data = parsers.get_wrap_data(data)
        
    fig = plt.figure(figsize = (4, 4))
    ax = fig.add_subplot(111, projection='3d')

    color_dict = {}
    for key, data_ in dynophore_dict.items():
        color = data_['color']
        color_dict[key] = color
    superfeatures_colors = {superfeature_id: tuple(colors.hex2color(f"#{color}")) for superfeature_id, color in color_dict.items()}
    
    superfeature_name = list(dynophore_dict.keys())

    for i in range(len(data)):
        superfeature_data = wrap_data[i]
        cluster_data, noise = superfeature_data[superfeature_data[:, 3] != 0], superfeature_data[superfeature_data[:, 3] == 0]
        color = [list(superfeatures_colors.values())[i]] * len(cluster_data)

        ax.scatter(cluster_data[:, 0], cluster_data[:, 1], cluster_data[:, 2], c = color, label = superfeature_name[i], s = 1)
        ax.scatter(noise[:, 0], noise[:, 1], noise[:, 2], c = 'grey', s = 2)
    plt.xlabel("$x$")
    plt.ylabel("$y$")
    ax.set_title("Clustering of all superfeatures")
    ax.legend(bbox_to_anchor=(1.1, 1.05))