# Random-Split Self & Cross Evaluation

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from pathlib import Path

import sys
sys.path.append("../src/")
from structures import ConventionalCNN
from dataloader import loadCAMELS, split_expanded_dataset_from_json

# Make float32 the default dtype for newly created floating-point tensors.
dtype = torch.float32
torch.set_default_dtype(dtype)

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Notebook-wide settings: `batch_size` is read by make_random_test_loader,
# `individual` is forwarded to loadCAMELS and log_normalize.
individual = False
batch_size = 64

## Helper Functions

In [None]:
import gc

def clear_memory():
    """Free unreferenced Python objects, then release PyTorch's cached CUDA memory.

    The garbage collector runs first so that tensors it frees can be returned
    by the subsequent ``torch.cuda.empty_cache()`` call.
    """
    gc.collect()
    torch.cuda.empty_cache()


def log_normalize(arr, individual=False, ref_stats=None):
    """Take ``log10`` of ``arr`` and standardize the result.

    Args:
        arr: Array passed to ``np.log10``.
        individual: Used only when ``ref_stats`` is not given. If true, the
            mean and standard deviation are computed over axes ``(1, 2)`` with
            ``keepdims=True`` (one pair per entry along axis 0); otherwise a
            single mean/std over the whole array is used.
        ref_stats: Optional ``(mean, std)`` pair. When truthy it is used
            instead of statistics computed from ``arr``.

    Returns:
        ``(normalized, (mean, std))`` where ``normalized`` is
        ``(log10(arr) - mean) / std`` and the second item holds the statistics
        that were applied.
    """
    logged = np.log10(arr)
    if ref_stats:
        mu, sigma = ref_stats
    elif individual:
        mu = logged.mean(axis=(1, 2), keepdims=True)
        sigma = logged.std(axis=(1, 2), keepdims=True)
    else:
        mu = logged.mean()
        sigma = logged.std()
    normalized = (logged - mu) / sigma
    return normalized, (mu, sigma)


def _apply_sb35_windowing(data: np.ndarray, step: int = 10, length: int = 5) -> np.ndarray:
    """Keep ``length`` consecutive entries out of every ``step`` along axis 0.

    With the defaults this takes 5 frames every 10 (rows 0-4, 10-14, 20-24, ...),
    which the original note describes as matching the SB35_half/cutout sampling.

    Args:
        data: Array indexed along its first axis.
        step: Distance between the first rows of successive windows.
        length: Number of consecutive rows kept per window.

    Returns:
        ``data`` restricted to the selected rows, in their original order.
        Only windows that fit entirely inside ``data`` are kept, so the result
        is empty when ``data`` has fewer than ``length`` rows.
    """
    # Only the row count is needed; no index array has to be materialized for it.
    n_rows = data.shape[0]
    starts = np.arange(0, n_rows - length + 1, step)
    # Broadcast each window start against the in-window offsets, then flatten.
    idx = (starts[:, None] + np.arange(length)[None, :]).reshape(-1)
    return data[idx]


def _apply_crop(data: np.ndarray, crop: str = "tl") -> np.ndarray:
    """Return one quadrant of every entry in ``data``.

    Axes 1 and 2 are split at ``size // 2``. ``crop`` selects the quadrant:
    ``"tl"``, ``"tr"``, ``"bl"`` or ``"br"`` (first letter picks the half along
    axis 1, second letter the half along axis 2).

    Raises:
        ValueError: If ``crop`` is not one of the four supported codes.
    """
    mid_row = data.shape[1] // 2
    mid_col = data.shape[2] // 2
    first_rows, last_rows = slice(None, mid_row), slice(mid_row, None)
    first_cols, last_cols = slice(None, mid_col), slice(mid_col, None)
    quadrants = {
        "tl": (first_rows, first_cols),
        "tr": (first_rows, last_cols),
        "bl": (last_rows, first_cols),
        "br": (last_rows, last_cols),
    }
    try:
        rows, cols = quadrants[crop]
    except (KeyError, TypeError):
        raise ValueError(f"Unknown crop: {crop}") from None
    return data[:, rows, cols]


