In [None]:
import json
import os

In [None]:
# All data is read from "<$WORK>/data"; when $WORK is unset this degrades to the
# relative path "data" (resolved against the kernel's working directory).
WORK_DIR = os.getenv("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")

In [None]:
n_periods: int = 40  # selects the dataset variant dirs "final_skew<n>" / "final_base<n>" below

### Utils

In [None]:
def get_subdirs_from_splits(
    subset: str, splits: list[str], data_dir: "str | None" = None
) -> list[str]:
    """Collect the unique names of subdirectories across several splits.

    For every split in ``splits``, looks in ``<data_dir>/<subset>/<split>`` and
    gathers the names of its entries that are directories (plain files are
    ignored). Splits whose directory does not exist are skipped and reported
    instead of raising.

    Args:
        subset: Subset path relative to ``data_dir``, e.g. ``"improved/final_skew40"``.
        splits: Split directory names to scan, e.g. ``["train", "train_z5_z10"]``.
        data_dir: Root data directory. Defaults to the notebook-level ``DATA_DIR``.

    Returns:
        Sorted list of unique subdirectory names found across all splits.
        Sorting makes the result (and anything printed from it) reproducible
        across kernel restarts, unlike the iteration order of a ``set`` of str.
    """
    root = DATA_DIR if data_dir is None else data_dir
    subdirs: set[str] = set()
    for split_name in splits:
        split_dir = os.path.join(root, subset, split_name)
        if not os.path.isdir(split_dir):
            # Best-effort: keep going, but don't hide that a split was absent.
            print(f"Skipping missing split dir: {split_dir}")
            continue
        subdirs.update(
            d
            for d in os.listdir(split_dir)
            if os.path.isdir(os.path.join(split_dir, d))
        )
    print(f"Found {len(subdirs)} subdirs in {subset} splits {splits}")
    return sorted(subdirs)

In [None]:
def find_substrings(
    setA: list[str],
    setB: list[str],
    setA_name: str = "setA",
    setB_name: str = "setB",
):
    """Match names in ``setB`` against names in ``setA`` by "_"-separated token.

    A name ``b`` in ``setB`` matches a name ``a`` in ``setA`` when ``a`` equals
    one of the tokens of ``b.split("_")``. This is whole-token equality, not a
    raw substring test: ``"Lorenz"`` matches ``"Lorenz_Rossler"`` but ``"Lor"``
    does not.

    Args:
        setA: Candidate component names (the "patterns").
        setB: Names to search; each is split on ``"_"``.
        setA_name: Label for ``setA`` used in the printed summary.
        setB_name: Label for ``setB`` used in the printed summary.

    Returns:
        ``(matching_strings, matching_patterns)``:

        - ``matching_strings``: sorted unique names in ``setB`` with at least
          one token in ``setA``.
        - ``matching_patterns``: sorted unique names in ``setA`` occurring as
          a token of at least one ``setB`` name.
    """
    patterns = set(setA)
    # Tokenise each name once (the naive form re-split every b once per pattern).
    tokens_by_name = {b: set(b.split("_")) for b in setB}
    matching_strings = sorted(
        b for b, tokens in tokens_by_name.items() if tokens & patterns
    )
    all_tokens = set().union(*tokens_by_name.values())
    matching_patterns = sorted(patterns & all_tokens)
    res = (matching_strings, matching_patterns)
    print(f"Found {len(res[0])} {setB_name} systems names that use {setA_name} systems")
    print(f"Found {len(res[1])} {setA_name} systems names in {setB_name}")
    return res

### Compare Skew and Base Subdirectories

In [None]:
# Subset paths (relative to DATA_DIR) of the skew and base datasets for this n_periods.
skew_subset, base_subset = (
    f"improved/final_{variant}{n_periods}" for variant in ("skew", "base")
)

In [None]:
# Union of subdirectory names across every skew train split.
skew_splits_train = [
    "train",
    "train_z5_z10",
    "train_z10_z15",
]
skew_subdirs_train = get_subdirs_from_splits(subset=skew_subset, splits=skew_splits_train)

