# Low-Rank MLP Experiment Playground

Fresh notebook for building and evaluating low-rank MLPs on DA-MH chains with reproducible metrics.

## Notebook Outline
1. Import the low-rank MLP classes and training utilities from `rank_mlp_double`, plus the shared helpers from `test_mlp.py`.
2. Configure architecture ranks, training windows, validation tail, and optimizer knobs.
3. Load the HDF5 chain, prepare helper utilities, and train through a list of checkpoints.
4. Log per-checkpoint metrics plus the L1 error of the log-posterior (`logpi`) on the chain tail starting at step 50,000 (`MASTER_VAL_START`), so that all checkpoints are compared on the same data.

In [None]:
from __future__ import annotations

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from pathlib import Path
from copy import deepcopy
from types import SimpleNamespace

from rank_mlp_double import (
    LowRankLinear,
    LowRankMLP,
    build_train_sizes,
    collect_training_range,
    prepare_training_arrays,
    logpi_l1_error,
    run_training_cycle,
)

from test_mlp import (
    load_data,
    train_mlp,
    standardize_features,
    apply_standardization,
    unique_preserve_order,
    log_posterior_unnorm_numpy,
)


In [None]:
# Variant notes (diffs vs this base config):
# Base notebook logs results to rank_mlp_progress_new.csv.
# - rank_MLP copy: trains a single 40k-step window, disables warm starts, ranks [10,1,1,1]; results rank_mlp_progress.csv.
# - rank_MLP2: keeps 1k warm-start windows but limits ranks to [1,1,1,1] with compression ratio 0.01; results rank_mlp_progress2.csv.
# - rank_MLP3: widens to hidden_dim=1024 with 3 low-rank layers at rank 2; results rank_mlp_progress4.csv.
# - rank_MLP4: uses 2k windows (6 checkpoints) with hidden_dim=1024 and rank-2 layers; results rank_mlp_progress5.csv.
# - rank_MLP5: 1k windows, compression ratio 0.1, 4 layers hidden_dim=1024, ranks [2,2,2], noise_std=0.1; results rank_mlp_progress6.csv.
# - rank_MLP_2_layers_1: 2-layer variant, hidden_dim=1024, ranks [10,10], trains with MSE; results rank_mlp_progress_2L_mse_1.csv.
# - rank_MLP_2_layers_2: 2-layer variant, hidden_dim=512, ranks [2,2], MSE loss; results rank_mlp_progress_2L_mse_2.csv.
# - rank_MLP_2_layers_3: 2-layer variant, hidden_dim=1024, ranks [10,10], compression ratio 0.025; results rank_mlp_progress_2L_10_3.csv.
# - rank_MLP_2_layers_4: 2-layer variant, hidden_dim=2048, ranks [2,2], compression ratio 0.05; results rank_mlp_progress_2L_4.csv.
# - rank_MLP_2_layers_5: 2-layer variant, hidden_dim=1024, ranks [10,10], compression ratio 0.05; results rank_mlp_progress_2L_5.csv.
# - rank_MLP_2_layers_double: 2k windows (10 checkpoints), no warm starts, noise_std=0.1, 80 train loops, batch 32; results rank_mlp_progress_2L_double_cold.csv.
# - rank_MLP_2_layers_double2: same as previous double variant but logs to rank_mlp_progress_2L_double_cold2.csv.


# --- Data source and noise scales handed to load_data / the log-posterior helpers ---
DATA_PATH = "data1.h5"
SIGMA_PRIOR = 1.0
SIGMA_LIK = 0.3

# --- Training schedule and validation windows ---
TRAIN_START_STEP = 0
WINDOW_SIZE = 1_000
NUM_TRAIN_WINDOWS = 30
MAX_TOTAL_TRAIN_STEPS = 40_000
MASTER_VAL_START = 50_000
MASTER_VAL_LENGTH = None  # None means: take everything from MASTER_VAL_START to the end of the chain.

