# Batch models analysis

## Purpose
- Overview of batch output
    - Number of models in each setting-grid cell
    - Word accuracy at last epoch
    - Nonword accuracy at last epoch
- Use overview to control two plots
    - Development plot
    - W vs. NW plot
- Can "zoom-in" to specific run if needed

In [None]:
%load_ext lab_black
import os, json
import pandas as pd
import altair as alt
import numpy as np
from meta import check_cfgs_params
from evaluate import make_df_wnw

# Select the default data transformer first, then lift its row limit.
# enable(...) replaces the transformer options wholesale, so calling it
# after disable_max_rows() (as before) silently restored max_rows=5000.
alt.data_transformers.enable("default")
alt.data_transformers.disable_max_rows()

In [None]:
def glue_grain_idf(cfgs):
    """
    Parse and concat all condition-level results from item-level csvs,
    and merge them with the run-level meta data (cfg) of each model.

    cfgs: batch cfgs in dictionary format (the one saved to disk for running
        papermill). ``cfgs[i]["model_folder"]`` is used as a path prefix:
        "model_config.json" is appended to it directly, so it is expected
        to end with a path separator.

    Returns a tuple ``(cfgs_df, merged_df)``:
        cfgs_df   -- one row per run, built from each run's saved
                     model_config.json; the row index is the run's
                     position ``i`` in cfgs.
        merged_df -- ``vis(model_folder).strain_i_hist`` of every run,
                     concatenated and left-merged with cfgs_df on
                     "code_name".
    Both are empty frames when cfgs is empty.
    """
    from evaluate import vis
    from tqdm import tqdm
    from meta import model_cfg

    # Collect per-run frames and concatenate once after the loop; calling
    # pd.concat inside the loop re-copies the accumulated frame every time.
    cfg_frames = []
    eval_frames = []

    for i in tqdm(range(len(cfgs))):
        model_folder = cfgs[i]["model_folder"]

        # Extract the cfg (with UUID) from the actual saved cfg json
        this_model_cfg = model_cfg(model_folder + "model_config.json")
        cfg_frames.append(pd.DataFrame(this_model_cfg.to_dict(), index=[i]))

        # Evaluate results
        eval_frames.append(vis(model_folder).strain_i_hist)

    if not cfg_frames:
        # Nothing to glue: pd.concat rejects an empty list and merging two
        # column-less frames on "code_name" raises KeyError.
        empty = pd.DataFrame(columns=["code_name"])
        return empty, empty.copy()

    cfgs_df = pd.concat(cfg_frames)
    evals_df = pd.concat(eval_frames, ignore_index=True)

    return cfgs_df, pd.merge(evals_df, cfgs_df, how="left", on="code_name")

## Read batch files

In [None]:
batch_name = "zer_test5_momentum0"
batch_output_dir = f"batch_eval/{batch_name}/"

# Run-level configs and the batch results table saved for this batch
cfgs = pd.read_csv(f"{batch_output_dir}cfgs.csv", index_col=0)
df = pd.read_csv(f"{batch_output_dir}bcdf.csv", index_col=0)

# Number of distinct rng seeds present in the batch results
n_rng = len(df["rng_seed"].unique())

check_cfgs_params(cfgs)

### Read Strain data

In [None]:
from tqdm import tqdm

# Stack every model's item-level strain results into a single frame.
# Iterate the code_name values directly: cfgs was read with index_col=0, so
# the previous label lookup cfgs.code_name[i] only worked while that index
# happened to be exactly 0..n-1. Frames are concatenated once, not per run.
strain_frames = [
    pd.read_csv(f"models/{code_name}/result_strain_item.csv")
    for code_name in tqdm(cfgs.code_name)
]
idf = pd.concat(strain_frames, ignore_index=True)

# Attach the run-level columns from cfgs to every item row
idf = idf.merge(cfgs, on="code_name")

In [None]:
# Floor zero_error_radius at 0: any value that is not strictly positive
# (including NaN, since NaN > 0 is False) is replaced by 0.
idf["zero_error_radius"] = idf["zero_error_radius"].where(
    idf["zero_error_radius"] > 0, 0
)

In [None]:
# Show the columns available after merging strain items with cfgs
idf.keys()

In [None]:
# Columns kept as identifiers when reshaping to long format
sse_id_cols = [
    "code_name",
    "p_noise",
    "zero_error_radius",
    "epoch",
    "timestep",
    "word",
    "acc",
    "sse",
]
# Per-slot SSE columns sse_slot1 ... sse_slot10
sse_slot_cols = [f"sse_slot{slot}" for slot in range(1, 11)]

# Long format: one row per original row and slot, with the slot column name
# in "variable" and its SSE in "value"
df = idf[sse_id_cols + sse_slot_cols].melt(id_vars=sse_id_cols)

### Slot based SSE in all items

In [None]:
# Mean per-slot SSE at timestep 1 over all items, pooled across the runs and
# items that share the same p_noise / ZER / epoch / slot.
# `values` is given explicitly: without it pivot_table also tries to average
# the string columns (code_name, word), which pandas 2.x rejects with a
# TypeError (older pandas dropped them with a warning).
df_allsse = (
    df.loc[df.timestep == 1]
    .pivot_table(
        index=["p_noise", "zero_error_radius", "epoch", "variable"],
        values=["acc", "sse", "value"],
    )
    .reset_index()
)


