# Batch models analysis

## Purpose
- Overview of batch output
    - Count of models in each cell of the setting grid
    - Word accuracy at last epoch
    - Nonword accuracy at last epoch
- Use overview to control two plots
    - Development plot
    - W vs. NW plot
- Can "zoom in" on a specific run if needed

In [None]:
%load_ext lab_black
import os, json
import pandas as pd
import altair as alt
import numpy as np
from evaluate import make_df_wnw

# Select the transformer first: enabling a transformer by name resets its
# options, so the row limit must be lifted afterwards for it to stay lifted.
alt.data_transformers.enable("default")
alt.data_transformers.disable_max_rows()

# Class for visualizing batch results (Working)

In [None]:
# NOTE(review): `alt` is only imported two lines below, so this call relies on
# an earlier cell having imported altair already.
alt.data_transformers.disable_max_rows()
import pandas as pd
import altair as alt
from altair.expr import datum
import evaluate


class VisualizeBatchResults:
    """Visualize the results of multiple runs in a batch.

    Roadmap:
    0. Count model (Done)
    1. Dashboard (Done)
    2. Heatmap over epoch
    3. Overall X vs. Y
    4. Control param 1d (as line color) version of (3)
    5. 3d version of (4)
    6. Standard deviation diagnostic (Done)

    Tables kept on the instance:
    cfgs: contents of cfgs.csv
    df: contents of bcdf.csv (raw data)
    sdf, cdf: derived by parse_dfs
    odf, cdf_wide: derived by dashboard
    """

    def __init__(self, batch_output_folder):
        """Load cfgs.csv and bcdf.csv from the folder and derive the tables.

        batch_output_folder: folder containing both CSV files; a trailing
            path separator is optional.
        """
        self.batch_output_folder = batch_output_folder
        # os.path.join keeps the path valid with or without a trailing slash.
        self.cfgs = pd.read_csv(
            os.path.join(batch_output_folder, "cfgs.csv"), index_col=0
        )
        self.df = pd.read_csv(
            os.path.join(batch_output_folder, "bcdf.csv"), index_col=0
        )
        self.n_rng = len(self.df.rng_seed.unique())
        self.varying_hparams = []  # will auto generate during check_cfgs_params
        self.check_cfgs_params(verbose=False)
        self.parse_dfs()

    def parse_dfs(self):
        """Derive the filtered and the cell-mean tables from the raw data.

        df: raw data file
        sdf: rows of df at the largest unit_time, restricted to the columns
            used downstream
        cdf: mean value in a h-param cell
        """
        cell_keys = self.varying_hparams + [
            "batch_unique_setting_string",
            "exp",
            "cond",
            "epoch",
        ]

        # SDF: Selected useful data file
        sel = self.varying_hparams + [
            "batch_unique_setting_string",
            "code_name",
            "exp",
            "cond",
            "epoch",
            "acc",
        ]
        is_last = self.df.unit_time == self.df.unit_time.max()
        self.sdf = self.df.loc[is_last, sel]

        # CDF: Cell data file with the mean in each h-parameter setting cell.
        # numeric_only=True skips the string column code_name, which recent
        # pandas versions refuse to average instead of silently dropping.
        self.cdf = self.sdf.groupby(cell_keys, as_index=False).mean(
            numeric_only=True
        )

    def check_cfgs_params(self, verbose=True):
        """
        Check how many varying and static hyperparameters the config
        dataframe has. Also fill self.varying_hparams with the varying ones
        (only when that list is still empty).

        verbose: print the varying and the static hyperparameters
        """

        hide_verbose_cfgs = [
            "code_name",
            "uuid",
            "batch_unique_setting_string",
            "w_pp_noise",  # redundant to p_noise
            "w_pc_noise",  # redundant to p_noise
            "w_cp_noise",  # redundant to p_noise
            "bias_c_noise",  # redundant to p_noise
            "bias_p_noise",  # redundant to p_noise
        ]

        # rng_seed varies by design, so it is never a varying h-parameter
        varying_h_params_blacklist = hide_verbose_cfgs + ["rng_seed"]

        n_unique = {x: len(self.cfgs[x].unique()) for x in self.cfgs.columns}

        # Avoid duplication while rerunning check_cfgs_params
        if not self.varying_hparams:
            self.varying_hparams.extend(
                x
                for x in self.cfgs.columns
                if x not in varying_h_params_blacklist and n_unique[x] > 1
            )

        # Summarize varying and non-varying cfgs
        if verbose:
            print("===== Batch level varying hyperparams =====")
            for x in self.cfgs.columns:
                if x not in hide_verbose_cfgs and n_unique[x] > 1:
                    print("{}: {}".format(x, self.cfgs[x].unique()))

            print("\n===== Batch level static hyperparams =====")
            for x in self.cfgs.columns:
                if n_unique[x] == 1:
                    print(f"{x}: {self.cfgs[x].unique()[0]}")

    def _unplottable_reason(self):
        """Return why the h-param grid cannot be drawn, or None if it can."""
        n_varying = len(self.varying_hparams)
        if n_varying == 0:
            return "No varying parameter, cannot plot"
        if n_varying > 3:
            return "Too many (>3) varying parameter, cannot plot"
        return None

    def plot_model_count_grid(self):
        """Plot the number of models in each h-param cell as a heatmap.

        The 1st varying h-param goes on x, the 2nd (if any) on y and the 3rd
        (if any) on column. Returns a message string instead of a chart when
        there are no or more than 3 varying h-params.
        """
        reason = self._unplottable_reason()
        if reason is not None:
            return reason

        # Map the h-params after the first one onto y / column, in order
        extra_channels = dict(
            zip(("y", "column"), (f"{h}:O" for h in self.varying_hparams[1:]))
        )

        return (
            alt.Chart(self.cfgs)
            .mark_rect()
            .encode(
                x=f"{self.varying_hparams[0]}:O",
                color="count(code_name)",
                tooltip=["count(code_name)"],
                **extra_channels,
            )
            .properties(title="Model counts")
        )

    def plot_cell_sd_grid(self, conditions, dv="acc"):
        """
        Plot standard deviation of DV at given CONDITIONS at LAST TIME STEP

        conditions (list): filter by conditions at df.cond
        dv: dependent variable to plot on heatmap (e.g., acc, sse)

        Returns a message string instead of a chart when there are no or more
        than 3 varying h-params.
        """
        reason = self._unplottable_reason()
        if reason is not None:
            return reason

        # Select useful data
        is_last = self.df.timestep == self.df.timestep.max()
        sel_df = self.df.loc[
            is_last & self.df.cond.isin(conditions),
            self.varying_hparams + ["cond", "epoch", "rng_seed", dv],
        ]

        # Collapse condition (numeric_only skips the string column cond)
        mean_df = (
            sel_df.groupby(self.varying_hparams + ["epoch", "rng_seed"])
            .mean(numeric_only=True)
            .reset_index()
        )

        # Calculate standard deviation in each cell
        plot_df = mean_df.groupby(self.varying_hparams + ["epoch"]).std().reset_index()

        # Map the h-params after the first one onto row / column, in order
        facets = dict(
            zip(("row", "column"), (f"{h}:O" for h in self.varying_hparams[1:]))
        )

        # Plot heatmap
        return (
            alt.Chart(plot_df)
            .mark_rect()
            .encode(
                x="epoch:O",
                y=f"{self.varying_hparams[0]}:O",
                color=dv,
                tooltip=self.varying_hparams + [dv],
                **facets,
            )
            .properties(
                title=f"Standard deviation from rng_seed within each cell in {conditions}"
            )
        )

    def sdplot_in_all_cond(self):
        """Stack one standard deviation heatmap per condition in df.cond.

        Returns a message string instead of a chart when the h-param grid
        cannot be drawn.
        """
        reason = self._unplottable_reason()
        if reason is not None:
            return reason

        return alt.vconcat(
            *[self.plot_cell_sd_grid([x]) for x in self.df.cond.unique()]
        )

    def get_sort_order(self, hparam_name):
        """Return the axis sort order: ascending for p_noise, else descending."""
        return "ascending" if hparam_name == "p_noise" else "descending"

    def dashboard(
        self, exp=["strain", "grain"], word=["INC_HF"], nonword=["unambiguous"]
    ):
        """Build the interactive dashboard.

        exp: values of cdf.exp drawn in the development plot
        word: values of cdf.cond labelled "word" in the performance plot
        nonword: values of cdf.cond labelled "nonword" in the performance plot

        Layout: control space on top; development space and performance space
        side by side below, both filtered to the settings selected in the
        control space. Sets self.odf and self.cdf_wide and sorts self.cdf in
        place. The control space always shows the "strain" accuracy,
        independent of `exp`. At most the first three varying h-params are
        mapped (x, y, column).
        """
        # The list defaults are only read, never mutated, so sharing them
        # across calls is safe.
        setting_keys = self.varying_hparams + ["batch_unique_setting_string"]

        # ODF: Overview data file with last epoch mean strain
        is_last_strain = (self.cdf.exp == "strain") & (
            self.cdf.epoch == self.cdf.epoch.max()
        )
        self.odf = (
            self.cdf.loc[is_last_strain, setting_keys + ["acc"]]
            .groupby(setting_keys, as_index=False)
            .mean()
        )

        # CDF_wide: Wide format of cell data file, subset to word vs. nonword
        tmp = self.cdf.loc[self.cdf.cond.isin(list(word) + list(nonword))].copy()
        tmp["wnw"] = tmp.cond.isin(word).map({True: "word", False: "nonword"})

        # values is a list so the columns stay two-level ("acc", <wnw>); only
        # acc is averaged, never the string columns exp / cond.
        self.cdf_wide = tmp.pivot_table(
            index=setting_keys + ["epoch"],
            columns="wnw",
            values=["acc"],
        ).reset_index()

        # Flatten the two-level columns: ("acc", "word") -> "accword"
        self.cdf_wide.columns = [
            "".join(c).strip() for c in self.cdf_wide.columns.values
        ]

        self.cdf_wide.rename(
            columns={"accword": "word_acc", "accnonword": "nonword_acc"}, inplace=True
        )

        self.cdf_wide["word_advantage"] = (
            self.cdf_wide.word_acc - self.cdf_wide.nonword_acc
        )

        # Start constructing dashboard

        # Starts with the first setting of the overview selected
        select_control_space = alt.selection(
            type="multi",
            on="click",
            empty="none",
            fields=["batch_unique_setting_string"],
            init=[
                {
                    "batch_unique_setting_string": self.odf.batch_unique_setting_string.iloc[
                        0
                    ]
                }
            ],
        )

        select_dev_cond = alt.selection_multi(fields=["cond"], bind="legend")

        # Control space: one positional channel per varying h-param, as many
        # as are available (zip stops at the shorter sequence)
        grid_encodings = {
            channel: channel_cls(f"{hparam}:O", sort=self.get_sort_order(hparam))
            for (channel, channel_cls), hparam in zip(
                (("x", alt.X), ("y", alt.Y), ("column", alt.Column)),
                self.varying_hparams,
            )
        }

        control_space = (
            alt.Chart(self.odf)
            .mark_rect()
            .encode(
                color=alt.Color(
                    "acc", scale=alt.Scale(scheme="redyellowgreen", domain=(0, 1))
                ),
                opacity=alt.condition(
                    select_control_space, alt.value(1), alt.value(0.3)
                ),
                tooltip="acc",
                **grid_encodings,
            )
            .add_selection(select_control_space)
            .properties(title="Select a control parameter setting:")
        )

        # Development space
        self.cdf.sort_values(
            by=["batch_unique_setting_string", "cond", "epoch"], inplace=True
        )

        development_space = (
            alt.Chart(self.cdf.loc[self.cdf.exp.isin(exp)])
            .mark_line()
            .encode(
                y=alt.Y("acc:Q", scale=alt.Scale(domain=(0, 1))),
                x="epoch:Q",
                color="cond:N",
                tooltip=["epoch", "acc"],
                opacity=alt.condition(select_dev_cond, alt.value(1), alt.value(0)),
            )
            .transform_filter(select_control_space)
            .add_selection(select_dev_cond)
            .properties(
                title="Developmental space: Accuracy in each condition over epoch"
            )
        )

        # Performance space
        wnw_line = (
            alt.Chart(self.cdf_wide)
            .mark_line(color="black")
            .encode(
                y=alt.Y("nonword_acc:Q", scale=alt.Scale(domain=(0, 1))),
                x=alt.X("word_acc:Q", scale=alt.Scale(domain=(0, 1))),
                tooltip=[
                    "batch_unique_setting_string",
                    "epoch",
                    "word_acc",
                    "nonword_acc",
                ],
            )
            .transform_filter(select_control_space)
        )

        # Reference line where word accuracy equals nonword accuracy
        diagonal = (
            alt.Chart(pd.DataFrame({"x": [0, 1], "y": [0, 1]}))
            .mark_line(color="#D3D3D3")
            .encode(
                x=alt.X("x", axis=alt.Axis(title="word")),
                y=alt.Y("y", axis=alt.Axis(title="nonword")),
            )
        )

        performance_space = (diagonal + wnw_line).properties(
            title="Performance space: Nonword accuracy vs. Word accuracy"
        )

        # Merge dashboard
        return control_space & (development_space | performance_space)

