# Workflow 4 - analyze the results
<hr>

1. Cluster interactive plots
2. Structure plots
3. Trajectory movie plots

In [1]:
#required imports
from ipywidgets import interact, interactive, fixed, interact_manual
from ipywidgets import interactive, fixed, IntSlider, VBox, HBox, Layout
import ipywidgets as widgets
import matplotlib.patheffects as PathEffects
from IPython.display import display,clear_output
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from engens.core.EnGens import EnGen
from engens.core.ClustEn import *
from engens.core.FeatureSelector import *
from engens.core.DimReduction import *
import pickle as pk
import os



In [2]:
# Restore the objects pickled by the earlier workflows (file names suggest
# workflow 3 for the EnGen and clustering objects, workflow 2 for the reducer).
# NOTE(review): unpickling executes arbitrary code; these files are assumed to
# be the trusted outputs of the previous notebooks — confirm before sharing.
engen = None
with open("wf3_resulting_EnGen.pickle", "rb") as engen_fh:
    engen = pk.load(engen_fh)

clust = None
with open("wf3_resulting_Clust.pickle", "rb") as clust_fh:
    clust = pk.load(clust_fh)

dimred = None
with open("wf2_resulting_Reducer.pickle", "rb") as dimred_fh:
    dimred = pk.load(dimred_fh)

In [3]:
# Select the inline matplotlib backend so figures render in the cell output.
%matplotlib inline

## Cluster plots
<hr>

In [4]:
# Notebook-level state read by the cluster plotting functions below.

# Plain display settings.
color_opt = "frame"
scat_size = 10
rep_size = 400
reps_text = []

# Columns 0 and 1 of the reduced data are the plot coordinates; the number of
# rows is taken as the trajectory length.
x, y = dimred.transformed_data[:, 0], dimred.transformed_data[:, 1]
traj_len = dimred.transformed_data.shape[0]
x_min, x_max = min(x), max(x)
y_min, y_max = min(y), max(y)

# One plasma color per frame (ordered by frame index) and the cluster label of
# each frame for the chosen clustering.
frame_c = cm.plasma(np.linspace(0, traj_len, traj_len) / traj_len)
c = frame_c
cluster_c = clust.labels[clust.chosen_index]

# Representative frames: their indices, coordinates and one Set1 color each.
rep_fnum = clust.chosen_frames
rep_len = rep_fnum.shape[0]
rep_x, rep_y = x[rep_fnum], y[rep_fnum]
rep_c = cm.Set1(np.linspace(0, rep_len, rep_len) / rep_len)

In [5]:

from plotly.offline import iplot
from plotly.graph_objs import graph_objs as go

def view_timeline():
    """Plot the cluster label of every trajectory frame as a timeline.

    Draws one marker per frame (x = frame index, y = cluster label, colored by
    label) and a dashed red vertical line at each representative frame in the
    notebook-level ``rep_fnum``, annotated with its cluster label and frame
    number. Reads the notebook-level ``clust``; renders the figure with
    ``iplot`` and returns nothing.
    """
    labels = clust.labels[clust.chosen_index]
    # Derive the x-axis from the number of labelled frames instead of a
    # hard-coded frame count, so x and y always have matching lengths.
    frame_idx = list(range(len(labels)))

    pfig = go.Figure()
    scatter = go.Scatter(
        x=frame_idx,
        y=labels,
        marker=dict(size=3, color=labels, colorscale="Viridis"),
        mode="markers",
    )
    pfig.add_trace(scatter)

    # Mark every representative frame with an annotated vertical line.
    for rf in rep_fnum:
        atxt = "Cluster "+str(labels[rf])+"; Frame "+str(rf)
        pfig.add_vline(x=rf, line_width=1, line_dash="dash", line_color="red",
                       annotation_text=atxt)

    pfig.update_layout(go.Layout(width=800, height=400),
                       xaxis_title="trajectory frames",
                       yaxis_title="cluster",
                       hovermode="x")
    # NOTE(review): "d" is not a numeric tick step; presumably a step of 1
    # (integer cluster ids) was intended — confirm before changing.
    pfig.update_yaxes(dtick="d")
    iplot(pfig)

