In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import random
from itertools import combinations

import numpy as np
from dysts.metrics import estimate_kl_divergence  # type: ignore
from tqdm import tqdm

from panda.utils import (
    load_trajectory_from_arrow,
)

In [None]:
# apply_custom_style("../config/plotting.yaml")

In [None]:
# Root working directory, taken from the WORK environment variable
# (empty string when the variable is not set).
WORK_DIR = os.environ.get("WORK", "")
# All dataset splits are looked up underneath <WORK_DIR>/data.
DATA_DIR = os.path.join(WORK_DIR, "data")

In [None]:
# Split directory (relative to DATA_DIR) that contains one sub-directory per
# skew system, and the single skew system inspected in the first comparison.
skew_split_name = "improved/final_skew40/train"
skew_system_name = "Thomas_Sakarya"

In [None]:
# Split directory (relative to DATA_DIR) that contains one sub-directory per
# base system, and the two base systems used in the first comparison.
# NOTE(review): presumably these are the driver and response halves of the
# "Thomas_Sakarya" skew system — confirm.
base_split_name = "improved/final_base40/train"
driver_system_name = "Thomas"
response_system_name = "Sakarya"

In [None]:
# Map each role ("skew", "driver", "response") to the sorted list of files
# found in <DATA_DIR>/<split>/<system>.
system_locations = {
    "skew": (skew_split_name, skew_system_name),
    "driver": (base_split_name, driver_system_name),
    "response": (base_split_name, response_system_name),
}

filepaths = {}
for name, (split, system) in system_locations.items():
    subdir = os.path.join(DATA_DIR, split, system)
    # Sort the listing so the file order does not depend on the filesystem.
    filenames = os.listdir(subdir)
    filenames.sort()
    filepaths[name] = [os.path.join(subdir, file) for file in filenames]
    print(f"{name} has {len(filepaths[name])} filepaths: {filepaths[name]}")

In [None]:
num_pairs = 5

# Seeded generator so the sampled file pairs are the same on every run.
rseed = 42
rng = random.Random(rseed)


def _draw_cross_pair(group_a, group_b):
    """Pick one random file from each of the two named groups (A first, then B)."""
    return rng.choice(filepaths[group_a]), rng.choice(filepaths[group_b])


# Each entry: (group A, group B, zero-argument sampler returning one file pair).
# The skew-skew sampler draws two distinct files from the same system.
pairings = [
    ("skew", "skew", lambda: tuple(rng.sample(filepaths["skew"], 2))),
    ("skew", "driver", lambda: _draw_cross_pair("skew", "driver")),
    ("skew", "response", lambda: _draw_cross_pair("skew", "response")),
    ("driver", "response", lambda: _draw_cross_pair("driver", "response")),
]

for name_a, name_b, pair_fn in pairings:
    pairs = [pair_fn() for _ in range(num_pairs)]
    print(f"Sampled {name_a}-{name_b} pairs:", pairs)

    kld_values = []
    for file_a, file_b in pairs:
        coords_a, _ = load_trajectory_from_arrow(file_a)
        coords_b, _ = load_trajectory_from_arrow(file_b)
        # The estimator is given the transposed coordinate arrays.
        kld = estimate_kl_divergence(coords_a.T, coords_b.T)
        print(f"KLD: {kld}")
        kld_values.append(kld)

    # Nothing to summarise when no pair was evaluated (num_pairs == 0).
    if not kld_values:
        continue
    mean_kld = np.mean(kld_values)
    std_kld = np.std(kld_values)
    print(f"Mean KLD for {name_a}-{name_b}: {mean_kld}")
    print(f"Std KLD for {name_a}-{name_b}: {std_kld}")