def _random_split_expanded_dataset(
    data: torch.Tensor,
    labels: torch.Tensor,
    chunk_size: int,
    val_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42,
):
    """Randomly split chunked data into train/val/test without breaking chunks.

    Rows are grouped into consecutive chunks of ``chunk_size``; whole chunks
    are shuffled with ``np.random.default_rng(seed)`` and assigned to the test
    split first, then validation, then training. The original note says this is
    meant to match the random mode of ``train_sb`` (not visible in this file).

    Args:
        data: Indexable along axis 0; its length must be a multiple of
            ``chunk_size``.
        labels: Indexed with the same row indices as ``data``, so it is
            expected to be row-aligned with it.
        chunk_size: Number of consecutive rows that must stay together.
        val_ratio: Fraction of chunks (rounded down) used for validation.
        test_ratio: Fraction of chunks (rounded down) used for testing.
        seed: Seed of the shuffling generator.

    Returns:
        ``(train, val, test)``, each a ``(data_subset, labels_subset)`` tuple.
        A split that receives zero chunks is returned as a pair of empty
        subsets instead of raising.
    """
    n = data.shape[0]
    assert n % chunk_size == 0, "data length must be divisible by chunk_size"
    n_chunks = n // chunk_size
    idx = np.arange(n_chunks)
    rng = np.random.default_rng(seed)
    rng.shuffle(idx)

    n_test = int(test_ratio * n_chunks)
    n_val = int(val_ratio * n_chunks)
    test_idx = idx[:n_test]
    val_idx = idx[n_test : n_test + n_val]
    train_idx = idx[n_test + n_val :]

    offsets = np.arange(chunk_size)

    def _take(source, rows):
        # Torch tensors are indexed with an explicit long tensor so that an
        # empty selection is unambiguous; anything else is indexed with numpy.
        if isinstance(source, torch.Tensor):
            return source[torch.from_numpy(rows)]
        return source[rows]

    def _gather(idxs):
        # Expand chunk ids to row ids in one broadcast; unlike concatenating a
        # list of per-chunk ranges, this also works when `idxs` is empty.
        flat = (idxs[:, None] * chunk_size + offsets[None, :]).reshape(-1)
        flat = flat.astype(np.int64, copy=False)
        return _take(data, flat), _take(labels, flat)

    train = _gather(train_idx)
    val = _gather(val_idx)
    test = _gather(test_idx)
    return train, val, test


def make_random_test_loader(data_norm, labels, chunk_size, seed, val_ratio=0.1, test_ratio=0.1, shuffle_loader=False):
    """Build a ``DataLoader`` over the test part of a seeded random chunk split.

    Args:
        data_norm: Input samples (array or tensor), converted to the notebook
            ``dtype``.
        labels: Targets indexed like ``data_norm`` (array or tensor), converted
            to the notebook ``dtype``.
        chunk_size: Number of consecutive rows kept together by the split.
        seed: Seed forwarded to ``_random_split_expanded_dataset``.
        val_ratio: Fraction of chunks reserved for validation (discarded here).
        test_ratio: Fraction of chunks placed in the returned test loader.
        shuffle_loader: Whether the loader shuffles samples.

    Returns:
        A ``DataLoader`` yielding ``(inputs, targets)`` batches of size
        ``batch_size``.
    """
    # Imported locally because the notebook's import cell only brings in DataLoader.
    from torch.utils.data import TensorDataset

    # as_tensor accepts arrays and tensors alike and avoids the copy-construct
    # warning that torch.tensor() emits when `labels` is already a tensor.
    tensor = torch.as_tensor(data_norm, dtype=dtype)
    labels_t = torch.as_tensor(labels, dtype=dtype)
    _, _, (test_x, test_y) = _random_split_expanded_dataset(
        tensor,
        labels_t,
        chunk_size=chunk_size,
        val_ratio=val_ratio,
        test_ratio=test_ratio,
        seed=seed,
    )
    # The split returns a plain (data, labels) tuple. Passed directly to
    # DataLoader it would count as a two-item dataset, so pair the rows up.
    test_set = TensorDataset(test_x, test_y)
    return DataLoader(test_set, batch_size=batch_size, shuffle=shuffle_loader)


