# Batch models analysis

## Purpose
- Overview of batch output
    - Count of models in each cell of the setting grid
    - Word accuracy at last epoch
    - Nonword accuracy at last epoch
- Use overview to control two plots
    - Development plot
    - W vs. NW plot
- Can "zoom in" on a specific run if needed

In [None]:
%load_ext lab_black
import os, json
import pandas as pd
import altair as alt
from meta import check_cfgs_params
from evaluate import make_df_wnw

# Select the transformer first, then lift the row limit. Enabling a
# transformer replaces its options, so calling enable("default") after
# disable_max_rows() would bring the default max_rows check back.
alt.data_transformers.enable("default")
alt.data_transformers.disable_max_rows()

## Read files

In [None]:
# Locate the batch output folder and load its two summary tables.
batch_name = "O2P_setD"
batch_output_dir = f"batch_eval/{batch_name}/"

# The first column of each CSV is used as the row index.
cfgs = pd.read_csv(os.path.join(batch_output_dir, "cfgs.csv"), index_col=0)
df = pd.read_csv(os.path.join(batch_output_dir, "bcdf.csv"), index_col=0)

check_cfgs_params(cfgs)

### Explicitly provide varying h-params after reviewing unique params

In [None]:
# Hyperparameters that vary across the batch. Narrow this list (e.g. to
# ["p_noise"]) when the batch only sweeps a subset of them.
variates = [
    "hidden_units",
    "learning_rate",
    "p_noise",
]

# View h-param grid

In [None]:
# Heatmap counting the configs (by code_name) in every cell of the
# p_noise x hidden_units grid, with one row of panels per learning_rate.
count_encoding = dict(
    x="p_noise:O",
    y="hidden_units:O",
    row="learning_rate:O",
    color="count(code_name)",
    tooltip=["count(code_name)"],
)
plot_n = alt.Chart(cfgs).mark_rect().encode(**count_encoding)
plot_n = plot_n.properties(title="Model counts")

plot_n

### Check standard deviation for anomalies

In [None]:
def plot_std(df, variates, conditions, dv):
    """
    Plot standard deviation of DV across rng_seed at the LAST TIME STEP.

    dv is first averaged over the selected conditions within each
    (variates, epoch, rng_seed) group, then its standard deviation over
    rng_seed is taken in each (variates, epoch) cell and drawn as a heatmap.

    df: pandas dataframe containing batch condition data file (bcdf)
    variates: a list of varying hyperparameters; the heatmap layout is fixed
        to hidden_units (y), learning_rate (row) and p_noise (column), so
        those three columns must be part of variates
    conditions: filter by conditions at df.cond
    dv: dependent variable to plot on heatmap (e.g., acc, sse)

    Returns an altair chart whose color scale domain is fixed to [0, 0.2].
    """
    variates = list(variates)

    # Select useful data: last time step, requested conditions only
    is_last_step = df.timestep == df.timestep.max()
    sel_df = df.loc[is_last_step & df.cond.isin(conditions)]

    # Collapse condition. Only dv is aggregated: the string column cond must
    # not reach mean(), which raises TypeError on non-numeric columns in
    # pandas >= 2.0.
    mean_df = (
        sel_df.groupby(variates + ["epoch", "rng_seed"])[dv].mean().reset_index()
    )
    # Calculate standard deviation across rng_seed in each cell
    plot_df = mean_df.groupby(variates + ["epoch"])[dv].std().reset_index()

    # Plot heatmap
    return (
        alt.Chart(plot_df)
        .mark_rect()
        .encode(
            x="epoch:O",
            y="hidden_units:O",
            row="learning_rate:O",
            column="p_noise:O",
            color=alt.Color(dv, scale=alt.Scale(domain=[0, 0.2])),
            tooltip=variates + [dv],
        )
    )

# Examine rng_seed variation
### Plot all standard deviations if more than one rng_seed is found

In [None]:
# Number of distinct random seeds present in the batch
n_rng = len(df.rng_seed.unique())


def _save_acc_std(conditions, html_name):
    """Build the accuracy SD heatmap for CONDITIONS and save it as HTML_NAME
    inside the batch output folder; returns the chart."""
    chart = plot_std(df, variates, conditions, "acc")
    chart.save(batch_output_dir + html_name)
    return chart


