<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/meta_intelligence_e2e_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
meta_intelligence_e2e.py

Notebook/Colab-safe end-to-end toolkit for training, evaluating, predicting,
and benchmarking a compact MLP on synthetic datasets.

Subcommands:
- train:     train a model (self/supervised/hybrid)
- eval:      evaluate a checkpoint on a fresh synthetic test set
- predict:   predict for inline points or CSV rows
- benchmark: run multiple configurations over several seeds, summarize results
- export-onnx: export a trained checkpoint to ONNX (if onnx available)

Artifacts in <run_dir>:
- config.json
- metrics.csv
- model.pt
- metrics.png        (if matplotlib available)
- boundary.png       (if dim == 2 and matplotlib available)
"""

import argparse
import csv
import json
import math
import os
import random
import sys
import time
from dataclasses import dataclass, asdict
from typing import Dict, List, Optional, Tuple

import numpy as np

# Optional plotting (script runs without it)
try:
    import matplotlib.pyplot as plt
    _HAS_MPL = True
except Exception:
    _HAS_MPL = False

# Optional ONNX export
try:
    import onnx  # noqa: F401
    import torch.onnx
    _HAS_ONNX = True
except Exception:
    _HAS_ONNX = False

import torch
import torch.nn as nn
import torch.nn.functional as F


# --- Notebook-safe CLI utilities ------------------------------------------------

def sanitize_argv(argv: Optional[List[str]] = None) -> List[str]:
    """Return ``argv`` without the arguments a Jupyter/Colab kernel injects.

    Two kinds of tokens are dropped:
    * ``-f`` together with the token that follows it, and
    * any token ending in ``.json`` that mentions "jupyter" or "kernel".

    Args:
        argv: Argument list to clean; ``sys.argv[1:]`` is used when ``None``.

    Returns:
        A new list holding the remaining arguments in their original order.
    """
    raw = sys.argv[1:] if argv is None else argv
    kept: List[str] = []
    tokens = iter(raw)
    for token in tokens:
        if token == "-f":
            # Swallow the value that belongs to -f (if there is one).
            next(tokens, None)
            continue
        is_kernel_file = token.endswith(".json") and ("jupyter" in token or "kernel" in token)
        if not is_kernel_file:
            kept.append(token)
    return kept


# --- Reproducibility and misc ---------------------------------------------------

def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch (CPU + CUDA) random generators.

    Also sets ``cudnn.deterministic = True`` and ``cudnn.benchmark = False``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def ts() -> str:
    """Return the current local time formatted as ``YYYYMMDD-HHMMSS``."""
    stamp_format = "%Y%m%d-%H%M%S"
    return time.strftime(stamp_format, time.localtime())


def ensure_dir(path: str) -> None:
    """Create ``path`` (including missing parents); an existing directory is fine."""
    os.makedirs(name=path, exist_ok=True)


# --- Data generation ------------------------------------------------------------

@dataclass
class DataConfig:
    """Parameters of the synthetic dataset generators."""

    # Which generator to use: "blobs", "moons" or "circles".
    dataset: str = "blobs"
    # Number of feature columns in the generated X.
    dim: int = 2
    # Number of classes. NOTE: the generators in this file do not read this
    # field; they always emit the two labels 0 and 1.
    n_classes: int = 2
    # Blobs: distance of each centre from the origin. Moons: overall scale
    # factor. Circles: radius of the outer ring.
    radius: float = 3.0
    # Blobs: standard deviation of the points around their centre.
    spread: float = 1.0
    # Standard deviation of the Gaussian columns appended when dim > 2.
    extra_noise: float = 0.5
    # Moons: standard deviation of the noise added to each point.
    moons_noise: float = 0.15
    # Circles: standard deviation of the noise added to each point.
    circles_noise: float = 0.1
    # Circles: inner radius expressed as a fraction of ``radius``.
    circles_factor: float = 0.5


def make_dataset(n_samples: int, cfg: DataConfig, seed: int) -> Tuple[np.ndarray, np.ndarray]:
    """Generate ``(X, y)`` with the generator selected by ``cfg.dataset``.

    Args:
        n_samples: Total number of samples to generate.
        cfg: Dataset parameters; ``cfg.dataset`` picks the generator.
        seed: Seed forwarded to the generator.

    Returns:
        The ``(X, y)`` pair produced by the chosen generator.

    Raises:
        ValueError: If ``cfg.dataset`` is not "blobs", "moons" or "circles".
    """
    generators = (
        ("blobs", make_blobs),
        ("moons", make_moons),
        ("circles", make_circles),
    )
    for name, generate in generators:
        if cfg.dataset == name:
            return generate(n_samples, cfg, seed)
    raise ValueError(f"Unknown dataset: {cfg.dataset}")


def pad_dims(X: np.ndarray, target_dim: int, noise: float, rng: np.random.Generator) -> np.ndarray:
    """Return ``X`` as float32 with exactly ``target_dim`` columns.

    When ``X`` already has at least ``target_dim`` columns the surplus columns
    are cut off; otherwise zero-mean Gaussian columns with standard deviation
    ``noise`` are drawn from ``rng`` and appended on the right.
    """
    n_rows, n_cols = X.shape[:2]
    missing = target_dim - n_cols
    if missing <= 0:
        truncated = X[:, :target_dim]
        return truncated.astype(np.float32)

    filler = rng.normal(0, noise, size=(n_rows, missing))
    return np.hstack([X.astype(np.float32), filler.astype(np.float32)])


def make_blobs(n_samples: int, cfg: DataConfig, seed: int) -> Tuple[np.ndarray, np.ndarray]:
    """Generate two Gaussian blobs on opposite sides of the origin.

    Class 0 is centred at angle 0 and class 1 at angle pi, both at distance
    ``cfg.radius``; points are spread with standard deviation ``cfg.spread``.
    The first class receives ``n_samples // 2`` points and the second the
    rest. The 2-D points are then resized to ``cfg.dim`` columns via
    ``pad_dims`` and shuffled.

    Returns:
        ``(X, y)`` with ``X`` float32 of shape ``(n_samples, cfg.dim)`` and
        ``y`` int64 labels in {0, 1}.
    """
    rng = np.random.default_rng(seed)
    half = n_samples // 2
    counts = (half, n_samples - half)
    angles = (0.0, math.pi)

    clouds = []
    labels = []
    # Draw class 0 before class 1 so the random stream is consumed in a fixed order.
    for label, (angle, count) in enumerate(zip(angles, counts)):
        centre = np.array(
            [cfg.radius * math.cos(angle), cfg.radius * math.sin(angle)],
            dtype=np.float32,
        )
        offsets = rng.normal(0, cfg.spread, size=(count, 2)).astype(np.float32)
        clouds.append(offsets + centre)
        labels.append(np.full(count, label, dtype=np.int64))

    y = np.concatenate(labels, axis=0)
    X = pad_dims(np.vstack(clouds), cfg.dim, cfg.extra_noise, rng)
    order = rng.permutation(n_samples)
    return X[order], y[order]


def make_moons(n_samples: int, cfg: DataConfig, seed: int) -> Tuple[np.ndarray, np.ndarray]:
    """Generate two interleaving half-circles ("moons").

    Class 0 lies on the upper unit half-circle; class 1 lies on the lower
    half-circle shifted by (+1.0, -0.5). Gaussian noise with standard
    deviation ``cfg.moons_noise`` is added, the result is scaled by
    ``cfg.radius``, resized to ``cfg.dim`` columns via ``pad_dims`` and
    shuffled.

    Returns:
        ``(X, y)`` with ``X`` float32 of shape ``(n_samples, cfg.dim)`` and
        ``y`` int64 labels in {0, 1}.
    """
    rng = np.random.default_rng(seed)
    n_upper = n_samples // 2
    n_lower = n_samples - n_upper

    # Draw both angle vectors first, then both noise blocks (fixed RNG order).
    theta_upper = rng.uniform(0, math.pi, size=n_upper)
    theta_lower = rng.uniform(0, math.pi, size=n_lower)
    upper = np.column_stack((np.cos(theta_upper), np.sin(theta_upper)))
    lower = np.column_stack((np.cos(theta_lower) + 1.0, -np.sin(theta_lower) - 0.5))
    upper = upper + rng.normal(0, cfg.moons_noise, size=upper.shape)
    lower = lower + rng.normal(0, cfg.moons_noise, size=lower.shape)

    X = np.vstack([upper, lower]).astype(np.float32)
    X *= cfg.radius
    y = np.concatenate(
        [np.zeros(n_upper, dtype=np.int64), np.ones(n_lower, dtype=np.int64)],
        axis=0,
    )
    X = pad_dims(X, cfg.dim, cfg.extra_noise, rng)
    order = rng.permutation(n_samples)
    return X[order], y[order]


def make_circles(n_samples: int, cfg: DataConfig, seed: int) -> Tuple[np.ndarray, np.ndarray]:
    """Generate two concentric noisy rings.

    Class 0 is the outer ring of radius ``cfg.radius``; class 1 is the inner
    ring of radius ``cfg.radius * cfg.circles_factor``. Gaussian noise with
    standard deviation ``cfg.circles_noise`` is added, then the points are
    resized to ``cfg.dim`` columns via ``pad_dims`` and shuffled.

    Returns:
        ``(X, y)`` with ``X`` float32 of shape ``(n_samples, cfg.dim)`` and
        ``y`` int64 labels in {0, 1}.
    """
    rng = np.random.default_rng(seed)
    n_outer = n_samples // 2
    n_inner = n_samples - n_outer
    full_turn = 2 * math.pi

    # Draw both angle vectors first, then both noise blocks (fixed RNG order).
    phi_outer = rng.uniform(0, full_turn, size=n_outer)
    phi_inner = rng.uniform(0, full_turn, size=n_inner)
    radius_outer = cfg.radius
    radius_inner = cfg.radius * cfg.circles_factor

    outer = np.column_stack((radius_outer * np.cos(phi_outer), radius_outer * np.sin(phi_outer)))
    inner = np.column_stack((radius_inner * np.cos(phi_inner), radius_inner * np.sin(phi_inner)))
    outer = outer + rng.normal(0, cfg.circles_noise, size=outer.shape)
    inner = inner + rng.normal(0, cfg.circles_noise, size=inner.shape)

    X = np.vstack([outer, inner]).astype(np.float32)
    y = np.concatenate(
        [np.zeros(n_outer, dtype=np.int64), np.ones(n_inner, dtype=np.int64)],
        axis=0,
    )
    X = pad_dims(X, cfg.dim, cfg.extra_noise, rng)
    order = rng.permutation(n_samples)
    return X[order], y[order]


def split_dataset(X: np.ndarray, y: np.ndarray, val_ratio: float, test_ratio: float, seed: int):
    """Randomly partition ``(X, y)`` into train, validation and test subsets.

    A seeded permutation of the row indices is cut into three consecutive
    pieces: the first ``int(test_ratio * n)`` indices form the test set, the
    next ``int(val_ratio * n)`` the validation set, and the rest the training
    set.

    Returns:
        ``(X_train, y_train, X_val, y_val, X_test, y_test)``.
    """
    n_rows = X.shape[0]
    order = np.random.default_rng(seed).permutation(n_rows)
    test_count = int(test_ratio * n_rows)
    val_count = int(val_ratio * n_rows)
    test_idx, val_idx, train_idx = np.split(order, [test_count, test_count + val_count])

    parts = []
    for indices in (train_idx, val_idx, test_idx):
        parts.append(X[indices])
        parts.append(y[indices])
    return tuple(parts)


def to_tensor(x: np.ndarray, y: Optional[np.ndarray], device: torch.device):
    """Copy ``x`` (as float32) and, if given, ``y`` (as int64) onto ``device``.

    Returns:
        ``(x_tensor, y_tensor)``; ``y_tensor`` is ``None`` when ``y`` is ``None``.
    """
    features = torch.tensor(x, dtype=torch.float32, device=device)
    if y is None:
        return features, None
    labels = torch.tensor(y, dtype=torch.long, device=device)
    return features, labels


def augment_noise(x: torch.Tensor, sigma: float = 0.25) -> torch.Tensor:
    """Return ``x`` plus standard-normal noise scaled by ``sigma``."""
    noise = torch.randn_like(x)
    return x + sigma * noise


# --- Model ---------------------------------------------------------------------

class MLP(nn.Module):
    """Fully connected classifier: ``depth`` hidden blocks, then a linear head.

    Each hidden block is ``Linear -> ReLU`` followed by ``Dropout`` when
    ``dropout > 0``. The whole stack is stored in ``self.net``.
    """

    def __init__(self, in_dim: int, n_classes: int, hidden: int = 64, depth: int = 2, dropout: float = 0.0):
        super().__init__()
        widths = [in_dim] + [hidden] * depth
        stack: List[nn.Module] = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU(inplace=True))
            if dropout > 0:
                stack.append(nn.Dropout(dropout))
        # Output head maps the last width to one logit per class.
        stack.append(nn.Linear(widths[-1], n_classes))
        self.net = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the logits computed by ``self.net`` for input ``x``."""
        return self.net(x)


