In [None]:
from pathlib import Path
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

In [None]:
basepath = Path(".")

# Substring of a run's `from_pretrained` value -> pretraining condition label.
# Rows whose value contains none of these keys are labelled "NA".
pretrain_condition_map = {
    "70rfkjiv": "REF",
    "t4wviuxs": "REF",
    "bk1zoetn": "NBACK",
    "cjtwcbfr": "NBACK",
}

# Glob patterns (relative to basepath) for the run CSVs to load.
RUN_GLOBS = (
    "ref_n_back_finetuning_rnn_4to5/*.csv",
    "ignore_aware_ref_n_back_rnn_4_role_pretraining/*.csv",
)
# A file is loaded only if its path contains one of these tags.
TD_PROB_TAGS = ("tdprob=1_", "tdprob=0_")


def model_class_from_path(path) -> str:
    """Infer the model class ("lstm" or "rnn") from a run file's path.

    "lstm" is checked first, then "rnn", against the full path string.

    Raises:
        ValueError: if the path contains neither "lstm" nor "rnn".
    """
    # NOTE(review): both RUN_GLOBS directories contain "rnn" in their names,
    # so any file under them without "lstm" is classified "rnn" and the
    # ValueError below cannot trigger for these globs. Matching against
    # `Path(path).name` would be stricter; confirm file names carry the
    # model class before switching.
    path_str = str(path)
    if "lstm" in path_str:
        return "lstm"
    if "rnn" in path_str:
        return "rnn"
    raise ValueError(f"Unknown model class in path: {path}")


def pretrain_condition_from(value) -> str:
    """Map one `from_pretrained` cell value to its pretraining condition.

    Returns the label of the first key in `pretrain_condition_map` (in dict
    order) that occurs as a substring of `str(value)`, otherwise "NA".
    Missing values (NaN) stringify to "nan" and therefore map to "NA".
    """
    text = str(value)
    return next(
        (label for key, label in pretrain_condition_map.items() if key in text),
        "NA",
    )


def load_runs(base, patterns=RUN_GLOBS) -> list:
    """Read every matching run CSV under `base` and annotate it.

    Files are visited pattern by pattern and sorted by path within each
    pattern. Directory listing order is otherwise arbitrary, and the
    resulting row order is what seaborn uses to order string-valued facets
    and hue levels, so sorting keeps the plots reproducible.

    Args:
        base: directory the glob patterns are relative to.
        patterns: glob patterns to search; defaults to RUN_GLOBS.

    Returns:
        A list of DataFrames, one per loaded CSV, each with added
        `model_class` and `pretrain_condition` columns.

    Raises:
        ValueError: from `model_class_from_path` for an unrecognised path.
    """
    frames = []
    for pattern in patterns:
        for path in sorted(Path(base).glob(pattern)):
            if not any(tag in str(path) for tag in TD_PROB_TAGS):
                continue
            model_class = model_class_from_path(path)
            frame = pd.read_csv(path)
            frame["model_class"] = model_class
            frame["pretrain_condition"] = frame["from_pretrained"].apply(
                pretrain_condition_from
            )
            frames.append(frame)
    return frames


dfs = load_runs(basepath)

In [None]:
# Stack every loaded run into one long-format frame (one row per CSV row).
# Fail early with a readable message rather than pandas' generic
# "No objects to concatenate" when nothing matched the loading cell's filters.
assert dfs, "no run CSVs loaded; check basepath, the glob patterns and the tdprob filter"
df = pd.concat(dfs, ignore_index=True)

# Columns the plotting cell below relies on.
missing_columns = {"epoch", "test_acc", "td_prob", "model_class", "pretrain_condition"} - set(df.columns)
assert not missing_columns, f"missing expected columns: {sorted(missing_columns)}"
df

In [None]:
# Test accuracy over epochs. There is one facet per (model_class, td_prob)
# pair and one line per pretraining condition. Each line is aggregated over
# rows sharing an epoch (presumably several runs/seeds), with a standard-error band.
view = df.copy()
# Optional: restrict to a single model class, e.g.
# view = view[view.model_class == "lstm"]

# Pin the hue order once for the whole grid. `hue` is forwarded through
# map_dataframe to sns.lineplot, which would otherwise infer the level order
# (and so the colour assignment) separately from each facet's subset. Panels
# could then disagree on which colour means REF / NBACK / NA.
hue_order = sorted(view["pretrain_condition"].dropna().unique())

g = sns.FacetGrid(
    view,
    col="td_prob",
    row="model_class",
    height=4,
    sharex=False,
    sharey=False,
)
g.map_dataframe(
    sns.lineplot,
    x="epoch",
    y="test_acc",
    hue="pretrain_condition",
    hue_order=hue_order,
    errorbar="se",
    # Each facet draws its own legend; with the shared hue_order they all agree.
    legend="full",
)
g.set_titles(row_template="model_class={row_name}", col_template="N_back? = {col_name}")
g.set(ylim=(0.3, 1))
# Zoom every panel of the second facet row to 0.7-1. Indexing the whole row
# (guarded) works for any number of td_prob columns and a single-row grid.
# NOTE(review): row order follows first appearance of model_class in `df`,
# so which model class lands in row 1 depends on load order. Consider
# passing an explicit `row_order` if a specific model class should be zoomed.
if g.axes.shape[0] > 1:
    for ax in g.axes[1, :]:
        ax.set_ylim(0.7, 1)
g.set_axis_labels("epoch", "test_acc")
g.figure.suptitle(
    "Test Accuracy over epochs by pretraining on set size 4 tasks [test on set size 5 (exc. NA)]"
)
for ax in g.axes.flat:
    ax.grid(True)

# plt.legend() is intentionally not called. It would only rebuild the legend
# on the current (last) axes and drop the hue title lineplot set there.
g.tight_layout()