# Seed-to-seed variation only exists when the batch has several seeds
if n_rng > 1:

    # Strain SD
    sd_strain = _save_acc_std(["INC_HF"], "stdev_strain_INCHF.html")

    # Grain SD
    sd_grain = _save_acc_std(["unambiguous"], "stdev_grain_unambiguous.html")

    # Taraban SD (all eight conditions pooled)
    taraban_w = [
        "High-frequency exception",
        "High-frequency regular-inconsistent",
        "Low-frequency exception",
        "Low-frequency regular-inconsistent",
        "Regular control for High-frequency exception",
        "Regular control for High-frequency regular-inconsistent",
        "Regular control for Low-frequency exception",
        "Regular control for Low-frequency regular-inconsistent",
    ]
    sd_taraban = _save_acc_std(taraban_w, "stdev_taraban_all.html")

    # Glushko SD
    glushko_nw = ["Exception", "Regular"]
    sd_glushko = _save_acc_std(glushko_nw, "stdev_glushko.html")

# All runs at last time step

### no rng_seed aggregation

In [None]:
def plot_fig2(df):
    """
    Scatter nonword accuracy against word accuracy at the last time step.

    df: batch condition dataframe (bcdf); rows at the maximum timestep are
        handed to make_df_wnw with ["INC_HF"] and ["unambiguous"] as its two
        condition lists (presumably the word and nonword conditions -
        confirm in evaluate.make_df_wnw)

    Returns a layered altair chart: a black identity line with the
    semi-transparent points, colored by epoch, drawn on top of it.
    """
    at_last_step = df.timestep == df.timestep.max()
    wnw_df = make_df_wnw(df.loc[at_last_step], ["INC_HF"], ["unambiguous"])

    unit_scale = alt.Scale(domain=(0, 1))
    epoch_color = alt.Color(
        "epoch", scale=alt.Scale(scheme="redyellowgreen", domain=(0, 100))
    )
    points = (
        alt.Chart(wnw_df)
        .mark_point()
        .encode(
            x=alt.X("word_acc:Q", scale=unit_scale),
            y=alt.Y("nonword_acc:Q", scale=unit_scale),
            color=epoch_color,
            opacity=alt.value(0.3),
            tooltip=["code_name", "epoch", "word_acc", "nonword_acc"],
        )
    )

    # Reference line where word and nonword accuracy are equal
    identity_df = pd.DataFrame({"x": [0, 1], "y": [0, 1]})
    identity_line = (
        alt.Chart(identity_df).mark_line(color="black").encode(x="x", y="y")
    )
    return identity_line + points


plot_fig2(df).properties(title="All runs overlay (no aggregation)")

### rng_seed aggregate

In [None]:
# Relabel every row with its setting string (presumably shared by runs that
# differ only in rng_seed - confirm) and redraw the overlay.
# NOTE: this overwrites df["code_name"] for all later cells.
agg_overlay = None
if n_rng > 1:
    df["code_name"] = df.batch_unique_setting_string
    # The original called `plot_f2`, which is not defined anywhere in this
    # notebook; plot_fig2 is the overlay function used for the previous plot.
    agg_overlay = plot_fig2(df).properties(
        title="All runs overlay (within setting cell aggregation)"
    )

# Keep the chart as the cell's final top-level expression so the notebook
# renders it; an expression nested inside `if` is never displayed.
agg_overlay

# Word and NW heatmap

### Create df for plotting

In [None]:
# With several seeds, label both tables by setting string instead of the
# existing code_name.
if n_rng > 1:
    for table in (cfgs, df):
        table["code_name"] = table.batch_unique_setting_string


pivotvars = [*variates, "code_name", "epoch", "timestep", "unit_time", "exp", "cond"]
selvars = [*pivotvars, "acc", "sse"]

# pivot_table's default aggregation (mean) collapses acc and sse within
# every unique combination of pivotvars.
df_cell_mean = pd.pivot_table(df[selvars], index=pivotvars).reset_index()

# Select by condition and last time step
is_final_step = df_cell_mean.timestep == df_cell_mean.timestep.max()
is_wnw_cond = df_cell_mean.cond.isin(["INC_HF", "unambiguous"])
df_sel = df_cell_mean.loc[is_final_step & is_wnw_cond]

# Word / nonword table, plus the difference between the two accuracies
df_wnw = make_df_wnw(df_sel, ["INC_HF"], ["unambiguous"])
df_wnw["word_advantage"] = df_wnw.word_acc - df_wnw.nonword_acc

# Attach the settings from cfgs, then keep only the columns used for plotting
selvars_wnw = [
    *variates,
    "code_name",
    "epoch",
    "word_acc",
    "nonword_acc",
    "word_advantage",
]
df_wnw = df_wnw.merge(cfgs)[selvars_wnw]

### Dashboard

In [None]:
# Number of rows at the final epoch. Computed straight from df_wnw because
# last_epoch_df is only created in a later cell, so referencing it here
# raises NameError on a fresh top-to-bottom run.
len(df_wnw.loc[df_wnw.epoch == df_wnw.epoch.max()])