# --- Loss helpers and metrics ---------------------------------------------------

def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Return the mean Shannon entropy (in nats) of ``softmax(logits)``.

    Probabilities are clamped to at least 1e-12 before the logarithm.
    """
    probs = torch.softmax(logits, dim=-1).clamp_min(1e-12)
    per_sample = -torch.sum(probs * torch.log(probs), dim=-1)
    return per_sample.mean()


def cross_entropy_soft_targets(logits: torch.Tensor, target_probs: torch.Tensor) -> torch.Tensor:
    """Return the mean cross-entropy between ``target_probs`` and ``softmax(logits)``."""
    log_probs = torch.log_softmax(logits, dim=-1)
    per_sample = -torch.sum(target_probs * log_probs, dim=-1)
    return per_sample.mean()


def one_hot(n_classes: int, labels: torch.Tensor) -> torch.Tensor:
    """Return ``labels`` encoded as float32 one-hot rows of width ``n_classes``."""
    encoded = F.one_hot(labels, num_classes=n_classes)
    return encoded.to(torch.float32)


def label_smooth(target_one_hot: torch.Tensor, smoothing: float) -> torch.Tensor:
    """Blend one-hot targets with the uniform distribution.

    Returns ``(1 - smoothing) * target + smoothing / K`` where ``K`` is the
    size of the last dimension. For ``smoothing <= 0`` the input tensor is
    returned unchanged (same object).
    """
    if smoothing <= 0.0:
        return target_one_hot
    n_classes = target_one_hot.shape[-1]
    uniform = torch.full_like(target_one_hot, 1.0 / n_classes)
    keep = 1.0 - smoothing
    return keep * target_one_hot + smoothing * uniform


def consistency_kl(logits_a: torch.Tensor, logits_b: torch.Tensor) -> torch.Tensor:
    """Return the batch-mean symmetric KL divergence between two predictions.

    Both softmax distributions are clamped to at least 1e-8; the result is
    ``0.5 * mean(KL(a || b) + KL(b || a))``.
    """
    probs_a = torch.softmax(logits_a, dim=-1).clamp_min(1e-8)
    probs_b = torch.softmax(logits_b, dim=-1).clamp_min(1e-8)
    log_a = probs_a.log()
    log_b = probs_b.log()
    forward_kl = torch.sum(probs_a * (log_a - log_b), dim=-1)
    reverse_kl = torch.sum(probs_b * (log_b - log_a), dim=-1)
    return 0.5 * (forward_kl + reverse_kl).mean()


@torch.no_grad()
def accuracy(logits: torch.Tensor, y: torch.Tensor) -> float:
    """Return the fraction of rows whose arg-max logit equals the label in ``y``."""
    predicted = torch.argmax(logits, dim=-1)
    hits = torch.eq(predicted, y).float()
    return hits.mean().item()


# --- Training + evaluation ------------------------------------------------------

@dataclass
class TrainConfig:
    """All settings of a single training run."""

    # --- objective -----------------------------------------------------------
    mode: str               # "supervised", "self" or "hybrid"
    # --- optimisation --------------------------------------------------------
    steps: int              # number of mini-batch updates
    batch_size: int
    lr: float
    wd: float               # AdamW weight decay
    entropy_bonus: float    # weight of the entropy term subtracted from the loss
    label_sharpen: float    # passed to label_smooth() as the smoothing amount
    # --- bookkeeping ---------------------------------------------------------
    seed: int
    outdir: str             # parent directory of the run directory
    log_every: int          # log/validate every this many steps
    # --- data ----------------------------------------------------------------
    dim: int
    n_classes: int
    train_size: int
    val_ratio: float
    test_ratio: float
    dataset: str
    # --- model ---------------------------------------------------------------
    hidden: int
    depth: int
    dropout: float
    # --- optional settings ---------------------------------------------------
    device: str = "auto"            # "auto", "cpu" or "cuda"
    early_patience: int = 0         # 0 disables early stopping
    scheduler_gamma: float = 0.0    # 0 disables StepLR
    sigma_aug: float = 0.25         # noise sigma for the consistency views

    def device_obj(self) -> torch.device:
        """Resolve ``self.device`` to a ``torch.device``.

        "cpu" and "cuda" are taken literally; any other value picks CUDA when
        it is available and the CPU otherwise.
        """
        if self.device in ("cpu", "cuda"):
            return torch.device(self.device)
        cuda_ok = torch.cuda.is_available()
        return torch.device("cuda" if cuda_ok else "cpu")


def make_run_dir(cfg: TrainConfig) -> str:
    """Create the directory for one run and store its configuration.

    The directory is placed under ``cfg.outdir`` and named after the current
    timestamp, dataset, mode, seed and dim. ``cfg`` is written into it as
    ``config.json``.

    Returns:
        Path of the run directory.
    """
    run_name = f"{ts()}_{cfg.dataset}_{cfg.mode}_seed{cfg.seed}_dim{cfg.dim}"
    run_dir = os.path.join(cfg.outdir, run_name)
    ensure_dir(run_dir)

    config_path = os.path.join(run_dir, "config.json")
    settings = asdict(cfg)
    with open(config_path, "w") as handle:
        json.dump(settings, handle, indent=2)
    return run_dir


def train_run(cfg: TrainConfig) -> str:
    """Train one MLP as described by ``cfg`` and return the run directory.

    Creates a run directory (including ``config.json``), generates a synthetic
    dataset, and performs up to ``cfg.steps`` mini-batch updates. At every
    logging step a row is appended to ``metrics.csv`` and validation accuracy
    is measured; ``model.pt`` is overwritten whenever that accuracy improves.
    When matplotlib is available the metrics plot is attempted, and for
    ``cfg.dim == 2`` a decision-boundary plot as well.

    Args:
        cfg: Training configuration. ``cfg.mode`` selects the active loss
            terms: "supervised" (soft-target cross-entropy), "self"
            (consistency between two noisy views) or "hybrid" (both).

    Returns:
        Path of the run directory containing the artifacts.
    """
    set_seed(cfg.seed)
    device = cfg.device_obj()
    run_dir = make_run_dir(cfg)

    # Data
    dc = DataConfig(dataset=cfg.dataset, dim=cfg.dim, n_classes=cfg.n_classes)
    # Generate train_size + 4096 samples, then carve out val/test by ratio.
    X, y = make_dataset(cfg.train_size + 4096, dc, seed=cfg.seed)
    # NOTE: the test split (Xte, yte) is produced here but not used below.
    Xtr, ytr, Xval, yval, Xte, yte = split_dataset(X, y, cfg.val_ratio, cfg.test_ratio, seed=cfg.seed + 1)
    Xtr_t, ytr_t = to_tensor(Xtr, ytr, device)
    Xval_t, yval_t = to_tensor(Xval, yval, device)

    model = MLP(cfg.dim, cfg.n_classes, hidden=cfg.hidden, depth=cfg.depth, dropout=cfg.dropout).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
    # Optional StepLR: multiply the lr by gamma every third of the run.
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=max(1, cfg.steps // 3), gamma=cfg.scheduler_gamma) \
        if cfg.scheduler_gamma and cfg.scheduler_gamma > 0 else None

    # Start metrics.csv with its header; rows are appended at logging steps.
    csv_path = os.path.join(run_dir, "metrics.csv")
    with open(csv_path, "w", newline="") as f:
        csv.writer(f).writerow(["step", "loss", "sup", "ssl", "entropy", "val_acc", "lr"])

    best_val = -1.0  # below any real accuracy, so the first logged step always saves
    best_path = os.path.join(run_dir, "model.pt")
    n = Xtr_t.shape[0]
    patience_left = cfg.early_patience

    for step in range(1, cfg.steps + 1):
        model.train()
        # Sample a mini-batch of row indices with replacement.
        idx = np.random.randint(0, n, size=(cfg.batch_size,))
        xb = Xtr_t[idx]
        yb = ytr_t[idx]

        logits = model(xb)

        # Both terms default to zero so the inactive one drops out of the sum.
        loss_sup = torch.tensor(0.0, device=device)
        loss_ssl = torch.tensor(0.0, device=device)

        if cfg.mode in ("supervised", "hybrid"):
            y1h = one_hot(cfg.n_classes, yb)
            # cfg.label_sharpen is used as the label-smoothing amount here.
            ysoft = label_smooth(y1h, cfg.label_sharpen)
            loss_sup = cross_entropy_soft_targets(logits, ysoft)

        if cfg.mode in ("self", "hybrid"):
            # Two independently noised views of the same batch should agree.
            xa = augment_noise(xb, cfg.sigma_aug)
            xb2 = augment_noise(xb, cfg.sigma_aug)
            la = model(xa)
            lb = model(xb2)
            loss_ssl = consistency_kl(la, lb)

        ent = entropy_from_logits(logits)
        # Entropy is subtracted: a positive bonus rewards higher-entropy predictions.
        total = loss_sup + loss_ssl - cfg.entropy_bonus * ent

        opt.zero_grad(set_to_none=True)
        total.backward()
        opt.step()
        if sched:
            sched.step()

        # Log on every log_every-th step and always on the first and last step.
        if step % cfg.log_every == 0 or step in (1, cfg.steps):
            model.eval()
            with torch.no_grad():
                val_acc = accuracy(model(Xval_t), yval_t)
            lr_current = next(iter(opt.param_groups))["lr"]
            with open(csv_path, "a", newline="") as f:
                csv.writer(f).writerow([
                    step,
                    f"{float(total):.6f}",
                    f"{float(loss_sup):.6f}",
                    f"{float(loss_ssl):.6f}",
                    f"{float(ent):.6f}",
                    f"{val_acc:.4f}",
                    f"{lr_current:.6e}",
                ])
            print(f"[{step:5d}/{cfg.steps}] {cfg.dataset}/{cfg.mode} "
                  f"loss={float(total):.4f} sup={float(loss_sup):.4f} ssl={float(loss_ssl):.4f} "
                  f"ent={float(ent):.4f} val_acc={val_acc:.3f} lr={lr_current:.2e}")

            # Early stopping on best val_acc
            # (patience is counted in logging steps, not in training steps).
            if val_acc > best_val + 1e-6:
                best_val = val_acc
                torch.save({
                    "model_state": model.state_dict(),
                    "in_dim": cfg.dim,
                    "n_classes": cfg.n_classes,
                    "config": asdict(cfg),
                    "dataset": cfg.dataset,
                }, best_path)
                patience_left = cfg.early_patience
            else:
                if cfg.early_patience > 0:
                    patience_left -= 1
                    if patience_left <= 0:
                        print(f"Early stopping at step {step} (best val_acc={best_val:.4f}).")
                        break

    # Optional plots
    if _HAS_MPL:
        try_plot_metrics(csv_path, os.path.join(run_dir, "metrics.png"))
        if cfg.dim == 2:
            # Best effort: plot the best checkpoint; if reloading fails the
            # model keeps the weights from the last training step.
            try:
                payload = torch.load(best_path, map_location=device)
                model.load_state_dict(payload["model_state"])
            except Exception:
                pass
            # Scatter at most 1000 training and 1000 validation points.
            X_comb = np.vstack([Xtr[:1000], Xval[:1000]])
            y_comb = np.concatenate([ytr[:1000], yval[:1000]], axis=0)
            try_plot_boundary(model, X_comb, y_comb, os.path.join(run_dir, "boundary.png"), device)

    print(f"Done. Outputs saved to: {run_dir}")
    return run_dir


@torch.no_grad()
def eval_run(ckpt_path: str, dataset: Optional[str] = None, dim: Optional[int] = None,
             n_samples: int = 4096, seed: Optional[int] = None) -> float:
    """Evaluate a checkpoint on a newly generated synthetic dataset.

    Args:
        ckpt_path: Path to a checkpoint written by training (``model.pt``).
        dataset: Dataset name to evaluate on; defaults to the one stored in
            the checkpoint (falling back to "blobs").
        dim: Input dimension override; defaults to the checkpoint's ``in_dim``.
        n_samples: Number of samples to generate.
        seed: Seed for the evaluation data; defaults to the training seed
            stored in the checkpoint (42 if absent) plus 999.

    Returns:
        Accuracy on the generated samples.

    Raises:
        ValueError: If neither the checkpoint nor ``dim`` supplies an input
            dimension, or the checkpoint has no ``n_classes``.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    payload = torch.load(ckpt_path, map_location=device)
    saved_cfg = payload.get("config", {})

    dataset_name = dataset or payload.get("dataset", saved_cfg.get("dataset", "blobs"))
    in_dim = dim or payload.get("in_dim")
    n_classes = payload.get("n_classes")
    if in_dim is None or n_classes is None:
        raise ValueError("Checkpoint missing in_dim/n_classes metadata.")

    # Rebuild the architecture from the stored hyper-parameters, then load weights.
    net = MLP(
        in_dim,
        n_classes,
        hidden=saved_cfg.get("hidden", 64),
        depth=saved_cfg.get("depth", 2),
        dropout=saved_cfg.get("dropout", 0.0),
    ).to(device)
    net.load_state_dict(payload["model_state"])
    net.eval()

    data_cfg = DataConfig(dataset=dataset_name, dim=in_dim, n_classes=n_classes)
    if seed is None:
        test_seed = saved_cfg.get("seed", 42) + 999
    else:
        test_seed = seed
    X, y = make_dataset(n_samples, data_cfg, seed=test_seed)
    Xt, yt = to_tensor(X, y, device)
    acc = accuracy(net(Xt), yt)
    print(f"Eval accuracy ({dataset_name}, dim={in_dim}, n={n_samples}): {acc:.4f}")

    # Plot boundary if 2D
    if _HAS_MPL and in_dim == 2:
        png_path = os.path.join(os.path.dirname(ckpt_path), "boundary_eval.png")
        try_plot_boundary(net, X, y, png_path, device)

    return acc