In [None]:
def sample_kld_pairs(pair_type, filepaths_by_dim, num_pairs, rng):
    """Draw `num_pairs` distinct file pairs whose files share a dimension key.

    Args:
        pair_type: "intra" pairs two files of the same system; "inter" pairs
            files of two different systems. Any other value yields no
            candidates.
        filepaths_by_dim: mapping {dim: {system: [filepaths]}}; pairs are only
            formed inside one `dim` entry.
        num_pairs: number of pairs to return.
        rng: object providing `sample(population, k)` (e.g. `random.Random`).

    Returns:
        A list of `num_pairs` (file_a, file_b) tuples sampled without
        replacement from all candidate pairs.

    Raises:
        ValueError: if fewer than `num_pairs` candidate pairs exist.
    """
    candidates = []
    for _dim, per_system in tqdm(filepaths_by_dim.items(), desc="Sampling KLD pairs"):
        if pair_type == "intra":
            # combinations() is empty for systems with fewer than two files.
            for files in per_system.values():
                candidates.extend(combinations(files, 2))
        elif pair_type == "inter":
            names = list(per_system)
            # Visit every unordered pair of systems once (first before second).
            for position, first in enumerate(names):
                for second in names[position + 1 :]:
                    candidates += [
                        (a, b) for a in per_system[first] for b in per_system[second]
                    ]

    available = len(candidates)
    if available < num_pairs:
        raise ValueError(
            f"Not enough unique {pair_type}-system same-dimension pairs ({available}) to sample {num_pairs} pairs without repeats."
        )
    return rng.sample(candidates, num_pairs)


def compute_klds(pairs):
    """Estimate the KL divergence for each (file_a, file_b) pair.

    Both files are loaded with `load_trajectory_from_arrow`; only the first
    element of each returned tuple is used. Pairs whose arrays differ in
    `shape[0]` are reported on stdout and skipped, so the result can be shorter
    than `pairs`.

    Args:
        pairs: iterable of (file_a, file_b) path tuples.

    Returns:
        List of the values returned by `estimate_kl_divergence`, in input order.
    """
    results = []
    for path_a, path_b in tqdm(pairs, desc="Computing KLDs"):
        traj_a, _ = load_trajectory_from_arrow(path_a)
        traj_b, _ = load_trajectory_from_arrow(path_b)
        # assumes axis 0 is the system dimension — TODO confirm against loader
        dim_a, dim_b = traj_a.shape[0], traj_b.shape[0]
        if dim_a == dim_b:
            # The estimator receives the transposed arrays.
            results.append(estimate_kl_divergence(traj_a.T, traj_b.T))
        else:
            print(
                f"Skipping pair due to mismatched dimensions: {dim_a} vs {dim_b}"
            )
    return results

In [None]:
rseed = 42
rng = random.Random(rseed)
num_skew_systems_to_sample = 10
num_pairs = 10

# Sample skew systems and gather filepaths by dimension and system.
# os.listdir returns names in arbitrary order, so they are sorted before the
# seeded draw; otherwise the same seed could select different systems on a
# different filesystem.
skew_dir = os.path.join(DATA_DIR, skew_split_name)
skew_system_names = sorted(
    d for d in os.listdir(skew_dir) if os.path.isdir(os.path.join(skew_dir, d))
)
sampled_skew_systems = rng.sample(
    skew_system_names, min(num_skew_systems_to_sample, len(skew_system_names))
)

# {dim: {system: [filepaths]}}, where dim is coords.shape[0] of each loaded file.
skew_filepaths = {}
for system in sampled_skew_systems:
    subdir = os.path.join(skew_dir, system)
    for file in sorted(os.listdir(subdir)):
        filepath = os.path.join(subdir, file)
        coords, _ = load_trajectory_from_arrow(filepath)
        dim = coords.shape[0]
        skew_filepaths.setdefault(dim, {}).setdefault(system, []).append(filepath)

# Compute KLDs for intra- and inter-system pairs, store results in a dict
skew_kld_results = {}