In [6]:
def view_PCs():
    """Scatter all frames in the reduced (C1, C2) plane, colored by cluster.

    Every frame gets a hover label with its cluster and frame index; the
    representative frames (notebook-level ``rep_x``/``rep_y``/``rep_fnum``)
    are overlaid as larger red markers. Renders with ``iplot``; returns nothing.
    """
    labels = clust.labels[clust.chosen_index]
    hover_text = [
        "cluster="+str(lab)+" \n frame="+str(idx)
        for idx, lab in enumerate(list(labels))
    ]

    # All frames, colored by their cluster label.
    frames_trace = go.Scatter(
        x=x,
        y=y,
        marker=dict(size=3, color=labels, colorscale="Viridis"),
        mode="markers",
        text=hover_text,
        hovertemplate="<b>%{text}</b>",
    )
    # Representative frames, highlighted on top of the frame cloud.
    reps_trace = go.Scatter(
        x=rep_x,
        y=rep_y,
        marker=dict(size=6, color='red', line=dict(width=2, color='DarkSlateGrey')),
        mode="markers",
        text=rep_fnum,
        hovertemplate="<b>Representative (frame - %{text})</b>",
    )

    pfig = go.Figure()
    pfig.add_trace(frames_trace)
    pfig.add_trace(reps_trace)
    pfig.update_layout(
        go.Layout(width=450, height=450, showlegend=False),
        xaxis_title="C1",
        yaxis_title="C2",
    )
    iplot(pfig)

In [7]:

def view_bar():
    """Show the weight of each cluster as a bar chart.

    Bars are placed at 0 .. clust.chosen_index + 1 and take their heights from
    ``clust.cluster_weights(clust.chosen_index)``. When ``clust.thr`` is set, a
    dashed horizontal line marks it as the cluster cutoff. Renders with
    ``iplot``; returns nothing.
    """
    # NOTE(review): presumably chosen_index + 2 is the number of clusters — confirm.
    cluster_ids = list(range(clust.chosen_index + 2))
    bar = go.Bar(x=cluster_ids,
                 y=clust.cluster_weights(clust.chosen_index),
                 marker={'color': cluster_ids, 'colorscale': "Viridis"})
    pfig = go.Figure()
    pfig.add_trace(bar)
    pfig.update_layout(go.Layout(width=450, height=450),
        xaxis_title="cluster",
        yaxis_title="cluster weight")
    pfig.update_xaxes(dtick="d")
    # Identity test: `== None` can be overridden (e.g. element-wise on arrays).
    if clust.thr is not None:
        pfig.add_hline(y=clust.thr, line_width=2, line_dash="dash", line_color="black",
                  annotation_text="cluster cutoff")
    iplot(pfig)

In [8]:
# Wrap each plotting function in an interactive widget, then lay them out:
# component scatter and weight bars side by side, timeline underneath.
widg1, widg2, widg3 = (
    interactive(view_fn) for view_fn in (view_timeline, view_PCs, view_bar)
)
top_row = HBox([widg2, widg3])
display(VBox([top_row, widg1]))

