# Examine Rueckl19 dataset

In [None]:
%load_ext lab_black
import pandas as pd
import altair as alt

# Use altair's "default" data transformer and switch off its max-rows check
# so that the (large) simulation frames below can be charted in full.
alt.data_transformers.enable("default")
alt.data_transformers.disable_max_rows()

### Ingest, tidy

In [None]:
df = pd.read_csv("plotdf.csv", index_col=0)

# Accuracy summaries used throughout this notebook.
# Alternative definition averaging over all word / nonword conditions:
#
# df['word_acc'] = df[[
#     'HF_CON_Accuracy', 'HF_INC_Accuracy', 'LF_CON_Accuracy', 'LF_INC_Accuracy'
# ]].mean(axis=1)
# df['nonword_acc'] = df[['NW_AMB_Accuracy', 'NW_UN_Accuracy']].mean(axis=1)
#
# The active definitions use a single condition each (HF_INC for words,
# NW_UN for nonwords); a row-wise mean over one column just copies it as float.
df["word_acc"] = df[["HF_INC_Accuracy"]].mean(axis=1)
df["nonword_acc"] = df[["NW_UN_Accuracy"]].mean(axis=1)

# Harmonise column names with the naming used in the rest of the notebook.
df = df.rename(
    columns={
        "ID": "code_name",  # model identifier is called code_name onward
        "Trial.Scaled": "epoch",  # scaled trial is called epoch onward
        "Pnoise": "p_noise",
        "Hidden": "hidden_units",
        "Epsilon": "learning_rate",
        "PhoHid": "cleanup_units",
        "Classification": "group",
    }
)

# Keep only the columns the analyses below use.
df = df[
    [
        "code_name",
        "epoch",
        "hidden_units",
        "cleanup_units",
        "p_noise",
        "learning_rate",
        "word_acc",
        "nonword_acc",
        "group",
    ]
]

# Restrict to the parameter grid analysed in this notebook.
df = df[
    df.p_noise.isin([0, 2, 4, 8])
    & df.hidden_units.isin([50, 100, 250])
    & df.cleanup_units.isin([20])
    & df.learning_rate.isin([0.002, 0.004, 0.006, 0.008, 0.01])
]

### Is there more than one model per unique combination of settings? Yes...

In [None]:
def count_models(df, title="Model counts"):
    """Print the number of models in ``df`` and chart how many share each setting.

    Parameters
    ----------
    df : DataFrame
        Simulation results with (at least) the columns ``code_name``,
        ``p_noise``, ``hidden_units``, ``learning_rate`` and ``cleanup_units``.
        A model (``code_name``) may occupy several rows.
    title : str, optional
        Chart title. Pass a distinct title when several count charts are
        concatenated side by side; defaults to ``"Model counts"``.

    Returns
    -------
    altair.Chart
        Heatmap of the model count ``n`` per (p_noise, hidden_units) cell,
        faceted by learning_rate (rows) and cleanup_units (columns).
    """
    print("There are {} models in the datafile".format(len(df.code_name.unique())))

    settings = ["p_noise", "hidden_units", "learning_rate", "cleanup_units"]

    # One row per model: the mean of each setting over that model's rows
    # (settings are expected to be constant within a model).
    per_model = df[["code_name"] + settings].pivot_table(index="code_name")

    # Number of models sharing each combination of settings.
    pvt = per_model.groupby(settings).size().reset_index(name="n")

    return (
        alt.Chart(pvt)
        .mark_rect()
        .encode(
            x="p_noise:O",
            y="hidden_units:O",
            row="learning_rate:O",
            column="cleanup_units:O",
            color="n:O",
            tooltip=["p_noise", "hidden_units", "cleanup_units", "learning_rate", "n"],
        )
        .properties(title=title)
    )


# Split the models by their performance classification.
df_upper, df_mid, df_lower = (
    df.loc[df.group == label] for label in ("Upper", "Mid", "Lower")
)

# Count charts side by side: all models first, then each group in turn.
count_by_group = alt.hconcat(
    *(count_models(frame) for frame in (df, df_upper, df_mid, df_lower))
)

count_by_group.save("count_model_selgridall.html")

count_by_group

# Variance within each cell (standard deviation across models)