for pair_type in ["intra", "inter"]:
    pairs = sample_kld_pairs(pair_type, skew_filepaths, num_pairs, rng)
    klds = compute_klds(pairs)
    # mean/std stay None when every sampled pair was skipped.
    skew_kld_results[pair_type] = {
        "pairs": pairs,
        "mean": np.mean(klds) if klds else None,
        "std": np.std(klds) if klds else None,
        "values": klds,
    }

# Optionally print concise summary
for pair_type, res in skew_kld_results.items():
    print(
        f"{pair_type.capitalize()}-system skew pairs: mean KLD={res['mean']}, std={res['std']}, n={len(res['values'])}"
    )

In [None]:
num_base_systems_to_sample = 10
num_pairs = 10

# Sample base systems and gather filepaths by dimension and system.
# `rng` is not re-created here: it is the generator from an earlier cell and
# its state carries over. os.listdir returns names in arbitrary order, so they
# are sorted before the seeded draw to keep the selection reproducible.
base_dir = os.path.join(DATA_DIR, base_split_name)
base_system_names = sorted(
    d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
)
sampled_base_systems = rng.sample(
    base_system_names, min(num_base_systems_to_sample, len(base_system_names))
)

# {dim: {system: [filepaths]}}, where dim is coords.shape[0] of each loaded file.
base_filepaths = {}
for system in sampled_base_systems:
    subdir = os.path.join(base_dir, system)
    for file in sorted(os.listdir(subdir)):
        filepath = os.path.join(subdir, file)
        coords, _ = load_trajectory_from_arrow(filepath)
        dim = coords.shape[0]
        base_filepaths.setdefault(dim, {}).setdefault(system, []).append(filepath)

# Compute KLDs for intra- and inter-system pairs, store results in a dict
base_kld_results = {}

for pair_type in ["intra", "inter"]:
    pairs = sample_kld_pairs(pair_type, base_filepaths, num_pairs, rng)
    klds = compute_klds(pairs)
    # mean/std stay None when every sampled pair was skipped.
    base_kld_results[pair_type] = {
        "pairs": pairs,
        "mean": np.mean(klds) if klds else None,
        "std": np.std(klds) if klds else None,
        "values": klds,
    }

# Optionally print concise summary
for pair_type, res in base_kld_results.items():
    print(
        f"{pair_type.capitalize()}-system base pairs: mean KLD={res['mean']}, std={res['std']}, n={len(res['values'])}"
    )

In [None]:
# base_kld_results["intra"]

In [None]:
# Sampling sizes for this cell.
# Upper bound on how many skew systems are drawn (fewer if fewer exist).
num_skew_systems_to_sample = 10
# Maximum number of file pairs kept per pair type.
num_pairs = 10

# --- Efficiently gather filepaths for base and skew systems by dimension and system name ---


def gather_filepaths_by_dim_and_system(root_dir, system_names, desc=None):
    """Group the files of each system by the leading dimension of their data.

    Every file in `<root_dir>/<system>` is opened with
    `load_trajectory_from_arrow` and filed under `coords.shape[0]`.
    NOTE(review): only the shape is used here, but the regular loader is
    called, so the file is presumably read in full — confirm if this is slow.

    Args:
        root_dir: directory containing one sub-directory per system.
        system_names: iterable of sub-directory names to scan.
        desc: progress-bar label; when not None the systems are iterated
            through `tqdm` with this description.

    Returns:
        dict of the form {dim: {system: [filepaths]}}, with each system's
        files in sorted filename order.
    """
    grouped = {}
    systems = system_names if desc is None else tqdm(system_names, desc=desc)
    for system in systems:
        system_dir = os.path.join(root_dir, system)
        for entry in sorted(os.listdir(system_dir)):
            path = os.path.join(system_dir, entry)
            trajectory, _ = load_trajectory_from_arrow(path)
            by_system = grouped.setdefault(trajectory.shape[0], {})
            by_system.setdefault(system, []).append(path)
    return grouped


