## Evaluate

In [None]:
import torch
import ipywidgets as widgets
import os 

from IPython.display import clear_output 
from evaluation import MotifDockEvaluation

# Module-level `global` statement; it has no effect at this scope and is kept
# only to mirror the declaration used inside the callback below.
global evaluation


def on_click(change):
    """Run a sampling pass using the current values of the control widgets.

    Builds a ``MotifDockEvaluation`` from the path widget and the selected
    device, stores it in the module-level ``evaluation`` variable, and calls
    its ``sampling`` method with the values read from the remaining widgets.

    Args:
        change: Change notification passed in by the widget observer; it is
            not inspected.

    NOTE(review): this is registered as a 'value' observer on a ToggleButton,
    so it presumably runs on both press and release -- confirm that is wanted.
    """
    global evaluation

    # The evaluation object is created first, then the widget values are read.
    evaluation = MotifDockEvaluation(path_widgets.value, device)

    sampling_options = dict(
        N=sample_widgets.value,
        timesteps=timesteps_widgets.value,
        refix_steps=refix_widgets.value,
        resampling=resampling_widgets.value,
        model_name=model_widgets.value,
        silvr_rate=silvr_widgets.value,
    )
    evaluation.sampling(**sampling_options)

  
# --- Control widgets for the evaluation cell -------------------------------

# Free-text path handed to MotifDockEvaluation by the callback.
path_widgets = widgets.Textarea(
    value='../../MotifDock/data/valid/',
    placeholder='Type something',
    description='Path:',
    disabled=False,
    layout=widgets.Layout(height="100%", width="auto"),
)

# Passed to sampling() as N, bounded to 1..32.
sample_widgets = widgets.BoundedIntText(
    value=32,
    min=1,
    max=32,
    step=1,
    description='Sample:',
    disabled=False,
)

# Toggle whose value changes trigger on_click (registered below).
button_widgets = widgets.ToggleButton(
    value=False,
    description='run',
    disabled=False,
    button_style='',
    tooltip='Description',
    icon='check',
)

# Passed to sampling() as timesteps.
timesteps_widgets = widgets.IntSlider(
    value=50,
    min=0,
    max=500,
    step=10,
    description='Timesteps:',
    orientation='horizontal',
)

# Passed to sampling() as resampling.
resampling_widgets = widgets.IntSlider(
    value=10,
    min=1,
    max=10,
    step=1,
    description='Resamlpe:',
    orientation='horizontal',
)

# Passed to sampling() as refix_steps.
refix_widgets = widgets.IntSlider(
    value=100,
    min=0,
    max=500,
    step=1,
    description='Refix steps',
    orientation='horizontal',
)

# Passed to sampling() as silvr_rate.
silvr_widgets = widgets.FloatSlider(
    value=0.01,
    min=0,
    max=1,
    step=0.01,
    description='silvr_rate',
    orientation='horizontal',
)

# One entry per item found in ./model, in sorted order; the directory is
# listed once, when this cell runs.
model_widgets = widgets.Dropdown(
    options=sorted(os.listdir("./model")),
    description='Model:',
    disabled=False,
)

# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Call on_click whenever the button's `value` trait changes.
button_widgets.observe(on_click, 'value')

clear_output()
display(
    path_widgets,
    sample_widgets,
    timesteps_widgets,
    refix_widgets,
    resampling_widgets,
    model_widgets,
    button_widgets,
    silvr_widgets,
)

### Plot

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets 

from pathlib import Path 