In [None]:
# Columns that identify one cell of the parameter grid at one epoch.
_cell_keys = ["learning_rate", "hidden_units", "p_noise", "epoch"]

# Standard deviation (despite the *_variance names) across all rows that share
# a cell, computed for both accuracy measures at once and then split.
_acc_std = (
    df[_cell_keys + ["word_acc", "nonword_acc"]]
    .groupby(_cell_keys)
    .std()
    .reset_index()
)
w_acc_variance = _acc_std[_cell_keys + ["word_acc"]]
nw_acc_variance = _acc_std[_cell_keys + ["nonword_acc"]]

In [None]:
def _std_heatmap(data, measure):
    """Heatmap of the per-cell standard deviation stored in column ``measure``.

    Epoch runs along x and hidden_units along y, faceted by learning_rate
    (rows) and p_noise (columns); colour is fixed to the 0-0.2 range.
    """
    return (
        alt.Chart(data)
        .mark_rect()
        .encode(
            x="epoch:O",
            y="hidden_units:O",
            row="learning_rate:O",
            column="p_noise:O",
            color=alt.Color(measure + ":Q", scale=alt.Scale(domain=[0, 0.2])),
            tooltip=["p_noise", "hidden_units", "learning_rate", measure],
        )
    )


# Word accuracy variance
owv = _std_heatmap(w_acc_variance, "word_acc")
# Nonword accuracy variance
onwv = _std_heatmap(nw_acc_variance, "nonword_acc")

(owv & onwv).properties(title="old sims word and nonword variance")

# M270 variance

In [None]:
import pandas as pd
import altair as alt

# Load the O2P_m270 batch-evaluation results.
# NOTE(review): this rebinds the notebook-wide `df` (previously the filtered
# plotdf.csv frame). Later cells that read `word_acc`, `nonword_acc`, `group`
# or `cleanup_units` from `df` will see this frame instead. Confirm that
# bcdf.csv provides those columns, or that those cells are meant to run on
# the earlier `df`.
df = pd.read_csv("../OSP/batch_eval/O2P_m270/bcdf.csv", index_col=0)

In [None]:
def plot_std(df, variates, conditions, dv):
    """Heatmap of the standard deviation of ``dv`` within each cell.

    Only rows at the final time step (``timestep == timestep.max()``) whose
    ``cond`` is in ``conditions`` are used. Rows are grouped by ``variates``
    plus ``epoch``; when several conditions are given they are pooled within
    a cell.

    Parameters
    ----------
    df : DataFrame
        Needs ``timestep``, ``cond``, ``epoch``, ``dv`` and every column in
        ``variates``.
    variates : list of str
        Grouping columns that define a cell. The chart layout is hard-coded
        to ``hidden_units`` (y), ``learning_rate`` (rows) and ``p_noise``
        (columns), so those three should be among them.
    conditions : list of str
        Values of ``cond`` to keep.
    dv : str
        Numeric column whose standard deviation is plotted.

    Returns
    -------
    altair.Chart
    """
    # Keep final-time-step rows for the requested conditions.
    keep = (df.timestep == df.timestep.max()) & df.cond.isin(conditions)
    sel_df = df.loc[keep, variates + ["epoch", dv]]

    # Standard deviation of dv in each cell. Only dv is aggregated: passing
    # the string column `cond` to .std() raises TypeError in pandas >= 2.0
    # (older versions silently dropped it).
    plot_df = sel_df.groupby(variates + ["epoch"])[dv].std().reset_index()

    # Plot heatmap
    return (
        alt.Chart(plot_df)
        .mark_rect()
        .encode(
            x="epoch:O",
            y="hidden_units:O",
            row="learning_rate:O",
            column="p_noise:O",
            color=alt.Color(dv, scale=alt.Scale(domain=[0, 0.2])),
            tooltip=variates + [dv],
        )
    )

In [None]:
# Columns that define a simulation cell; also reused by the next cell.
variates = ["hidden_units", "learning_rate", "p_noise"]
# Per-cell SD of `acc` for the INC_HF condition at the last time step.
plot_std(df, variates, ["INC_HF"], "acc")

In [None]:
# Same per-cell SD heatmap, for the "unambiguous" condition.
plot_std(df, variates, ["unambiguous"], "acc")

# Replicate Fig. 2

In [None]:
# Number of distinct models (NaN counted as a value of its own).
df.code_name.nunique(dropna=False)