def load_model_single(path, input_shape=(256, 256)):
    """Create a ``ConventionalCNN`` on ``device`` and load its weights from ``path``.

    Args:
        path: Location of a saved state dict, read with ``weights_only=True``.
        input_shape: Forwarded to the ``ConventionalCNN`` constructor.

    Returns:
        The model with the loaded state dict. It is not switched to eval mode here.
    """
    net = ConventionalCNN(input_shape=input_shape, output_shape=1, H=16, output_positive=True)
    net = net.to(device)
    weights = torch.load(path, map_location=device, weights_only=True)
    net.load_state_dict(weights)
    return net


def load_model_pair(om_path, sig_path, input_shape=(256, 256)):
    """Load two single-output models and return them as ``(om_model, sig_model)``.

    ``om_path`` is loaded first, then ``sig_path``; both use the same
    ``input_shape``.
    """
    om_model = load_model_single(om_path, input_shape=input_shape)
    sig_model = load_model_single(sig_path, input_shape=input_shape)
    return om_model, sig_model


def get_predictions(model_om, model_sig, val_loader, minmax):
    """Run both models over a loader and return truths with rescaled predictions.

    Args:
        model_om: Model whose output becomes prediction column 0.
        model_sig: Model whose output becomes prediction column 1.
        val_loader: Iterable of ``(inputs, targets)`` batches.
        minmax: Indexed as ``minmax[:2, 0]`` and ``minmax[:2, 1]``; predictions
            are mapped through ``p * (col1 - col0) + col0`` (presumably the
            per-parameter min and max - confirm against ``loadCAMELS``).

    Returns:
        ``(truths, preds, 2)``: the stacked targets, the rescaled two-column
        predictions, and the constant ``2``.
    """
    for net in (model_om, model_sig):
        net.eval()

    om_batches, sig_batches, truth_batches = [], [], []
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            om_batches.append(model_om(inputs).cpu())
            sig_batches.append(model_sig(inputs).cpu())
            truth_batches.append(targets.cpu())
            del inputs

    om_all = torch.cat(om_batches).numpy()
    sig_all = torch.cat(sig_batches).numpy()
    preds = np.c_[om_all, sig_all]
    truths = torch.cat(truth_batches).numpy()
    del om_batches, sig_batches
    clear_memory()

    # Undo the label scaling for the two predicted parameters.
    offset = minmax[:2, 0]
    span = minmax[:2, 1] - offset
    preds = preds * span + offset

    mse = np.mean((preds - truths) ** 2, axis=0)
    print(f"  MSE: Om={mse[0]:.4e}, Sig8={mse[1]:.4e}")
    return truths, preds, 2


def _metrics_point(y_true, y_pred, y_sigma=None, eps=1e-12):
    """Compute point estimates of four regression metrics.

    Args:
        y_true: Reference values.
        y_pred: Predicted values.
        y_sigma: Optional per-point uncertainties; values below ``eps`` are
            raised to ``eps`` before use.
        eps: Small constant guarding the R^2 denominator, the relative-error
            denominator and ``y_sigma`` against zero.

    Returns:
        ``(r2, eps_rel, rmse, chi2_red)`` where ``eps_rel`` is the mean
        absolute relative error and ``chi2_red`` is chi-squared divided by
        ``max(size - 1, 1)``, or NaN when ``y_sigma`` is ``None``.
    """
    truth = np.asarray(y_true)
    residual = np.asarray(y_pred) - truth
    squared = residual ** 2

    total_ss = np.sum((truth - np.mean(truth)) ** 2) + eps
    r2 = 1.0 - np.sum(squared) / total_ss
    rmse = np.sqrt(np.mean(squared))
    eps_rel = np.mean(np.abs(residual) / np.maximum(np.abs(truth), eps))

    if y_sigma is None:
        return r2, eps_rel, rmse, np.nan

    sigma = np.maximum(np.asarray(y_sigma), eps)
    dof = max(int(truth.size) - 1, 1)
    chi2_red = np.sum((residual / sigma) ** 2) / dof
    return r2, eps_rel, rmse, chi2_red


