# Generalization behavior of networks
Here we behaviorally evaluate several networks on datasets with different properties,
and try to understand which training exposure leads to generalization to which other
data distributions.

In [None]:
from pathlib import Path
import typing
import yaml
import pandas as pd
import numpy as np
from tqdm.auto import tqdm

In [None]:
import workingmem
import workingmem.model
from workingmem.task.SIR import SIRDataset, SIRConfig, SIRTokenizer

from best_worst_models import best_worst

In [None]:
# Sweep manifest: a list of records, each with a `sweep_id` and that sweep's
# `dataset.td_prob` / `dataset.role_n_congruence` settings (see the loop below).
# previous sweep: "../configs/ref_n_back.yaml_25-11-19_sweep_dict.yaml"
SWEEP_DICT_PATH = Path("../configs/ref_n_back.yaml_25-11-30_sweep_dict.yaml")
# Open read-only: the former "+r" mode also requested write access, which is
# never needed here and fails on read-only checkouts/mounts.
with SWEEP_DICT_PATH.open("r", encoding="utf-8") as f:
    sweep_data = yaml.safe_load(f)

# sweep_id -> list of checkpoint model configs, filled in by the next cells
models = {}

In [None]:
# display the sweep manifest exactly as loaded from YAML
sweep_data

In [None]:
# For every sweep, collect all of its checkpoints (in the order `best_worst`
# yields them) into `models[sweep_id]`, then turn the manifest into a DataFrame.
#
# This cell rebinds `sweep_data` to a DataFrame at the end; iterating a
# DataFrame yields column names, so normalize back to a list of dicts first
# to keep the cell safe to re-run.
sweep_records = (
    sweep_data.to_dict("records")
    if isinstance(sweep_data, pd.DataFrame)
    else sweep_data
)
for record in sweep_records:
    sweep_id = record["sweep_id"]
    td_prob = record["dataset.td_prob"]
    role_n_congruence = record["dataset.role_n_congruence"]

    print(f"{sweep_id=}, {td_prob=}, {role_n_congruence=}")
    model_ckpts = Path(f"../model_checkpoints/{sweep_id}/").resolve()
    # count the entries in the checkpoint directory; each entry is assumed to
    # be one checkpoint (NOTE(review): stray files would inflate this — confirm)
    num_models = len(list(model_ckpts.glob("*")))
    models[sweep_id] = list(best_worst(model_ckpts, num_models, verbose=True))[
        :num_models
    ]
    print()

sweep_data = pd.DataFrame(sweep_records)

In [None]:
# display the sweep_id -> list of checkpoint model configs mapping
models

In [None]:
# display the sweep manifest, now converted to a DataFrame
sweep_data

In [None]:
# Load the evaluation ("target") datasets, keyed by (td_prob, role_n_congruence):
# - role_n_congruence = 0
# - td_prob = {0, 1}
# Each one is the "val" split of the dataset path recorded in the training
# history of a checkpoint from the matching sweep.
datasets = {}
for td_prob in (0, 1):
    for role_n_congruence in (0,):
        # select the sweep(s) whose `td_prob` and `role_n_congruence` match
        matching = sweep_data[
            (sweep_data["dataset.td_prob"] == td_prob)
            & (sweep_data["dataset.role_n_congruence"] == role_n_congruence)
        ]
        # fail with context instead of an opaque positional IndexError
        if matching.empty:
            raise LookupError(
                f"no sweep with dataset.td_prob={td_prob} and "
                f"dataset.role_n_congruence={role_n_congruence} in the sweep dict"
            )
        # if several sweeps match, the first one (sweep-dict order) is used
        sweep_id = matching.iloc[0]["sweep_id"]
        checkpoints = models.get(sweep_id)
        if not checkpoints:
            raise LookupError(f"no model checkpoints found for sweep {sweep_id!r}")
        # any checkpoint of this sweep will do (presumably all of them were
        # trained on the same dataset — confirm)
        model_conf = checkpoints[0]
        history_path = model_conf.from_pretrained / "history.yaml"
        history = yaml.safe_load(history_path.read_text(encoding="utf-8"))
        # the last history entry carries the dataset path; `generate=False`
        # presumably loads existing data instead of regenerating it
        this_dataset = SIRDataset.from_path(
            history[-1]["dataset_path"], split="val", generate=False
        )
        datasets[td_prob, role_n_congruence] = this_dataset

datasets

In [None]:
# Encode one probe string with each target dataset's tokenizer and show the
# token ids side by side (presumably these should agree for cross-dataset
# evaluation to be meaningful — confirm).
_probe_text = "St reg_5 item_1"
tuple(
    datasets[key].tokenizer.encode(_probe_text).ids for key in ((0, 0), (1, 0))
)

In [None]:
# Cross-evaluation: every checkpoint of every sweep (the *source* training
# distribution) is tested on every dataset in `datasets` (the *target*
# distribution), producing one record per (checkpoint, target dataset) pair.
#
# Toggles (both False reproduces the full evaluation):
# skip source sweeps with td_prob in {0, 1} and role_n_congruence == 0
SKIP_TARGET_SOURCES = False
# evaluate only the first checkpoint of each sweep, in the order `best_worst`
# returned them (presumably best-first — confirm)
TOP_ONLY = False

records = []
# pass `total` so tqdm shows progress without materializing every row first
for _, row in tqdm(sweep_data.iterrows(), total=len(sweep_data)):
    src_td_prob = row["dataset.td_prob"]
    src_role_n_congruence = row["dataset.role_n_congruence"]
    if SKIP_TARGET_SOURCES and src_td_prob in (0, 1) and src_role_n_congruence == 0:
        continue
    sweep_id = row["sweep_id"]
    for model_conf in models[sweep_id]:
        print(model_conf)
        # load the checkpoint once and reuse it for every target dataset
        wm_model = workingmem.model.ModelWrapper(model_conf)
        # NOTE(review): Path.stem drops everything after the last "."; use
        # `.name` instead if checkpoint names can contain dots
        model_name = model_conf.from_pretrained.stem
        for (tgt_td_prob, tgt_role_n_congruence), dataset in datasets.items():
            out = wm_model.test(dataset=dataset)
            records.append(
                {
                    "src_td_prob": src_td_prob,
                    "src_role_n_congruence": src_role_n_congruence,
                    "tgt_td_prob": tgt_td_prob,
                    "tgt_role_n_congruence": tgt_role_n_congruence,
                    "loss": out["loss"],
                    "acc": out["acc"],
                    "sweep_id": sweep_id,
                    "model": model_name,
                }
            )
        if TOP_ONLY:
            break

In [None]:
# Persist the cross-evaluation results: one CSV row per record, no index column.
results_csv = "td_role_cong_generalization_ignore_aware_20251203_all.csv"
results_df = pd.DataFrame(records)
results_df.to_csv(results_csv, index=False)