In [None]:
# Fig. 2: nonword vs. word accuracy for every model and epoch, with a radio
# button (or a click on a point) that restricts the chart to one
# classification group.
# NOTE(review): needs `word_acc`, `nonword_acc`, `group`, `code_name` and
# `epoch` in whichever `df` was loaded last.
sel_group = alt.selection(
    type="single",
    on="click",
    fields=["group"],
    bind=alt.binding_radio(options=["Upper", "Mid", "Lower"], name="Classification: "),
)

base = (
    alt.Chart(df)
    .mark_point()
    .encode(
        y=alt.Y("nonword_acc:Q", scale=alt.Scale(domain=(0, 1))),
        x=alt.X("word_acc:Q", scale=alt.Scale(domain=(0, 1))),
        color=alt.Color(
            "epoch", scale=alt.Scale(scheme="redyellowgreen", domain=(0, 1))
        ),
        opacity=alt.condition(sel_group, alt.value(0.2), alt.value(0)),
        tooltip=["code_name", "epoch", "word_acc", "nonword_acc"],
    )
    .add_selection(sel_group)
    # Filter on the selection defined above. This previously referenced an
    # undefined `genre_select` and left `base = (` unclosed (SyntaxError).
    .transform_filter(sel_group)
)

# Reference line y = x: points above it have higher nonword than word accuracy.
diagonal = (
    alt.Chart(pd.DataFrame({"x": [0, 1], "y": [0, 1]}))
    .mark_line(color="black")
    .encode(x="x", y="y")
)

f2 = diagonal + base
f2.save("fig2.html")

### Group average plots

In [None]:
# Inspect which columns the current `df` provides.
df.columns

In [None]:
# Mean of every numeric column within each (group, epoch).
# numeric_only=True drops non-numeric columns explicitly; the implicit mean in
# DataFrame.pivot_table raises on them in pandas >= 2.0.
dfg = df.groupby(["group", "epoch"]).mean(numeric_only=True).reset_index()

sel_group = alt.selection(
    type="single",
    on="click",
    fields=["group"],
    bind=alt.binding_radio(options=["Upper", "Mid", "Lower"], name="Classification: "),
)

# Group-average trajectory; only the selected group is drawn opaque.
base = (
    alt.Chart(dfg)
    .mark_point()
    .encode(
        y=alt.Y("nonword_acc:Q", scale=alt.Scale(domain=(0, 1))),
        x=alt.X("word_acc:Q", scale=alt.Scale(domain=(0, 1))),
        color=alt.Color(
            "epoch", scale=alt.Scale(scheme="redyellowgreen", domain=(0, 1))
        ),
        opacity=alt.condition(sel_group, alt.value(1), alt.value(0)),
        # Rows are group means, so show the group rather than a model ID.
        tooltip=["group", "epoch", "word_acc", "nonword_acc"],
    )
    .add_selection(sel_group)
)

# Reference line y = x.
diagonal = (
    alt.Chart(pd.DataFrame({"x": [0, 1], "y": [0, 1]}))
    .mark_line(color="black")
    .encode(x="x", y="y")
)

diagonal + base

### Aggregate cell heatmaps

In [None]:
# Cell-level means: average every numeric column over the rows that share an
# (epoch, hidden_units, cleanup_units, p_noise, learning_rate) cell.
# numeric_only=True drops string columns such as `group` explicitly; the
# implicit mean in DataFrame.pivot_table raises on them in pandas >= 2.0.
# NOTE(review): if `code_name` is numeric it is averaged too, so in dfc it is
# a per-cell value rather than a real model ID (main_dashboard keys on it).
dfc = (
    df.groupby(["epoch", "hidden_units", "cleanup_units", "p_noise", "learning_rate"])
    .mean(numeric_only=True)
    .reset_index()
)

# Positive values mean words are read more accurately than nonwords.
dfc["word_advantage"] = dfc.word_acc - dfc.nonword_acc

In [None]:
# Radio-button epoch picker, also reused by the nonword and word-advantage
# heatmaps; each chart is filtered down to the chosen epoch.
sel_epoch = alt.selection(
    type="single",
    on="click",
    fields=["epoch"],
    bind=alt.binding_radio(options=list(dfc.epoch.unique()), name="Epoch: "),
)

