# Examine one model

In [None]:
%load_ext lab_black
import os
import altair as alt
from ipywidgets import interact
import pandas as pd
import meta, data_wrangling, modeling, metrics, evaluate

# meta.limit_gpu_memory_use(7000)

In [None]:
# Choose which saved run to inspect, then rebuild its model from the stored config.
code_name = "full_pretrain"
tf_root = "/home/jupyter/tf"
cfg = meta.ModelConfig.from_json(
    os.path.join(tf_root, "models", code_name, "model_config.json")
)

# `data`, `model` and `cfg` are consumed by the evaluation cell below.
data = data_wrangling.MyData()
model = modeling.HS04Model(cfg)
model.build()

## Evaluate model

In [None]:
# Run the "strain" evaluation on the model built above. The plotting cell
# below reads its results from `test.strain_mean_df`.
test = evaluate.EvalOral(cfg, model, data)
test.eval("strain")

## Strain

In [None]:
# NOTE(review): this cell was labelled "Half-pretrain (Chang 2019)", but
# `code_name` above is "full_pretrain" and the cell plotting `df2` calls it the
# full-corpus run -- confirm which model these results come from.
df2 = test.strain_mean_df.copy()


def create_plot(df):
    """Display an interactive strain dashboard for the results in ``df``.

    ``df`` must provide the columns read below: ``epoch``, ``timetick``,
    ``task``, ``testset``, plus the metric columns offered by the ``use_y``
    dropdown (``acc``, ``sse``, ``conditional_sse``). The contrast panels
    also expect ``testset`` values named ``strain_{hf|lf}_{con|inc}_{hi|li}``.

    The widget is rendered by ``ipywidgets.interact`` as a side effect, and the
    function itself returns ``None``.
    """
    # Vega expressions for the three 2x2x2 contrasts. Each takes the sum of the
    # 4 cells at one factor level, minus the sum of the 4 cells at the other,
    # divided by 4 (i.e. a difference of level means). They are built once
    # here rather than on every widget interaction.
    contrasts = {
        "contrast_frequency": (
            "(datum.strain_hf_con_hi + datum.strain_hf_con_li"
            " + datum.strain_hf_inc_hi + datum.strain_hf_inc_li"
            " - (datum.strain_lf_con_hi + datum.strain_lf_con_li"
            " + datum.strain_lf_inc_hi + datum.strain_lf_inc_li)) / 4"
        ),
        "contrast_consistency": (
            "(datum.strain_hf_con_hi + datum.strain_hf_con_li"
            " + datum.strain_lf_con_hi + datum.strain_lf_con_li"
            " - (datum.strain_hf_inc_hi + datum.strain_hf_inc_li"
            " + datum.strain_lf_inc_hi + datum.strain_lf_inc_li)) / 4"
        ),
        "contrast_imageability": (
            "(datum.strain_hf_con_hi + datum.strain_lf_con_hi"
            " + datum.strain_hf_inc_hi + datum.strain_lf_inc_hi"
            " - (datum.strain_hf_con_li + datum.strain_lf_con_li"
            " + datum.strain_hf_inc_li + datum.strain_lf_inc_li)) / 4"
        ),
    }

    @interact(
        use_y=["acc", "sse", "conditional_sse"],
        task=["pho_sem", "sem_pho", "pho_pho", "sem_sem"],
        timetick=(1, 12, 1),
        y_max=(1, 20, 1),
    )
    def plot(use_y="acc", timetick=12, task="pho_sem", y_max=1):
        sdf = df.loc[(df.timetick == timetick) & (df.task == task)]
        y_scale = alt.Scale(domain=(0, y_max))
        base = alt.Chart(sdf).encode(x="epoch:Q")

        # One line per test set.
        plot_by_cond = base.mark_line().encode(
            y=alt.Y(f"{use_y}:Q", scale=y_scale), color="testset:N"
        )

        # Mean across test sets. The error band is fed the raw per-testset
        # values: errorband aggregates on its own, so giving it the
        # already-aggregated mean(...) field only triggers a dropped-aggregate
        # warning.
        average_line = base.mark_line().encode(
            y=alt.Y(f"mean({use_y}):Q", scale=y_scale), color="task:N"
        )
        average_band = base.mark_errorband().encode(
            y=alt.Y(f"{use_y}:Q", scale=y_scale), color="task:N"
        )
        plot_average = average_line + average_band

        # Contrast panels: pivot to one column per testset, then evaluate each
        # expression. The pivot removes the `testset` field, so these charts
        # must not inherit the per-testset colour encoding.
        contrast_scale = alt.Scale(domain=(-y_max, y_max))

        def create_contrast_plot(name):
            return (
                base.transform_pivot("testset", value=use_y, groupby=["epoch"])
                .transform_calculate(difference=contrasts[name])
                .mark_line()
                .encode(y=alt.Y("difference:Q", scale=contrast_scale))
                .properties(title=name)
            )

        contrast_plots = alt.hconcat(
            *(create_contrast_plot(name) for name in contrasts)
        )

        return (plot_by_cond | plot_average) & contrast_plots

In [None]:
# Trained with the full corpus: `df2` holds the strain results evaluated above.
create_plot(df2)

In [None]:
# Trained with the half corpus.
# NOTE(review): `df` is never assigned anywhere in this notebook, so this cell
# would otherwise fail with a bare NameError. Load the half-corpus model's
# strain results (same columns as `test.strain_mean_df`) into `df` first.
if "df" not in globals():
    raise NameError(
        "`df` is not defined: load the half-corpus strain results into `df` "
        "before running this cell."
    )
create_plot(df)