def _metrics_with_uncertainty(y_true, y_pred, y_sigma=None, n_boot=200, seed=0):
    """Bootstrap the metrics of ``_metrics_point`` to get a mean and a spread.

    Args:
        y_true: Reference values.
        y_pred: Predicted values.
        y_sigma: Optional per-point uncertainties, resampled with the points.
        n_boot: Number of bootstrap resamples (drawn with replacement).
        seed: Seed of the resampling generator.

    Returns:
        Four ``(mean, std)`` pairs in the order R^2, relative error, RMSE,
        reduced chi-squared. With at most one sample no resampling is done:
        the point estimates are returned with a spread of ``0.0``.
    """
    truth = np.asarray(y_true)
    pred = np.asarray(y_pred)
    sigma = None if y_sigma is None else np.asarray(y_sigma)

    n = truth.shape[0]
    if n <= 1:
        point = _metrics_point(truth, pred, y_sigma=sigma)
        return tuple((value, 0.0) for value in point)

    rng = np.random.default_rng(seed)
    samples = np.empty((n_boot, 4), dtype=float)
    for row in samples:
        # One resample per row; each `row` is a writable view into `samples`.
        picks = rng.integers(0, n, size=n)
        picked_sigma = None if sigma is None else sigma[picks]
        row[:] = _metrics_point(truth[picks], pred[picks], y_sigma=picked_sigma)

    means = samples.mean(axis=0)
    stds = samples.std(axis=0, ddof=1)
    return tuple(zip(means, stds))


def plot_comparison(data, chunk_size, title, num_samples=200, n_boot=200, seed=0):
    """Plot chunk-averaged predictions against truth for both targets.

    Predictions are reshaped to ``(-1, chunk_size, 2)``; each chunk's mean is
    the plotted point and its standard deviation the error bar. Every panel
    also shows a text box with bootstrapped R^2, relative error, RMSE and
    reduced chi-squared for the plotted points.

    Args:
        data: ``(truths, preds, _)`` as returned by ``get_predictions``.
        chunk_size: Number of consecutive rows that form one chunk.
        title: Figure title.
        num_samples: Maximum number of chunks plotted (the first ones).
        n_boot: Number of bootstrap resamples for the metrics box.
        seed: Base bootstrap seed; panel ``i`` uses ``seed + i``.

    Returns:
        The created matplotlib figure.
    """
    truths, preds, _ = data
    target_names = [r"$\Omega_m$", r"$\sigma_8$"]
    fig, axs = plt.subplots(1, 2, figsize=(10, 4))
    chunks = preds.reshape(-1, chunk_size, 2)
    pred_mean = chunks.mean(axis=1)
    pred_std = chunks.std(axis=1)
    # One truth row per chunk: the first row of every chunk is taken.
    true = truths[::chunk_size]
    n = min(num_samples, true.shape[0])
    for i, ax in enumerate(axs):
        x = true[:n, i]
        y = pred_mean[:n, i]
        yerr = pred_std[:n, i]
        (r2_m, r2_s), (eps_m, eps_s), (rmse_m, rmse_s), (chi2_m, chi2_s) = _metrics_with_uncertainty(x, y, y_sigma=yerr, n_boot=n_boot, seed=seed + i)
        eps_pct_m = 100.0 * eps_m
        eps_pct_s = 100.0 * eps_s
        lines = [
            rf"$R^2 = {r2_m:.4f} \pm {r2_s:.4f}$",
            rf"$\epsilon = ({eps_pct_m:.2f} \pm {eps_pct_s:.2f})\%$",
            # Literal braces must be doubled in an f-string; a single {RMSE}
            # would be evaluated as a variable and raise NameError.
            rf"$\mathrm{{RMSE}} = {rmse_m:.3e} \pm {rmse_s:.3e}$",
            rf"$\chi^2_\nu = {chi2_m:.3e} \pm {chi2_s:.3e}$",
        ]
        textstr = "\n".join(lines)
        ax.errorbar(x, y, yerr=yerr, fmt='none', capsize=2, ecolor='tab:orange', alpha=0.7)
        ax.scatter(x, y, s=4, c='k', zorder=10)
        ax.plot([x.min(), x.max()], [x.min(), x.max()], 'r--', label='Ideal')
        ax.text(0.95, 0.05, textstr, transform=ax.transAxes, fontsize=9, va='bottom', ha='right', bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.5))
        ax.set_xlabel("Truth")
        ax.set_ylabel("Prediction")
        ax.set_title(target_names[i])
        ax.legend(loc='upper left')
        ax.grid(True, alpha=0.3)
    fig.suptitle(title, fontsize=12, fontweight='bold')
    plt.tight_layout()
    return fig