# Cell-mean word accuracy: p_noise (x) by hidden_units (y), faceted by
# learning_rate (rows) and epoch (columns).
w = (
    alt.Chart(dfc, title="Word acc")
    .mark_rect()
    .encode(
        alt.X("p_noise:O"),
        alt.Y("hidden_units:O"),
        alt.Row("learning_rate:O"),
        alt.Column("epoch:O"),
        alt.Color(
            "word_acc", scale=alt.Scale(scheme="redyellowgreen", domain=(0, 1))
        ),
        tooltip=[
            "p_noise",
            "hidden_units",
            "cleanup_units",
            "learning_rate",
            "word_acc",
            "nonword_acc",
        ],
    )
    .add_selection(sel_epoch)
    .transform_filter(sel_epoch)
)

w.save("heatmap_word.html")

In [None]:
# Cell-mean nonword accuracy, same layout and epoch picker as the word map.
nw = (
    alt.Chart(dfc, title="Nonword acc")
    .mark_rect()
    .encode(
        alt.X("p_noise:O"),
        alt.Y("hidden_units:O"),
        alt.Row("learning_rate:O"),
        alt.Column("epoch:O"),
        alt.Color(
            "nonword_acc", scale=alt.Scale(scheme="redyellowgreen", domain=(0, 1))
        ),
        tooltip=[
            "p_noise",
            "hidden_units",
            "cleanup_units",
            "learning_rate",
            "word_acc",
            "nonword_acc",
        ],
    )
    .add_selection(sel_epoch)
    .transform_filter(sel_epoch)
)

nw.save("heatmap_nonword.html")

In [None]:
# Cell-mean word advantage (word_acc - nonword_acc) on a diverging scale
# clamped to +/-0.3, same layout and epoch picker as the accuracy maps.
wnw = (
    alt.Chart(dfc)
    .mark_rect()
    .encode(
        x="p_noise:O",
        y="hidden_units:O",
        row="learning_rate:O",
        column="epoch:O",
        color=alt.Color(
            "word_advantage",
            scale=alt.Scale(scheme="redyellowgreen", domain=(-0.3, 0.3)),
        ),
        # Include the encoded value itself, not only its two components.
        tooltip=[
            "p_noise",
            "hidden_units",
            "cleanup_units",
            "learning_rate",
            "word_acc",
            "nonword_acc",
            "word_advantage",
        ],
    )
    .add_selection(sel_epoch)
    .transform_filter(sel_epoch)
    .properties(title="Word advantage (Word - Nonword)")
)

wnw.save("heatmap_wnw.html")

In [None]:
# # Get model level mean word advantage sorting
# # Merge it back to cell level df

# dfm = dfc.pivot_table(
#     index=['hidden_units', 'cleanup_units', 'p_noise', 'learning_rate']
# ).reset_index()

# dfm['cell_id'] = dfm.index
# dfms = dfm.sort_values('word_advantage').reset_index(drop=True)
# dfms['sorted_adv'] = dfms.index
# dfms = dfms[['code_name', 'cell_id', 'sorted_adv']]

# dfc = dfc.merge(dfms, on='code_name')

In [None]:
# Cell means at epoch 0.1 only; this frame feeds the ANOVA further below.
dfc_speed = dfc[dfc["epoch"] == 0.1]
dfc_speed.sample(5)

In [None]:
# NOTE(review): this cell read `dfc_speed.loc[]`, which is a SyntaxError
# (empty subscript). The intended selection is unknown; this placeholder
# shows every row and column of dfc_speed until the selectors are filled in.
dfc_speed.loc[:, :]

In [None]:
import statsmodels.api as sm
from statsmodels.formula.api import ols

# Full-factorial OLS of word accuracy at epoch 0.1 on the simulation
# parameters, followed by a type-III ANOVA table.
#
# `a * b * c * d` already expands to all main effects and lower-order
# interactions. The previous formula's extra main-effect and 3-way terms were
# therefore redundant. Parameters with a single level in dfc_speed (e.g.
# cleanup_units when only the 20-unit setting is present) are left out: their
# terms are perfectly collinear with the intercept and the other terms,
# making the design matrix singular.
# NOTE(review): predictors enter as uncentred continuous variables, so the
# type-III main-effect tests are evaluated where the other predictors equal 0.
# Consider centring them or wrapping them in C().
_factors = ["hidden_units", "p_noise", "cleanup_units", "learning_rate"]
_varying = [f for f in _factors if dfc_speed[f].nunique() > 1]