# Sample skew systems. os.listdir returns names in arbitrary order, so the
# directory names are sorted before the seeded draw; otherwise the selection
# would depend on the filesystem rather than only on the seed.
skew_dir = os.path.join(DATA_DIR, skew_split_name)
skew_system_names = sorted(
    d for d in os.listdir(skew_dir) if os.path.isdir(os.path.join(skew_dir, d))
)
sampled_skew_systems = rng.sample(
    skew_system_names, min(num_skew_systems_to_sample, len(skew_system_names))
)

# Gather filepaths for base and skew systems with progress bars.
# All base systems are used (no sampling); they are sorted as well so the
# insertion order of base_filepaths is stable across machines.
base_dir = os.path.join(DATA_DIR, base_split_name)
base_system_names = sorted(
    d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
)
base_filepaths = gather_filepaths_by_dim_and_system(
    base_dir, base_system_names, desc="Base systems"
)
skew_filepaths = gather_filepaths_by_dim_and_system(
    skew_dir, sampled_skew_systems, desc="Skew systems"
)


# Helper to parse driver and response from skew system name (assumes "driver_response")
def parse_driver_response(skew_name):
    """Split a skew system name at its first underscore.

    Args:
        skew_name: name expected to look like "driver_response".

    Returns:
        (driver, response) — the text before and after the first underscore —
        or (skew_name, None) when the name contains no underscore.
    """
    driver, separator, response = skew_name.partition("_")
    if not separator:
        # No underscore at all: the whole name is the driver.
        return skew_name, None
    return driver, response


# Efficient KLD pair sampling for skew-base, for "intra" and "inter"
def sample_skew_base_kld_pairs(
    pair_type, skew_filepaths, base_filepaths, num_pairs, rng
):
    """Sample (skew_file, base_file) pairs whose files share a dimension key.

    Args:
        pair_type: "intra" pairs a skew system with the base systems named in
            its own name (driver and/or response); "inter" pairs it with every
            other base system of the same dimension. Any other value produces
            no pairs.
        skew_filepaths: mapping {dim: {skew system: [filepaths]}}.
        base_filepaths: mapping {dim: {base system: [filepaths]}}.
        num_pairs: maximum pairs taken per (skew, base) combination and
            maximum size of the returned list.
        rng: object providing `sample(population, k)` (e.g. `random.Random`).

    Returns:
        List of at most `num_pairs` (skew_file, base_file) tuples; empty when
        nothing matches.
    """
    pairs = []
    for dim, skew_dim_dict in skew_filepaths.items():
        base_dim_dict = base_filepaths.get(dim)
        if not base_dim_dict:
            # No base system has this dimension.
            continue
        base_systems = set(base_dim_dict)
        for skew_name, skew_files in skew_dim_dict.items():
            driver, response = parse_driver_response(skew_name)
            if pair_type == "intra":
                relevant_bases = {s for s in (driver, response) if s in base_systems}
            elif pair_type == "inter":
                relevant_bases = base_systems - {driver, response}
            else:
                continue
            # Iterate in sorted order: a set of strings has a hash-dependent
            # order that changes between interpreter runs, which would change
            # both the rng draws and the order of `pairs`, making the result
            # irreproducible even with a seeded rng.
            for base in sorted(relevant_bases):
                base_files = base_dim_dict[base]
                n = min(num_pairs, len(skew_files), len(base_files))
                if n == 0:
                    continue
                # Only draw from the rng when a list is longer than n; a list
                # of exactly n files is used as-is, in its existing order.
                skew_sample = (
                    skew_files if len(skew_files) == n else rng.sample(skew_files, n)
                )
                base_sample = (
                    base_files if len(base_files) == n else rng.sample(base_files, n)
                )
                pairs.extend(zip(skew_sample, base_sample))
    # Subsample if too many pairs
    if len(pairs) > num_pairs:
        pairs = rng.sample(pairs, num_pairs)
    return pairs


# Results per pair type: the sampled pairs, the KLD values, and their mean/std
# (None when no pair could be evaluated).
skew_base_kld_results = {}