def compute_stats(data):
    """Return bias and scatter of ``truth - prediction`` for both targets.

    Args:
        data: ``(truths, preds, _)``; column 0 feeds the ``*_om`` entries and
            column 1 the ``*_sig`` entries.

    Returns:
        Dict with ``bias_*`` (mean of the difference) and ``scatter_*``
        (standard deviation of the difference) for ``om`` and ``sig``.
    """
    truths, preds, _ = data
    diff_om = truths[:, 0] - preds[:, 0]
    diff_sig = truths[:, 1] - preds[:, 1]
    return {
        'bias_om': np.mean(diff_om),
        'scatter_om': np.std(diff_om),
        'bias_sig': np.mean(diff_sig),
        'scatter_sig': np.std(diff_sig),
    }


def compute_scorebox_scores(data, chunk_size, n_boot=200, seed=0):
    """Compute the bootstrapped metrics shown in the plot text box, for all chunks.

    Predictions are averaged per chunk (mean as estimate, standard deviation
    as uncertainty) and compared with the first truth row of each chunk.

    Args:
        data: ``(truths, preds, _)`` as returned by ``get_predictions``.
        chunk_size: Number of consecutive rows that form one chunk.
        n_boot: Number of bootstrap resamples.
        seed: Base bootstrap seed; target ``i`` uses ``seed + i``.

    Returns:
        Dict of floats keyed ``{metric}_{tag}_m`` / ``{metric}_{tag}_s`` (mean
        and spread) for ``tag`` in ``om``/``sig``. Relative errors are stored
        in percent under ``eps_{tag}_pct_*``.
    """
    truths, preds, _ = data
    per_chunk = preds.reshape(-1, chunk_size, 2)
    chunk_mean = per_chunk.mean(axis=1)
    chunk_std = per_chunk.std(axis=1)
    chunk_truth = truths[::chunk_size]

    scores = {}
    for col, tag in enumerate(['om', 'sig']):
        r2, eps, rmse, chi2 = _metrics_with_uncertainty(
            chunk_truth[:, col],
            chunk_mean[:, col],
            y_sigma=chunk_std[:, col],
            n_boot=n_boot,
            seed=seed + col,
        )
        scores[f'r2_{tag}_m'] = float(r2[0])
        scores[f'r2_{tag}_s'] = float(r2[1])
        scores[f'eps_{tag}_pct_m'] = float(100.0 * eps[0])
        scores[f'eps_{tag}_pct_s'] = float(100.0 * eps[1])
        scores[f'rmse_{tag}_m'] = float(rmse[0])
        scores[f'rmse_{tag}_s'] = float(rmse[1])
        scores[f'chi2_{tag}_m'] = float(chi2[0])
        scores[f'chi2_{tag}_s'] = float(chi2[1])
    return scores


def _ensure_dir(path: Path) -> None:
    """Create ``path`` and any missing parents; do nothing if it already exists."""
    path.mkdir(exist_ok=True, parents=True)


def _slugify(name: str) -> str:
    """Turn a plot title into a file-name stem.

    Colons are dropped, spaces become ``_``, ``|`` and ``/`` become ``-`` and
    ``>`` becomes ``to``. All other characters are kept unchanged.
    """
    # None of the replacement texts contains a character that is itself
    # replaced, so one translate pass gives the same result as chained replaces.
    table = str.maketrans({':': '', ' ': '_', '|': '-', '>': 'to', '/': '-'})
    return name.translate(table)


