# Evaluation 3.0

- Speed
- GCP BigQuery (BQ) support
- Support for the v4 model

## Notes
- Everything on TensorBoard is convenient and very fast.
- Item-level details will be needed down the line; that is where BigQuery comes into play, but it is not a priority until triangle model v4 is stable.
- Currently I take an in-between approach: store data locally per model, then aggregate mean-level statistics for batch runs (varying hyperparameters or multiple runs).
- I have already coded almost everything in bits and pieces; it just needs better integration.
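A minimal sketch of the "store locally per model, then aggregate" step, assuming each run's evaluation is a long-format pandas DataFrame with hypothetical columns `epoch`, `timetick`, `cond`, `acc`, and `sse` (the real columns written by `evaluate` may differ). `aggregate_runs` is a hypothetical helper:

```python
import pandas as pd


def aggregate_runs(run_dfs, keys=("epoch", "timetick", "cond"), metrics=("acc", "sse")):
    """Stack per-run DataFrames and compute mean/std of each metric across runs."""
    combined = pd.concat(run_dfs, ignore_index=True)
    summary = combined.groupby(list(keys))[list(metrics)].agg(["mean", "std"])
    # Flatten the (metric, stat) MultiIndex columns into names like "acc_mean"
    summary.columns = [f"{metric}_{stat}" for metric, stat in summary.columns]
    return summary.reset_index()
```

Keeping the per-run files untouched and aggregating on read means a new batch (another hyperparameter value or another seed) only adds one more DataFrame to the list.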

In [None]:
# %load_ext lab_black
import os
import altair as alt
import pandas as pd
import numpy as np
import tensorflow as tf
import meta, data_wrangling, modeling, metrics, evaluate, testcase_plots
from tqdm import tqdm
from importlib import reload
reload(evaluate)
reload(data_wrangling)
reload(metrics)

In [None]:

def init(code_name, epoch=250):
    """Load a model's config, restore its weights at `epoch`, and load the data."""
    cfg = meta.ModelConfig.from_json(os.path.join("models", code_name, "model_config.json"))
    model = modeling.MyModel(cfg)
    checkpoint = cfg.path["weights_checkpoint_fstring"].format(epoch=epoch)
    model.load_weights(checkpoint)
    data = data_wrangling.MyData()
    return cfg, model, data


cfg, model, data = init("triangle_with_strain")

## HS04 test cases 1: Overall performance
![Overall performance](references/hs04_fig9.png)

> The network was trained for 1.5 million word presentations. At the conclusion of training, the network produced the correct semantic representations for 97.3% of the items. For the other 2.7% of the words, it activated an average of 1.6 spurious features and failed to activate an average of 0.8 features. The model produced correct phonological representations for 99.2% of the words. On the remaining 0.8% of the words, it produced an average of 1.1 incorrect phonemes. Figure 9 depicts semantic and phonological accuracy over the course of training.
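HS04 describe incorrect items by how many semantic features were spuriously activated or missed. A minimal sketch of that breakdown, assuming binary semantic targets, a 0.5 activation threshold, and NumPy arrays of shape `(items, features)`; `sem_error_breakdown` is a hypothetical helper, not part of the `metrics` module:

```python
import numpy as np


def sem_error_breakdown(y_true, y_pred, threshold=0.5):
    """Mean number of spurious and missed semantic features over incorrect items."""
    active_pred = y_pred >= threshold
    active_true = y_true >= threshold
    spurious = np.sum(active_pred & ~active_true, axis=-1)  # on, but should be off
    missed = np.sum(~active_pred & active_true, axis=-1)    # off, but should be on
    incorrect = (spurious + missed) > 0
    if not incorrect.any():
        return 0.0, 0.0
    return float(spurious[incorrect].mean()), float(missed[incorrect].mean())
```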

### Evaluation procedure:
1. Evaluate the entire training set at every:
   - saved epoch
   - output timestep
2. Compute both accuracy (acc) and sum squared error (sse) during the evaluation:
   - pho accuracy: every phoneme slot is correct
   - sem accuracy: output on the correct side of 0.5 (cosine similarity doesn't make sense to me here, since magnitude determines whether a node is active or not)
3. Plot HS04 Figure 9 (above).
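Step 2 can be sketched in NumPy. This is a minimal sketch under assumptions, not the `metrics` module: semantic "right side" accuracy is taken at the item level (all units correct), phonology is assumed to be laid out as `n_slots` consecutive slots decoded to the nearest row of a hypothetical `phoneme_table`, and SSE is averaged over items, which may differ from what `metrics.SumSquaredError` reports. All function names are hypothetical:

```python
import numpy as np


def sem_right_side_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of items whose semantic units are all on the correct side of threshold."""
    unit_correct = (y_pred >= threshold) == (y_true >= threshold)
    return float(np.mean(np.all(unit_correct, axis=-1)))


def mean_item_sse(y_true, y_pred):
    """Sum of squared errors over units for each item, averaged over items."""
    return float(np.mean(np.sum((y_true - y_pred) ** 2, axis=-1)))


def nearest_phoneme(slots, phoneme_table):
    """Index of the closest phoneme prototype (squared Euclidean) for every slot."""
    # slots: (items, n_slots, features); phoneme_table: (n_phonemes, features)
    dist = np.sum((slots[:, :, None, :] - phoneme_table[None, None, :, :]) ** 2, axis=-1)
    return np.argmin(dist, axis=-1)  # shape (items, n_slots)


def pho_accuracy(y_true, y_pred, phoneme_table, n_slots):
    """Fraction of items where the decoded phoneme matches the target in every slot."""
    n_items = y_true.shape[0]
    true_ids = nearest_phoneme(y_true.reshape(n_items, n_slots, -1), phoneme_table)
    pred_ids = nearest_phoneme(y_pred.reshape(n_items, n_slots, -1), phoneme_table)
    return float(np.mean(np.all(true_ids == pred_ids, axis=-1)))
```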

In [None]:
reload(evaluate)
reload(testcase_plots)
x = evaluate.TestSet(cfg, model)
df = x.eval('train_r1000', 'triangle')
mdf = testcase_plots.make_mean_df(df)
fig9 = testcase_plots.plot_hs04_fig9(mdf)
fig9.save(os.path.join(cfg.path['plot_folder'], 'fig9.html'))

## HS04 test cases 2: Freq x Cons in Taraban
![Freq x Cons](references/hs04_fig10.png)

> In the present model, the error computed at the end of processing was essentially zero for almost all items. This is because the model incorporates a phonological attractor, which tends to pull unit activities to their external values over time. In order to measure the difficulty the network had in reaching these states, we recorded the integral of the error over the course of processing the item from time step 4 to the final time step, 12 (the summation began with time step 4 because it takes four samples for information to flow to phonology from orthography via all routes).
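The integrated error described above can be approximated by summing the per-tick error over ticks 4 to 12. A minimal sketch, assuming predictions are stacked as `(ticks, items, units)` with tick 1 at index 0 and that a discrete sum stands in for the integral; both helpers are hypothetical:

```python
import numpy as np


def sse_by_tick(y_true, y_pred_by_tick):
    """Per-item sum squared error at every tick: (ticks, items, units) -> (ticks, items)."""
    return np.sum((np.asarray(y_pred_by_tick) - np.asarray(y_true)[None]) ** 2, axis=-1)


def integrated_error(errors_by_tick, start_tick=4, end_tick=12):
    """Sum per-item error over ticks start_tick..end_tick (1-based, inclusive)."""
    errors_by_tick = np.asarray(errors_by_tick)
    # Tick t lives at index t - 1, so ticks 4..12 are indices 3..11
    return errors_by_tick[start_tick - 1:end_tick].sum(axis=0)
```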

In [None]:
# Convert Taraban to new testset package format (Run once)
# reload(data_wrangling)
# data = data_wrangling.MyData()
# taraban = data.create_testset_from_words(data.df_taraban.word, data.df_taraban.cond)
# data_wrangling.save_testset(taraban, 'dataset/testsets/taraban.pkl.gz')

In [None]:
reload(testcase_plots)
reload(evaluate)
df = x.eval('taraban', 'triangle')
mdf = testcase_plots.make_cond_mean_df(df)

# TODO: Refactor testset-specific post-processing
mdf = mdf.loc[mdf.cond.isin(['High-frequency exception', 'Regular control for High-frequency exception',
       'Low-frequency exception', 'Regular control for Low-frequency exception'])].copy()  # copy to avoid SettingWithCopyWarning
mdf['freq'] = mdf.cond.apply(lambda c: 'High' if c in ('High-frequency exception', 'Regular control for High-frequency exception') else 'Low')
mdf['reg'] = mdf.cond.apply(lambda c: 'Regular' if c.startswith('Regular') else 'Exception')

fig10 = testcase_plots.plot_hs04_fig10(mdf)
fig10.save(os.path.join(cfg.path['plot_folder'], 'fig10.html'))
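For the refactoring TODO above, the 2x2 design can live in one lookup table instead of string tests. A sketch; `TARABAN_FACTORS` and `add_taraban_factors` are hypothetical names, and the condition labels are copied from the cell above:

```python
import pandas as pd

# Hypothetical lookup: condition label -> (frequency, regularity)
TARABAN_FACTORS = {
    "High-frequency exception": ("High", "Exception"),
    "Regular control for High-frequency exception": ("High", "Regular"),
    "Low-frequency exception": ("Low", "Exception"),
    "Regular control for Low-frequency exception": ("Low", "Regular"),
}


def add_taraban_factors(mdf, factors=TARABAN_FACTORS):
    """Keep only the 2x2 Taraban conditions and add 'freq' and 'reg' columns."""
    out = mdf.loc[mdf["cond"].isin(list(factors))].copy()
    out["freq"] = out["cond"].map(lambda c: factors[c][0])
    out["reg"] = out["cond"].map(lambda c: factors[c][1])
    return out
```

Adding another testset then means adding another lookup table rather than another block of lambdas.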

In [None]:

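# 'Regular control for High-frequency exception' does not start with 'High-', so test for the substring
mdf['freq'] = mdf.cond.apply(lambda c: 'High' if 'High-frequency' in c else 'Low')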
mdf

In [None]:
taraban = data_wrangling.load_testset('dataset/testsets/taraban.pkl.gz')

In [None]:
grain = data_wrangling.load_testset('dataset/testsets/grain.pkl.gz')

In [None]:
rep_names = ('ort', 'pho_large_grain', 'pho_small_grain', 'sem', 'pho')
# Use `rep`, not `x`, so the TestSet bound to `x` earlier is not overwritten
for rep in rep_names:
    grain[rep] = tf.cast(grain[rep], dtype=tf.float32)

In [None]:
data_wrangling.save_testset(testset=grain, file='dataset/testsets/grain.pkl.gz')

In [None]:
# Random sample 1000 items from train set (tmp fix for OOM issue)
# TODO eval on batch
s1000 = data.df_train.sample(1000).index
s1000 = data.create_testset_from_train_idx(s1000)
data_wrangling.save_testset(testset=s1000, file='dataset/testsets/train_r1000.pkl.gz')
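For the "eval on batch" TODO, one option is to slice the test set along the item axis, run the model on each slice, and concatenate. A sketch under assumptions: `predict_fn` returns a dict of array-likes (as `y_pred['pho']` / `y_pred['sem']` suggests), and items lie along `item_axis` in every output (1 if outputs are stacked as `(ticks, items, units)`). `predict_in_batches` is hypothetical:

```python
import numpy as np


def predict_in_batches(predict_fn, inputs, batch_size=256, item_axis=0):
    """Run predict_fn on slices of inputs and concatenate each output along item_axis."""
    chunks = {}
    for start in range(0, len(inputs), batch_size):
        batch_out = predict_fn(inputs[start:start + batch_size])
        for name, value in batch_out.items():
            chunks.setdefault(name, []).append(np.asarray(value))
    return {name: np.concatenate(parts, axis=item_axis) for name, parts in chunks.items()}
```

With the notebook's model this might look like `predict_in_batches(lambda ort: model([ort] * cfg.n_timesteps), ort, item_axis=1)`, untested against the real `modeling.MyModel`. Only one batch's intermediate activations are alive at a time, but the concatenated outputs still have to fit in memory.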

In [None]:
df

# Examine one model

In [None]:
model.set_active_task("triangle")
y_pred = model([data.testsets["strain"]["ort"]] * cfg.n_timesteps)
y_true = {out: data.testsets["strain"][out] for out in ('pho', 'sem')}

pho_acc = metrics.PhoAccuracy()
pho_sse = metrics.SumSquaredError()
sem_acc = metrics.RightSideAccuracy()
sem_sse = metrics.SumSquaredError()

# Score the last output timestep only ([-1])
pho_acc.update_state(y_true['pho'], y_pred['pho'][-1])
pho_sse.update_state(y_true['pho'], y_pred['pho'][-1])
sem_acc.update_state(y_true['sem'], y_pred['sem'][-1])
sem_sse.update_state(y_true['sem'], y_pred['sem'][-1])
print(f"pho accuracy: {pho_acc.out.numpy():.4f}, sem accuracy: {sem_acc.out.numpy():.4f}")
print(f"pho sse: {pho_sse.out.numpy():.4f}, sem sse: {sem_sse.out.numpy():.4f}")

# Prototype testset implementation (manual)
We need a vectorized map over these dimensions:
- model (1 for now)
- epoch (39)
- timestep (11)
- testset x cond (taraban, glushko, hs04 img)
- task (9: 5 main, 4 experimental)
- output (2 in triangle, otherwise 1)
- metrics (acc, sse, cosine)
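One way to prototype this is a plain loop over the Cartesian product of the dimensions that collects one row per combination into a long-format DataFrame. It is not a true vectorized map, but it produces the tidy table Altair plots from. `eval_grid` and its `eval_fn` callback are hypothetical; `eval_fn` stands in for whatever runs one (epoch, timestep, task, ...) evaluation and returns a dict of metrics:

```python
import itertools

import pandas as pd


def eval_grid(eval_fn, **dims):
    """Call eval_fn once per combination of dimension values; return a long DataFrame.

    eval_fn(**combo) must return a dict mapping metric name -> value.
    """
    names = list(dims)
    rows = []
    for values in itertools.product(*(dims[name] for name in names)):
        combo = dict(zip(names, values))
        rows.append({**combo, **eval_fn(**combo)})
    return pd.DataFrame(rows)
```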

In [None]:
model.set_active_task("ort_sem")
y_pred = model([data.testsets["strain"]["ort"]] * cfg.n_timesteps)
y_true = data.testsets["strain"]["sem"]
# Fresh metric objects, in case update_state accumulates across calls
sem_acc = metrics.RightSideAccuracy()
sem_sse = metrics.SumSquaredError()
sem_acc.update_state(y_true, y_pred['sem'][-1])
sem_sse.update_state(y_true, y_pred['sem'][-1])
print(f"sem accuracy: {sem_acc.out.numpy():.4f}, sem sse: {sem_sse.out.numpy():.4f}")

In [None]:
model.set_active_task("exp_ops")
y_pred = model([data.testsets["strain"]["ort"]] * cfg.n_timesteps)
y_true = data.testsets["strain"]["sem"]
# Fresh metric objects, in case update_state accumulates across calls
sem_acc = metrics.RightSideAccuracy()
sem_sse = metrics.SumSquaredError()
sem_acc.update_state(y_true, y_pred['sem'][-1])
sem_sse.update_state(y_true, y_pred['sem'][-1])
print(f"sem accuracy: {sem_acc.out.numpy():.4f}, sem sse: {sem_sse.out.numpy():.4f}")

In [None]:
model.set_active_task("ort_pho")
y_pred = model([data.testsets["strain"]["ort"]] * cfg.n_timesteps)
y_true = data.testsets["strain"]["pho"]
# Fresh metric objects, in case update_state accumulates across calls
pho_acc = metrics.PhoAccuracy()
pho_sse = metrics.SumSquaredError()
pho_acc.update_state(y_true, y_pred['pho'][-1])
pho_sse.update_state(y_true, y_pred['pho'][-1])
print(f"pho accuracy: {pho_acc.out.numpy():.4f}, pho sse: {pho_sse.out.numpy():.4f}")

In [None]:
model.set_active_task("exp_osp")
y_pred = model([data.testsets["strain"]["ort"]] * cfg.n_timesteps)
y_true = data.testsets["strain"]["pho"]
# Fresh metric objects, in case update_state accumulates across calls
pho_acc = metrics.PhoAccuracy()
pho_sse = metrics.SumSquaredError()
pho_acc.update_state(y_true, y_pred['pho'][-1])
pho_sse.update_state(y_true, y_pred['pho'][-1])
print(f"pho accuracy: {pho_acc.out.numpy():.4f}, pho sse: {pho_sse.out.numpy():.4f}")



In [None]:
x = evaluate.TestSet(cfg, model)
x.eval('strain', 'ort_pho')


# Model-level examine class (after eval)

In [None]:
from ipywidgets import interact


class examine:
    """Load (or compute) a model's strain evaluation and plot it interactively."""

    def __init__(self, code_name, tf_root="/home/jupyter/tf"):

        try:
            # Fast load from disk
            csv_file = os.path.join(tf_root, 'models', code_name, 'eval', 'strain_mean_df.csv')
            self.df = pd.read_csv(csv_file)
        except FileNotFoundError:
            # No cached CSV: evaluate from scratch
            self.cfg = meta.ModelConfig.from_json(os.path.join(tf_root, 'models', code_name, 'model_config.json'))
            self.data = data_wrangling.MyData()
            self.model = modeling.HS04Model(self.cfg)
            self.model.build()
            self.test_strain = evaluate.EvalOral(self.cfg, self.model, self.data)
            self.df = self.test_strain.strain_mean_df

    def plot_op_strain(self):
        df = self.df

        @interact(
            use_y=['acc','sse','conditional_sse'],
            timetick=(1,12,1),
            y_max=(1, 20, 1)
            )
        def plot(use_y='acc', timetick=12, y_max=1):
            sdf = df.loc[(df.timetick==timetick)] 
            
            # Plot by condition
            plot_by_cond = alt.Chart(sdf).mark_line().encode(
                x=alt.X('epoch:Q', scale=alt.Scale(domain=(0, 100), clamp=True)),
                y=alt.Y(f"{use_y}:Q", scale=alt.Scale(domain=(0, y_max))),
                color='cond:N'
            )

            # Contrasts
            contrasts = {}
            contrasts['contrast_frequency'] = """(datum.HF_INC + datum.HF_CON - (datum.LF_INC + datum.LF_CON))/2""" 
            contrasts['contrast_consistency'] = """(datum.LF_CON + datum.HF_CON - (datum.LF_INC + datum.HF_INC))/2""" 

            def create_contrast_plot(name):
                return plot_by_cond.encode(y=alt.Y("difference:Q", scale=alt.Scale(domain=(-y_max, y_max)))
                    ).transform_pivot('cond', value=use_y, groupby=['epoch']
                    ).transform_calculate(difference = contrasts[name]
                    ).properties(title=name)

            return plot_by_cond | create_contrast_plot('contrast_frequency') | create_contrast_plot('contrast_consistency')

    def plot(self):
        """ Create an interactive plot for strain """
        df = self.df

        @interact(
            use_y=['acc','sse','conditional_sse'],
            task=['pho_sem', 'sem_pho', 'pho_pho', 'sem_sem'],
            timetick=(1,12,1),
            y_max=(1, 20, 1)
            )
        def plot(use_y='acc', timetick=12, task='pho_sem', y_max=1):
            sdf = df.loc[(df.timetick==timetick) & (df.task==task)] 
            
            # Plot by condition
            plot_by_cond = alt.Chart(sdf).mark_line().encode(
                x='epoch:Q',
                y=alt.Y(f"{use_y}:Q", scale=alt.Scale(domain=(0, y_max))),
                color='testset:N'
            )

            # Plot average
            plot_average = plot_by_cond.encode(y=alt.Y(f"mean({use_y}):Q", scale=alt.Scale(domain=(0, y_max))), color='task')
            plot_average += plot_average.mark_errorband()

            # Plot contrasts
            contrasts = {}
            contrasts['contrast_frequency'] = """(datum.strain_hf_con_hi + datum.strain_hf_con_li + datum.strain_hf_inc_hi + datum.strain_hf_inc_li - 
                (datum.strain_lf_con_hi + datum.strain_lf_con_li + datum.strain_lf_inc_hi + datum.strain_lf_inc_li))/4"""
            contrasts['contrast_consistency'] = """(datum.strain_hf_con_hi + datum.strain_hf_con_li + datum.strain_lf_con_hi + datum.strain_lf_con_li - 
                (datum.strain_hf_inc_hi + datum.strain_hf_inc_li + datum.strain_lf_inc_hi + datum.strain_lf_inc_li))/4"""
            contrasts['contrast_imageability'] = """(datum.strain_hf_con_hi + datum.strain_lf_con_hi + datum.strain_hf_inc_hi + datum.strain_lf_inc_hi - 
                (datum.strain_hf_con_li + datum.strain_lf_con_li + datum.strain_hf_inc_li + datum.strain_lf_inc_li))/4"""

            def create_contrast_plot(name):
                return plot_by_cond.encode(y=alt.Y("difference:Q", scale=alt.Scale(domain=(-y_max, y_max)))
                    ).transform_pivot('testset', value=use_y, groupby=['epoch']
                    ).transform_calculate(difference = contrasts[name]
                    ).properties(title=name)

            contrast_plots = alt.hconcat()
            for c in contrasts.keys():
                contrast_plots |= create_contrast_plot(c)


            return((plot_by_cond | plot_average) & contrast_plots)
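The Vega-Lite `transform_pivot` + `transform_calculate` contrasts in `plot_op_strain` can be cross-checked in pandas. A sketch, assuming `df` has one row per (`epoch`, `cond`) after filtering to a single `timetick`, with the `HF_INC`/`HF_CON`/`LF_INC`/`LF_CON` labels used above. If there were duplicate rows, `pivot_table` would average them whereas Vega-Lite's pivot sums by default. `frequency_consistency_contrasts` is hypothetical:

```python
import pandas as pd


def frequency_consistency_contrasts(df, value="acc"):
    """Pivot cond into columns (one row per epoch) and compute the two 2x2 contrasts."""
    wide = df.pivot_table(index="epoch", columns="cond", values=value)
    out = pd.DataFrame(index=wide.index)
    # Same formulas as the Vega-Lite expressions in plot_op_strain
    out["contrast_frequency"] = (wide["HF_INC"] + wide["HF_CON"] - (wide["LF_INC"] + wide["LF_CON"])) / 2
    out["contrast_consistency"] = (wide["LF_CON"] + wide["HF_CON"] - (wide["LF_INC"] + wide["HF_INC"])) / 2
    return out.reset_index()
```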


In [None]:
tmp = examine('boo')
tmp.plot_op_strain()
# Full looks familiar... good interaction, fast learning overall (will slow down later, using a fast learning rate to save time on testing)

In [None]:
tmp = examine('op_half_stationary')
tmp.plot_op_strain()
# Learns slower...
# HF_INC seems a tiny bit lower (more apparent in earlier ticks); maybe HF items have more CON O-P tokens?



In [None]:
tmp = examine('op_half_rank_noclip')
tmp.plot_op_strain()
# HF_INC decreases further --> CON > F



In [None]:
tmp = examine('op_half_rank_hc_30000')
tmp.plot_op_strain()
# Strong frequency effect



## Strain

In [None]:
# Half-pretrain (Chang 2019)

half_pretrain = examine("half_pretrain")
half_pretrain.plot()


In [None]:
# Chang 2019

chang_pretrain = examine("chang_pretrain")
chang_pretrain.plot()

In [None]:
# Full-pretrain 
full_pretrain = examine("full_pretrain")
full_pretrain.plot()