In [None]:
# Folder holding this batch's cfgs.csv and bcdf.csv, which the class reads on
# construction.
batch_folder = "batch_eval/O2P_replication_set12/"
vis = VisualizeBatchResults(batch_folder)

### Create Strain and Grain / Taraban and Glushko dashboard

In [None]:
# Strain / grain dashboard, saved as an HTML file.
vis.dashboard(exp=["strain", "grain"], word=["INC_HF"], nonword=["unambiguous"]).save(
    "dash_sg.html"
)

# Taraban / Glushko dashboard with its own word and nonword conditions.
vis.dashboard(
    exp=["taraban", "glushko"],
    word=["High-frequency regular-inconsistent"],
    nonword=["Regular"],
).save("dash_tg.html")

### View h-param grid

In [None]:
# Print the varying and the static hyperparameters (verbose defaults to True).
vis.check_cfgs_params()
# Last expression of the cell: the model-count grid becomes the cell output.
vis.plot_model_count_grid()

### Check standard deviation for anomaly

# All runs at last time step

### no rng_seed aggregation

In [None]:
def plot_fig2(df, word=None, nonword=None):
    """Scatter nonword accuracy against word accuracy at the last time step.

    df: data with timestep and cond columns, passed on to make_df_wnw after
        filtering to the largest timestep
    word: conditions treated as words (default ["INC_HF"])
    nonword: conditions treated as nonwords (default ["unambiguous"])

    Returns the scatter (colored by epoch) layered on a black diagonal.
    """
    word = ["INC_HF"] if word is None else list(word)
    nonword = ["unambiguous"] if nonword is None else list(nonword)

    last_step = df.loc[df.timestep == df.timestep.max()]
    pdf = make_df_wnw(last_step, word, nonword)

    base = (
        alt.Chart(pdf)
        .mark_point()
        .encode(
            y=alt.Y("nonword_acc:Q", scale=alt.Scale(domain=(0, 1))),
            x=alt.X("word_acc:Q", scale=alt.Scale(domain=(0, 1))),
            # NOTE(review): the color domain assumes epochs run from 0 to 100
            color=alt.Color(
                "epoch", scale=alt.Scale(scheme="redyellowgreen", domain=(0, 100))
            ),
            opacity=alt.value(0.3),
            tooltip=["code_name", "epoch", "word_acc", "nonword_acc"],
        )
    )

    # Reference line where word accuracy equals nonword accuracy
    diagonal = (
        alt.Chart(pd.DataFrame({"x": [0, 1], "y": [0, 1]}))
        .mark_line(color="black")
        .encode(x="x", y="y")
    )
    return diagonal + base