def run_evaluation(model_om, model_sig, val_loader, minmax, chunk_size, title, num_samples=200, save_dir=None, base_name=None):
    """Evaluate a model pair on a loader, plot the result and return its metrics.

    Args:
        model_om: Model for the first target.
        model_sig: Model for the second target.
        val_loader: Loader passed to ``get_predictions``.
        minmax: Label scaling passed to ``get_predictions``.
        chunk_size: Number of consecutive rows that form one chunk.
        title: Printed heading and figure title; also the file stem when
            ``base_name`` is not given.
        num_samples: Maximum number of chunks shown in the plot.
        save_dir: If not ``None``, the figure is saved there as PNG and PDF
            (the directory is created when missing).
        base_name: Optional file stem, slugified before use.

    Returns:
        Dict combining ``compute_stats`` and ``compute_scorebox_scores``.
    """
    print(f"\n{title}")
    data = get_predictions(model_om, model_sig, val_loader, minmax)
    fig = plot_comparison(data, chunk_size, title, num_samples=num_samples)
    try:
        if save_dir is not None:
            out_dir = Path(save_dir)
            _ensure_dir(out_dir)
            stem = _slugify(base_name or title)
            fig.savefig(out_dir / f"{stem}.png", dpi=200, bbox_inches='tight')
            fig.savefig(out_dir / f"{stem}.pdf", bbox_inches='tight')
        plt.show()
    finally:
        # pyplot keeps its own reference to every figure it creates, so `del fig`
        # alone never frees it; close it explicitly to avoid piling up figures
        # across the per-seed loops.
        plt.close(fig)
    stats = compute_stats(data)
    stats.update(compute_scorebox_scores(data, chunk_size=chunk_size, n_boot=200, seed=0))
    del data, fig
    clear_memory()
    return stats

## Config

In [None]:
# Where the trained checkpoints are read from.
model_dir = Path("../data/models")

# Rows per chunk for each dataset; chunks are kept intact by the random split.
chunk_sb28 = 15
chunk_sb35_cutout = 15

# Split fractions and the seeds evaluated; the seed also appears in the checkpoint file names.
val_ratio = 0.1
test_ratio = 0.1
seeds = [0, 42, 123, 456, 789]

# Output locations for figures and summary tables, created up front.
results_dir = Path("../plot/random_cross_eval")
plots_dir = results_dir / "figures"
tables_dir = results_dir / "tables"
for out_dir in (plots_dir, tables_dir):
    _ensure_dir(out_dir)

all_results = []

## Load Data and Stats

In [None]:
# SB28: raw maps, the first two label columns, and log-normalized maps with their statistics.
_dataSB28, _labelsSB28, minmaxSB28 = loadCAMELS(field="Mtot", box="SB28", normalization=False, individual=individual)
labelsSB28 = torch.tensor(_labelsSB28[:, :2], dtype=dtype)
dataSB28_norm, statsSB28 = log_normalize(_dataSB28, individual=individual)

# SB35 cutout: window the frames, then keep the top-left quadrant of each map.
_dataSB35, _labelsSB35, minmaxSB35 = loadCAMELS(field="Mtot", box="SB35", normalization=False, individual=individual)
_labelsSB35_sel = _labelsSB35[:, :2]
# The random split indexes data and labels with the same row indices, so the
# labels must go through the same frame selection as the maps.
# NOTE(review): assumes loadCAMELS returns one label row per map row; the guard
# leaves the labels untouched if the lengths already differ - confirm.
if len(_labelsSB35_sel) == len(_dataSB35):
    _labelsSB35_sel = _apply_sb35_windowing(_labelsSB35_sel)
labelsSB35 = torch.tensor(_labelsSB35_sel, dtype=dtype)
dataSB35_window = _apply_sb35_windowing(_dataSB35)
dataSB35_cutout = _apply_crop(dataSB35_window, crop="tl")
dataSB35_cutout_norm, statsSB35_full = log_normalize(dataSB35_cutout, individual=individual)
# Same statistics under the name used by the cross-evaluation cell.
statsSB35_cutout = statsSB35_full

del dataSB35_window
clear_memory()

## Self Evaluation (test split of a seeded random split)

In [None]:
self_results = []

