In [None]:
import json
import os
from typing import Any

from panda.utils.dyst_utils import (
    init_base_system_from_params,
    init_skew_system_from_params,
)

In [None]:
# Data root: $WORK/data, or a relative "data" directory when $WORK is unset.
WORK_DIR = os.getenv("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")

# Every (dataset, split) combination to inspect, ordered dataset-major.
dataset_names = [
    "improved/final_skew40",
    "improved/final_base40",
]
split_names = [
    "test_zeroshot",
    "test_zeroshot_z5_z10",
]
split_dir = [
    os.path.join(DATA_DIR, dataset_name, split_name)
    for dataset_name in dataset_names
    for split_name in split_names
]
split_dir

In [None]:
def init_system(
    params_data: dict[str, Any], system_type: str = "base"
) -> dict[str, Any]:
    """Instantiate dynamical systems from a name -> parameter-entries mapping.

    Args:
        params_data: Mapping from system name to an iterable of parameter dicts
            (e.g. as loaded from a ``filtered_params_dict.json`` file). For
            ``system_type="skew"`` each name must have the form
            ``"<driver>_<response>"`` with exactly one underscore.
        system_type: ``"base"`` to use ``init_base_system_from_params`` or
            ``"skew"`` to use ``init_skew_system_from_params``.

    Returns:
        Dict mapping each system name to an initialized system. Every entry of
        a name is initialized, but only the last one that succeeds is kept
        (later entries overwrite earlier ones). Names whose entries all fail
        are omitted.

    Raises:
        ValueError: If ``system_type`` is neither ``"base"`` nor ``"skew"``.
            Per-entry initialization errors are NOT raised; they are printed
            and skipped (best-effort loading).
    """
    # Reject unknown types up front; otherwise a typo would silently pick the
    # skew initializer with the base calling convention and fail every entry.
    if system_type not in ("base", "skew"):
        raise ValueError(
            f"system_type must be 'base' or 'skew', got {system_type!r}"
        )

    systems: dict[str, Any] = {}
    for system_name, entries in params_data.items():
        for i, param_dict in enumerate(entries):
            try:
                if system_type == "skew":
                    parts = system_name.split("_")
                    if len(parts) != 2:
                        raise ValueError(
                            f"expected '<driver>_<response>' name, got {system_name!r}"
                        )
                    driver_name, response_name = parts
                    system = init_skew_system_from_params(
                        driver_name, response_name, param_dict
                    )
                else:
                    system = init_base_system_from_params(system_name, param_dict)
            except Exception as e:
                # Best-effort: report which system/entry failed and keep going.
                print(f"  {system_name} entry {i}: Failed to initialize - {e}")
                continue
            systems[system_name] = system

    return systems


In [None]:
# Resolve parameter files under the same DATA_DIR as the configuration cell so
# every cell agrees on the data root (no second, hard-coded $WORK fallback).
# NOTE(review): base params are read from the "test" split while skew params
# come from "test_zeroshot" — confirm this asymmetry is intended.
work_dir = WORK_DIR  # kept in case later cells reference it
base_params_file = os.path.join(
    DATA_DIR, "improved/final_base40/parameters/test/filtered_params_dict.json"
)
skew_params_file = os.path.join(
    DATA_DIR,
    "improved/final_skew40/parameters/test_zeroshot/filtered_params_dict.json",
)

with open(base_params_file) as f:
    base_params = json.load(f)

with open(skew_params_file) as f:
    skew_params = json.load(f)

print(f"Loaded {len(base_params)} systems from base params")
print(f"Loaded {len(skew_params)} systems from skew params")

base_systems = init_system(base_params, "base")
skew_systems = init_system(skew_params, "skew")

# The merge below lets skew entries overwrite base entries with the same name;
# surface that instead of silently shrinking the combined count.
overlapping_names = base_systems.keys() & skew_systems.keys()
if overlapping_names:
    print(
        f"Warning: {len(overlapping_names)} system names appear in both base and "
        f"skew sets; skew entries take precedence: {sorted(overlapping_names)}"
    )
all_systems = {**base_systems, **skew_systems}

len(all_systems)