# --- Run switches and output location ---
USE_STANDARDIZATION = False
WARM_START = True
RESULTS_CSV = "rank_mlp_progress_new.csv"

# --- Rank adaptation thresholds ---
GROWTH_COMPRESSION_RATIO = 0.025
IMPROVEMENT_TOL = 1e-5

# Architecture settings consumed by make_low_rank_model().
LOW_RANK_ARCH = dict(
    hidden_dim=256,
    num_hidden_layers=5,
    ranks=[20, 20, 20, 20],
    activation=nn.Tanh(),
    noise_std=0.01,
    apply_final_activation=True,
)

# Optimizer / training-loop settings; exposed both as a dict and as an
# attribute-style namespace (the namespace is what run_training_cycle receives).
TRAINING_CFG = dict(
    max_adam_epochs=1000,
    adam_lr=1e-3,
    adam_patience=100,
    tol=1e-5,
    max_lbfgs_iter=50,
    loss_name="l1",
    train_loops=40,
    batch_size=64,
    loss_domain="obs",
    batch_growth=1.2,
    verbose=1,
    loop_improvement_pct=0.1,
)
TRAINING_CFG_NS = SimpleNamespace(**TRAINING_CFG)

# Seed both NumPy and PyTorch so repeated runs start from the same RNG state.
SEED = 123
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible; the bare name below echoes the choice in the notebook output.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DEVICE


In [None]:
# Load everything load_data returns for this dataset and these noise scales.
par, obs, y_obs, chain, props, logpi_true = load_data(DATA_PATH, SIGMA_PRIOR, SIGMA_LIK)
input_dim, output_dim = par.shape[1], obs.shape[1]

# Checkpoint schedule: the sequence of training sizes the training loop iterates over.
train_sizes = build_train_sizes(
    total_chain=chain.shape[0],
    train_start=TRAIN_START_STEP,
    window=WINDOW_SIZE,
    max_total_steps=MAX_TOTAL_TRAIN_STEPS,
    num_windows=NUM_TRAIN_WINDOWS,
)

print(f"Device: {DEVICE}")
print(f"Input dim: {input_dim}, output dim: {output_dim}")
print(f"Window size: {WINDOW_SIZE}, total train checkpoints: {len(train_sizes)}")
print(f"Train checkpoints (first up to 10 entries): {train_sizes[:10]}{'...' if len(train_sizes) > 10 else ''}")

# Length of the master validation tail. None is passed through untouched;
# an explicit length is clipped to the number of steps left after MASTER_VAL_START.
master_val_length = (
    None
    if MASTER_VAL_LENGTH is None
    else min(MASTER_VAL_LENGTH, max(0, chain.shape[0] - MASTER_VAL_START))
)


In [None]:
def make_low_rank_model() -> LowRankMLP:
    """Build a fresh ``LowRankMLP`` from ``LOW_RANK_ARCH`` and the dataset dimensions.

    ``hidden_dim``, ``num_hidden_layers`` and ``ranks`` must be present in
    ``LOW_RANK_ARCH``; ``activation``, ``noise_std`` and
    ``apply_final_activation`` fall back to ``nn.Tanh()``, ``0.01`` and
    ``True`` when missing. ``input_dim`` and ``output_dim`` are read from the
    notebook globals.
    """
    arch = LOW_RANK_ARCH
    constructor_kwargs = {
        "input_dim": input_dim,
        "hidden_dim": arch["hidden_dim"],
        "output_dim": output_dim,
        "num_hidden_layers": arch["num_hidden_layers"],
        "ranks": arch["ranks"],
        "activation": arch.get("activation", nn.Tanh()),
        "noise_std": arch.get("noise_std", 0.01),
        "apply_final_activation": arch.get("apply_final_activation", True),
    }
    return LowRankMLP(**constructor_kwargs)