@torch.no_grad()
def predict_run(ckpt_path: str, points_csv: Optional[str], points_inline: Optional[str]) -> None:
    """Print class predictions for user-supplied points as JSON.

    Points are read from ``points_csv`` when given (header-less CSV, one point
    per row; blank rows are skipped), otherwise from ``points_inline`` in the
    form ``"x1,x2; y1,y2; ..."``. For every point the output contains the
    coordinates, the predicted class and the class probabilities.

    Args:
        ckpt_path: Path to a checkpoint holding ``in_dim``, ``n_classes`` and
            ``model_state``.
        points_csv: Path to the CSV file, or ``None``.
        points_inline: Inline point string, or ``None``.

    Raises:
        ValueError: If neither source is given, a point does not have
            ``in_dim`` values, a value is not a number, or no points were
            found at all.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    payload = torch.load(ckpt_path, map_location=device)
    in_dim = payload["in_dim"]
    n_classes = payload["n_classes"]
    cfg = payload.get("config", {})

    model = MLP(in_dim, n_classes, hidden=cfg.get("hidden", 64),
                depth=cfg.get("depth", 2), dropout=cfg.get("dropout", 0.0)).to(device)
    model.load_state_dict(payload["model_state"])
    model.eval()

    pts: List[List[float]] = []
    if points_csv:
        with open(points_csv, "r", newline="") as f:
            for row in csv.reader(f):
                if not row:
                    continue
                vals = [float(t) for t in row]
                if len(vals) != in_dim:
                    raise ValueError(f"Each CSV row must have {in_dim} values, got {len(vals)}.")
                pts.append(vals)
    elif points_inline:
        # format: "x1,x2; y1,y2; ..."
        for chunk in points_inline.split(";"):
            chunk = chunk.strip()
            if not chunk:
                continue
            row = [float(t) for t in chunk.split(",") if t.strip() != ""]
            if len(row) != in_dim:
                raise ValueError(f"Each point must have {in_dim} values.")
            pts.append(row)
    else:
        raise ValueError("Provide --points-csv or --points-inline.")

    # An empty CSV or a string of bare separators would otherwise produce a
    # 0-length tensor and an opaque shape error inside the model.
    if not pts:
        raise ValueError("No points found in the provided input.")

    Xt = torch.tensor(np.array(pts, dtype=np.float32), device=device)
    logits = model(Xt)
    probs = logits.softmax(dim=-1).cpu().numpy()
    preds = probs.argmax(axis=-1).tolist()

    out = [{"x": p, "pred": int(c), "probs": [float(q) for q in pr]} for p, c, pr in zip(pts, preds, probs)]
    print(json.dumps(out, indent=2))


def export_onnx(ckpt_path: str, out_path: Optional[str] = None) -> None:
    """Export a trained checkpoint to an ONNX file.

    Prints a notice and returns without exporting when the optional ONNX
    import at the top of the file failed.

    Args:
        ckpt_path: Path to a checkpoint holding ``in_dim``, ``n_classes`` and
            ``model_state``.
        out_path: Destination file; defaults to ``model.onnx`` next to the
            checkpoint.
    """
    if not _HAS_ONNX:
        print("ONNX not available. Install onnx to enable export.")
        return
    device = torch.device("cpu")
    payload = torch.load(ckpt_path, map_location=device)
    in_dim = payload["in_dim"]
    n_classes = payload["n_classes"]
    cfg = payload.get("config", {})

    # Rebuild the model with the dropout setting it was trained with. Dropout
    # layers occupy slots in the nn.Sequential, so building with dropout=0.0
    # would shift the layer indices and the saved state-dict keys (e.g.
    # "net.3.weight") would no longer match. eval() below disables dropout.
    model = MLP(in_dim, n_classes, hidden=cfg.get("hidden", 64),
                depth=cfg.get("depth", 2), dropout=cfg.get("dropout", 0.0)).to(device)
    model.load_state_dict(payload["model_state"])
    model.eval()

    # A single all-zero row is enough for tracing; the batch axis is dynamic.
    dummy = torch.zeros(1, in_dim, dtype=torch.float32, device=device)
    out_path = out_path or os.path.join(os.path.dirname(ckpt_path), "model.onnx")
    torch.onnx.export(
        model, dummy, out_path,
        input_names=["input"], output_names=["logits"],
        dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
        opset_version=13,
    )
    print(f"Exported ONNX to: {out_path}")


# --- Plotting helpers -----------------------------------------------------------

def try_plot_metrics(csv_path: str, out_png: str) -> None:
    """Plot the curves stored in a metrics CSV and save them as a PNG.

    Does nothing when ``csv_path`` does not exist or holds no data rows. The
    losses, entropy and validation accuracy share the left axis; the learning
    rate gets a secondary axis when any recorded value is non-zero.
    """
    if not os.path.exists(csv_path):
        return

    steps, loss, sup, ssl, ent, acc, lr = [], [], [], [], [], [], []
    required = (
        ("step", int, steps),
        ("loss", float, loss),
        ("sup", float, sup),
        ("ssl", float, ssl),
        ("entropy", float, ent),
        ("val_acc", float, acc),
    )
    with open(csv_path, "r") as f:
        for record in csv.DictReader(f):
            for column, convert, bucket in required:
                bucket.append(convert(record[column]))
            # "lr" is optional in the file; a missing column counts as 0.
            lr.append(float(record.get("lr", 0.0)))
    if not steps:
        return

    fig, ax1 = plt.subplots(1, 1, figsize=(8, 4.5))
    curves = (
        (loss, "total", "#1f77b4"),
        (sup, "supervised", "#ff7f0e"),
        (ssl, "self", "#2ca02c"),
        (ent, "entropy", "#9467bd"),
        (acc, "val_acc", "#d62728"),
    )
    for values, label, color in curves:
        ax1.plot(steps, values, label=label, color=color)
    ax1.set_title("Training metrics")
    ax1.set_xlabel("step")
    ax1.grid(True, alpha=0.3)
    ax1.legend(loc="best")

    if any(lr):
        ax2 = ax1.twinx()
        ax2.plot(steps, lr, label="lr", color="#8c564b", linestyle="--", alpha=0.6)
        ax2.set_ylabel("learning rate")

    fig.tight_layout()
    fig.savefig(out_png, dpi=140)
    plt.close(fig)


def try_plot_boundary(model: nn.Module, X: np.ndarray, y: np.ndarray, out_png: str, device: torch.device) -> None:
    """Save a 2D decision-boundary plot with the data points on top.

    The model is evaluated on a 400 x 400 grid that extends 1.5 units beyond
    the data range; the filled contours show the softmax probability of
    class 1. Returns without plotting when ``X`` does not have exactly two
    columns.

    Args:
        model: Classifier returning logits with at least two classes.
        X: Points to scatter; only plotted when ``X.shape[1] == 2``.
        y: Labels used to colour the points.
        out_png: Destination image path.
        device: Device the grid tensor is created on (should match the model).
    """
    if X.shape[1] != 2:
        return
    xmin, xmax = X[:, 0].min() - 1.5, X[:, 0].max() + 1.5
    ymin, ymax = X[:, 1].min() - 1.5, X[:, 1].max() + 1.5
    xx, yy = np.meshgrid(np.linspace(xmin, xmax, 400), np.linspace(ymin, ymax, 400))
    grid = np.c_[xx.ravel(), yy.ravel()].astype(np.float32)
    with torch.no_grad():
        logits = model(torch.tensor(grid, device=device))
        Z = logits.softmax(dim=-1)[:, 1].cpu().numpy().reshape(xx.shape)
    fig, ax = plt.subplots(1, 1, figsize=(6.5, 5.5))
    cs = ax.contourf(xx, yy, Z, levels=30, cmap="coolwarm", alpha=0.85)
    fig.colorbar(cs, ax=ax, label="P(class=1)")
    ax.scatter(X[:, 0], X[:, 1], c=y, cmap="bwr", edgecolor="k", s=12, alpha=0.8)
    ax.set_title("Decision boundary")
    ax.set_xlabel("x1")
    ax.set_ylabel("y1")
    ax.grid(True, alpha=0.2)
    fig.tight_layout()
    fig.savefig(out_png, dpi=140)
    plt.close(fig)


# Backward-compatible alias for the earlier underscore-less spelling.
tryplotboundary = try_plot_boundary


# --- Benchmark ------------------------------------------------------------------

@dataclass
class BenchmarkCase:
    """One cell of a benchmark grid."""

    mode: str       # training mode
    dataset: str    # dataset name
    dim: int        # input dimension
    seed: int       # random seed


def benchmarkrun(outcsv: str,
                  modes: List[str],
                  datasets: List[str],
                  dims: List[int],
                  seeds: List[int],
                  base_cfg: Dict) -> None:
    """Train and evaluate every dataset/mode/dim/seed combination.

    Each combination is trained with ``train_run`` and its checkpoint is then
    scored by ``eval_run`` on a freshly generated set (seed ``s + 12345``).
    One summary row per combination is written to ``outcsv``.

    Args:
        outcsv: Path of the summary CSV; its parent directory is created.
        modes: Training modes to run.
        datasets: Dataset names to run.
        dims: Input dimensions to run.
        seeds: Random seeds to run.
        base_cfg: Settings shared by all runs (``steps``, ``batch_size``,
            ``lr``, ...). Missing entries fall back to built-in defaults;
            ``bench_n`` sets the number of evaluation samples.
    """
    def option(key: str, default):
        # Callers have used both "batch_size" and "batchsize" style keys,
        # so accept the underscore-less spelling as a fallback.
        if key in base_cfg:
            return base_cfg[key]
        return base_cfg.get(key.replace("_", ""), default)

    fieldnames = ["dataset", "mode", "dim", "seed", "run_dir", "val_best_acc"]
    rows = []
    for ds in datasets:
        for mode in modes:
            for d in dims:
                for s in seeds:
                    cfg = TrainConfig(
                        mode=mode,
                        steps=int(option("steps", 300)),
                        batch_size=int(option("batch_size", 128)),
                        lr=float(option("lr", 1e-3)),
                        wd=float(option("wd", 0.0)),
                        entropy_bonus=float(option("entropy_bonus", 0.0)),
                        label_sharpen=float(option("label_sharpen", 0.0)),
                        seed=s,
                        outdir=str(option("outdir", "runs")),
                        log_every=int(option("log_every", 100)),
                        dim=d,
                        n_classes=int(option("n_classes", 2)),
                        train_size=int(option("train_size", 8192)),
                        val_ratio=float(option("val_ratio", 0.15)),
                        test_ratio=float(option("test_ratio", 0.15)),
                        dataset=ds,
                        hidden=int(option("hidden", 64)),
                        depth=int(option("depth", 2)),
                        dropout=float(option("dropout", 0.0)),
                        device=str(option("device", "auto")),
                        early_patience=int(option("early_patience", 0)),
                        scheduler_gamma=float(option("scheduler_gamma", 0.0)),
                        sigma_aug=float(option("sigma_aug", 0.25)),
                    )
                    print(f"=== Benchmark: {ds} / {mode} / dim={d} / seed={s} ===")
                    run_dir = train_run(cfg)
                    ckpt = os.path.join(run_dir, "model.pt")
                    acc = eval_run(ckpt, dataset=ds, dim=d,
                                   n_samples=int(option("bench_n", 4096)),
                                   seed=s + 12345)
                    # Row keys must match ``fieldnames`` exactly, otherwise
                    # csv.DictWriter raises ValueError.
                    rows.append({
                        "dataset": ds, "mode": mode, "dim": d, "seed": s,
                        "run_dir": run_dir, "val_best_acc": acc,
                    })

    os.makedirs(os.path.dirname(outcsv) or ".", exist_ok=True)
    with open(outcsv, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        w.writerows(rows)
    print(f"Benchmark summary saved to: {outcsv}")


# Snake-case alias matching the naming of the other *_run functions.
benchmark_run = benchmarkrun


# --- CLI -----------------------------------------------------------------------

def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser with one sub-parser per subcommand.

    The chosen subcommand is stored in ``args.cmd`` ("train", "eval",
    "predict", "benchmark" or "export-onnx"; ``None`` when omitted).
    """
    def opt(flag: str, **settings):
        # Pair a flag with its add_argument() keyword arguments.
        return flag, settings

    def register(target: argparse.ArgumentParser, options) -> None:
        # Add the options in the listed order (the order shows up in --help).
        for flag, settings in options:
            target.add_argument(flag, **settings)

    p = argparse.ArgumentParser(description="MetaIntelligence end-to-end toolkit (notebook/Colab-safe)")
    sub = p.add_subparsers(dest="cmd")

    # Train
    train = sub.add_parser("train", help="Train a model (self/supervised/hybrid)")
    register(train, [
        opt("--mode", choices=["self", "supervised", "hybrid"], default="supervised"),
        opt("--dataset", choices=["blobs", "moons", "circles"], default="blobs"),
        opt("--steps", type=int, default=500),
        opt("--batch-size", type=int, default=128),
        opt("--lr", type=float, default=1e-3),
        opt("--wd", type=float, default=0.0),
        opt("--entropy-bonus", type=float, default=0.0),
        opt("--label-sharpen", type=float, default=0.0),
        opt("--seed", type=int, default=42),
        opt("--outdir", type=str, default="runs"),
        opt("--log-every", type=int, default=50),
        opt("--dim", type=int, default=2),
        opt("--n-classes", type=int, default=2),
        opt("--train-size", type=int, default=8192),
        opt("--val-ratio", type=float, default=0.15),
        opt("--test-ratio", type=float, default=0.15),
        opt("--hidden", type=int, default=64),
        opt("--depth", type=int, default=2),
        opt("--dropout", type=float, default=0.0),
        opt("--device", type=str, default="auto", choices=["auto", "cpu", "cuda"]),
        opt("--early-patience", type=int, default=0, help="Early stopping patience (0 disables)"),
        opt("--scheduler-gamma", type=float, default=0.0, help="StepLR gamma (0 disables)"),
        opt("--sigma-aug", type=float, default=0.25, help="Noise sigma for self/hybrid consistency"),
    ])

    # Eval
    evaluate = sub.add_parser("eval", help="Evaluate a saved checkpoint on a fresh synthetic test set")
    register(evaluate, [
        opt("--ckpt", type=str, required=True, help="Path to model.pt"),
        opt("--dataset", type=str, default=None, help="Override dataset used for eval"),
        opt("--dim", type=int, default=None, help="Override input dim used for eval"),
        opt("--n", type=int, default=4096, help="Number of samples for eval"),
        opt("--seed", type=int, default=None, help="Random seed for eval set"),
    ])

    # Predict
    predict = sub.add_parser("predict", help="Predict for given points")
    register(predict, [
        opt("--ckpt", type=str, required=True, help="Path to model.pt"),
        opt("--points-csv", type=str, default=None, help="CSV file without header, each row is a point"),
        opt("--points-inline", type=str, default=None,
            help='Inline points, e.g., "x1,x2; y1,y2" (must match input dim)'),
    ])

    # Benchmark
    bench = sub.add_parser("benchmark", help="Run grid of modes/datasets/dims across seeds and summarize")
    register(bench, [
        opt("--modes", type=str, default="supervised,hybrid", help="Comma list of modes"),
        opt("--datasets", type=str, default="blobs,moons,circles", help="Comma list of datasets"),
        opt("--dims", type=str, default="2", help="Comma list of dims (e.g., '2,4')"),
        opt("--seeds", type=str, default="1,2,3", help="Comma list of seeds"),
        opt("--summary", type=str, default="runs/benchmarksummary.csv", help="Output CSV"),
        opt("--steps", type=int, default=300),
        opt("--batch-size", type=int, default=128),
        opt("--lr", type=float, default=1e-3),
        opt("--wd", type=float, default=0.0),
        opt("--entropy-bonus", type=float, default=0.0),
        opt("--label-sharpen", type=float, default=0.0),
        opt("--outdir", type=str, default="runs"),
        opt("--log-every", type=int, default=100),
        opt("--n-classes", type=int, default=2),
        opt("--train-size", type=int, default=8192),
        opt("--val-ratio", type=float, default=0.15),
        opt("--test-ratio", type=float, default=0.15),
        opt("--hidden", type=int, default=64),
        opt("--depth", type=int, default=2),
        opt("--dropout", type=float, default=0.0),
        opt("--device", type=str, default="auto", choices=["auto", "cpu", "cuda"]),
        opt("--early-patience", type=int, default=0),
        opt("--scheduler-gamma", type=float, default=0.0),
        opt("--sigma-aug", type=float, default=0.25),
        opt("--bench-n", type=int, default=4096, help="Samples per eval during benchmark"),
    ])

    # ONNX export
    onnx_export = sub.add_parser("export-onnx", help="Export a trained checkpoint to ONNX")
    register(onnx_export, [
        opt("--ckpt", type=str, required=True, help="Path to model.pt"),
        opt("--out", type=str, default=None, help="Output .onnx path"),
    ])

    return p