# NOTE(review): `df` and `batch_output_dir` are not defined in any cell visible
# here — confirm where they come from before running this cell.
plot_fig2(df).properties(title="All runs overlay (no aggregation)").save(
    batch_output_dir + "fig2.html"
)

### Plot figure 2 facet by regularization constant

### rng_seed aggregate

In [None]:
# NOTE(review): `n_rng`, `df` and `batch_output_dir` are not defined at module
# level in the visible cells (the class only exposes vis.n_rng / vis.df) — confirm.
if n_rng > 1:
    # Overwrites code_name in place with the setting string, so the tooltip
    # of plot_fig2 shows the setting instead of the individual run.
    df["code_name"] = df.batch_unique_setting_string
    plot_fig2(df).properties(
        title="All runs overlay (within setting cell aggregation)"
    ).save(batch_output_dir + "fig2_agg.html")

### Create df for plotting

In [None]:
# NOTE(review): `n_rng`, `cfgs`, `df` and `variates` are not defined in the
# visible cells — confirm where they come from before running this cell.
if n_rng > 1:
    # Identify runs by their setting string in both tables (in place)
    cfgs["code_name"] = cfgs.batch_unique_setting_string
    df["code_name"] = df.batch_unique_setting_string


pivotvars = variates + ["code_name", "epoch", "timestep", "unit_time", "exp", "cond"]
selvars = pivotvars + ["acc", "sse"]

