# Examine Rueckl19 dataset

In [None]:
%load_ext lab_black
import pandas as pd
import altair as alt

# Select Altair's "default" data transformer, then switch off its row limit so
# the charts below can be built from dataframes of any length.
alt.data_transformers.enable("default")
alt.data_transformers.disable_max_rows()

### Ingest, tidy

In [None]:
# Identifier (raw "ID" column, renamed to "code_name" below) of the single run
# that the rest of this notebook examines.
SELECTED_CODE_NAME = 86052131

df = pd.read_csv("plotdf.csv", index_col=0)

# Word and nonword accuracy are each taken from a single condition here:
# HF_INC for words and NW_UN for nonwords. The selections are kept as column
# lists so that more conditions can be averaged in, e.g.
#   words:    HF_CON_Accuracy, HF_INC_Accuracy, LF_CON_Accuracy, LF_INC_Accuracy
#   nonwords: NW_AMB_Accuracy, NW_UN_Accuracy
df["word_acc"] = df[["HF_INC_Accuracy"]].mean(axis=1)
df["nonword_acc"] = df[["NW_UN_Accuracy"]].mean(axis=1)

# Map the raw column headers onto the names used from here on.
df = df.rename(
    columns={
        "ID": "code_name",  # "ID" is called code_name from here on
        "Trial.Scaled": "epoch",  # "Trial.Scaled" is called epoch from here on
        "Pnoise": "p_noise",
        "Hidden": "hidden_units",
        "Epsilon": "learning_rate",
        "PhoHid": "cleanup_units",
        "Classification": "group",  # "Classification" is called group from here on
    }
)

# Keep only the columns the plots need.
df = df[
    [
        "code_name",
        "epoch",
        "hidden_units",
        "cleanup_units",
        "p_noise",
        "learning_rate",
        "word_acc",
        "nonword_acc",
        "group",
    ]
]

# Restrict to the selected run. The explicit copy gives the filtered frame its
# own data, so adding a column below never targets a slice of the full frame
# (which can raise SettingWithCopyWarning on pandas without copy-on-write).
df = df.loc[df["code_name"] == SELECTED_CODE_NAME].copy()

# Positive values mean words are read more accurately than nonwords.
df["word_advantage"] = df["word_acc"] - df["nonword_acc"]

# Long-format version of df for the word/nonword plot: the two accuracy
# columns are stacked into "acc", with "wnw" recording which column each
# value came from.
mdf = pd.melt(
    df,
    id_vars=["code_name", "epoch"],
    value_vars=["word_acc", "nonword_acc"],
    var_name="wnw",
    value_name="acc",
)

### Plot the usuals

In [None]:
### Development ###

# Accuracy against epoch, read from the long-format frame; marker shape tells
# word accuracy apart from nonword accuracy, and the y axis is pinned to [0, 1].
plot_dev = (
    alt.Chart(mdf, title="Plot word and nonword accuracy by epoch")
    .mark_point(size=80)
    .encode(
        alt.X("epoch:Q"),
        alt.Y("acc:Q", scale=alt.Scale(domain=(0, 1))),
        alt.Color("code_name:N", legend=None),
        alt.Shape("wnw:N"),
        tooltip=["code_name", "epoch", "acc"],
    )
)


### Word vs. Nonword

# Connected scatter of nonword accuracy against word accuracy, both axes
# pinned to [0, 1].
wnw_line = (
    alt.Chart(df)
    .mark_line()
    .encode(
        y=alt.Y("nonword_acc:Q", scale=alt.Scale(domain=(0, 1))),
        x=alt.X("word_acc:Q", scale=alt.Scale(domain=(0, 1))),
        color=alt.Color("code_name:N", legend=None),
        # Without an explicit order a line mark joins its points sorted by the
        # x field; ordering by epoch makes the path follow training instead,
        # which matters whenever word accuracy is not monotonic over epochs.
        order="epoch:Q",
        tooltip=["code_name", "epoch", "word_acc", "nonword_acc"],
    )
)

# Same encoding drawn as points, recoloured by epoch.
wnw_point = wnw_line.mark_point().encode(
    color=alt.Color("epoch", scale=alt.Scale(scheme="redyellowgreen"))
)

# Reference line y = x: points on it have equal word and nonword accuracy.
diagonal = (
    alt.Chart(pd.DataFrame({"x": [0, 1], "y": [0, 1]}))
    .mark_line(color="black")
    .encode(x="x", y="y")
)

# Layer order: diagonal at the back, then the path, then the points on top.
wnw = diagonal + wnw_line + wnw_point


### Mini heatmap ###

# One rectangle per epoch, coloured by word_advantage on a diverging scale
# whose domain is fixed to [-0.3, 0.3].
mini_wadv = (
    alt.Chart(df, title="Word - Nonword")
    .mark_rect()
    .encode(
        alt.X("epoch:O"),
        alt.Color(
            "word_advantage:Q",
            scale=alt.Scale(scheme="redyellowgreen", domain=(-0.3, 0.3)),
        ),
        tooltip=["word_acc", "nonword_acc", "word_advantage"],
    )
)

# Left column: development plot stacked above the mini heatmap.
# Right: the word-vs-nonword chart.
main_plot = alt.hconcat(alt.vconcat(plot_dev, mini_wadv), wnw)
main_plot