for pair_type in ["intra", "inter"]:
    print(f"Computing KLDs for {pair_type}-system skew-base pairs...")
    pairs = sample_skew_base_kld_pairs(
        pair_type, skew_filepaths, base_filepaths, num_pairs, rng
    )
    # compute_klds already loops over the pairs with its own progress bar and
    # drops mismatched-dimension pairs, so it is called once for the whole
    # batch instead of once per pair (which produced one extra single-item
    # progress bar for every pair).
    klds = compute_klds(pairs)
    skew_base_kld_results[pair_type] = {
        "pairs": pairs,
        "mean": np.mean(klds) if klds else None,
        "std": np.std(klds) if klds else None,
        "values": klds,
    }

# Print concise summary
for pair_type, res in skew_base_kld_results.items():
    print(
        f"{pair_type.capitalize()}-system skew-base pairs: mean KLD={res['mean']}, std={res['std']}, n={len(res['values'])}"
    )

In [None]:
# Sampling sizes for this cell.
# Upper bound on how many skew systems are drawn (fewer if fewer exist).
num_skew_systems_to_sample = 10
# Maximum number of file pairs kept per comparison type.
num_pairs = 3


def gather_filepaths_by_dim_and_system(root_dir, system_names, desc=None):
    """Index the files under `root_dir` as {dim: {system: [filepaths]}}.

    Each file in `<root_dir>/<system>` is loaded with
    `load_trajectory_from_arrow`; `dim` is `shape[0]` of the returned array.
    Files of a system are visited in sorted filename order.

    Args:
        root_dir: directory containing one sub-directory per system.
        system_names: iterable of sub-directory names to scan.
        desc: optional progress-bar label; a truthy value wraps the iteration
            over systems in `tqdm`.

    Returns:
        Nested dict {dim: {system: [filepaths]}}.
    """
    index = {}
    if desc:
        system_names = tqdm(system_names, desc=desc)
    for system in system_names:
        folder = os.path.join(root_dir, system)
        for path in (os.path.join(folder, name) for name in sorted(os.listdir(folder))):
            coords, _ = load_trajectory_from_arrow(path)
            dim = coords.shape[0]
            if dim not in index:
                index[dim] = {}
            if system not in index[dim]:
                index[dim][system] = []
            index[dim][system].append(path)
    return index


# Sample skew systems. The names are sorted first because os.listdir gives no
# ordering guarantee, and the seeded rng.sample below picks by position.
skew_dir = os.path.join(DATA_DIR, skew_split_name)
skew_system_names = sorted(
    d for d in os.listdir(skew_dir) if os.path.isdir(os.path.join(skew_dir, d))
)
sampled_skew_systems = rng.sample(
    skew_system_names, min(num_skew_systems_to_sample, len(skew_system_names))
)

# Gather filepaths for base and skew systems with progress bars.
# Every base system is scanned; sorting keeps the key order of base_filepaths
# independent of the filesystem.
base_dir = os.path.join(DATA_DIR, base_split_name)
base_system_names = sorted(
    d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
)
base_filepaths = gather_filepaths_by_dim_and_system(
    base_dir, base_system_names, desc="Base systems"
)
skew_filepaths = gather_filepaths_by_dim_and_system(
    skew_dir, sampled_skew_systems, desc="Skew systems"
)


def parse_driver_response(skew_name):
    """Return (driver, response) split at the first "_" of `skew_name`.

    When the name has no underscore the result is (skew_name, None).
    """
    cut = skew_name.find("_")
    if cut < 0:
        return (skew_name, None)
    return (skew_name[:cut], skew_name[cut + 1 :])