model = ols(
    "word_acc ~ " + (" * ".join(_varying) or "1"),
    data=dfc_speed,
).fit()

sm.stats.anova_lm(model, typ=3).round(4)

### P-noise without aggregation

In [None]:
# Reference line y = x: points above it have higher nonword than word accuracy.
diagonal = (
    alt.Chart(pd.DataFrame({"x": [0, 1], "y": [0, 1]}))
    .mark_line(color="black")
    .encode(x="x", y="y")
)

# One trajectory per p_noise level through (word_acc, nonword_acc) space.
# order="epoch" connects each line's points in training order. Without it a
# line mark connects points sorted by the x field (word_acc), which scrambles
# the trajectory whenever word accuracy is not monotonic in epoch.
plot_pnoise = (
    alt.Chart()
    .mark_line()
    .encode(
        y=alt.Y("nonword_acc:Q", scale=alt.Scale(domain=(0, 1))),
        x=alt.X("word_acc:Q", scale=alt.Scale(domain=(0, 1))),
        color=alt.Color("p_noise", type="ordinal", scale=alt.Scale(scheme="reds")),
        order="epoch:Q",
        tooltip=[
            "epoch",
            "hidden_units",
            "cleanup_units",
            "p_noise",
            "learning_rate",
            "word_acc",
            "nonword_acc",
        ],
    )
)

# Overlay on the diagonal; facet by hidden_units (rows) and learning_rate (columns).
alt.layer(diagonal + plot_pnoise, data=dfc).facet(
    row="hidden_units:O", column="learning_rate:O"
)

### Hidden units effect

In [None]:
# One trajectory per hidden-layer size through (word_acc, nonword_acc) space.
# order="epoch" connects each line's points in training order; by default a
# line mark connects them sorted by the x field (word_acc) instead.
plot_hidden = (
    alt.Chart()
    .mark_line()
    .encode(
        y=alt.Y("nonword_acc:Q", scale=alt.Scale(domain=(0, 1))),
        x=alt.X("word_acc:Q", scale=alt.Scale(domain=(0, 1))),
        color=alt.Color(
            "hidden_units", type="ordinal", scale=alt.Scale(scheme="blues")
        ),
        order="epoch:Q",
        tooltip=[
            "epoch",
            "hidden_units",
            "cleanup_units",
            "p_noise",
            "learning_rate",
            "word_acc",
            "nonword_acc",
        ],
    )
)

# Overlay on the diagonal; facet by p_noise (rows) and learning_rate (columns).
alt.layer(diagonal + plot_hidden, data=dfc).facet(
    row="p_noise:O", column="learning_rate:O"
)

### Learning rate effect

In [None]:
# One trajectory per learning rate through (word_acc, nonword_acc) space.
# order="epoch" connects each line's points in training order; by default a
# line mark connects them sorted by the x field (word_acc) instead.
plot_lr = (
    alt.Chart()
    .mark_line()
    .encode(
        y=alt.Y("nonword_acc:Q", scale=alt.Scale(domain=(0, 1))),
        x=alt.X("word_acc:Q", scale=alt.Scale(domain=(0, 1))),
        color=alt.Color(
            "learning_rate", type="ordinal", scale=alt.Scale(scheme="greens")
        ),
        order="epoch:Q",
        tooltip=[
            "epoch",
            "hidden_units",
            "cleanup_units",
            "p_noise",
            "learning_rate",
            "word_acc",
            "nonword_acc",
        ],
    )
)

# Overlay on the diagonal; facet by hidden_units (rows) and p_noise (columns).
alt.layer(diagonal + plot_lr, data=dfc).facet(
    row="hidden_units:O", column="p_noise:O"
)

### Dashboard