for seed in seeds:
    # The same seed selects the test split and the checkpoint files.
    split_kwargs = dict(seed=seed, val_ratio=val_ratio, test_ratio=test_ratio)

    # SB28 models on the SB28 test split
    loader_sb28 = make_random_test_loader(dataSB28_norm, labelsSB28, chunk_size=chunk_sb28, **split_kwargs)
    om_ckpt = model_dir / f"SB28_om_seed{seed}_best.pt"
    sig_ckpt = model_dir / f"SB28_sig_seed{seed}_best.pt"
    model_om, model_sig = load_model_pair(om_ckpt, sig_ckpt, input_shape=(256, 256))
    res = run_evaluation(
        model_om,
        model_sig,
        loader_sb28,
        minmaxSB28,
        chunk_sb28,
        f"SB28 self (seed={seed})",
        save_dir=plots_dir,
        base_name=f"sb28_seed{seed}_self",
    )
    self_results.append((f"SB28 self seed {seed}", res))
    del loader_sb28, model_om, model_sig
    clear_memory()

    # SB35_cutout models on the SB35_cutout test split
    loader_sb35 = make_random_test_loader(dataSB35_cutout_norm, labelsSB35, chunk_size=chunk_sb35_cutout, **split_kwargs)
    om_ckpt = model_dir / f"SB35_cutout15_om_seed{seed}_best.pt"
    sig_ckpt = model_dir / f"SB35_cutout15_sig_seed{seed}_best.pt"
    model_om, model_sig = load_model_pair(om_ckpt, sig_ckpt, input_shape=(256, 256))
    res = run_evaluation(
        model_om,
        model_sig,
        loader_sb35,
        minmaxSB35,
        chunk_sb35_cutout,
        f"SB35_cutout self (seed={seed})",
        save_dir=plots_dir,
        base_name=f"sb35cutout_seed{seed}_self",
    )
    self_results.append((f"SB35_cutout self seed {seed}", res))
    del loader_sb35, model_om, model_sig
    clear_memory()

## Cross Evaluation SB28 ↔ SB35_cutout (matching seeds)

In [None]:
cross_results = []

for seed in seeds:
    # The same seed selects the test split and the checkpoint files.
    split_kwargs = dict(seed=seed, val_ratio=val_ratio, test_ratio=test_ratio)

    # SB28 models on SB35_cutout maps, normalized with the SB28 statistics
    data_cutout_norm_sb28, _ = log_normalize(dataSB35_cutout, individual=individual, ref_stats=statsSB28)
    loader_sb35_for_sb28 = make_random_test_loader(data_cutout_norm_sb28, labelsSB35, chunk_size=chunk_sb35_cutout, **split_kwargs)
    om_ckpt = model_dir / f"SB28_om_seed{seed}_best.pt"
    sig_ckpt = model_dir / f"SB28_sig_seed{seed}_best.pt"
    model_om, model_sig = load_model_pair(om_ckpt, sig_ckpt, input_shape=(256, 256))
    res = run_evaluation(
        model_om,
        model_sig,
        loader_sb35_for_sb28,
        minmaxSB28,
        chunk_sb35_cutout,
        f"SB28 model on SB35_cutout (seed={seed})",
        save_dir=plots_dir,
        base_name=f"sb28_seed{seed}_to_sb35cutout",
    )
    cross_results.append((f"SB28 → SB35_cutout seed {seed}", res))
    # Free the re-normalized copy before building the next one.
    del data_cutout_norm_sb28, loader_sb35_for_sb28, model_om, model_sig
    clear_memory()

    # SB35_cutout models on SB28 maps, normalized with the SB35_cutout statistics
    data_sb28_norm_sb35, _ = log_normalize(_dataSB28, individual=individual, ref_stats=statsSB35_cutout)
    loader_sb28_for_sb35 = make_random_test_loader(data_sb28_norm_sb35, labelsSB28, chunk_size=chunk_sb28, **split_kwargs)
    om_ckpt = model_dir / f"SB35_cutout15_om_seed{seed}_best.pt"
    sig_ckpt = model_dir / f"SB35_cutout15_sig_seed{seed}_best.pt"
    model_om, model_sig = load_model_pair(om_ckpt, sig_ckpt, input_shape=(256, 256))
    res = run_evaluation(
        model_om,
        model_sig,
        loader_sb28_for_sb35,
        minmaxSB35,
        chunk_sb28,
        f"SB35_cutout model on SB28 (seed={seed})",
        save_dir=plots_dir,
        base_name=f"sb35cutout_seed{seed}_to_sb28",
    )
    cross_results.append((f"SB35_cutout → SB28 seed {seed}", res))
    del data_sb28_norm_sb35, loader_sb28_for_sb35, model_om, model_sig
    clear_memory()

