# Evaluation 3.0

- Speed
- GCP BigQuery (BQ) support for batch runs
- Support for v4 model output dict format

## Notes
- Everything on TensorBoard is convenient and really fast, but kind of difficult to customize
- Will need item-level details down the line, which is where BigQuery comes into play, but that is not so important until triangle model v4 is stable
- Currently I am doing something in between: store data locally per model, then aggregate mean-level statistics for batch runs (varying h-params or multiple runs)
- I have already coded almost everything in bits and pieces; it just needs better integration
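A minimal sketch of the local "in between" aggregation step, assuming each model's item-level results load into a DataFrame with columns `epoch`, `timetick`, `cond`, `acc` and `sse`. The helper names and the `models/<code_name>/eval/item_df.csv` layout are hypothetical, not what `evaluate` currently writes.

```python
import os

import pandas as pd


def load_item_dfs(code_names, root="models", file_name="item_df.csv"):
    """Read each model's item-level results (hypothetical file layout)."""
    return {
        name: pd.read_csv(os.path.join(root, name, "eval", file_name))
        for name in code_names
    }


def aggregate_batch(item_dfs):
    """Average item-level results per run, then summarise across runs.

    item_dfs: dict of code_name -> DataFrame with columns
              epoch, timetick, cond, acc, sse (one row per item).
    """
    keys = ["epoch", "timetick", "cond"]
    items = pd.concat(
        [df.assign(code_name=name) for name, df in item_dfs.items()],
        ignore_index=True,
    )
    # Mean-level statistics for each run (the level the plots use)
    per_run = items.groupby(["code_name"] + keys, as_index=False)[["acc", "sse"]].mean()
    # Batch summary: mean and spread across runs (e.g. repeated runs of one setting)
    across_runs = per_run.groupby(keys)[["acc", "sse"]].agg(["mean", "std"])
    return per_run, across_runs
```

Keeping item-level rows until the last step means the same files could later be pushed to BigQuery unchanged.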

In [None]:
# %load_ext lab_black
import os
import altair as alt
import pandas as pd
import numpy as np
import tensorflow as tf
import meta, data_wrangling, modeling, evaluate, testcase_plots
from tqdm import tqdm
from importlib import reload

In [None]:
def init(code_name):
    cfg = meta.ModelConfig.from_json(
        os.path.join("models", code_name, "model_config.json")
    )
    model = modeling.MyModel(cfg)
    checkpoint = cfg.path["weights_checkpoint_fstring"].format(epoch=250)
    model.load_weights(checkpoint)
    data = data_wrangling.MyData()
    test = evaluate.TestSet(cfg, model)
    return cfg, model, data, test


cfg, model, data, test = init("triangle_high_time_res_4M_fix")

## HS04 test case 1: Overall performance
![Overall performance](references/hs04_fig9.png)

> The network was trained for 1.5 million word presentations. At the conclusion of training, the network produced the correct semantic representations for 97.3% of the items. For the other 2.7% of the words, it activated an average of 1.6 spurious features and failed to activate an average of 0.8 features. The model produced correct phonological representations for 99.2% of the words. On the remaining 0.8% of the words, it produced an average of 1.1 incorrect phonemes. Figure 9 depicts semantic and phonological accuracy over the course of training.

### Evaluation procedure:
1. Evaluate the 1000 randomly sampled training-set items at each:
    - saved epoch
    - output timestep
2. Obtain both accuracy (acc) and sum of squared error (sse) during the evaluation:
    - pho accuracy = all phoneme slots decode to the correct phoneme
    - sem accuracy = correct side of 0.5 (cosine doesn't make sense to me, since the magnitude determines whether a node activates or not)
3. Plot HS04 Fig. 9 (above)
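The two accuracy definitions can be sketched in NumPy. This assumes phoneme slots are decoded to the nearest phoneme prototype (squared Euclidean distance) and that a semantic item counts as correct only when *all* units are on the correct side of 0.5; both choices, and the function names, are assumptions rather than what `evaluate.TestSet` does.

```python
import numpy as np


def pho_accuracy(pho_out, targets, prototypes):
    """Item is correct only if every phoneme slot decodes to the target phoneme.

    pho_out:    (n_items, n_slots, n_features) output activations
    targets:    (n_items, n_slots) integer index of the target phoneme per slot
    prototypes: (n_phonemes, n_features) feature vector of each phoneme
    """
    # Squared distance from each slot's output to every phoneme prototype
    dist = ((pho_out[:, :, None, :] - prototypes[None, None, :, :]) ** 2).sum(-1)
    decoded = dist.argmin(-1)  # nearest phoneme per slot
    return (decoded == targets).all(axis=1).mean()


def sem_accuracy(sem_out, sem_target, threshold=0.5):
    """Item is correct only if every semantic unit is on the correct side of 0.5."""
    same_side = (sem_out > threshold) == (sem_target > threshold)
    return same_side.all(axis=1).mean()


def sse(output, target):
    """Per-item sum of squared error over all output units."""
    axes = tuple(range(1, output.ndim))
    return ((output - target) ** 2).sum(axis=axes)
```

The threshold criterion only asks which side of 0.5 each unit lands on, which matches the note above that magnitude, not direction, decides whether a node is on.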

In [None]:
# Random sample 1000 items from train set (tmp fix for OOM issue)
# TODO eval on batch
# reload(data_wrangling)
# data = data_wrangling.MyData()
# s1000 = data.df_train.sample(1000).index
# s1000 = data.create_testset_from_train_idx(s1000)
# data_wrangling.save_testset(testset=s1000, file='dataset/testsets/train_r1000.pkl.gz')

In [None]:
def run_test1(test):
    df = test.eval("train_r1000", "triangle")
    mdf = testcase_plots.make_mean_df(df)
    fig9 = testcase_plots.plot_hs04_fig9(mdf, test.cfg.n_timesteps)
    fig9.save(os.path.join(test.cfg.path["plot_folder"], "test1.html"))


run_test1(test)

## HS04 test case 2: Freq x Cons in Taraban
![Freq x Cons](references/hs04_fig10.png)

> In the present model, the error computed at the end of processing was essentially zero for almost all items. This is because the model incorporates a phonological attractor, which tends to pull unit activities to their external values over time. In order to measure the difficulty the network had in reaching these states, we recorded the integral of the error over the course of processing the item from time step 4 to the final time step, 12 (the summation began with time step 4 because it takes four samples for information to flow to phonology from orthography via all routes).

### IMPORTANT: From this point onwards, SSE and ACC are the integral over ticks 4 to 12
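Assuming per-tick error is stored with tick 1 in row 0, the "integral" over ticks 4 to 12 is just a sum over that window. This is a hypothetical helper; scaling the sum by the integration constant would give a time-weighted integral with the same ordering across conditions.

```python
import numpy as np


def integrated_error(per_tick, first_tick=4, last_tick=12):
    """Sum a per-tick metric over ticks first_tick..last_tick (inclusive, 1-indexed).

    per_tick: array of shape (n_ticks, ...) where row 0 holds tick 1.
    """
    per_tick = np.asarray(per_tick)
    window = per_tick[first_tick - 1 : last_tick]  # rows for ticks 4..12
    return window.sum(axis=0)
```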

In [None]:
# Convert Taraban to new testset package format (Run once)
# reload(data_wrangling)
# data = data_wrangling.MyData()
# taraban = data.create_testset_from_words(data.df_taraban.word, data.df_taraban.cond)
# data_wrangling.save_testset(taraban, 'dataset/testsets/taraban.pkl.gz')

In [None]:
def run_test2(test):
    df = test.eval("taraban", "triangle")
    mdf = testcase_plots.make_cond_mean_df(df)

    # TODO: Refactor testset-specific post-processing
    keep_conds = [
        "High-frequency exception",
        "Regular control for High-frequency exception",
        "Low-frequency exception",
        "Regular control for Low-frequency exception",
    ]
    # .copy() avoids pandas' SettingWithCopyWarning when adding columns below
    mdf = mdf.loc[mdf.cond.isin(keep_conds)].copy()
    # Frequency: exception items and their regular controls share the label
    mdf["freq"] = mdf.cond.apply(lambda x: "High" if "High-frequency" in x else "Low")
    # Regularity: control items start with "Regular"
    mdf["reg"] = mdf.cond.apply(
        lambda x: "Regular" if x.startswith("Regular") else "Exception"
    )

    fig10 = testcase_plots.plot_hs04_fig10(mdf, tick_after=12)
    fig10.save(os.path.join(test.cfg.path["plot_folder"], "test2.html"))


run_test2(test)

## HS04 test case 3: Glushko nonwords

> The model produced correct pronunciations for 93% of the nonwords derived from regular words and 84% of the ones derived from exception words. 
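If nonwords end up with more than one acceptable pronunciation (the "maybe multiple answers" level in the notes below), a decoded output could be scored as correct when it matches any of them. A sketch with hypothetical phoneme tuples; whether the Glushko set should accept multiple answers is an assumption here.

```python
def nonword_correct(decoded, acceptable):
    """True if the decoded pronunciation matches any acceptable pronunciation.

    decoded:    sequence of phoneme symbols produced by the model
    acceptable: iterable of sequences, each one an acceptable pronunciation
    """
    return tuple(decoded) in {tuple(a) for a in acceptable}


def nonword_accuracy(decoded_list, acceptable_list):
    """Proportion of nonwords whose output matches at least one acceptable answer."""
    hits = [nonword_correct(d, a) for d, a in zip(decoded_list, acceptable_list)]
    return sum(hits) / len(hits)
```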

In [None]:
# glushko = {}
# glushko['item'] = data.df_glushko.word
# glushko['cond'] = data.df_glushko.cond
# glushko['ort'] = tf.constant(data.x_glushko, dtype=tf.float32)
# glushko['phoneme'] = [data.pho_glushko[word] for word in glushko['item']]
# glushko['pho'] = [data.y_glushko[word] for word in glushko['item']]
# glushko['sem'] = None
# data_wrangling.save_testset(testset=glushko, file='dataset/testsets/glushko.pkl.gz')


In [None]:
def run_test3(test):
    df = test.eval("glushko", "triangle")
    mdf = testcase_plots.make_cond_mean_df(df)
    test3 = testcase_plots.plot_conds(mdf, tick_after=12)
    test3.save(os.path.join(test.cfg.path["plot_folder"], "test3.html"))
    # test3


run_test3(test)

## HS04 test case 4: IMG
![IMG](references/hs04_fig11.png)

> We first performed a median split of all items in the training set along the frequency dimension. All words were then categorized as regular or exception. Finally, we used the imageability norms of the Medical Research Council Psycholinguistic Database (Coltheart, 1981) to code all items in the training set that were in the database and did a median split on these items, categorizing them as high or low in imageability. We then identified words that fit each of the categories formed by crossing frequency, regularity, and imageability. The smallest number of items, 28, was obtained for the low-frequency, low-imageability irregular cell in the design. For each of the other cells in the design we randomly chose 28 of the qualifying words. All words were presented to the model, and its output was analyzed as in the simulation of frequency by consistency.
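The median-split-and-sample procedure in the quote can be sketched in pandas. Unlike the commented-out cell below, which keeps every qualifying word, this hypothetical helper also downsamples every cell to the size of the smallest one (28 in HS04). It labels each dimension `h`/`l` rather than `hf`/`lf` etc.; column names such as `wf`, `op`, `img` follow that cell, and `groupby(...).sample` needs pandas 1.1 or later.

```python
import pandas as pd


def median_split_cells(df, cols, seed=0):
    """Median-split each column in `cols`, cross the splits into conditions,
    and sample the same number of words (the smallest cell size) per condition."""
    df = df.dropna(subset=cols).copy()
    gp_cols = []
    for c in cols:
        # "h" if strictly above the median, else "l" (ties go to "l")
        df[c + "_gp"] = (df[c] > df[c].median()).map({True: "h", False: "l"})
        gp_cols.append(c + "_gp")
    # e.g. "h_l" = high on cols[0], low on cols[1]
    df["cond"] = df[gp_cols].apply("_".join, axis=1)
    n_per_cell = df["cond"].value_counts().min()  # smallest cell sets the size
    return df.groupby("cond").sample(n=n_per_cell, random_state=seed)
```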

In [None]:
# # # Make a new test set that loosely follows HS04
# t = data.df_train[['word', 'wf']].copy()

# # Source surprisal
# df_surprisal = pd.read_csv('corpus/noam_surprisal.csv')
# op = dict(zip(df_surprisal.word.str.lower(), df_surprisal["uncond.surprisal"]))

# # Source MRC imageability
# df_img = pd.read_csv('corpus/MRC_img.csv', header=None, names=['word', 'img'])
# img = dict(zip(df_img.word.str.lower(), df_img.img))

# # Merge
# t['img'] = t.word.apply(lambda x: img[x] if x in img.keys() else None)
# t['op'] = t.word.apply(lambda x: op[x] if x in op.keys() else None)

# t = t.dropna()

# t['freq_gp'] = t.wf.apply(lambda x: 'hf' if x > t.wf.median() else 'lf')
# t['op_gp'] = t.op.apply(lambda x: 'hs' if x > t.op.median() else 'ls')
# t['img_gp'] = t.img.apply(lambda x: 'hi' if x > t.img.median() else 'li')
# t['cond'] = t.freq_gp + '_' + t.op_gp + '_' + t.img_gp

# print(f"Word frequency median cutoff = {t.wf.median()}")
# print(f"Imageability median cutoff = {t.img.median()}")
# print(f"OP surprisal median cutoff = {t.op.median()}")

# print("Number of words in each condition:")
# print(t.groupby(['cond']).count().word)

# Packing
# hs04_img = data.create_testset_from_words(words=t.word, conds=t.cond)
# data_wrangling.save_testset(hs04_img, 'dataset/testsets/hs04_img.pkl.gz')

In [None]:
def run_test4(test):
    df = test.eval("hs04_img", "triangle")
    mdf = testcase_plots.make_cond_mean_df(df)
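    # cond is "<freq>_<op>_<img>", e.g. "hf_hs_hi"
    mdf["fc"] = mdf.cond.apply(lambda x: x[:5])  # frequency x consistency, e.g. "hf_hs"
    mdf["img"] = mdf.cond.apply(lambda x: x[-2:])  # imageability, "hi" or "li"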
    test4 = testcase_plots.plot_hs04_fig11(mdf, tick_after=12)
    test4.save(os.path.join(test.cfg.path["plot_folder"], "test4.html"))


run_test4(test)

## HS04 test case 5: Lesion
![Lesion](references/hs04_fig14.png)

> All words in the training set were presented to the trained reading model. To assess the time course of activity at a more fine grain, we ran the network for 4 units of whole time, as in training, but discretized over 48 samples, rather than 12, giving an integration constant of 0.083. The total input to target phonological units from the orth→phon path was summed at each sample. Similarly, the total input to target semantic units from orth→sem, from phon→sem, and from the semantic cleanup units was measured at each sample.
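The integration constant is the step size: 4 units of time split into 48 samples gives 4 / 48 ≈ 0.083 (vs. 4 / 12 ≈ 0.333 at 12 samples). A sketch of the usual time-averaged-input update, assuming the model integrates net input this way; it has not been checked against `modeling.MyModel`, and the function name is hypothetical.

```python
import numpy as np


def time_averaged_input(net_inputs, total_time=4.0, n_samples=48, x0=0.0):
    """Leaky integration of net input with step size dt = total_time / n_samples.

    net_inputs: array of shape (n_samples, n_units), instantaneous net input per sample
    Returns the integrated input at every sample (same shape).
    """
    dt = total_time / n_samples  # 4 / 48 ≈ 0.083
    x = np.full(net_inputs.shape[1:], x0, dtype=float)
    history = np.empty_like(net_inputs, dtype=float)
    for t, net in enumerate(net_inputs):
        x = dt * net + (1.0 - dt) * x  # move a fraction dt toward the new input
        history[t] = x
    return history
```

Finer sampling leaves the total simulated time unchanged and only smooths the trajectory, which is why the per-sample inputs can be summed by pathway to trace the time course.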

# Loopy loop loop...
We need to vectorize (map) functions over these levels; some might not be possible (variable size, multiple answers):
- model (1 for now)
- epoch (39)
- timestep (11)
- testset x cond (taraban, glushko, hs04 img)
- task (9, 5 main, 4 experimental)
- output (2 in triangle, otherwise 1)
- maybe multiple answers 
- metrics (acc, sse, cosine) 
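One way to split the levels: fixed-size ones (epoch, timestep) can be stacked into a dense array and vectorized with NumPy, while variable-size ones (model, testset/cond, task, output) stay as an outer loop and become labels on a long dataframe. The helper names and array layout below are assumptions for illustration.

```python
import itertools

import numpy as np
import pandas as pd


def metrics_grid(outputs, targets):
    """Vectorized acc/sse over the epoch and timestep axes at once.

    outputs: (n_epochs, n_timesteps, n_items, n_units) activations
    targets: (n_items, n_units) binary targets, broadcast over epoch/timestep
    Returns two arrays of shape (n_epochs, n_timesteps).
    """
    sse = ((outputs - targets) ** 2).sum(axis=-1)  # per item
    correct = ((outputs > 0.5) == (targets > 0.5)).all(axis=-1)  # per item
    return correct.mean(axis=-1), sse.mean(axis=-1)


def to_long_df(acc, sse, epochs, **labels):
    """Flatten the (epoch, timestep) grid into long rows, tagged with the
    non-vectorized levels (e.g. model, testset, task, output)."""
    rows = [
        {**labels, "epoch": e, "timetick": t + 1, "acc": acc[i, t], "sse": sse[i, t]}
        for (i, e), t in itertools.product(enumerate(epochs), range(acc.shape[1]))
    ]
    return pd.DataFrame(rows)
```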

## SEM OUTPUT

In [None]:
def run_test5a(test):
    df_intact = test.eval("train_r1000", "triangle")
    df_os_lesion = test.eval("train_r1000", "exp_ops")
    df_ops_lesion = test.eval("train_r1000", "ort_sem")

    df = pd.concat([df_intact, df_os_lesion, df_ops_lesion])
    mdf = testcase_plots.make_mean_df(df)

    test5a = testcase_plots.plot_hs04_fig14(mdf, output="sem")
    test5a.save(os.path.join(test.cfg.path["plot_folder"], "test5_sem.html"))


run_test5a(test)

In [None]:
def run_test5b(test):
    df_intact = test.eval("train_r1000", "triangle")
    df_op_lesion = test.eval("train_r1000", "exp_osp")
    df_osp_lesion = test.eval("train_r1000", "ort_pho")

    df = pd.concat([df_intact, df_op_lesion, df_osp_lesion])
    mdf = testcase_plots.make_mean_df(df)
    testcase_plots.print_unique(mdf)

    test5b = testcase_plots.plot_hs04_fig14(mdf, output='pho')
    test5b.save(os.path.join(test.cfg.path["plot_folder"], "test5_pho.html"))


run_test5b(test)

In [None]:
# These are tasks (second argument of test.eval), not testsets
tasks = ("pho_pho", "pho_sem", "sem_pho", "sem_sem")
pd.concat([test.eval("train_r1000", task) for task in tasks])

# Model-level examine class (after eval)

In [None]:
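from ipywidgets import interact  # provides the @interact widgets used below


class examine:
    """Load a model's strain mean dataframe (from disk if cached, else by evaluating) and plot it."""

    def __init__(self, code_name, tf_root="/home/jupyter/tf"):

        try:
            # Fast load from disk
            csv_file = os.path.join(tf_root, 'models', code_name, 'eval', 'strain_mean_df.csv')
            self.df = pd.read_csv(csv_file)
        except FileNotFoundError: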
            # Eval from scratch
            self.cfg = meta.ModelConfig.from_json(os.path.join(tf_root, 'models', code_name, 'model_config.json'))
            self.data = data_wrangling.MyData()
            self.model = modeling.HS04Model(self.cfg)
            self.model.build()
            self.test_strain = evaluate.EvalOral(self.cfg, self.model, self.data)
            self.df = self.test_strain.strain_mean_df

    def plot_op_strain(self):
        df = self.df

        @interact(
            use_y=['acc','sse','conditional_sse'],
            timetick=(1,12,1),
            y_max=(1, 20, 1)
            )
        def plot(use_y='acc', timetick=12, y_max=1):
            sdf = df.loc[(df.timetick==timetick)] 
            
            # Plot by condition
            plot_by_cond = alt.Chart(sdf).mark_line().encode(
                x=alt.X('epoch:Q', scale=alt.Scale(domain=(0, 100), clamp=True)),
                y=alt.Y(f"{use_y}:Q", scale=alt.Scale(domain=(0, y_max))),
                color='cond:N'
            )

            # Contrasts
            contrasts = {}
            contrasts['contrast_frequency'] = """(datum.HF_INC + datum.HF_CON - (datum.LF_INC + datum.LF_CON))/2""" 
            contrasts['contrast_consistency'] = """(datum.LF_CON + datum.HF_CON - (datum.LF_INC + datum.HF_INC))/2""" 

            def create_contrast_plot(name):
                return plot_by_cond.encode(y=alt.Y("difference:Q", scale=alt.Scale(domain=(-y_max, y_max)))
                    ).transform_pivot('cond', value=use_y, groupby=['epoch']
                    ).transform_calculate(difference = contrasts[name]
                    ).properties(title=name)

            return plot_by_cond | create_contrast_plot('contrast_frequency') | create_contrast_plot('contrast_consistency')

    def plot(self):
        """ Create an interactive plot for strain """
        df = self.df

        @interact(
            use_y=['acc','sse','conditional_sse'],
            task=['pho_sem', 'sem_pho', 'pho_pho', 'sem_sem'],
            timetick=(1,12,1),
            y_max=(1, 20, 1)
            )
        def plot(use_y='acc', timetick=12, task='pho_sem', y_max=1):
            sdf = df.loc[(df.timetick==timetick) & (df.task==task)] 
            
            # Plot by condition
            plot_by_cond = alt.Chart(sdf).mark_line().encode(
                x='epoch:Q',
                y=alt.Y(f"{use_y}:Q", scale=alt.Scale(domain=(0, y_max))),
                color='testset:N'
            )

            # Plot average
            plot_average = plot_by_cond.encode(y=alt.Y(f"mean({use_y}):Q", scale=alt.Scale(domain=(0, y_max))), color='task')
            plot_average += plot_average.mark_errorband()

            # Plot contrasts
            contrasts = {}
            contrasts['contrast_frequency'] = """(datum.strain_hf_con_hi + datum.strain_hf_con_li + datum.strain_hf_inc_hi + datum.strain_hf_inc_li - 
                (datum.strain_lf_con_hi + datum.strain_lf_con_li + datum.strain_lf_inc_hi + datum.strain_lf_inc_li))/4"""
            contrasts['contrast_consistency'] = """(datum.strain_hf_con_hi + datum.strain_hf_con_li + datum.strain_lf_con_hi + datum.strain_lf_con_li - 
                (datum.strain_hf_inc_hi + datum.strain_hf_inc_li + datum.strain_lf_inc_hi + datum.strain_lf_inc_li))/4"""
            contrasts['contrast_imageability'] = """(datum.strain_hf_con_hi + datum.strain_lf_con_hi + datum.strain_hf_inc_hi + datum.strain_lf_inc_hi - 
                (datum.strain_hf_con_li + datum.strain_lf_con_li + datum.strain_hf_inc_li + datum.strain_lf_inc_li))/4"""

            def create_contrast_plot(name):
                return plot_by_cond.encode(y=alt.Y("difference:Q", scale=alt.Scale(domain=(-y_max, y_max)))
                    ).transform_pivot('testset', value=use_y, groupby=['epoch']
                    ).transform_calculate(difference = contrasts[name]
                    ).properties(title=name)

            contrast_plots = alt.hconcat()
            for c in contrasts.keys():
                contrast_plots |= create_contrast_plot(c)


            return (plot_by_cond | plot_average) & contrast_plots
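For checking the plots outside Altair, the same 2x2 frequency and consistency contrasts used in `plot_op_strain` can be computed in pandas, mirroring its `transform_pivot` + `transform_calculate` steps. A sketch assuming one row per epoch and condition (so the aggregation in the pivot does not matter); `contrasts_2x2` is a hypothetical helper.

```python
import pandas as pd


def contrasts_2x2(df, value="acc"):
    """Frequency and consistency contrasts per epoch.

    df: long dataframe with columns epoch, cond (HF_CON, HF_INC, LF_CON, LF_INC) and `value`
    """
    wide = df.pivot_table(index="epoch", columns="cond", values=value, aggfunc="mean")
    hf = wide["HF_INC"] + wide["HF_CON"]
    lf = wide["LF_INC"] + wide["LF_CON"]
    con = wide["LF_CON"] + wide["HF_CON"]
    inc = wide["LF_INC"] + wide["HF_INC"]
    return pd.DataFrame(
        {
            "contrast_frequency": (hf - lf) / 2,  # mean HF minus mean LF
            "contrast_consistency": (con - inc) / 2,  # mean CON minus mean INC
        }
    )
```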


In [None]:
tmp = examine('boo')
tmp.plot_op_strain()
# Full looks familiar... good interaction, fast learning overall (will slow down later, using a fast learning rate to save time on testing)

In [None]:
tmp = examine('op_half_stationary')
tmp.plot_op_strain()
# Learns slower...
# HF_INC seems a tiny bit lower (more apparent in earlier ticks); maybe HF items have more CON O-P tokens?



In [None]:
tmp = examine('op_half_rank_noclip')
tmp.plot_op_strain()
# HF_INC decreases further --> CON > F



In [None]:
tmp = examine('op_half_rank_hc_30000')
tmp.plot_op_strain()
# Strong frequency effect



## Strain

In [None]:
# Half-pretrain (Chang 2019)

half_pretrain = examine("half_pretrain")
half_pretrain.plot()


In [None]:
# Chang 2019

chang_pretrain = examine("chang_pretrain")
chang_pretrain.plot()

In [None]:
# Full-pretrain 
full_pretrain = examine("full_pretrain")
full_pretrain.plot()