# pivot_table without aggfunc averages acc and sse over rows sharing pivotvars
df_cell_mean = df[selvars].pivot_table(index=pivotvars).reset_index()

# Select by condition and last time steps
df_sel = df_cell_mean.loc[
    (df_cell_mean.timestep == df_cell_mean.timestep.max())
    & df_cell_mean.cond.isin(["INC_HF", "unambiguous"])
]

# Make file
df_wnw = make_df_wnw(df_sel, ["INC_HF"], ["unambiguous"])
df_wnw["word_advantage"] = df_wnw.word_acc - df_wnw.nonword_acc
# merge without `on` joins on every column name the two frames share
df_wnw = df_wnw.merge(cfgs)

# Keep only the columns used by the plots below
selvars_wnw = variates + [
    "code_name",
    "epoch",
    "word_acc",
    "nonword_acc",
    "word_advantage",
]
df_wnw = df_wnw[selvars_wnw]

In [None]:
# Heatmap of word accuracy at the largest epoch in df_wnw.
# The h-param names on x / y / row are hard-coded for this batch.
heatmap_last = (
    alt.Chart(df_wnw.loc[df_wnw.epoch == df_wnw.epoch.max()])
    .mark_rect()
    .encode(
        x="p_noise:O",
        y="hidden_units:O",
        row="learning_rate:O",
        color=alt.Color(
            "word_acc:Q", scale=alt.Scale(scheme="redyellowgreen", domain=(0, 1))
        ),
        tooltip=["word_acc", "nonword_acc", "word_advantage"] + variates,
    )
    .properties(title="Last epoch word accuracy")
)

