# STRAW - Experiment 0

Interactive demo notebook. All logic lives in the `straw/` package —
this notebook is for quick exploration and visualisation.

For full experiment runs, use the CLI:
```bash
python run_experiment.py                          # all experiments
python run_experiment.py --experiment full_rank16  # single experiment
```

In [None]:
import torch
import yaml

from straw.models import ExecutorNet, ModulatorNet, build_resnet34, MODEL_REGISTRY
from straw.data import get_datasets, get_dataloaders
from straw.training import Trainer
from straw.evaluation import evaluate_model, run_evaluation_suite
from straw.visualization.plots import plot_comparison, plot_evaluation_bars, plot_evaluation_suite

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

## 1. Load config & data

In [None]:
# Load the experiment config (or override values inline in the next cell).
# Decode as UTF-8 explicitly so the YAML parses the same on every platform
# instead of depending on the OS locale's default encoding.
with open("configs/experiment.yaml", encoding="utf-8") as f:
    # safe_load only constructs plain Python objects (dicts, lists, scalars).
    raw_config = yaml.safe_load(f)

# Shared settings; each entry in raw_config["experiments"] is merged on top
# of these when an experiment is selected below.
defaults = raw_config["defaults"]
print(defaults)

In [None]:
# Pick an experiment (change index or override values).
# Start from a copy of the shared defaults, then let the chosen experiment's
# own keys take precedence over them.
exp_cfg = dict(defaults)
exp_cfg.update(raw_config["experiments"][0])  # full_rank16
exp_cfg["device"] = device
print(exp_cfg)

In [None]:
# Build the dataset splits for this experiment, then the loaders over them.
# data_root is optional in the config and falls back to ./data.
dataset_kwargs = dict(
    data_root=exp_cfg.get("data_root", "./data"),
    train_subset=exp_cfg["train_subset"],
)
datasets = get_datasets(**dataset_kwargs)
loaders = get_dataloaders(datasets, batch_size=exp_cfg["batch_size"])

## 2. Model overview

In [None]:
# Quick parameter comparison
import pandas as pd


def count(m):
    """Return the number of trainable parameters (``requires_grad``) in *m*."""
    return sum(p.numel() for p in m.parameters() if p.requires_grad)


executor = ExecutorNet()
# The modulator gets its own fresh ExecutorNet instance rather than sharing
# `executor` above, so the two models are counted independently.
modulator = ModulatorNet(ExecutorNet(), rank=exp_cfg["rank"])
resnet = build_resnet34()

# Counts are pre-formatted with thousands separators purely for display.
df = pd.DataFrame([
    {"Model": "Standalone (Executor)", "Params": f"{count(executor):,}", "Type": "Baseline"},
    {"Model": "ResNet34 (Modified)",   "Params": f"{count(resnet):,}",   "Type": "Industry Standard"},
    {"Model": "Modulator (Hypernet)",  "Params": f"{count(modulator):,}", "Type": f"Ours (Rank {exp_cfg['rank']})"},
])
print(df.to_string(index=False))

## 3. Train all models

In [None]:
trainer = Trainer(
    device=device,
    lr=exp_cfg["lr"],
    batch_size=exp_cfg["batch_size"],
)

# Instantiate every model listed in the config via the registry and train it
# on the training split; results are keyed by the config's model key.
results = {}
for model_key in exp_cfg["models"]:
    build = MODEL_REGISTRY[model_key]
    model = build(exp_cfg)
    results[model_key] = trainer.train(
        model=model,
        model_name=model_key,
        train_dataset=datasets["train"],
        num_epochs=exp_cfg["num_epochs"],
    )

In [None]:
# Training curves: one loss and one accuracy series per model, in the
# order the models were trained.
names = list(results)
losses = [results[name].loss_history for name in names]
accuracies = [results[name].accuracy_history for name in names]
plot_comparison(losses, accuracies, names)

## 4. Evaluate on all test sets

In [None]:
# Evaluate every trained model on every loader except the training one.
trained_models = {name: run.model for name, run in results.items()}
test_loaders = {
    split: loader for split, loader in loaders.items() if split != "train"
}

eval_results = run_evaluation_suite(trained_models, test_loaders, device)
plot_evaluation_suite(eval_results)

## 5. Run another experiment (change config inline)

Change the experiment index (or override individual parameters) in the experiment-selection cell of section 1, then re-run the cells in sections 1–4.

In [None]:
# Example: run 10% subset experiment
# Uncomment the two lines below and run this cell, then re-run the cells from
# the data-loading cell (get_datasets) onward. Do not re-run the
# experiment-selection cell in section 1: it rebuilds exp_cfg from
# experiments[0] and would discard this override.
# exp_cfg = {**defaults, **raw_config["experiments"][1]}  # 10pct_rank16
# exp_cfg["device"] = device