## Save Summary Tables

In [None]:
def _pm(m, s, fmt):
    """Format ``m`` and ``s`` as ``"<m> ± <s>"`` using ``fmt`` for both numbers.

    Returns ``'nan'`` when either value is ``None`` or not finite.
    """
    if m is None or s is None:
        return 'nan'
    if not (np.isfinite(m) and np.isfinite(s)):
        return 'nan'
    return ' ± '.join(fmt.format(value) for value in (m, s))

def _write_table(results, prefix):
    """Print a summary table and save it as aligned text and as TSV.

    Args:
        results: Iterable of ``(experiment_name, stats_dict)`` pairs; each dict
            is read with the ``{metric}_{tag}_m`` / ``{metric}_{tag}_s`` keys
            for ``tag`` in ``om``/``sig``.
        prefix: Table heading; its lower-cased form starts the output file
            names ``<prefix>_summary.txt`` and ``<prefix>_summary.tsv``
            inside ``tables_dir``.
    """
    rule = '=' * 110
    divider = '-' * 110
    targets = [('om', 'Om'), ('sig', 'Sig8')]
    header = f"{'Experiment':<40} {'Target':<6} {'R^2':>18} {'epsilon[%]':>18} {'RMSE':>18} {'chi2_nu':>18}"

    lines = [rule, f'{prefix} SUMMARY', rule, header, divider]
    for name, stats in results:
        for tag, label in targets:
            r2 = _pm(stats.get(f'r2_{tag}_m'), stats.get(f'r2_{tag}_s'), '{:.4f}')
            eps = _pm(stats.get(f'eps_{tag}_pct_m'), stats.get(f'eps_{tag}_pct_s'), '{:.2f}') + '%'
            rmse = _pm(stats.get(f'rmse_{tag}_m'), stats.get(f'rmse_{tag}_s'), '{:.3e}')
            chi2 = _pm(stats.get(f'chi2_{tag}_m'), stats.get(f'chi2_{tag}_s'), '{:.3e}')
            lines.append(f"{name:<40} {label:<6} {r2:>18} {eps:>18} {rmse:>18} {chi2:>18}")
        lines.append(divider)

    # The printed table is the saved text preceded by one blank line.
    table_text = '\n'.join(lines)
    print('\n' + table_text)
    # Rows contain non-ASCII characters (e.g. '±'), so do not depend on the
    # platform's default encoding.
    with open(tables_dir / f"{prefix.lower()}_summary.txt", 'w', encoding='utf-8') as f:
        f.write(table_text)

    tsv_columns = [
        'Experiment', 'Target', 'R2', 'R2_std', 'epsilon_pct', 'epsilon_pct_std',
        'RMSE', 'RMSE_std', 'chi2_nu', 'chi2_nu_std',
    ]
    tsv = ['\t'.join(tsv_columns)]
    for name, stats in results:
        for tag, label in targets:
            # Raw (unformatted) values; a missing key is written as "None".
            values = [
                stats.get(f'r2_{tag}_m'), stats.get(f'r2_{tag}_s'),
                stats.get(f'eps_{tag}_pct_m'), stats.get(f'eps_{tag}_pct_s'),
                stats.get(f'rmse_{tag}_m'), stats.get(f'rmse_{tag}_s'),
                stats.get(f'chi2_{tag}_m'), stats.get(f'chi2_{tag}_s'),
            ]
            tsv.append('\t'.join(str(cell) for cell in [name, label, *values]))
    with open(tables_dir / f"{prefix.lower()}_summary.tsv", 'w', encoding='utf-8') as f:
        f.write('\n'.join(tsv))

# Write the self-evaluation table first, then the cross-evaluation table.
_write_table(results=self_results, prefix='SELF')
_write_table(results=cross_results, prefix='CROSS')