def select_layer_for_growth(model: LowRankMLP) -> tuple[int | None, float | None]:
    """Pick the layer whose singular spectrum has the largest min/max ratio.

    Layers reporting an empty spectrum, or whose largest singular value is
    not strictly positive, are skipped. On ties the earliest layer wins.

    Returns:
        ``(layer_index, ratio)`` for the selected layer, or ``(None, None)``
        when no layer qualifies.
    """
    winner: tuple[int | None, float | None] = (None, None)
    for position, spectrum in enumerate(model.singular_values_by_layer()):
        if spectrum.numel() == 0:
            continue
        largest = spectrum.max().item()
        if largest <= 0:
            continue
        flatness = spectrum.min().item() / largest
        # Strict comparison keeps the first layer when two ratios are equal.
        if winner[1] is None or flatness > winner[1]:
            winner = (position, flatness)
    return winner


In [None]:

from collections.abc import Iterable

def plot_singular_values_at_ranks(
    models,
    max_rank: int | None = None,
    figsize=(7, 4),
) -> None:
    """Plot singular spectra for every low-rank layer in each supplied model.

    One figure is drawn per model, with one log-scale curve per layer that
    reports a non-empty spectrum.

    Args:
        models: A single ``LowRankMLP``, a mapping ``label -> model``, or an
            iterable whose entries are either models or ``(label, model)``
            tuples with a string label. Unlabelled entries are named
            ``model_<position>``.
        max_rank: Forwarded unchanged to ``model.singular_values_by_layer``.
        figsize: Figure size handed to ``plt.subplots``.

    Raises:
        ValueError: If ``models`` normalises to an empty collection.
        TypeError: If any entry is not a ``LowRankMLP``.
    """
    def normalize_models(arg):
        # Reduce every accepted input shape to a list of (label, model) pairs.
        if isinstance(arg, LowRankMLP):
            return [("model_0", arg)]
        if isinstance(arg, dict):
            return [(str(label), model) for label, model in arg.items()]
        if isinstance(arg, Iterable):
            normalized = []
            for idx, entry in enumerate(arg):
                if (
                    isinstance(entry, tuple)
                    and len(entry) == 2
                    and isinstance(entry[0], str)
                ):
                    label, model = entry
                else:
                    label = f"model_{idx}"
                    model = entry
                normalized.append((label, model))
            return normalized
        # Anything else is wrapped as-is and rejected by the type check below.
        return [("model_0", arg)]

    labeled = normalize_models(models)
    if not labeled:
        raise ValueError("No models provided for plotting.")

    for label, model in labeled:
        if not isinstance(model, LowRankMLP):
            raise TypeError(
                f"Expected LowRankMLP instances, got {type(model).__name__} for label {label}."
            )
        sv_lists = model.singular_values_by_layer(max_rank=max_rank)
        if not sv_lists:
            print(f"[plot] {label}: no low-rank layers to visualize.")
            continue

        fig, ax = plt.subplots(figsize=figsize)
        plotted_any = False
        for layer_idx, sv_tensor in enumerate(sv_lists, start=1):
            if sv_tensor.numel() == 0:
                continue
            xs = np.arange(1, sv_tensor.numel() + 1)
            # Tensor.numpy() refuses CUDA or grad-tracking tensors, so detach
            # and move to host first (both are no-ops for plain CPU tensors).
            ys = sv_tensor.detach().cpu().numpy()
            ax.plot(xs, ys, marker='o', label=f"layer {layer_idx}")
            plotted_any = True
        ax.set_yscale('log')
        ax.set_xlabel('rank index')
        ax.set_ylabel('singular value')
        ax.set_title(f'Singular values ({label})')
        if plotted_any:
            # Only build a legend when at least one curve exists; an empty
            # legend just makes matplotlib emit a warning.
            ax.legend()
        ax.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.5)
        plt.show()


In [None]:
# Main experiment loop: for each checkpoint size in train_sizes, train a model,
# compress its ranks, then greedily try doubling the rank of one layer at a time
# for as long as the validation logpi L1 error keeps improving. One record per
# checkpoint is appended to checkpoint_records and the CSV is rewritten.
checkpoint_records = []
previous_model = None  # Final model of the previous checkpoint (only kept when WARM_START).