In [None]:
# Bare expression: the chart becomes the cell output.
heatmap_last

### Heatmap over epoch

In [None]:
def heatmap(df_wnw, prefix=""):
    """Save word, nonword and word-advantage heatmaps over epoch as HTML.

    df_wnw: table with p_noise, hidden_units, learning_rate, epoch, word_acc,
        nonword_acc and word_advantage columns
    prefix: string put in front of each output file name

    Files are written to the module-level `batch_output_dir`; the tooltip
    also lists the columns named in the module-level `variates`.
    """
    grid = (
        alt.Chart(df_wnw)
        .mark_rect()
        .encode(
            x="p_noise:O",
            y="hidden_units:O",
            row="learning_rate:O",
            column="epoch:O",
            tooltip=["word_acc", "nonword_acc", "word_advantage"] + variates,
        )
    )

    def colored_by(field, domain, title):
        # Same grid, colored by one measure on a red-yellow-green scale
        scale = alt.Scale(scheme="redyellowgreen", domain=domain)
        return grid.encode(color=alt.Color(field, scale=scale)).properties(
            title=title
        )

    # All three charts are built before the first file is written
    outputs = [
        ("heatmap_word.html", colored_by("word_acc:Q", (0, 1), "Word accuracy")),
        (
            "heatmap_nonword.html",
            colored_by("nonword_acc:Q", (0, 1), "Nonword accuracy"),
        ),
        (
            "heatmap_wordadvantage.html",
            # Word - Nonword
            colored_by("word_advantage:Q", (-0.3, 0.3), "Word advantage"),
        ),
    ]

    for file_name, chart in outputs:
        chart.save(batch_output_dir + prefix + file_name)