In [None]:
# Union of subdirectory names across every base train split.
base_splits_train = [
    "train",
    "train_z5_z10",
    "train_z10_z15",
]
base_subdirs_train = get_subdirs_from_splits(subset=base_subset, splits=base_splits_train)

In [None]:
# Compare skew train vs base test_zeroshot: res[1] lists the base test_zeroshot
# names that occur as an "_"-separated component of some skew train name.
base_splits_test_zeroshot = [
    "test_zeroshot",
    "test_zeroshot_z5_z10",
    "test_zeroshot_z10_z15",
]
base_subdirs_test_zeroshot = get_subdirs_from_splits(
    base_subset, base_splits_test_zeroshot
)
res = find_substrings(
    base_subdirs_test_zeroshot,
    skew_subdirs_train,
    setA_name="base test_zeroshot",
    setB_name="skew train",
)
# Sorted so the printed list is identical across kernel restarts.
print(sorted(res[1]))

In [None]:
# Compare skew test_zeroshot vs base test_zeroshot: res[1] lists the base
# test_zeroshot names that occur as an "_"-separated component of some skew
# test_zeroshot name.
skew_splits_test_zeroshot = [
    "test_zeroshot",
    "test_zeroshot_z5_z10",
    "test_zeroshot_z10_z15",
]
skew_subdirs_test_zeroshot = get_subdirs_from_splits(
    skew_subset, skew_splits_test_zeroshot
)

res = find_substrings(
    base_subdirs_test_zeroshot,
    skew_subdirs_test_zeroshot,
    setA_name="base test_zeroshot",
    setB_name="skew test_zeroshot",
)
# Sorted so the printed list is identical across kernel restarts.
print(sorted(res[1]))

In [None]:
# Compare skew train vs skew test_zeroshot: res[1] lists the skew train names
# that occur as an "_"-separated component of some skew test_zeroshot name.
# NOTE(review): find_substrings matches single "_"-tokens, so a skew train name
# that itself contains "_" can never match. If the intent is exact-name overlap
# between the two splits, use set(skew_subdirs_train) & set(skew_subdirs_test_zeroshot)
# instead — confirm.
res = find_substrings(
    skew_subdirs_train,
    skew_subdirs_test_zeroshot,
    setA_name="skew train",
    setB_name="skew test_zeroshot",
)
# Sorted so the printed list is identical across kernel restarts.
print(sorted(res[1]))

In [None]:
# Compare base train vs base test_zeroshot: res[1] lists the base train names
# that occur as an "_"-separated component of some base test_zeroshot name.
res = find_substrings(
    base_subdirs_train,
    base_subdirs_test_zeroshot,
    setA_name="base train",
    setB_name="base test_zeroshot",
)
# Sorted so the printed list is identical across kernel restarts.
print(sorted(res[1]))

### Test Scaling Law Splits

In [None]:
scalinglaw_dir: str = os.path.join(DATA_DIR, "improved/scalinglaw")  # root of the scaling-law splits

In [None]:
# Everything present in scalinglaw_dir, sorted for a stable listing when inspected.
scalinglaw_subdirs = sorted(os.listdir(scalinglaw_dir))

# Expected scaling-law splits. Later cells strip the trailing "_ic<k>" suffix to
# locate each split's params file.
# NOTE(review): "<start>-<end>" looks like a system-index range and "ic<k>" like
# a per-system count — confirm against the data generation script.
split_subdirs = [
    "split_0-163_ic128",
    "split_163-327_ic64",
    "split_327-655_ic32",
    "split_655-1311_ic16",
    "split_1311-2622_ic8",
    "split_2622-5244_ic4",
    "split_5244-10489_ic2",
]
num_subdirs_lst = []  # number of subdirectories in each split's train/ dir
subdirs_lst = []  # sorted subdirectory names of each split's train/ dir
for split_subdir in split_subdirs:
    split_subdir_train_path = os.path.join(scalinglaw_dir, split_subdir, "train")
    # Fail loudly if an expected split is missing rather than silently undercounting.
    if not os.path.isdir(split_subdir_train_path):
        raise ValueError(
            f"Split subdir {split_subdir} does not exist in {scalinglaw_dir} "
            f"(missing {split_subdir_train_path})"
        )
    # Only directories count as subdirs; a stray file in train/ would otherwise
    # make the os.listdir() below raise NotADirectoryError.
    subdirs = sorted(
        d
        for d in os.listdir(split_subdir_train_path)
        if os.path.isdir(os.path.join(split_subdir_train_path, d))
    )
    num_files_in_all_subdirs = sum(
        len(os.listdir(os.path.join(split_subdir_train_path, subdir)))
        for subdir in subdirs
    )
    print(
        f"Found {len(subdirs)} subdirs and {num_files_in_all_subdirs} files in {split_subdir_train_path}"
    )
    num_subdirs_lst.append(len(subdirs))
    subdirs_lst.append(subdirs)