In [None]:
def main_dashboard(df):
    """Build a linked dashboard of per-run accuracy charts.

    A click in the overview heatmap, the accuracy-by-epoch panel or the
    word-vs-nonword panel selects runs (by ``code_name``) in one shared
    multi-selection. The panels use that selection to set mark opacity.

    Parameters
    ----------
    df : DataFrame
        Needs ``code_name``, ``epoch``, ``p_noise``, ``hidden_units``,
        ``learning_rate``, ``word_acc``, ``nonword_acc`` and ``word_advantage``.

    Returns
    -------
    altair chart
        ``overview | (plot_epoch & mini_wadv) | wnw_interactive``.
    """
    # Shared click selection of runs, keyed on code_name.
    sel_run = alt.selection(type="multi", on="click", fields=["code_name"])

    # df for overview: last epoch only
    df_ov = df[df.epoch == df.epoch.max()]

    # Shared master over-view
    overview = (
        alt.Chart(df_ov)
        .mark_rect()
        .encode(
            x="p_noise:O",
            y="hidden_units:O",
            row="learning_rate:O",
            color=alt.Color(
                "word_acc", scale=alt.Scale(scheme="redyellowgreen", domain=(0, 1))
            ),
            opacity=alt.condition(sel_run, alt.value(1), alt.value(0.1)),
            tooltip=[
                "code_name",
                "p_noise",
                "hidden_units",
                "learning_rate",
                "word_acc",
                "nonword_acc",
            ],
        )
        .add_selection(sel_run)
        .properties(title="Word accuracy at the end of training")
    )

    # Long format: one row per (code_name, epoch, word/nonword) accuracy.
    wnw_mdf = df.melt(
        id_vars=["code_name", "epoch"],
        value_vars=["word_acc", "nonword_acc"],
        var_name="wnw",
        value_name="acc",
    )

    plot_epoch = (
        alt.Chart(wnw_mdf)
        .mark_point(size=80)
        .encode(
            y=alt.Y("acc:Q", scale=alt.Scale(domain=(0, 1))),
            x="epoch:Q",
            color=alt.Color("code_name:N", legend=None),
            shape="wnw:N",
            opacity=alt.condition(sel_run, alt.value(1), alt.value(0)),
            tooltip=["code_name", "epoch", "acc"],
        )
        .add_selection(sel_run)
        .properties(title="Plot word and nonword accuracy by epoch")
    )

    # Word vs. nonword trajectory per run. order="epoch" connects points in
    # training order; a line mark otherwise connects them sorted by the x
    # field (word_acc), which scrambles non-monotonic trajectories.
    wnw_line = (
        alt.Chart(df)
        .mark_line()
        .encode(
            y=alt.Y("nonword_acc:Q", scale=alt.Scale(domain=(0, 1))),
            x=alt.X("word_acc:Q", scale=alt.Scale(domain=(0, 1))),
            color=alt.Color("code_name:N", legend=None),
            order="epoch:Q",
            opacity=alt.condition(sel_run, alt.value(1), alt.value(0)),
            tooltip=["code_name", "epoch", "word_acc", "nonword_acc"],
        )
    )

    # Same data as points, coloured by epoch instead of by run.
    wnw_point = wnw_line.mark_point().encode(
        color=alt.Color("epoch", scale=alt.Scale(scheme="redyellowgreen"))
    )

    # Reference line y = x.
    diagonal = (
        alt.Chart(pd.DataFrame({"x": [0, 1], "y": [0, 1]}))
        .mark_line(color="black")
        .encode(x="x", y="y")
    )

    wnw = diagonal + wnw_line + wnw_point

    wnw_interactive = wnw.add_selection(sel_run).properties(
        title="Word vs. Nonword accuracy at final time step"
    )

    ### Mini heatmap ###
    # Word-advantage strip along epoch; only selected runs are opaque.
    mini_wadv = (
        alt.Chart(df)
        .mark_rect()
        .encode(
            x="epoch:O",
            color=alt.Color(
                "word_advantage:Q",
                scale=alt.Scale(scheme="redyellowgreen", domain=(-0.3, 0.3)),
            ),
            opacity=alt.condition(sel_run, alt.value(1), alt.value(0)),
            tooltip=["word_acc", "nonword_acc", "word_advantage"],
        )
        .properties(title="Word - Nonword")
    )

    return overview | (plot_epoch & mini_wadv) | wnw_interactive


# Build the dashboard on the cell-level means and save it as standalone HTML.
main_plot = main_dashboard(dfc)
main_plot.save("dashboard_all.html")

In [None]:
# Render the dashboard inline in the notebook.
main_plot