def sample_skew_vs_base_pairs(skew_filepaths, base_filepaths, which, num_pairs, rng):
    """Sample file pairs that compare skew systems with a chosen counterpart.

    Args:
        skew_filepaths: mapping {dim: {skew system: [filepaths]}}.
        base_filepaths: mapping {dim: {base system: [filepaths]}}; not used
            when `which` is "skew".
        which: selects the counterpart of each skew system:
            "driver" / "response" — the base system named before / after the
                first underscore of the skew name (skipped if absent);
            "base" — one randomly chosen base system of the same dimension
                that is neither its driver nor its response;
            "skew" — other skew files of the same dimension, both within one
                system and across systems.
            Any other value yields an empty list.
        num_pairs: maximum pairs taken per skew system and maximum size of the
            returned list.
        rng: object providing `sample` and `choice` (e.g. `random.Random`).

    Returns:
        List of at most `num_pairs` (file_a, file_b) tuples.
    """

    def _cap(candidates):
        # Subsample only when there are more candidates than requested.
        if len(candidates) > num_pairs:
            return rng.sample(candidates, num_pairs)
        return candidates

    def _aligned(skew_files, base_files):
        # Pair files position by position; a list longer than `count` is
        # subsampled first (skew side before base side).
        count = min(num_pairs, len(skew_files), len(base_files))
        if count == 0:
            return []
        left = rng.sample(skew_files, count) if len(skew_files) > count else skew_files
        right = rng.sample(base_files, count) if len(base_files) > count else base_files
        return list(zip(left, right))

    collected = []

    if which == "skew":
        for systems in skew_filepaths.values():
            # Intra-system: every unordered pair of files of one system.
            for files in systems.values():
                collected.extend(combinations(files, 2))
            # Inter-system: every file of one system against every file of a
            # later system in the same dimension.
            names = list(systems)
            for position, first in enumerate(names):
                for second in names[position + 1 :]:
                    for file_a in systems[first]:
                        for file_b in systems[second]:
                            collected.append((file_a, file_b))
        return _cap(collected)

    for dim, systems in skew_filepaths.items():
        bases = base_filepaths.get(dim)
        if not bases:
            # No base system has this dimension.
            continue
        for skew_name, skew_files in systems.items():
            driver, response = parse_driver_response(skew_name)
            if which == "driver" or which == "response":
                target = driver if which == "driver" else response
                if not target or target not in bases:
                    continue
            elif which == "base":
                # Exclude driver and response from base candidates
                excluded = {driver, response}
                others = [
                    name
                    for name in bases
                    if name not in excluded and name is not None
                ]
                if not others:
                    continue
                target = rng.choice(others)
            else:
                continue
            collected.extend(_aligned(skew_files, bases[target]))
    return _cap(collected)


# Results per comparison type ("driver", "response", "base", "skew").
# NOTE: this rebinds skew_kld_results, which an earlier cell also assigns
# (there keyed by "intra"/"inter").
skew_kld_results = {}

for which in ["driver", "response", "base", "skew"]:
    if which == "skew":
        print("Computing KLDs for skew-skew pairs...")
    else:
        print(f"Computing KLDs for skew-{which} vs. base system pairs...")
    pairs = sample_skew_vs_base_pairs(
        skew_filepaths, base_filepaths, which, num_pairs, rng
    )
    # compute_klds already iterates the pairs with its own progress bar and
    # drops mismatched-dimension pairs, so one batched call replaces the
    # per-pair calls (each of which opened its own single-item progress bar).
    klds = compute_klds(pairs)
    skew_kld_results[which] = {
        "pairs": pairs,
        "mean": np.mean(klds) if klds else None,
        "std": np.std(klds) if klds else None,
        "values": klds,
    }

# Print concise summary
for which, res in skew_kld_results.items():
    if which == "skew":
        label = "Skew-skew pairs"
    else:
        label = f"Skew-{which} vs. base system pairs"
    print(
        f"{label}: mean KLD={res['mean']}, std={res['std']}, n={len(res['values'])}"
    )

In [None]:
# Display the full results dict built in the previous cell (keys "driver",
# "response", "base", "skew"), including the sampled pairs and raw values.
skew_kld_results