In [None]:
# Writes the three heatmap HTML files with an empty file-name prefix.
heatmap(df_wnw)

# Dashboard

In [None]:
# Lift altair's row limit before building the dashboard below.
alt.data_transformers.disable_max_rows()


def main_dashboard(df):
    """Assemble the linked dashboard for a word / nonword accuracy table.

    df: table with code_name, epoch, p_noise, hidden_units, learning_rate,
        word_acc, nonword_acc and word_advantage columns; the tooltip of the
        mini heatmap also lists the columns named in the module-level
        `variates`.

    One click selection on code_name drives the opacity of every panel.
    Returns: overview | (accuracy by epoch & mini heatmap) | word vs. nonword.
    """
    picked = alt.selection(type="multi", on="click", fields=["code_name"])

    def opacity_unless_picked(otherwise):
        # Opacity 1 for picked runs, `otherwise` for all other runs
        return alt.condition(picked, alt.value(1), alt.value(otherwise))

    def accuracy_scale():
        # Accuracy axes always span the full 0-1 range
        return alt.Scale(domain=(0, 1))

    # Shared master over-view: word accuracy at the largest epoch
    final_epoch = df[df.epoch == df.epoch.max()]
    overview_tooltip = [
        "code_name",
        "p_noise",
        "hidden_units",
        "learning_rate",
        "word_acc",
        "nonword_acc",
    ]
    overview = (
        alt.Chart(final_epoch)
        .mark_rect()
        .encode(
            x="p_noise:O",
            y="hidden_units:O",
            row="learning_rate:O",
            color=alt.Color(
                "word_acc", scale=alt.Scale(scheme="redyellowgreen", domain=(0, 1))
            ),
            opacity=opacity_unless_picked(0.1),
            tooltip=overview_tooltip,
        )
        .add_selection(picked)
        .properties(title="Word accuracy at the end of training")
    )

    # Long format: one row per run, epoch and word/nonword measure
    long_acc = df.melt(
        id_vars=["code_name", "epoch"],
        value_vars=["word_acc", "nonword_acc"],
        var_name="wnw",
        value_name="acc",
    )
    plot_epoch = (
        alt.Chart(long_acc)
        .mark_point(size=80)
        .encode(
            y=alt.Y("acc:Q", scale=accuracy_scale()),
            x="epoch:Q",
            color=alt.Color("code_name:N", legend=None),
            shape="wnw:N",
            opacity=opacity_unless_picked(0),
            tooltip=["code_name", "epoch", "acc"],
        )
        .add_selection(picked)
        .properties(title="Plot word and nonword accuracy by epoch")
    )

    # Word vs. nonword: a line per run, with points colored by epoch on top
    run_lines = (
        alt.Chart(df)
        .mark_line()
        .encode(
            y=alt.Y("nonword_acc:Q", scale=accuracy_scale()),
            x=alt.X("word_acc:Q", scale=accuracy_scale()),
            color=alt.Color("code_name:N", legend=None),
            opacity=opacity_unless_picked(0),
            tooltip=["code_name", "epoch", "word_acc", "nonword_acc"],
        )
    )
    epoch_points = run_lines.mark_point().encode(
        color=alt.Color("epoch", scale=alt.Scale(scheme="redyellowgreen"))
    )
    identity_line = (
        alt.Chart(pd.DataFrame({"x": [0, 1], "y": [0, 1]}))
        .mark_line(color="black")
        .encode(x="x", y="y")
    )
    wnw_interactive = (
        (identity_line + run_lines + epoch_points)
        .add_selection(picked)
        .properties(title="Word vs. Nonword accuracy at final time step")
    )

    ### Mini heatmap ###
    mini_wadv = (
        alt.Chart(df)
        .mark_rect()
        .encode(
            x="epoch:O",
            color=alt.Color(
                "word_advantage:Q",
                scale=alt.Scale(scheme="redyellowgreen", domain=(-0.3, 0.3)),
            ),
            opacity=opacity_unless_picked(0),
            tooltip=["word_acc", "nonword_acc", "word_advantage"] + variates,
        )
        .properties(title="Word - Nonword")
    )

    return overview | (plot_epoch & mini_wadv) | wnw_interactive