In [None]:
# Ratio of each split's subdir count to the previous split's count.
print("Ratio of consecutive elements in num_subdirs_lst:")
for prev_count, count in zip(num_subdirs_lst, num_subdirs_lst[1:]):
    if prev_count == 0:
        # An empty previous split would otherwise raise ZeroDivisionError.
        print(f"{count} / {prev_count} = undefined")
        continue
    ratio = count / prev_count
    print(f"{count} / {prev_count} = {ratio:.2f}")

In [None]:
# Per-split parameter files: <scalinglaw_dir>/params_disjoint/params_dict_<split_name>.json
split_params_dir = os.path.join(scalinglaw_dir, "params_disjoint")

# split_name -> list of "<skew_name>-pp<sample_idx>" system names in that split.
split_params_system_names = {}
for split_subdir in split_subdirs:
    # "split_0-163_ic128" -> "split_0-163": drop the trailing "_ic<k>" suffix.
    split_name = split_subdir.rsplit("_ic", 1)[0]
    print(f"Split name: {split_name}")
    params_path = os.path.join(split_params_dir, f"params_dict_{split_name}.json")
    # `with` closes the handle. JSON text is UTF-8, so don't depend on the locale encoding.
    with open(params_path, encoding="utf-8") as params_file:
        params_json = json.load(params_file)
    # Shape as read here: {skew_name: [ {"sample_idx": ..., ...}, ... ], ...}
    split_params_system_names[split_name] = [
        f"{skew_name}-pp{pert_params['sample_idx']}"
        for skew_name, skew_params in params_json.items()
        for pert_params in skew_params
    ]

In [None]:
# Sanity-check each split: no duplicate system names within it, and report any
# names it shares with the immediately preceding split.
split_names = list(split_params_system_names.keys())
for i, (split_name, system_names) in enumerate(split_params_system_names.items()):
    print(f"Split name: {split_name}, number of system names: {len(system_names)}")
    assert len(set(system_names)) == len(system_names), (
        f"Duplicate system names in {split_name}"
    )
    if i == 0:
        continue
    prev_system_names = split_params_system_names[split_names[i - 1]]
    common_system_names = set(prev_system_names).intersection(system_names)
    # Sorted so the report order is identical across kernel restarts.
    for common_system_name in sorted(common_system_names):
        print(f"System name {common_system_name} already exists in {split_name}")

In [None]:
# Compare every later split against the first (baseline) split. Taking the first
# key keeps the baseline consistent with the `i == 0` skip below instead of
# relying on a hard-coded split name.
baseline_split_name = next(iter(split_params_system_names))
baseline_system_names = split_params_system_names[baseline_split_name]
baseline_system_set = set(baseline_system_names)  # built once, not per split

for i, (split_name, system_names) in enumerate(split_params_system_names.items()):
    print(f"Split name: {split_name}, number of system names: {len(system_names)}")
    assert len(set(system_names)) == len(system_names), (
        f"Duplicate system names in {split_name}"
    )
    if i == 0:
        continue
    # Sorted so the "first 4" shown are the same on every run.
    common_system_names = sorted(baseline_system_set.intersection(system_names))
    print(
        f"Found {len(common_system_names)} common system names in {split_name}. Printing first 4:"
    )
    print(f"--> {common_system_names[:4]}")