class ViewPlot:
    """Plot RMSD statistics stored in an ``.npz`` result archive.

    The archive is read with ``np.load`` and must provide the keys
    ``conf_<method>_list`` and ``total_<method>_list`` for the methods
    ``rdkit``, ``fixed``, ``replacement`` and ``resampling``, plus
    ``total_original_list``.  Several methods index the arrays as
    ``data[:, 0]`` or ``data[:, :k]``, so each array needs at least two
    dimensions.

    NOTE(review): the top-k and "best" plots take the leading columns, so the
    arrays are presumably sorted with the best sample first -- confirm against
    the code that writes the archive.
    """

    # Left-to-right order of the violins.  Tick labels AND data are both
    # derived from this single tuple so they can never disagree (previously
    # the hand-written label lists were out of step with the data lists:
    # 'fixed'/'replacement' were swapped in one plot and the replacement data
    # was labelled 'unconditional' in the other).
    _METHODS = ('rdkit', 'replacement', 'fixed', 'resampling')

    def __init__(self, npz_path):
        """Load the RMSD arrays from ``npz_path`` and set the base font size.

        Args:
            npz_path: Path of the ``.npz`` archive, passed to ``np.load``.

        Raises:
            KeyError: If one of the expected arrays is missing from the archive.
        """
        plt.rc('font', size=16)
        rmsd_dict = np.load(npz_path)
        self.rmsd_dict = rmsd_dict

        self.conf_rdkit_list = rmsd_dict['conf_rdkit_list']
        self.conf_fixed_list = rmsd_dict['conf_fixed_list']
        self.conf_replacement_list = rmsd_dict['conf_replacement_list']
        self.conf_resampling_list = rmsd_dict['conf_resampling_list']

        self.total_original_list = rmsd_dict['total_original_list']
        self.total_rdkit_list = rmsd_dict['total_rdkit_list']
        self.total_fixed_list = rmsd_dict['total_fixed_list']
        self.total_replacement_list = rmsd_dict['total_replacement_list']
        self.total_resampling_list = rmsd_dict['total_resampling_list']

    def _violin_panel(self, ax, prefix, methods, ylabel, title, k=None):
        """Draw one violin per method on ``ax``, labelled with the method name.

        Args:
            ax: Matplotlib axes to draw on.
            prefix: ``'conf'`` or ``'total'``; selects ``<prefix>_<method>_list``.
            methods: Method names, in plotting order; also used as tick labels.
            ylabel: Y-axis label.
            title: Axes title.
            k: When given, only the first ``k`` columns of each array are
                plotted; when ``None`` every value is plotted.
        """
        arrays = [getattr(self, f'{prefix}_{method}_list') for method in methods]
        if k is None:
            samples = [data.reshape([-1]) for data in arrays]
        else:
            samples = [data[:, :k].reshape([-1]) for data in arrays]

        ax.violinplot(samples, showmedians=True, quantiles=[[0.25, 0.75]] * len(samples))
        # violinplot places the violins at x = 1..n by default.
        ax.set_xticks(range(1, len(methods) + 1), list(methods))
        ax.set_xlabel('Method')
        ax.set_ylabel(ylabel)
        ax.set_ylim([0, 6])
        ax.set_title(title)

    def draw_violinplot_all(self):
        """Draw side-by-side violin plots of every conf and total RMSD value."""
        fig = plt.figure(figsize=(30, 10))
        ax1 = fig.add_subplot(121)
        ax2 = fig.add_subplot(122)

        self._violin_panel(ax1, 'conf', self._METHODS, 'Confermer RMSD', 'All sample')
        self._violin_panel(ax2, 'total', self._METHODS, 'Docking RMSD', 'All sample')

    def draw_violinplot_top_k(self, k: int = 1):
        """Draw violin plots restricted to the first ``k`` columns of each array.

        The right-hand (total) panel additionally includes the ``original``
        series.

        Args:
            k: Number of leading columns to include.
        """
        fig = plt.figure(figsize=(30, 10))
        ax1 = fig.add_subplot(121)
        ax2 = fig.add_subplot(122)

        self._violin_panel(
            ax1, 'conf', self._METHODS, 'Confermer RMSD', f'Top-{k} sample', k=k
        )
        self._violin_panel(
            ax2, 'total', ('original',) + self._METHODS, 'Docking RMSD',
            f'Top-{k} sample', k=k
        )

    def draw_scatterplot(self):
        """Scatter column 0 of the rdkit arrays against the resampling arrays.

        Two figures are created: one for the ``conf_*`` arrays (axes limited
        to 0..5) and one for the ``total_*`` arrays.  Each carries dashed
        guide lines at 2 on both axes and a diagonal from (0, 0) to (10, 10).
        Note that this raises the global matplotlib font size to 25.
        """
        plt.rc('font', size=25)

        plt.figure(figsize=(10, 10))
        plt.xlabel('rdkit')
        plt.ylabel('confgen (resampling)')
        plt.margins(x=0, y=0)
        plt.axhline(y=2, c='r', linestyle='--')
        plt.axvline(x=2, c='r', linestyle='--')
        plt.xlim([0, 5])
        plt.ylim([0, 5])
        plt.plot([0, 10], [0, 10], c='r')
        plt.scatter(
            self.conf_rdkit_list[:, 0].reshape([-1]),
            self.conf_resampling_list[:, 0].reshape([-1]),
        )

        plt.figure(figsize=(10, 10))
        plt.margins(x=0, y=0)
        plt.axhline(y=2, c='r', linestyle='--')
        plt.axvline(x=2, c='r', linestyle='--')
        plt.plot([0, 10], [0, 10], c='r')
        plt.xlabel('rdkit best RMSD')
        plt.ylabel('ConfDiff best RMSD')
        plt.scatter(
            self.total_rdkit_list[:, 0].reshape([-1]),
            self.total_resampling_list[:, 0].reshape([-1]),
        )

    def rmsd_percentage(self, threshold: float = 2.0):
        """Print, per method, the fraction of rows whose column 0 is below ``threshold``.

        The printed value is a fraction of ``len(array)`` (0..1), not a
        percentage, and is computed on the ``total_*`` arrays only.

        Args:
            threshold: Strict upper bound applied to column 0.
        """
        for method in ('original', 'rdkit', 'fixed', 'replacement', 'resampling'):
            data = getattr(self, f'total_{method}_list')
            print(method, (data[:, 0] < threshold).sum() / len(data))
        

def on_click(change):
    """Render every plot and the summary for the currently selected archive.

    Prints the selected ``.npz`` path, builds a ``ViewPlot`` from it, then
    draws the all-sample violin plot, the top-10 and top-1 violin plots and
    the scatter plots, and finally prints the fractions below 2.0.

    Args:
        change: Change notification passed in by the widget observer; it is
            not inspected.
    """
    selected_npz = npz_widgets.value
    print(selected_npz)

    view = ViewPlot(selected_npz)
    view.draw_violinplot_all()
    for top_k in (10, 1):
        view.draw_violinplot_top_k(k=top_k)
    view.draw_scatterplot()
    view.rmsd_percentage(2.0)
    

# Every .npz file found (recursively) under ./result; the search runs once,
# when this cell is executed.
npz_widgets = widgets.Dropdown(
    options=list(Path('result').rglob('*.npz')),
    description='npz',
    disabled=False,
)

# Toggle whose value changes trigger on_click (registered below).
plot_widgets = widgets.ToggleButton(
    value=False,
    description='plot',
    disabled=False,
    button_style='',  # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Description',
    icon='check',  # (FontAwesome names without the `fa-` prefix)
)

# Call on_click whenever the button's `value` trait changes.
plot_widgets.observe(on_click, 'value')
display(npz_widgets, plot_widgets)