In [None]:
alt.data_transformers.enable("default")
alt.data_transformers.disable_max_rows()

# Selectors for interactions: multi-selection on click, keyed on code_name
sel_run = alt.selection(type="multi", on="click", fields=["code_name"])

########## Overview heatmap model selector ##########

# Rows of df_wnw at the final epoch
is_last_epoch = df_wnw.epoch == df_wnw.epoch.max()
last_epoch_df = df_wnw.loc[is_last_epoch]

word_acc_color = alt.Color(
    "word_acc:Q", scale=alt.Scale(scheme="redyellowgreen", domain=(0, 1))
)
# Selected cells are fully opaque, the rest are dimmed
dim_unselected = alt.condition(sel_run, alt.value(1), alt.value(0.1))

overview = alt.Chart(last_epoch_df).mark_rect()
overview = overview.encode(
    x="p_noise:O",
    y="hidden_units:O",
    row="learning_rate:O",
    color=word_acc_color,
    opacity=dim_unselected,
    tooltip=["word_acc:Q"] + variates,
)
overview = overview.add_selection(sel_run).properties(
    title="Word accuracy at last epoch"
)

In [None]:
########## Word vs. Nonword accuracy over epoch ##########
# Accuracy by epoch for each condition; points of runs outside sel_run are
# fully transparent.
only_selected = alt.condition(sel_run, alt.value(1), alt.value(0))

plot_epoch = alt.Chart(df_sel).mark_point()
plot_epoch = plot_epoch.encode(
    x="epoch:Q",
    y=alt.Y("acc:Q", scale=alt.Scale(domain=(0, 1))),
    color=alt.Color("cond:N"),
    opacity=only_selected,
    tooltip=["code_name", "cond", "epoch", "acc"],
)
plot_epoch = plot_epoch.properties(title="Development").add_selection(sel_run)


In [None]:
########## Word vs. Nonword plot with diagonal ##########

# One line per code_name through (word_acc, nonword_acc) space; lines of
# runs outside sel_run are fully transparent.
wnw_line = (
    alt.Chart(df_wnw)
    .mark_line()
    .encode(
        y=alt.Y("nonword_acc:Q", scale=alt.Scale(domain=(0, 1))),
        x=alt.X("word_acc:Q", scale=alt.Scale(domain=(0, 1))),
        color=alt.Color("code_name", legend=None),
        opacity=alt.condition(sel_run, alt.value(1), alt.value(0)),
        tooltip=["code_name", "epoch", "word_acc", "nonword_acc"] + variates,
    )
)

# Same encoding drawn as points, recolored by epoch.
# The original derived this from `wnw_base`, a name that is never defined
# (NameError); the base chart built above is called wnw_line.
wnw_point = wnw_line.mark_point().encode(
    color=alt.Color("epoch", scale=alt.Scale(scheme="redyellowgreen")),
)

# Reference line where word and nonword accuracy are equal
diagonal = (
    alt.Chart(pd.DataFrame({"x": [0, 1], "y": [0, 1]}))
    .mark_line(color="black")
    .encode(x="x", y="y")
)

wnw = diagonal + wnw_line + wnw_point

wnw_interactive = wnw.properties(
    title="Word vs. Nonword accuracy at final time step"
).add_selection(sel_run)

# Three panels side by side, all driven by the same sel_run selection
dashboard = overview | plot_epoch | wnw_interactive
# dashboard.save(batch_output_dir + 'dashboard.html')
dashboard

### Word vs. Nonword

In [None]:
# Every row of df_wnw as a faint point, colored by epoch
unit_scale = alt.Scale(domain=(0, 1))
epoch_color = alt.Color(
    "epoch", scale=alt.Scale(scheme="redyellowgreen", domain=(0, 100))
)
base = alt.Chart(df_wnw).mark_point()
base = base.encode(
    x=alt.X("word_acc:Q", scale=unit_scale),
    y=alt.Y("nonword_acc:Q", scale=unit_scale),
    color=epoch_color,
    opacity=alt.value(0.2),
    tooltip=["code_name", "epoch", "word_acc", "nonword_acc"],
)

# Reference line where word and nonword accuracy are equal
identity_df = pd.DataFrame({"x": [0, 1], "y": [0, 1]})
diagonal = alt.Chart(identity_df).mark_line(color="black").encode(x="x", y="y")

plot = alt.layer(diagonal, base)
plot
# plot.save(batch_output_dir + 'all_models_wnw.html')