# NOTE(review): `batch_output_dir` is not defined in the visible cells — confirm.
main_dashboard(df_wnw).save(batch_output_dir + "dashboard.html")

### Hyper-parameter effect plots

In [None]:
# Effect of p_noise: nonword vs. word accuracy, one line color per p_noise
# level. The chart has no data of its own; it is supplied by alt.layer below.
plot_pnoise = (
    alt.Chart()
    .mark_line()
    .encode(
        y=alt.Y("nonword_acc:Q", scale=alt.Scale(domain=(0, 1))),
        x=alt.X("word_acc:Q", scale=alt.Scale(domain=(0, 1))),
        color=alt.Color("p_noise:O", scale=alt.Scale(scheme="reds")),
        tooltip=variates + ["epoch", "word_acc", "nonword_acc"],
    )
)

# Reference line where word accuracy equals nonword accuracy; also reused by
# the hidden_units and learning_rate cells below.
diagonal = (
    alt.Chart(pd.DataFrame({"x": [0, 1], "y": [0, 1]}))
    .mark_line(color="black")
    .encode(x="x", y="y")
)

# Facet by the two remaining h-params.
p = alt.layer(diagonal + plot_pnoise, data=df_wnw).facet(
    row="hidden_units:O", column="learning_rate:O"
)

p.save(batch_output_dir + "Effect_pnoise.html")

In [None]:
# Effect of hidden_units: same plot as for p_noise, colored by hidden_units.
plot_hidden = (
    alt.Chart()
    .mark_line()
    .encode(
        y=alt.Y("nonword_acc:Q", scale=alt.Scale(domain=(0, 1))),
        x=alt.X("word_acc:Q", scale=alt.Scale(domain=(0, 1))),
        color=alt.Color("hidden_units:O", scale=alt.Scale(scheme="blues")),
        tooltip=variates + ["epoch", "word_acc", "nonword_acc"],
    )
)

# Reuses `diagonal` from the p_noise cell above, so that cell must run first.
h = alt.layer(diagonal + plot_hidden, data=df_wnw).facet(
    row="p_noise:O", column="learning_rate:O"
)

h.save(batch_output_dir + "Effect_hidden.html")

In [None]:
# Effect of learning_rate: same plot as for p_noise, colored by learning_rate.
plot_lr = (
    alt.Chart()
    .mark_line()
    .encode(
        y=alt.Y("nonword_acc:Q", scale=alt.Scale(domain=(0, 1))),
        x=alt.X("word_acc:Q", scale=alt.Scale(domain=(0, 1))),
        color=alt.Color(
            "learning_rate", type="ordinal", scale=alt.Scale(scheme="greens")
        ),
        tooltip=variates + ["epoch", "word_acc", "nonword_acc"],
    )
)

# Reuses `diagonal` from the p_noise cell above, so that cell must run first.
lr = alt.layer(diagonal + plot_lr, data=df_wnw).facet(
    row="hidden_units:O", column="p_noise:O"
)

lr.save(batch_output_dir + "Effect_lr.html")

SSE vs. P-noise

In [None]:
# Rows at the last time step in the four CON/INC x HF/LF conditions.
# NOTE(review): `df` is not defined in the visible cells — confirm.
sdf = df.loc[
    (
        (df.timestep == df.timestep.max())
        & (df.cond.isin(["CON_HF", "CON_LF", "INC_HF", "INC_LF"]))
    ),
    ["code_name", "epoch", "p_noise", "acc", "sse"],
]

# Collapse condition
sdf = sdf.groupby(["code_name", "epoch"]).mean().reset_index()

In [None]:
alt.data_transformers.disable_max_rows()
# Accuracy against SSE, one point per row of sdf, colored by p_noise.
c = (
    alt.Chart(sdf)
    .mark_point()
    .encode(
        y="acc",
        x=alt.X("sse", scale=alt.Scale(domain=(0, 8))),
        color="p_noise:N",
        opacity=alt.value(0.3),
    )
    .interactive()
)
# Bare expression: the chart becomes the cell output.
c