VBox(children=(HBox(children=(interactive(children=(Output(),), _dom_classes=('widget-interact',)), interactiv…

## Structure plots
<hr>

In [9]:
import nglview as nv
import pytraj as pt
import matplotlib.colors as mc

In [10]:
# Sample (clust.chosen_index + 2) evenly spaced colors from the viridis colormap;
# the structure viewer cell indexes this array by cluster label.
# NOTE(review): presumably chosen_index + 2 is the number of clusters — confirm.
colors = cm.viridis(np.linspace(0, clust.chosen_index+2, clust.chosen_index+2)/(clust.chosen_index+2))

In [11]:
# Collect the PDB files of the resulting ensemble.
# NOTE(review): os.listdir returns entries in arbitrary order, yet the loop
# below pairs the i-th file with clust.chosen_frames[i] — confirm that the
# listing order matches the order of the chosen frames.
res_pdbs = [
    entry for entry in os.listdir("./res_ensemble") if entry.endswith(".pdb")
]

# Add every structure to one NGL widget, colored by the cluster of its frame.
nw = nv.NGLWidget()
chosen_labels = clust.labels[clust.chosen_index]
for i, pdb_file in enumerate(res_pdbs):
    clust_n = chosen_labels[clust.chosen_frames[i]]
    nw.add_component(
        "./res_ensemble/"+pdb_file,
        name="Cluster - "+str(clust_n),
        default_representation=False,
    )
    nw[i].add_ball_and_stick(color=mc.rgb2hex(colors[clust_n]))

In [12]:
# Center the NGL view, then leave the widget as the cell's last expression so
# it is displayed as the cell output.
nw.center()
nw

NGLWidget()

## Trajectory movie plots
<hr>

In [13]:
# Notebook-level state read by the trajectory movie widgets below.

# Plain display settings.
color_opt = "frame"
scat_size = 10
rep_size = 400

# Columns 0 and 1 of the reduced data are the plot coordinates; the number of
# rows is taken as the trajectory length.
x, y = dimred.transformed_data[:, 0], dimred.transformed_data[:, 1]
traj_len = dimred.transformed_data.shape[0]
x_min, x_max = min(x), max(x)
y_min, y_max = min(y), max(y)

# One plasma color per frame (ordered by frame index) and the cluster label of
# each frame for the chosen clustering.
frame_c = cm.plasma(np.linspace(0, traj_len, traj_len) / traj_len)
c = frame_c
cluster_c = clust.labels[clust.chosen_index]

# Representative frames: indices, coordinates, one Set1 color each, and a text
# tag "C<position in rep_fnum>-F<frame number>" used to annotate them.
rep_fnum = clust.chosen_frames
rep_len = rep_fnum.shape[0]
rep_x, rep_y = x[rep_fnum], y[rep_fnum]
rep_c = cm.Set1(np.linspace(0, rep_len, rep_len) / rep_len)
reps_text = ["C"+str(pos)+"-F"+str(frame) for pos, frame in enumerate(rep_fnum)]
    

# Thin horizontal plasma strip that serves as a legend for the frame colors.
out_frame_cmap = widgets.Output()
ramp = np.linspace(0, traj_len, traj_len) / traj_len
gradient = np.vstack((ramp, ramp))  # two identical rows so imshow has a 2-D image
with out_frame_cmap:
    strip_fig = plt.figure(figsize=(7, 0.2))
    strip_ax = strip_fig.gca()
    strip_ax.imshow(gradient, aspect='auto', cmap="plasma")
    strip_ax.set_axis_off()
    plt.show()

# Output area that hosts the trajectory scatter plot.
out_plot = widgets.Output(layout=widgets.Layout(height='400px', width = '500px', justify_content='center'))

# Notebook-level handles shared with the `update` and `changed` callbacks.
# Plain assignments suffice here: a `global` statement has no effect at module
# level, so the former declarations were removed.
ax = None
reps = None
reps_vis = False
reps2 = None  # not read anywhere in the visible notebook; kept for compatibility


with out_plot:
    fig = plt.figure(figsize=(15, 10))
    ax = fig.add_subplot(1, 1, 1)
    # Seed the axes with the frame at index 1. An explicit RGBA `color` is
    # given, so no colormap is passed (it would be ignored).
    scatter = ax.scatter(x[1], y[1],
            color = c[1, :],
            s=scat_size, alpha =0.5, )

    # Representative frames as stars; hidden until the checkbox enables them.
    reps=ax.scatter(rep_x, rep_y, s=rep_size, color=rep_c, marker="*",
        edgecolors="black")
    reps.set_visible(reps_vis)
    ax.set_xlim([x_min, x_max])
    ax.set_ylim([y_min, y_max])
    ax.set_xlabel('C1')
    ax.set_ylabel('C2')
    plt.show()
    


def update(slider, x, y, c):
    """Redraw the trajectory plot up to the frame selected by the slider.

    Clears the shared axes ``ax``, plots frames ``[0, slider)`` with their
    per-frame colors, highlights the current frame with a red diamond, and
    re-adds the representative stars and their text tags with the visibility
    currently stored on the ``reps`` artist. The figure is then re-displayed
    in ``out_plot``.

    Parameters:
        slider: number of frames to show (value of the frame slider).
        x, y: per-frame plot coordinates.
        c: per-frame RGBA colors, one row per frame.

    Returns:
        The ``slider`` value, unchanged.
    """
    frames = slider
    # Visibility is carried on the original `reps` artist across redraws.
    reps_vis = reps.get_visible()
    ax.clear()

    # Explicit RGBA rows are supplied through `c`, so no colormap is passed
    # (it would be ignored).
    ax.scatter(x[:frames], y[:frames],
        c = c[:frames, :],
        s=10, alpha =0.5, )

    # The slider/Play range may reach len(x); clamp so the current-frame
    # marker never indexes past the last frame.
    current = min(frames, len(x) - 1)
    ax.scatter(x[current], y[current],
        color="red",
        s=100, alpha =1, marker="D",
        edgecolors="black")

    ax.scatter(rep_x, rep_y, s=rep_size, color=rep_c, marker="*", edgecolors="black", visible=reps_vis)
    for i, txt in enumerate(reps_text):
        txt_tmp = ax.annotate(txt, (rep_x[i], rep_y[i]), visible=reps_vis, fontsize=20, color=rep_c[i], weight="bold")
        # Black outline keeps the colored tag readable on top of the scatter.
        txt_tmp.set_path_effects([PathEffects.withStroke(linewidth=3, foreground='black')])

    # ax.clear() resets the limits, so restore the fixed data extent.
    ax.set_xlim([x_min, x_max])
    ax.set_ylim([y_min, y_max])
    with out_plot:
        clear_output(wait=True)
        display(ax.figure)
    return slider

# Frame indices run from 0 to traj_len - 1; using traj_len as the maximum lets
# the controls reach a value that is one past the last valid frame index.
last_frame = traj_len - 1

# Play button that steps through the frames one at a time.
play = widgets.Play(
    value=1,
    min=0,
    max=last_frame,
    step=1,
    disabled=False
)


frame_slider = widgets.IntSlider(value=1,min=0,max=last_frame,step=10,description='show frames:\n',)

# Keep the play button and the slider in sync on the front end.
widgets.jslink((play, 'value'), (frame_slider, 'value'))
box_fslider = HBox([play, frame_slider])

# Re-run `update` whenever the slider value changes; x, y and c are passed as
# fixed (non-widget) arguments.
slider = interactive(update, slider=frame_slider, x = fixed(x), y = fixed(y), c = fixed(c))


box = widgets.Checkbox(False, description='show cluster representatives')
def changed(b):
    """Toggle the cluster representatives when the checkbox value changes.

    Parameters:
        b: change dictionary delivered by ``observe``; ``b["new"]`` is the new
            checkbox value.

    Stores the visibility on the shared ``reps`` artist (which ``update``
    reads back) and redraws the plot at the current slider position.
    """
    reps.set_visible(bool(b["new"]))
    update(frame_slider.value, x, y, c)
# Observe only the `value` trait; without `names` the handler fires for every
# trait change of the checkbox and triggers redundant redraws.
box.observe(changed, names="value")


# Animated 3-D view of the trajectory, redrawn as ball-and-stick.
nglwidget = engen.show_animated_traj()
nglwidget.clear_representations()
nglwidget.add_ball_and_stick()

# Controls (play/slider, color strip, checkbox) stacked in a column above the plot.
column_layout = Layout(display='flex', flex_flow='column')
command_pallete = HBox([box_fslider, out_frame_cmap, box], layout=column_layout)
interactive1 = VBox([command_pallete, out_plot])

# Empty boxes on both sides of the main column.
main_column = VBox([interactive1, nglwidget])
display(HBox([VBox([]), main_column, VBox([])]))

HBox(children=(VBox(), VBox(children=(VBox(children=(HBox(children=(HBox(children=(Play(value=1, max=5001), In…