In [None]:
# Shared heatmap layout: p_noise x hidden_units cells, faceted by
# learning_rate (rows) and epoch (columns). Color is added per plot.
cell_encoding = dict(
    x="p_noise:O",
    y="hidden_units:O",
    row="learning_rate:O",
    column="epoch:O",
    tooltip=variates + ["word_acc", "nonword_acc"],
)
base = alt.Chart(df_wnw).mark_rect().encode(**cell_encoding)

# Word accuracy heatmap, written to word.html in the batch folder
word_color = alt.Color(
    "word_acc", scale=alt.Scale(scheme="redyellowgreen", domain=(0, 1))
)
w = base.encode(color=word_color).properties(title="Word acc")

w.save(batch_output_dir + "word.html")
w

In [None]:
# Nonword accuracy heatmap on the shared layout, written to nonword.html
nonword_color = alt.Color(
    "nonword_acc", scale=alt.Scale(scheme="redyellowgreen", domain=(0, 1))
)
nw = base.encode(color=nonword_color).properties(title="Nonword acc")

nw.save(batch_output_dir + "nonword.html")
nw

In [None]:
# Word advantage (word_acc - nonword_acc) heatmap; the color domain is
# symmetric around zero at +/- 0.3. Written to wadv.html.
advantage_color = alt.Color(
    "word_advantage", scale=alt.Scale(scheme="redyellowgreen", domain=(-0.3, 0.3))
)
adv = base.encode(color=advantage_color).properties(title="Word advantage")

adv.save(batch_output_dir + "wadv.html")
adv

In [None]:
# Word vs. nonword lines colored by p_noise. The chart is created without
# data; df_wnw is supplied when the layers are combined below.
unit_scale = alt.Scale(domain=(0, 1))
pnoise_color = alt.Color("p_noise", type="ordinal", scale=alt.Scale(scheme="reds"))

plot_pnoise = alt.Chart().mark_line()
plot_pnoise = plot_pnoise.encode(
    x=alt.X("word_acc:Q", scale=unit_scale),
    y=alt.Y("nonword_acc:Q", scale=unit_scale),
    color=pnoise_color,
    tooltip=variates + ["epoch", "word_acc", "nonword_acc"],
)

# Reference line where word and nonword accuracy are equal
identity_df = pd.DataFrame({"x": [0, 1], "y": [0, 1]})
diagonal = alt.Chart(identity_df).mark_line(color="black").encode(x="x", y="y")

# One panel per (hidden_units, learning_rate) combination
p_layers = alt.layer(diagonal + plot_pnoise, data=df_wnw)
p = p_layers.facet(row="hidden_units:O", column="learning_rate:O")

p.save(batch_output_dir + "p_noise.html")
p

In [None]:
# Word vs. nonword lines colored by hidden_units. The chart is created
# without data; df_wnw is supplied when the layers are combined below.
unit_scale = alt.Scale(domain=(0, 1))
hidden_color = alt.Color(
    "hidden_units", type="ordinal", scale=alt.Scale(scheme="blues")
)

plot_hidden = alt.Chart().mark_line()
plot_hidden = plot_hidden.encode(
    x=alt.X("word_acc:Q", scale=unit_scale),
    y=alt.Y("nonword_acc:Q", scale=unit_scale),
    color=hidden_color,
    tooltip=variates + ["epoch", "word_acc", "nonword_acc"],
)

# One panel per (p_noise, learning_rate) combination; `diagonal` is the
# identity line defined in an earlier cell.
h_layers = alt.layer(diagonal + plot_hidden, data=df_wnw)
h = h_layers.facet(row="p_noise:O", column="learning_rate:O")

h.save(batch_output_dir + "hidden.html")
h

In [None]:
# Word vs. nonword lines colored by learning_rate. The chart is created
# without data; df_wnw is supplied when the layers are combined below.
unit_scale = alt.Scale(domain=(0, 1))
lr_color = alt.Color(
    "learning_rate", type="ordinal", scale=alt.Scale(scheme="greens")
)

plot_lr = alt.Chart().mark_line()
plot_lr = plot_lr.encode(
    x=alt.X("word_acc:Q", scale=unit_scale),
    y=alt.Y("nonword_acc:Q", scale=unit_scale),
    color=lr_color,
    tooltip=variates + ["epoch", "word_acc", "nonword_acc"],
)

# One panel per (hidden_units, p_noise) combination; `diagonal` is the
# identity line defined in an earlier cell.
lr_layers = alt.layer(diagonal + plot_lr, data=df_wnw)
lr = lr_layers.facet(row="hidden_units:O", column="p_noise:O")

lr.save(batch_output_dir + "lr.html")
lr