alt.Chart(df_allsse).mark_line().encode(
    x="epoch:Q",
    y="value:Q",
    color="variable:N",
    row="p_noise:O",
    column="zero_error_radius:O",
).properties(title="Slot based SSE by ZER and Pnoise in all items")

### Accuracy 

In [None]:
# Per-run averages (over epochs, items and timesteps) of the selected columns
df_ov = (
    idf[["code_name", "epoch", "zero_error_radius", "p_noise", "acc"]]
    .pivot_table(index=["code_name"])
    .reset_index()
)

# Several runs (e.g. different rng seeds) can share one p_noise x ZER cell.
# Average them explicitly: plotting df_ov directly makes mark_rect draw one
# rect per run at the same position, so only the topmost run is visible.
df_ov_cell = (
    df_ov.groupby(["p_noise", "zero_error_radius"], as_index=False)["acc"]
    .mean()
    .round(2)
)

alt.Chart(df_ov_cell).mark_rect().encode(
    x="p_noise:O", y="zero_error_radius:O", color="acc:Q", tooltip="acc"
).properties(title="Average accuracy by control params")

- ZER slows down learning
- p_noise and ZER have a subadditive interaction
    - consistent with the interpretation of p_noise as a regularization strategy

In [None]:
# Mean accuracy per setting string and epoch, pooled over all other columns
acc_keys = ["batch_unique_setting_string", "epoch"]
tmp = idf.pivot_table(index=acc_keys, values=["acc"]).reset_index()

acc_chart = alt.Chart(tmp).mark_line()
acc_chart.encode(
    x=alt.X("epoch"),
    y=alt.Y("acc"),
    color=alt.Color("batch_unique_setting_string"),
).properties(title="ACC over epoch")

### Slot based SSE in Correct items

In [None]:
# Mean per-slot SSE at timestep 1 over correctly produced items (acc == 1),
# pooled across runs and items sharing the same p_noise / ZER / epoch / slot.
# `values` is given explicitly so the string columns (code_name, word) are
# not passed to the mean aggregation, which pandas 2.x rejects.
df_corsse = (
    df.loc[(df.timestep == 1) & (df.acc == 1)]
    .pivot_table(
        index=["p_noise", "zero_error_radius", "epoch", "variable"],
        values=["acc", "sse", "value"],
    )
    .reset_index()
)

alt.Chart(df_corsse).mark_line().encode(
    x="epoch:Q",
    y="value:Q",
    color="variable:N",
    row="p_noise:O",
    column="zero_error_radius:O",
).properties(title="Slot based SSE by ZER and Pnoise in correct items")

- Slot 4 carries the most info
- Slot 10 carries almost no info, yet it starts out quite well (low SSE)???
    - Not understood yet: the input is 0 and the weights are initialized from a small uniform(-0.1, +0.1), so shouldn't the output be near 0.5 at the beginning? And with ZER 0.5, shouldn't it learn very little?
    
Follow up in later section: 
- look at the output at Slot 4 and Slot 10

### Slot based SSE in Incorrect items

In [None]:
# Mean per-slot SSE at timestep 1 over incorrectly produced items (acc == 0),
# pooled across runs and items sharing the same p_noise / ZER / epoch / slot.
# `values` is given explicitly so the string columns (code_name, word) are
# not passed to the mean aggregation, which pandas 2.x rejects.
df_incorsse = (
    df.loc[(df.timestep == 1) & (df.acc == 0)]
    .pivot_table(
        index=["p_noise", "zero_error_radius", "epoch", "variable"],
        values=["acc", "sse", "value"],
    )
    .reset_index()
)


alt.Chart(df_incorsse).mark_line().encode(
    x="epoch:Q",
    y="value:Q",
    color="variable:N",
    row="p_noise:O",
    column="zero_error_radius:O",
).properties(title="Slot based SSE by ZER and Pnoise in incorrect items")

### Examine output directly
- Slot 4 (Max info)
- Slot 10 (Almost no info)

- Calculate Mean output at slot 4, and slot 10

In [None]:
# List idf's columns again before selecting the mean_output_* columns below
idf.keys()

In [None]:
# Per-run, per-epoch means at timestep 1 of accuracy and of the mean output
# at slot 4 and slot 10
output_keys = ["code_name", "zero_error_radius", "p_noise", "epoch"]
output_values = ["acc", "mean_output_slot4", "mean_output_slot10"]

first_timestep = idf.loc[idf.timestep == 1, output_keys + output_values]
df_output = first_timestep.pivot_table(index=output_keys).reset_index()

In [None]:
# Mean slot-10 output per run across epochs (one line per code_name)
slot10_chart = alt.Chart(df_output).mark_line(point=True).encode(
    x=alt.X("epoch:Q"),
    y=alt.Y("mean_output_slot10"),
    color=alt.Color("code_name"),
    tooltip=["p_noise", "zero_error_radius"],
)
slot10_chart.properties(title="Slot 10 output").interactive()

In [None]:
# Mean slot-4 output per run across epochs (one line per code_name)
slot4_chart = alt.Chart(df_output).mark_line(point=True).encode(
    x=alt.X("epoch:Q"),
    y=alt.Y("mean_output_slot4"),
    color=alt.Color("code_name"),
    tooltip=["p_noise", "zero_error_radius"],
)
slot4_chart.properties(title="Slot 4 output").interactive()