# Triple check vis

In [None]:
class vis():
    """Visualize the evaluation results of a single model.

    Item-level strain and grain evaluation tables are read from CSV files in
    the model folder, averaged to condition level, and plotted with Altair.

    ``parse_cond_df`` must have been called (directly, or through ``plots`` /
    ``export_result``) before any ``plot_*`` method, because those methods
    read ``self.cdf``.

    ``pandas`` is expected to be available as ``pd`` in the enclosing module;
    Altair and the project helpers are imported lazily by the methods that
    need them.
    """

    # Columns shown in the tooltip of every line chart built by this class.
    _TOOLTIP = ('epoch', 'timestep', 'sample_mil', 'acc', 'sse')

    def __init__(self, model_folder, s_item_csv, g_item_csv):
        """Load the model configuration and both item-level evaluation files.

        Args:
            model_folder: Folder holding ``model_config.json`` and the CSVs.
            s_item_csv: File name (inside ``model_folder``) of the strain
                item-level results.
            g_item_csv: File name (inside ``model_folder``) of the grain
                item-level results.
        """
        self.model_folder = model_folder
        self.load_config()

        self.read_eval_from_file(s_item_csv, g_item_csv)
        self.max_epoch = self.strain_i_hist['epoch'].max()

    def load_config(self):
        """Read ``model_config.json`` from the model folder into ``self.cfg``."""
        from meta import model_cfg
        self.cfg = model_cfg(self.model_folder + '/model_config.json', bypass_chk=True)

    def training_hist(self):
        """Load the training history pickle and return its ``plot_all()`` output."""
        # Imported here because an import inside __init__ is local to __init__
        # and therefore not visible from this method.
        from evaluate import training_history

        self.t_hist = training_history(self.cfg.path_history_pickle)
        return self.t_hist.plot_all()

    def read_eval_from_file(self, s_item_csv, g_item_csv):
        """Read the strain and grain item-level CSVs from the model folder."""
        self.strain_i_hist = pd.read_csv(self.model_folder + '/' + s_item_csv)
        self.grain_i_hist = pd.read_csv(self.model_folder + '/' + g_item_csv)

    # Condition level parsing
    @staticmethod
    def _condition_level(item_df, value_cols, cond, exp, rename=None):
        """Average item-level rows to one row per model/epoch/timestep/condition.

        Args:
            item_df: Item-level data frame.
            value_cols: Outcome columns to keep in addition to the shared ones.
            cond: Name of the column holding the condition label.
            exp: Label written to the ``exp`` column of the result.
            rename: Optional column renaming applied before averaging.

        Returns:
            A new data frame with the group means plus ``cond`` and ``exp``.
        """
        keep = ['code_name', 'epoch', 'sample_mil', 'timestep',
                'unit_time', cond, 'input_s'] + list(value_cols)
        cond_df = item_df[keep]
        if rename:
            cond_df = cond_df.rename(columns=rename)
        cond_df = cond_df.groupby(['code_name', 'epoch', 'timestep', cond],
                                  as_index=False).mean()
        # Copy the condition into a common column name so that strain and
        # grain results can share one plotting column after concatenation.
        cond_df['cond'] = cond_df[cond]
        cond_df['exp'] = exp
        return cond_df

    def parse_strain_cond_df(self, cond):
        """Build ``self.scdf``: strain results averaged within ``cond``."""
        self.scdf = self._condition_level(
            self.strain_i_hist, ['acc', 'sse'], cond, 'strain')

    def parse_grain_cond_df(self, cond):
        """Build ``self.gcdf``: grain results averaged within ``cond``.

        The ``*_acceptable`` scores are renamed to plain ``acc`` / ``sse`` so
        they line up with the strain columns.
        """
        self.gcdf = self._condition_level(
            self.grain_i_hist,
            ['acc_acceptable', 'sse_acceptable',
             'acc_small_grain', 'sse_small_grain',
             'acc_large_grain', 'sse_large_grain'],
            cond, 'grain',
            rename={'acc_acceptable': 'acc', 'sse_acceptable': 'sse'})

    def parse_cond_df(self, cond_strain='condition_pf', cond_grain='condition', output=None):
        """Build ``self.cdf`` by stacking the strain and grain condition tables.

        Args:
            cond_strain: Condition column of the strain data.
            cond_grain: Condition column of the grain data.
            output: Optional path; when given, ``self.cdf`` is saved there as CSV.
        """
        self.parse_strain_cond_df(cond_strain)
        self.parse_grain_cond_df(cond_grain)
        self.cdf = pd.concat([self.scdf, self.gcdf], sort=False)

        if output is not None:
            self.cdf.to_csv(output, index=False)
            print('Saved file to {}'.format(output))

    # Visualization
    def _subset(self, exp=None, **equal_to):
        """Return the rows of ``self.cdf`` matching ``exp`` and ``column=value`` pairs.

        With no filter at all, ``self.cdf`` itself is returned.
        """
        data = self.cdf
        if exp is not None:
            data = data.loc[data['exp'] == exp]
        for column, value in equal_to.items():
            data = data.loc[data[column] == value]
        return data

    def _line_chart(self, data, x, y, condition, selection, faded):
        """Build the line chart shared by all development/time plots.

        Args:
            data: Data frame to plot.
            x: Altair shorthand for the x channel.
            y: Column plotted on the y axis (axis fixed to 0-1).
            condition: Column used for the colour channel.
            selection: Altair selection controlling the opacity.
            faded: Opacity of the lines that are not selected.
        """
        import altair as alt

        return alt.Chart(data).mark_line(point=True).encode(
            y=alt.Y(y, scale=alt.Scale(domain=(0, 1))),
            x=x,
            color=condition,
            opacity=alt.condition(selection, alt.value(1), alt.value(faded)),
            tooltip=list(self._TOOLTIP),
        )

    def plot_dev(self, y, exp=None, condition='cond', timestep=None):
        """Plot ``y`` against epoch at a single time step.

        Args:
            y: Column to plot on the y axis.
            exp: Optional value of the ``exp`` column to restrict the data to.
            condition: Column used for colour and for click selection.
            timestep: 1-based time step; defaults to ``cfg.n_timesteps``.

        Returns:
            An interactive Altair chart.
        """
        import altair as alt

        if timestep is None:
            timestep = self.cfg.n_timesteps
        timestep -= 1  # Reindex: the data frame counts time steps from 0

        plot_df = self._subset(exp=exp, timestep=timestep)

        title = '{} at timestep {} / unit time {}'.format(y, timestep + 1, self.cfg.max_unit_time)
        sel = alt.selection(type='single', on='click', fields=[condition], empty='all')

        chart = self._line_chart(plot_df, 'epoch:Q', y, condition, sel, faded=0)
        return chart.add_selection(sel).interactive().properties(title=title)

    def plot_dev_interactive(self, y, exp=None, condition='cond'):
        """Plot ``y`` against epoch with a slider that picks the time step.

        Args:
            y: Column to plot on the y axis.
            exp: Optional value of the ``exp`` column to restrict the data to.
            condition: Column used for colour; clicking a legend entry
                highlights it.

        Returns:
            An Altair chart filtered by the slider-bound time step selection.
        """
        import altair as alt

        # Condition highlighter from legend
        select_cond = alt.selection(
            type='multi', on='click', fields=[condition], empty='all', bind="legend"
        )

        # Slider timestep filter, starting at the last time step
        last_step = self.cfg.n_timesteps - 1
        slider_time = alt.binding_range(min=0, max=last_step, step=1)
        select_time = alt.selection_single(
            name="filter",
            fields=['timestep'],
            bind=slider_time,
            init={'timestep': last_step}
        )

        chart = self._line_chart(
            self._subset(exp=exp), 'epoch:Q', y, condition, select_cond, faded=0.1)
        return chart.add_selection(
            select_time, select_cond
        ).transform_filter(select_time).properties(title='Development plot')

    def plot_time(self, y, exp=None, condition='cond', epoch=None):
        """Plot ``y`` against unit time at a single epoch.

        Args:
            y: Column to plot on the y axis.
            exp: Optional value of the ``exp`` column to restrict the data to.
            condition: Column used for colour and for click selection.
            epoch: Epoch to plot; defaults to ``self.max_epoch``.

        Returns:
            An interactive Altair chart.
        """
        import altair as alt

        if epoch is None:
            epoch = self.max_epoch

        plot_df = self._subset(exp=exp, epoch=epoch)

        title = '{} at epoch {} '.format(y, epoch)
        sel = alt.selection(type='single', on='click', fields=[condition], empty='all')

        chart = self._line_chart(plot_df, 'unit_time:Q', y, condition, sel, faded=0)
        return chart.add_selection(sel).interactive().properties(title=title)

    def plot_time_interactive(self, y, exp=None, condition='cond'):
        """Plot ``y`` against unit time with a slider that picks the epoch.

        The slider runs from ``cfg.save_freq`` to ``cfg.nEpo`` in steps of
        ``cfg.save_freq`` and starts at ``cfg.nEpo``.

        Args:
            y: Column to plot on the y axis.
            exp: Optional value of the ``exp`` column to restrict the data to.
            condition: Column used for colour; clicking a legend entry
                highlights it.

        Returns:
            An Altair chart filtered by the slider-bound epoch selection.
        """
        import altair as alt

        # Condition highlighter from legend
        select_cond = alt.selection(
            type='multi', on='click', fields=[condition], empty='all', bind="legend"
        )

        # Slider epoch filter
        slider_epoch = alt.binding_range(
            min=self.cfg.save_freq, max=self.cfg.nEpo, step=self.cfg.save_freq
        )
        select_epoch = alt.selection_single(
            name="filter",
            fields=['epoch'],
            bind=slider_epoch,
            init={'epoch': self.cfg.nEpo}
        )

        chart = self._line_chart(
            self._subset(exp=exp), 'unit_time:Q', y, condition, select_cond, faded=0.1)
        return chart.add_selection(
            select_epoch, select_cond
        ).transform_filter(select_epoch).properties(title='Interactive time plot')

    def plot_wnw(self, selected_cond):
        """Plot word accuracy against nonword accuracy, with a diagonal.

        Args:
            selected_cond: Conditions forwarded to ``make_df_wnw``. Passing
                ``None`` falls back to
                ``['INC_HF', 'ambiguous', 'unambiguous']``.

        Returns:
            A layered Altair chart (diagonal line + accuracy line).
        """
        import altair as alt

        if selected_cond is None:
            selected_cond = ['INC_HF', 'ambiguous', 'unambiguous']

        # NOTE(review): make_df_wnw is not defined in this class; it must be
        # in scope in the enclosing module. The chart below needs its result
        # to carry code_name, epoch, word_acc and nonword_acc columns.
        wnw_df = make_df_wnw(self.cdf, selected_cond=selected_cond)

        wnw_plot = (
            alt.Chart(wnw_df).mark_line(point=True).encode(
                y=alt.Y("nonword_acc:Q", scale=alt.Scale(domain=(0, 1))),
                x=alt.X("word_acc:Q", scale=alt.Scale(domain=(0, 1))),
                color=alt.Color("epoch", scale=alt.Scale(scheme="redyellowgreen")),
                tooltip=["code_name", "word_acc", "nonword_acc"],
            ).properties(
                title="Word vs. Nonword accuracy at final time step"
            )
        )

        # Plot diagonal
        diagline = alt.Chart(pd.DataFrame({
            'x': [0, 1],
            'y': [0, 1]
        })).mark_line().encode(x=alt.X('x', axis=alt.Axis(labels=False)),
                               y=alt.Y('y', axis=alt.Axis(labels=False)))

        return diagline + wnw_plot

    def plots(self, mode, ys, cond_strain='condition_pf', cond_grain='condition'):
        """Parse the condition-level data and concatenate one chart per ``y``.

        Args:
            mode: ``'d'`` for development plots, ``'t'`` for time plots.
            ys: Iterable of columns, one chart each.
            cond_strain: Condition column of the strain data.
            cond_grain: Condition column of the grain data.

        Returns:
            A horizontally concatenated Altair chart. For any other ``mode``
            a hint is printed and nothing is added to the result.
        """
        import altair as alt

        self.parse_cond_df(cond_strain, cond_grain)

        combined = alt.hconcat()
        for y in ys:
            if mode == 'd':
                combined |= self.plot_dev(y)
            elif mode == 't':
                # epoch must go by keyword: the second positional parameter
                # of plot_time is exp, not epoch.
                combined |= self.plot_time(y, epoch=self.max_epoch)
            else:
                print('Use d for development plot, use t for time plot')

        return combined

    def export_result(self):
        """Re-parse with default condition columns and return ``self.cdf`` re-indexed."""
        self.parse_cond_df()
        return self.cdf.reset_index(drop=True)