def main(argv: Optional[List[str]] = None) -> None:
    """Parse the command line and run the selected subcommand.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``. Arguments
            injected by a Jupyter/Colab kernel are removed first.

    Prints the help text when no arguments remain or no known subcommand was
    selected.
    """
    argv = sanitize_argv(argv)
    parser = build_parser()
    if not argv:
        parser.print_help()
        return
    # parse_known_args returns (namespace, leftovers); leftovers are ignored
    # so that stray notebook arguments do not abort the run.
    args, _unknown = parser.parse_known_args(argv)

    if args.cmd == "train":
        # argparse stores "--batch-size" as ``args.batch_size`` etc.
        cfg = TrainConfig(
            mode=args.mode,
            steps=args.steps,
            batch_size=args.batch_size,
            lr=args.lr,
            wd=args.wd,
            entropy_bonus=args.entropy_bonus,
            label_sharpen=args.label_sharpen,
            seed=args.seed,
            outdir=args.outdir,
            log_every=args.log_every,
            dim=args.dim,
            n_classes=args.n_classes,
            train_size=args.train_size,
            val_ratio=args.val_ratio,
            test_ratio=args.test_ratio,
            dataset=args.dataset,
            hidden=args.hidden,
            depth=args.depth,
            dropout=args.dropout,
            device=args.device,
            early_patience=args.early_patience,
            scheduler_gamma=args.scheduler_gamma,
            sigma_aug=args.sigma_aug,
        )
        train_run(cfg)

    elif args.cmd == "eval":
        eval_run(args.ckpt, dataset=args.dataset, dim=args.dim, n_samples=args.n, seed=args.seed)

    elif args.cmd == "predict":
        predict_run(args.ckpt, args.points_csv, args.points_inline)

    elif args.cmd == "benchmark":
        # The grid axes arrive as comma-separated strings.
        modes = [m.strip() for m in args.modes.split(",") if m.strip()]
        datasets = [d.strip() for d in args.datasets.split(",") if d.strip()]
        dims = [int(x) for x in args.dims.split(",") if x.strip()]
        seeds = [int(x) for x in args.seeds.split(",") if x.strip()]
        base_cfg = dict(
            steps=args.steps, batch_size=args.batch_size, lr=args.lr, wd=args.wd,
            entropy_bonus=args.entropy_bonus, label_sharpen=args.label_sharpen,
            outdir=args.outdir, log_every=args.log_every, n_classes=args.n_classes,
            train_size=args.train_size, val_ratio=args.val_ratio, test_ratio=args.test_ratio,
            hidden=args.hidden, depth=args.depth, dropout=args.dropout, device=args.device,
            early_patience=args.early_patience, scheduler_gamma=args.scheduler_gamma,
            sigma_aug=args.sigma_aug, bench_n=args.bench_n,
        )
        benchmarkrun(args.summary, modes, datasets, dims, seeds, base_cfg)

    elif args.cmd == "export-onnx":
        export_onnx(args.ckpt, args.out)

    else:
        parser.print_help()


if __name__ == "__main__":
    # Script entry point: run the CLI on the process arguments.
    try:
        main()
    except SystemExit:
        # SystemExit (e.g. raised by argparse on a usage error or --help) is
        # not suppressed: print a hint to stderr, then re-raise so the exit
        # status is preserved.
        print("Use 'exit', 'quit', or Ctrl-D to exit.", file=sys.stderr)
        raise