for step_idx, train_steps in enumerate(train_sizes, start=1):
    print(f"=== Step {step_idx}: train on first {train_steps} steps (window={WINDOW_SIZE}) ===")
    # NOTE(review): presumably the sample indices touched by the chain/proposals in
    # [TRAIN_START_STEP, TRAIN_START_STEP + train_steps) — confirm in rank_mlp_double.
    train_indices = collect_training_range(chain, props, TRAIN_START_STEP, train_steps)
    X_train_raw, X_train_proc, y_train, logpi_train, x_mean, x_std = prepare_training_arrays(
        par, obs, logpi_true, train_indices, USE_STANDARDIZATION
    )

    # Both helpers close over this iteration's training arrays and
    # standardization statistics, so they are redefined for every checkpoint.
    def run_cycle(candidate_model: LowRankMLP) -> float:
        return run_training_cycle(
            candidate_model,
            X_train_proc,
            y_train,
            X_train_raw,
            logpi_train,
            DEVICE,
            TRAINING_CFG_NS,
            SIGMA_PRIOR,
            SIGMA_LIK,
            y_obs,
        )

    def eval_logpi(candidate_model: LowRankMLP, start: int, length: int | None) -> float:
        return logpi_l1_error(
            candidate_model,
            par,
            obs,
            logpi_true,
            y_obs,
            chain,
            props,
            start,
            length,
            USE_STANDARDIZATION,
            x_mean,
            x_std,
            SIGMA_PRIOR,
            SIGMA_LIK,
            DEVICE,
        )

    # Warm start: continue from a copy of the previous checkpoint's final model;
    # otherwise (or on the first checkpoint) build a fresh one.
    if WARM_START and previous_model is not None:
        model = deepcopy(previous_model)
    else:
        model = make_low_rank_model()
    model.to(DEVICE)

    # Train once at the starting ranks, compress, then train again. Only the
    # second cycle's loss is kept; the first assignment is overwritten below.
    base_train_loss = run_cycle(model)
    print("Ranks before:", model.ranks)
    # NOTE(review): exact truncation rule of contract_ranks_by_ratio is defined in
    # rank_mlp_double — confirm how GROWTH_COMPRESSION_RATIO is interpreted.
    model.contract_ranks_by_ratio(GROWTH_COMPRESSION_RATIO)
    print("Ranks after initial compression:", model.ranks)
    base_train_loss = run_cycle(model)

    # Local validation window: the WINDOW_SIZE steps starting right after the
    # training range. The master window starts at the fixed MASTER_VAL_START.
    val_start = TRAIN_START_STEP + train_steps
    val_error = eval_logpi(model, val_start, WINDOW_SIZE)
    master_error = eval_logpi(model, MASTER_VAL_START, master_val_length)
    base_ranks = tuple(int(r) for r in getattr(model, "ranks", []))

    # The compressed baseline is the initial "best" that growth trials must beat.
    best_model = deepcopy(model)
    best_error = val_error
    best_master_error = master_error
    best_train_loss = base_train_loss

    growth_trials: list[dict] = []
    trial_idx = 0

    # Greedy rank growth: each trial expands one layer of a copy of the current
    # best model and is accepted only if the local validation error improves.
    while True:
        layer_idx, ratio = select_layer_for_growth(best_model)
        if layer_idx is None:
            print("  No eligible low-rank layer for further expansion.")
            break

        candidate = deepcopy(best_model)
        ranks_before = tuple(int(r) for r in candidate.ranks)
        max_rank_allowed = candidate.hidden_dim
        # Double the selected layer's rank (at least 1), capped at hidden_dim.
        new_rank = min(max_rank_allowed, max(1, ranks_before[layer_idx] * 2))
        if new_rank == ranks_before[layer_idx]:
            print("  Selected layer already at maximum rank; stopping growth loop.")
            break

        target_ranks = list(ranks_before)
        target_ranks[layer_idx] = new_rank
        # NOTE(review): contract_ranks_by_amount is called with a *larger* target
        # rank here; presumably it also handles expansion — confirm.
        candidate.contract_ranks_by_amount(target_ranks)
        trial_idx += 1
        print(
            f"  [growth][train={train_steps}] trial {trial_idx}: layer {layer_idx}, "
            f"ratio={ratio:.3e} -> rank {ranks_before[layer_idx]}→{new_rank}"
        )

        # Same train -> compress -> train sequence as the baseline, applied to
        # the expanded candidate.
        expand_loss = run_cycle(candidate)
        ranks_after_expand = tuple(int(r) for r in candidate.ranks)

        candidate.contract_ranks_by_ratio(GROWTH_COMPRESSION_RATIO)
        ranks_after_compress = tuple(int(r) for r in candidate.ranks)

        compress_loss = run_cycle(candidate)

        candidate_val_error = eval_logpi(candidate, val_start, WINDOW_SIZE)
        candidate_master_error = eval_logpi(candidate, MASTER_VAL_START, master_val_length)
        # Acceptance is decided on the local validation window only, and the
        # error must drop by more than IMPROVEMENT_TOL.
        improved = candidate_val_error < best_error - IMPROVEMENT_TOL

        growth_trials.append({
            "trial": trial_idx,
            "layer": int(layer_idx),
            "sv_ratio": float(ratio) if ratio is not None else float("nan"),
            "ranks_before": ranks_before,
            "ranks_after_expand": ranks_after_expand,
            "ranks_after_compress": ranks_after_compress,
            "expand_loss": float(expand_loss),
            "compress_loss": float(compress_loss),
            "prev_best_val": float(best_error),
            "candidate_val": float(candidate_val_error),
            "improved": bool(improved),
        })

        print(
            f"      ranks: before {ranks_before} | expand {ranks_after_expand} | "
            f"compress {ranks_after_compress}"
        )
        print(
            f"      val logpi: {best_error:.4e} -> {candidate_val_error:.4e} | improved={improved}"
        )

        if improved:
            # Promote the candidate and try another growth step from it.
            best_model = deepcopy(candidate)
            best_error = candidate_val_error
            best_master_error = candidate_master_error
            best_train_loss = compress_loss
        else:
            # The first non-improving trial ends growth for this checkpoint.
            print("      Improvement threshold not met; ending growth loop.")
            break

    final_model = best_model
    # Carry the final model into the next checkpoint only when warm starting.
    if WARM_START:
        previous_model = deepcopy(final_model)
    else:
        previous_model = None

    # "base_*" fields describe the compressed baseline; "final_*" and
    # "master_logpi_l1" describe the best model after the growth trials.
    record = {
        "train_steps": int(train_steps),
        "unique_train_samples": int(train_indices.size),
        "val_window_start": val_start,
        "val_window_length": WINDOW_SIZE,
        "base_train_loss": float(base_train_loss),
        "final_train_loss": float(best_train_loss),
        "base_val_logpi_l1": float(val_error),
        "final_val_logpi_l1": float(best_error),
        "master_logpi_l1": float(best_master_error),
        "base_ranks": base_ranks,
        "final_ranks": tuple(int(r) for r in final_model.ranks),
        "num_growth_trials": len(growth_trials),
        "growth_trials": growth_trials,
    }
    checkpoint_records.append(record)

    # Rewrite the whole CSV after every checkpoint so the file always holds
    # all records collected so far.
    if RESULTS_CSV:
        out_path = Path(RESULTS_CSV)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        pd.DataFrame(checkpoint_records).to_csv(out_path, index=False)
        print(f"  Saved progress to {out_path}")


In [None]:
# Collect all per-checkpoint records into a table; the bare name on the last
# line makes the notebook render it as the cell output.
results_df = pd.DataFrame(checkpoint_records)
results_df
