<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/MetaIntelligence_end_to_end_toolkit_(notebook_Colab_safe).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
# MetaIntelligence end-to-end toolkit (notebook/Colab-safe)
# - Train (supervised/selfsup/hybrid)
# - Eval on fresh synthetic data
# - Predict on points/CSV
# - Benchmark grid across modes/datasets/seeds
# - Export ONNX
# Safe for notebooks: main([]) prints help; plotting uses Agg backend

import argparse
from typing import Optional, List, Tuple, Dict, Any
import os
import sys
import json
import time
import random
from dataclasses import dataclass
import numpy as np

# Safe headless plotting
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import make_moons, make_circles, make_blobs, make_classification
from sklearn.metrics import accuracy_score


# ----------------------------
# Utilities
# ----------------------------

def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    When CUDA is available, the generators of all CUDA devices are seeded too.
    """
    # Same order as before: stdlib first, then NumPy, then the torch CPU generator.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def get_device(force_cpu: bool = False) -> torch.device:
    """Return the device computations should run on.

    Args:
        force_cpu: When true, always return the CPU device.

    Returns:
        ``cuda`` when CUDA is available and not overridden, otherwise ``cpu``.
    """
    # Short-circuit keeps the CUDA probe from running when the CPU is forced.
    use_cuda = (not force_cpu) and torch.cuda.is_available()
    return torch.device("cuda" if use_cuda else "cpu")


def ensure_dir(path: str) -> None:
    """Create directory ``path`` (with parents) unless it is empty or already exists."""
    if not path or os.path.exists(path):
        # Nothing to do for "" (e.g. dirname of a bare filename) or an existing path.
        return
    os.makedirs(path, exist_ok=True)


def now_ts() -> str:
    """Return the current local time formatted as ``YYYYMMDD-HHMMSS``."""
    timestamp_format = "%Y%m%d-%H%M%S"
    return time.strftime(timestamp_format)


def sanitize_filename(name: str) -> str:
    """Drop every character of ``name`` that is not alphanumeric or one of ``-_.``."""
    def is_allowed(ch: str) -> bool:
        return ch.isalnum() or ch in "-_."

    return "".join(filter(is_allowed, name))


# ----------------------------
# Data generation
# ----------------------------

@dataclass
class DataConfig:
    """Settings consumed by ``gen_synthetic`` to build one synthetic dataset."""
    # Generator name: "moons", "circles", "blobs" or "classification".
    dataset: str
    # Number of features; "moons" and "circles" require exactly 2.
    dims: int
    # Number of classes (used by the "blobs" and "classification" generators).
    classes: int
    # Total number of generated samples, before the train/validation split.
    samples: int
    # Noise knob forwarded to the generator (noise, cluster_std or flip_y).
    noise: float
    # Fraction of the samples placed in the validation split.
    valfrac: float
    # Seed for both the generator and the shuffle before splitting.
    seed: int


def gen_synthetic(cfg: DataConfig) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Generate, shuffle, split and standardize a synthetic classification set.

    Args:
        cfg: Dataset settings (generator name, dims, classes, sample count,
            noise level, validation fraction and seed).

    Returns:
        ``(X_train, y_train, X_val, y_val)`` with float32 features and int64
        labels. Features are standardized with the mean and standard deviation
        of the training split only.

    Raises:
        ValueError: If the dataset name is unknown, if "moons"/"circles" is
            requested with ``dims != 2``, or if ``valfrac`` does not leave
            between 1 and ``samples`` training rows.
    """
    n_train = int((1.0 - cfg.valfrac) * cfg.samples)
    if not 1 <= n_train <= cfg.samples:
        # With zero training rows the standardization statistics below would be
        # NaN and silently poison every returned feature.
        raise ValueError(
            f"valfrac={cfg.valfrac} with samples={cfg.samples} leaves {n_train} training rows; "
            f"need between 1 and {cfg.samples}"
        )

    if cfg.dataset == "moons":
        if cfg.dims != 2:
            raise ValueError("moons requires --dims=2")
        X, y = make_moons(n_samples=cfg.samples, noise=cfg.noise, random_state=cfg.seed)
    elif cfg.dataset == "circles":
        if cfg.dims != 2:
            raise ValueError("circles requires --dims=2")
        X, y = make_circles(n_samples=cfg.samples, noise=cfg.noise, factor=0.5, random_state=cfg.seed)
    elif cfg.dataset == "blobs":
        X, y = make_blobs(
            n_samples=cfg.samples, centers=cfg.classes, n_features=cfg.dims,
            cluster_std=max(cfg.noise, 0.05), random_state=cfg.seed
        )
    elif cfg.dataset == "classification":
        X, y = make_classification(
            n_samples=cfg.samples,
            n_features=cfg.dims,
            # Every feature is informative (no redundant/repeated columns).
            n_informative=max(2, cfg.dims),
            n_redundant=0,
            n_repeated=0,
            n_classes=cfg.classes,
            n_clusters_per_class=1,
            flip_y=min(cfg.noise, 0.4),
            class_sep=max(0.5, 2.0 - cfg.noise),
            random_state=cfg.seed,
        )
    else:
        raise ValueError(f"Unknown dataset: {cfg.dataset}")

    # Shuffle and split
    rng = np.random.RandomState(cfg.seed)
    idx = np.arange(cfg.samples)
    rng.shuffle(idx)
    X = X[idx]
    y = y[idx]
    X_train, y_train = X[:n_train], y[:n_train]
    X_val, y_val = X[n_train:], y[n_train:]

    # Standardize features for stability (statistics come from the training split only)
    mu = X_train.mean(axis=0, keepdims=True)
    sigma = X_train.std(axis=0, keepdims=True) + 1e-8
    X_train = (X_train - mu) / sigma
    X_val = (X_val - mu) / sigma

    return X_train.astype(np.float32), y_train.astype(np.int64), X_val.astype(np.float32), y_val.astype(np.int64)


def make_loaders(
    X_train: np.ndarray, y_train: np.ndarray,
    X_val: np.ndarray, y_val: np.ndarray,
    batch_size: int,
    labeled_frac: float = 1.0,
    seed: int = 42
) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
    """Wrap the arrays in DataLoaders, optionally hiding part of the labels.

    Args:
        X_train, y_train: Training features and labels.
        X_val, y_val: Validation features and labels.
        batch_size: Batch size used by every returned loader.
        labeled_frac: Fraction of training rows that keep their labels. Values
            of 1.0 or more keep all of them.
        seed: Seed of the permutation that picks the labeled rows.

    Returns:
        ``(labeled_loader, val_loader, unlabeled_loader)``. The unlabeled
        loader is ``None`` when every training row is labeled; otherwise it
        yields the remaining rows paired with all-zero placeholder labels.
    """
    def build(dataset: TensorDataset, shuffle: bool) -> DataLoader:
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=False)

    train_ds = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
    val_ds = TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val))

    # Fully supervised case: one shuffled training loader, no unlabeled pool.
    if labeled_frac >= 1.0:
        return build(train_ds, True), build(val_ds, False), None

    # Semi-supervised case: a seeded permutation decides which rows stay labeled.
    total = len(train_ds)
    keep = max(1, int(labeled_frac * total))  # always keep at least one labeled row
    order = np.random.RandomState(seed).permutation(total)
    labeled_rows, unlabeled_rows = order[:keep], order[keep:]

    labeled_ds = TensorDataset(
        torch.from_numpy(X_train[labeled_rows]),
        torch.from_numpy(y_train[labeled_rows]),
    )
    unlabeled_x = torch.from_numpy(X_train[unlabeled_rows])
    placeholder_y = torch.zeros(len(unlabeled_x)).long()
    unlabeled_ds = TensorDataset(unlabeled_x, placeholder_y)

    return build(labeled_ds, True), build(val_ds, False), build(unlabeled_ds, True)


# ----------------------------
# Model
# ----------------------------

class MLP(nn.Module):
    """Fully connected classifier: ``depth`` hidden blocks followed by a linear head.

    Each hidden block is Linear -> ReLU, plus Dropout when ``dropout > 0``.
    With ``depth <= 0`` the backbone is an identity and the head maps the raw
    input straight to class logits.
    """

    def __init__(self, in_dim: int, num_classes: int, width: int = 64, depth: int = 2, dropout: float = 0.0):
        super().__init__()
        use_dropout = dropout > 0
        stack = []
        fan_in = in_dim
        for _ in range(max(0, depth)):
            stack += [nn.Linear(fan_in, width), nn.ReLU(inplace=True)]
            if use_dropout:
                stack.append(nn.Dropout(dropout))
            fan_in = width
        # Attribute names and layer order define the state_dict keys; keep them stable.
        self.backbone = nn.Sequential(*stack) if stack else nn.Identity()
        self.head = nn.Linear(fan_in, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class logits for the batch ``x``."""
        return self.head(self.backbone(x))


# ----------------------------
# Training and evaluation
# ----------------------------

def train_supervised(
    model: nn.Module,
    lab_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    epochs: int,
    lr: float,
    weight_decay: float
) -> Dict[str, Any]:
    """Train ``model`` with cross-entropy and restore the best validation epoch.

    After every epoch the model is scored with ``evaluate_accuracy`` on
    ``val_loader``. The weights of the epoch with the highest validation
    accuracy (the earliest one on ties) are loaded back before returning.

    Args:
        model: Network to train in place.
        lab_loader: Yields ``(features, labels)`` training batches.
        val_loader: Yields validation batches for model selection.
        device: Device the model and batches are moved to.
        epochs: Number of passes over ``lab_loader``.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.

    Returns:
        Dict with ``best_val_acc``, ``best_epoch`` and ``history`` (one entry
        per epoch with ``epoch``, ``train_loss`` and ``val_acc``). With
        ``epochs == 0`` the best values are ``-1.0`` / ``-1``.
    """
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    best_val_acc, best_epoch, best_state = -1.0, -1, None
    history = []

    for epoch in range(1, epochs + 1):
        model.train()
        losses = []
        for xb, yb in lab_loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            logits = model(xb)
            loss = F.cross_entropy(logits, yb)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            losses.append(loss.item())

        val_acc = evaluate_accuracy(model, val_loader, device)
        avg_loss = float(np.mean(losses)) if losses else 0.0
        history.append({"epoch": epoch, "train_loss": avg_loss, "val_acc": val_acc})

        if val_acc > best_val_acc:
            best_val_acc, best_epoch = val_acc, epoch
            # Tensor.cpu() returns the very same tensor when it already lives on
            # the CPU, so without clone() the snapshot would alias the live
            # parameters and silently track every later optimizer step.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    # Load best state before returning
    if best_state is not None:
        model.load_state_dict(best_state, strict=True)
    return {"best_val_acc": best_val_acc, "best_epoch": best_epoch, "history": history}


def train_selfsup_or_hybrid(
    mode: str,
    model: nn.Module,
    lab_loader: DataLoader,
    unlab_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    epochs: int,
    lr: float,
    weight_decay: float,
    warmup_epochs: int,
    pseudo_thresh: float,
    lambda_consistency: float,
    aug_noise: float
) -> Dict[str, Any]:
    """Train with pseudo-labels and consistency regularization on unlabeled data.

    The first ``warmup_epochs`` epochs are plain supervised training on
    ``lab_loader``. Afterwards every step draws one unlabeled batch and
    optimizes the sum of:

    * cross-entropy on the labeled batch (``mode == "hybrid"`` only),
    * cross-entropy against the model's own predictions on unlabeled rows
      whose confidence is at least ``pseudo_thresh``,
    * ``lambda_consistency`` times the KL divergence between predictions on
      the unlabeled batch and on a Gaussian-perturbed copy of it.

    In ``"selfsup"`` mode the labeled loader still sets the number of steps per
    epoch, but its batches do not contribute to the loss after warmup.

    Args:
        mode: ``"selfsup"`` or ``"hybrid"``.
        model: Network to train in place.
        lab_loader: Yields labeled ``(features, labels)`` batches.
        unlab_loader: Yields ``(features, placeholder)`` unlabeled batches.
        val_loader: Yields validation batches for model selection.
        device: Device the model and batches are moved to.
        epochs: Total number of epochs, warmup included.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        warmup_epochs: Number of leading supervised-only epochs.
        pseudo_thresh: Minimum softmax confidence for a pseudo-label.
        lambda_consistency: Weight of the consistency term.
        aug_noise: Standard deviation of the Gaussian feature perturbation.

    Returns:
        Dict with ``best_val_acc``, ``best_epoch`` and ``history`` (one entry
        per epoch with ``epoch``, ``train_loss`` and ``val_acc``).

    Raises:
        ValueError: If ``unlab_loader`` yields no batches when one is needed.
    """
    assert mode in ("selfsup", "hybrid")
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    best_val_acc, best_epoch, best_state = -1.0, -1, None
    history = []

    # Cycle through unlabeled loader independently of the labeled epochs
    unlab_iter = None

    for epoch in range(1, epochs + 1):
        model.train()
        losses = []

        # Warmup: supervised on labeled only
        if epoch <= warmup_epochs:
            for xb, yb in lab_loader:
                xb = xb.to(device, non_blocking=True)
                yb = yb.to(device, non_blocking=True)
                logits = model(xb)
                loss = F.cross_entropy(logits, yb)
                opt.zero_grad(set_to_none=True)
                loss.backward()
                opt.step()
                losses.append(loss.item())
        else:
            if unlab_iter is None:
                unlab_iter = iter(unlab_loader)
            for xb_l, yb_l in lab_loader:
                total = torch.zeros((), device=device)

                # Supervised anchor on the labeled batch (hybrid mode only)
                if mode == "hybrid":
                    xb_l = xb_l.to(device, non_blocking=True)
                    yb_l = yb_l.to(device, non_blocking=True)
                    total = total + F.cross_entropy(model(xb_l), yb_l)

                # Fetch an unlabeled batch (recycle the iterator when exhausted)
                try:
                    xb_u, _ = next(unlab_iter)
                except StopIteration:
                    unlab_iter = iter(unlab_loader)
                    try:
                        xb_u, _ = next(unlab_iter)
                    except StopIteration:
                        # An empty loader would otherwise leak a bare StopIteration.
                        raise ValueError("unlab_loader yielded no batches") from None
                xb_u = xb_u.to(device, non_blocking=True)

                # Pseudo-labels come from the current model, without gradients
                with torch.no_grad():
                    probs_u = F.softmax(model(xb_u), dim=-1)
                    conf, pseudo = probs_u.max(dim=-1)
                    mask = conf >= pseudo_thresh

                # Pseudo-label supervised loss on the confident rows only
                if mask.any():
                    total = total + F.cross_entropy(model(xb_u[mask]), pseudo[mask])

                # Consistency regularization: the clean prediction is a fixed
                # target, so it is computed without building a graph.
                noise = torch.randn_like(xb_u) * aug_noise
                with torch.no_grad():
                    p = F.softmax(model(xb_u), dim=-1)
                log_q = F.log_softmax(model(xb_u + noise), dim=-1)
                loss_cons = F.kl_div(log_q, p, reduction="batchmean")
                total = total + lambda_consistency * loss_cons

                opt.zero_grad(set_to_none=True)
                total.backward()
                opt.step()
                losses.append(float(total.item()))

        val_acc = evaluate_accuracy(model, val_loader, device)
        avg_loss = float(np.mean(losses)) if losses else 0.0
        history.append({"epoch": epoch, "train_loss": avg_loss, "val_acc": val_acc})

        if val_acc > best_val_acc:
            best_val_acc, best_epoch = val_acc, epoch
            # Tensor.cpu() is a no-op for CPU tensors; clone() makes the snapshot
            # independent of the parameters the optimizer keeps updating.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state, strict=True)
    return {"best_val_acc": best_val_acc, "best_epoch": best_epoch, "history": history}


@torch.no_grad()
def evaluate_accuracy(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    The model is switched to eval mode and moved to ``device`` (and left that
    way). An empty loader scores ``0.0``.
    """
    model.eval().to(device)
    pred_chunks, true_chunks = [], []
    for features, labels in loader:
        batch_logits = model(features.to(device, non_blocking=True))
        pred_chunks.append(batch_logits.argmax(dim=-1).cpu().numpy())
        true_chunks.append(labels.numpy())

    if not true_chunks:
        return 0.0
    y_true = np.concatenate(true_chunks)
    if len(y_true) == 0:
        return 0.0
    y_pred = np.concatenate(pred_chunks)
    return float(accuracy_score(y_true, y_pred))


def save_checkpoint(
    path: str,
    model: nn.Module,
    train_cfg: Dict[str, Any],
    data_cfg: Dict[str, Any],
    metrics: Dict[str, Any]
) -> None:
    """Write the model weights plus their configuration and metrics to ``path``.

    The parent directory is created when missing. ``train_cfg`` must contain
    ``in_dim``, ``num_classes``, ``width``, ``depth`` and ``dropout``; they are
    copied into the ``model`` section so the network can be rebuilt on load.
    """
    ensure_dir(os.path.dirname(path))

    meta_section = {
        "timestamp": now_ts(),
        "framework": "torch",
        "task": "synthetic-classification",
    }

    # Weights first, then the architecture fields needed to rebuild the MLP.
    model_section = {"state_dict": {k: v.cpu() for k, v in model.state_dict().items()}}
    for field in ("in_dim", "num_classes", "width", "depth", "dropout"):
        model_section[field] = train_cfg[field]

    torch.save(
        {
            "meta": meta_section,
            "model": model_section,
            "data_cfg": data_cfg,
            "train_cfg": train_cfg,
            "metrics": metrics,
        },
        path,
    )


def load_model_from_ckpt(ckpt_path: str, device: torch.device) -> Tuple[nn.Module, Dict[str, Any]]:
    """Rebuild the MLP stored in ``ckpt_path`` and return it with the raw checkpoint.

    The architecture comes from the checkpoint's ``model`` section; the
    returned model is on ``device`` and in eval mode.
    """
    # NOTE(review): torch.load unpickles the file; only open trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=device)
    mcfg = ckpt["model"]

    hparams = {name: int(mcfg[name]) for name in ("in_dim", "num_classes", "width", "depth")}
    hparams["dropout"] = float(mcfg["dropout"])
    model = MLP(**hparams)

    model.load_state_dict(dict(mcfg["state_dict"]), strict=True)
    model.to(device).eval()
    return model, ckpt


def plot_decision_boundary(
    model: nn.Module,
    X: np.ndarray,
    y: np.ndarray,
    device: torch.device,
    title: str,
    save_path: Optional[str] = None
) -> Optional[str]:
    """Draw the model's predicted regions over 2-D data and save the figure.

    Args:
        model: Classifier mapping a ``(N, 2)`` float tensor to logits.
        X: Feature matrix; nothing is drawn unless it has exactly 2 columns.
        y: Labels used to color the scatter points.
        device: Device the evaluation grid is moved to.
        title: Figure title.
        save_path: Output image path; the parent directory is created.

    Returns:
        ``save_path`` when an image was written, otherwise ``None``.
    """
    if X.shape[1] != 2:
        return None
    x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
    y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
    xx, yy = np.meshgrid(
        np.linspace(x_min, x_max, 300),
        np.linspace(y_min, y_max, 300)
    )
    grid = np.c_[xx.ravel(), yy.ravel()].astype(np.float32)

    # Predict in eval mode so dropout cannot blur the boundary, then put the
    # model back in whatever mode the caller had it in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(torch.from_numpy(grid).to(device))
            Z = logits.argmax(dim=-1).cpu().numpy().reshape(xx.shape)
    finally:
        model.train(was_training)

    fig = plt.figure(figsize=(6, 5))
    try:
        plt.contourf(xx, yy, Z, alpha=0.3, cmap="coolwarm")
        plt.scatter(X[:, 0], X[:, 1], c=y, cmap="coolwarm", s=12, edgecolors="k", linewidths=0.2)
        plt.title(title)
        plt.tight_layout()
        if not save_path:
            return None
        ensure_dir(os.path.dirname(save_path))
        plt.savefig(save_path, dpi=150)
        return save_path
    finally:
        # Always release the figure, including when no path was given or
        # saving failed; previously it leaked in the no-path case.
        plt.close(fig)


# ----------------------------
# Command handlers
# ----------------------------

def cmdtrain(args) -> None:
    """Handle the ``train`` command: generate data, train, and save the results.

    Reads the attributes defined by the ``train`` sub-parser. Writes a
    checkpoint and ``train_summary.json`` into ``args.out``, plus a
    decision-boundary image when ``args.plot`` is set and the data is 2-D.
    """
    set_seed(args.seed)
    device = get_device(args.cpu)

    data_cfg = DataConfig(
        dataset=args.dataset,
        dims=args.dims,
        classes=args.classes,
        samples=args.samples,
        noise=args.noise,
        valfrac=args.valfrac,
        seed=args.seed,
    )
    data_cfg_dict = dict(vars(data_cfg))  # plain dict for the checkpoint and JSON

    X_tr, y_tr, X_val, y_val = gen_synthetic(data_cfg)
    lab_loader, val_loader, unlab_loader = make_loaders(
        X_tr, y_tr, X_val, y_val,
        batch_size=args.batch_size,
        # Only the semi-supervised modes hide labels.
        labeled_frac=args.labeled_frac if args.mode in ("selfsup", "hybrid") else 1.0,
        seed=args.seed
    )

    model = MLP(
        in_dim=data_cfg.dims,
        num_classes=data_cfg.classes,
        width=args.width,
        depth=args.depth,
        dropout=args.dropout
    )

    # No unlabeled pool (labeled_frac >= 1) means there is nothing to self-train on.
    if args.mode == "supervised" or unlab_loader is None:
        result = train_supervised(
            model, lab_loader, val_loader, device,
            epochs=args.epochs, lr=args.lr, weight_decay=args.weight_decay
        )
    else:
        result = train_selfsup_or_hybrid(
            args.mode, model, lab_loader, unlab_loader, val_loader, device,
            epochs=args.epochs, lr=args.lr, weight_decay=args.weight_decay,
            warmup_epochs=args.warmup_epochs, pseudo_thresh=args.pseudo_thresh,
            lambda_consistency=args.lambda_consistency, aug_noise=args.aug_noise
        )

    # Final eval on validation
    val_acc = evaluate_accuracy(model, val_loader, device)

    # Save outputs
    ensure_dir(args.out)
    ckpt_path = os.path.join(args.out, f"checkpoint_{sanitize_filename(args.dataset)}_{now_ts()}.pt")
    train_cfg = {
        "mode": args.mode,
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "lr": args.lr,
        "weight_decay": args.weight_decay,
        "width": args.width,
        "depth": args.depth,
        "dropout": args.dropout,
        # save_checkpoint reads "in_dim" and "num_classes" to describe the model.
        "in_dim": data_cfg.dims,
        "num_classes": data_cfg.classes,
        "seed": args.seed,
        "labeled_frac": args.labeled_frac,
        "pseudo_thresh": args.pseudo_thresh,
        "warmup_epochs": args.warmup_epochs,
        "lambda_consistency": args.lambda_consistency,
        "aug_noise": args.aug_noise,
    }
    metrics = {
        "best_val_acc": result["best_val_acc"],
        "best_epoch": result["best_epoch"],
        "final_val_acc": val_acc,
        "history": result["history"],
    }
    save_checkpoint(ckpt_path, model, train_cfg, data_cfg_dict, metrics)

    # Optional plot for 2D data
    fig_path = None
    if args.plot and data_cfg.dims == 2:
        fig_path = os.path.join(args.out, f"decision_boundary_{sanitize_filename(args.dataset)}.png")
        plot_decision_boundary(model, X_tr, y_tr, device, f"Train ({args.dataset})", fig_path)

    # Write summary JSON
    summary = {
        "checkpoint": ckpt_path,
        "dataset": data_cfg_dict,
        "train_cfg": train_cfg,
        "metrics": metrics,
        "plot": fig_path or "",
    }
    with open(os.path.join(args.out, "train_summary.json"), "w") as f:
        json.dump(summary, f, indent=2)

    print(f"[train] saved checkpoint: {ckpt_path}")
    print(
        f"[train] best_val_acc={metrics['best_val_acc']:.4f} at epoch {metrics['best_epoch']} "
        f"| final_val_acc={val_acc:.4f}"
    )
    if fig_path:
        print(f"[train] saved plot: {fig_path}")


def cmdeval(args) -> None:
    """Handle the ``eval`` command: score a checkpoint on fresh synthetic data.

    The test set reuses the dataset settings stored in the checkpoint with a
    shifted seed and ``args.samples`` rows. When ``args.out`` is set, an
    ``eval_summary.json`` is written there; with ``args.plot`` on 2-D data a
    decision-boundary image is saved as well (next to the checkpoint when no
    output directory was given).
    """
    device = get_device(args.cpu)
    model, ckpt = load_model_from_ckpt(args.ckpt, device)

    dcfg = ckpt["data_cfg"]
    # Fresh test set based on training config; new seed for variation
    test_seed = int(dcfg.get("seed", 42)) + 997
    data_cfg = DataConfig(
        dataset=dcfg["dataset"],
        dims=int(dcfg["dims"]),
        classes=int(dcfg["classes"]),
        samples=int(args.samples),
        noise=float(dcfg["noise"]),
        valfrac=0.0,  # no hold-out: every generated row is a test row
        seed=test_seed,
    )
    # With valfrac=0.0 all rows land in the first ("train") slot and the
    # validation slot is empty, so the test set must be taken from the first
    # slot. The rows are standardized with their own statistics because the
    # checkpoint does not store the training mean/std.
    X_test, y_test, _, _ = gen_synthetic(data_cfg)
    test_loader = DataLoader(
        TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test)),
        batch_size=256, shuffle=False
    )

    acc = evaluate_accuracy(model, test_loader, device)

    out_dir = args.out
    if out_dir:
        ensure_dir(out_dir)

    plot_path = None
    if args.plot and data_cfg.dims == 2:
        if not out_dir:
            # Fall back to the checkpoint's directory for the image (and summary).
            out_dir = os.path.dirname(args.ckpt) or "."
            ensure_dir(out_dir)
        plot_path = os.path.join(out_dir, f"eval_boundary_{now_ts()}.png")
        plot_decision_boundary(model, X_test, y_test, device, f"Eval ({dcfg['dataset']})", plot_path)

    result = {
        "ckpt": args.ckpt,
        "dataset": dict(vars(data_cfg)),
        "accuracy": acc,
        "plot": plot_path or "",
        "timestamp": now_ts(),
    }
    if out_dir:
        with open(os.path.join(out_dir, "eval_summary.json"), "w") as f:
            json.dump(result, f, indent=2)

    print(f"[eval] accuracy={acc:.4f} on {len(X_test)} samples")
    if plot_path:
        print(f"[eval] saved plot: {plot_path}")


def parsepoints(pointsstr: str, dims: int) -> np.ndarray:
    """Parse inline points such as ``"1,2; 3,4"`` into a float32 array.

    Points are separated by ``;`` and coordinates by ``,``; blank segments
    (for example after a trailing ``;``) are skipped.

    Args:
        pointsstr: The raw ``--points`` string.
        dims: Number of coordinates every point must have.

    Returns:
        Array of shape ``(n_points, dims)`` with dtype float32.

    Raises:
        ValueError: If a point has the wrong number of coordinates, a
            coordinate is not a number, or no point was found.
    """
    pts = []
    for seg in pointsstr.split(";"):
        seg = seg.strip()
        if not seg:
            continue
        comps = [c.strip() for c in seg.split(",")]
        if len(comps) != dims:
            raise ValueError(f"Point '{seg}' does not match dims={dims}")
        pts.append([float(x) for x in comps])
    if not pts:
        raise ValueError("No valid points parsed from --points")
    return np.asarray(pts, dtype=np.float32)


def cmdpredict(args) -> None:
    """Handle the ``predict`` command: classify points given inline or via CSV.

    Prints one line per sample. When ``args.out`` is set the predictions are
    also written to a CSV file there; with ``args.plot`` on 2-D inputs a
    scatter plot is saved (next to the checkpoint when no output directory was
    given, in which case the CSV is written there too).

    Raises:
        ValueError: If neither ``--points`` nor ``--csv`` is given, the CSV is
            empty, or the input width does not match the model's ``dims``.
    """
    import csv

    device = get_device(args.cpu)
    model, ckpt = load_model_from_ckpt(args.ckpt, device)
    dcfg = ckpt["data_cfg"]
    dims = int(dcfg["dims"])

    if args.points:
        X = parsepoints(args.points, dims)
    elif args.csv:
        with open(args.csv, "r", newline="") as f:
            # Blank lines are skipped; every remaining cell must parse as float.
            rows = [[float(x) for x in row] for row in csv.reader(f) if row]
        if not rows:
            raise ValueError("CSV is empty or invalid")
        X = np.asarray(rows, dtype=np.float32)
        if X.shape[1] != dims:
            raise ValueError(f"CSV has {X.shape[1]} columns but model expects dims={dims}")
    else:
        raise ValueError("Provide either --points or --csv")

    with torch.no_grad():
        logits = model(torch.from_numpy(X).to(device))
        probs = F.softmax(logits, dim=-1).cpu().numpy()
        yhat = probs.argmax(axis=1)

    out_dir = args.out
    if out_dir:
        ensure_dir(out_dir)

    # Print to console
    for i, (p, pr) in enumerate(zip(yhat, probs)):
        print(f"sample[{i}] -> class={int(p)} prob={pr[int(p)]:.4f} dist={np.array2string(pr, precision=3)}")

    # Optional plot for 2D
    if args.plot and X.shape[1] == 2:
        if not out_dir:
            out_dir = os.path.dirname(args.ckpt) or "."
            ensure_dir(out_dir)
        plot_path = os.path.join(out_dir, f"predict_points_{now_ts()}.png")
        fig = plt.figure(figsize=(5, 4))
        try:
            plt.scatter(X[:, 0], X[:, 1], c=yhat, cmap="coolwarm", s=24, edgecolors="k", linewidths=0.5)
            plt.title("Predictions")
            plt.tight_layout()
            plt.savefig(plot_path, dpi=150)
        finally:
            plt.close(fig)
        print(f"[predict] saved plot: {plot_path}")

    # Optional CSV export
    if out_dir:
        out_csv = os.path.join(out_dir, f"predictions_{now_ts()}.csv")
        with open(out_csv, "w", newline="") as f:
            writer = csv.writer(f)
            header = [f"x{i+1}" for i in range(X.shape[1])] + ["predclass"] + [f"p{i}" for i in range(probs.shape[1])]
            writer.writerow(header)
            for xi, cls, pr in zip(X, yhat, probs):
                # One flat row: features, predicted class, then class probabilities.
                writer.writerow(
                    [f"{z:.6f}" for z in xi.tolist()]
                    + [int(cls)]
                    + [f"{v:.6f}" for v in pr.tolist()]
                )
        print(f"[predict] saved CSV: {out_csv}")


def cmdbenchmark(args) -> None:
    """Handle the ``benchmark`` command: train/evaluate a grid and summarize it.

    For every dataset x mode x seed combination a model is trained with
    ``cmdtrain`` into its own sub-directory of ``args.out`` and scored on a
    fresh synthetic test set. The mean and standard deviation of the test
    accuracy per dataset/mode pair are printed and written to a CSV file.
    """
    import csv

    datasets = [s.strip() for s in args.datasets.split(",") if s.strip()]
    modes = [s.strip() for s in args.modes.split(",") if s.strip()]
    seeds = [args.seed + i for i in range(args.seeds)]

    ensure_dir(args.out)
    summary_rows = []
    print(f"[benchmark] datasets={datasets} modes={modes} seeds={len(seeds)}")

    for ds in datasets:
        for mode in modes:
            accs = []
            for sd in seeds:
                # Train quick run in a dedicated directory
                run_out = os.path.join(
                    args.out,
                    f"{sanitize_filename(ds)}_{sanitize_filename(mode)}_seed{sd}_{now_ts()}",
                )
                os.makedirs(run_out, exist_ok=True)
                # Attribute names mirror the dests of the ``train`` sub-parser.
                train_args = argparse.Namespace(
                    dataset=ds,
                    dims=args.dims,
                    classes=args.classes,
                    samples=args.samples,
                    noise=args.noise,
                    valfrac=args.valfrac,
                    mode=mode,
                    epochs=args.epochs,
                    batch_size=args.batch_size,
                    lr=args.lr,
                    weight_decay=args.weight_decay,
                    width=args.width,
                    depth=args.depth,
                    dropout=args.dropout,
                    aug_noise=args.aug_noise,
                    labeled_frac=args.labeled_frac,
                    pseudo_thresh=args.pseudo_thresh,
                    warmup_epochs=args.warmup_epochs,
                    lambda_consistency=args.lambda_consistency,
                    plot=False,
                    seed=sd,
                    cpu=args.cpu,
                    out=run_out,
                )
                cmdtrain(train_args)

                # Find checkpoint saved by train
                ckpts = sorted(f for f in os.listdir(run_out) if f.endswith(".pt"))
                if not ckpts:
                    print(f"[benchmark][warn] no checkpoint in {run_out}")
                    continue
                ckpt_path = os.path.join(run_out, ckpts[-1])

                # Evaluate on a fresh test set derived from the stored data config
                device = get_device(args.cpu)
                model, ckpt = load_model_from_ckpt(ckpt_path, device)
                dcfg = ckpt["data_cfg"]
                test_cfg = DataConfig(
                    dataset=dcfg["dataset"],
                    dims=int(dcfg["dims"]),
                    classes=int(dcfg["classes"]),
                    samples=int(args.samples),
                    noise=float(dcfg["noise"]),
                    valfrac=0.0,
                    seed=int(dcfg.get("seed", 42)) + 997,
                )
                # valfrac=0.0 puts every row in the first slot; the validation
                # slot is empty, so the test set is taken from the first slot.
                X_te, y_te, _, _ = gen_synthetic(test_cfg)
                test_loader = DataLoader(
                    TensorDataset(torch.from_numpy(X_te), torch.from_numpy(y_te)),
                    batch_size=256,
                )
                accs.append(evaluate_accuracy(model, test_loader, device))

            if accs:
                mean = float(np.mean(accs))
                std = float(np.std(accs))
            else:
                mean, std = 0.0, 0.0
            summary_rows.append({
                "dataset": ds,
                "mode": mode,
                "mean_acc": mean,
                "std_acc": std,
                "runs": len(accs),
            })
            print(f"[benchmark] {ds:>12} | {mode:>9} -> mean_acc={mean:.4f} ± {std:.4f} over {len(accs)} runs")

    # Save CSV
    csv_path = os.path.join(args.out, f"benchmark_{now_ts()}.csv")
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["dataset", "mode", "mean_acc", "std_acc", "runs"])
        for row in summary_rows:
            writer.writerow([
                row["dataset"], row["mode"],
                f"{row['mean_acc']:.6f}", f"{row['std_acc']:.6f}", row["runs"],
            ])
    print(f"[benchmark] saved summary CSV: {csv_path}")


def cmdexport_onnx(args) -> None:
    """Handle the ``export-onnx`` command: write the checkpoint's model as ONNX.

    The graph takes an ``input`` of shape ``(batch, in_dim)`` and produces
    ``logits``; the batch axis is exported as dynamic. The parent directory of
    ``args.out`` is created when missing.
    """
    device = get_device(args.cpu)
    model, ckpt = load_model_from_ckpt(args.ckpt, device)
    mcfg = ckpt["model"]
    in_dim = int(mcfg["in_dim"])
    dummy = torch.zeros(1, in_dim, dtype=torch.float32)

    # Export on CPU for portability
    model_cpu = model.to("cpu").eval()
    ensure_dir(os.path.dirname(args.out))
    torch.onnx.export(
        model_cpu,
        dummy,
        args.out,
        input_names=["input"],
        output_names=["logits"],
        dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
        opset_version=13,
    )
    print(f"[export-onnx] saved ONNX to {args.out}")

# ----------------------------
# Argparse
# ----------------------------

def build_parser():
    """Build the CLI parser with the train/eval/predict/benchmark/export-onnx commands."""
    parser = argparse.ArgumentParser(
        description="MetaIntelligence end-to-end toolkit (notebook/Colab-safe)"
    )
    commands = parser.add_subparsers(dest="cmd", required=True)

    def typed(kind, default, **extra):
        """Keyword arguments for an optional flag with a type and a default."""
        return {"type": kind, "default": default, **extra}

    def register(name, help_text, specs):
        """Create sub-command ``name`` and add its flags in the listed order."""
        sub = commands.add_parser(name, help=help_text)
        for flag, options in specs:
            sub.add_argument(flag, **options)

    switch = {"action": "store_true"}
    required_str = {"type": str, "required": True}

    # Train
    register("train", "Train a model (self/supervised/hybrid)", [
        ("--dataset", typed(str, "moons", choices=["moons", "circles", "blobs", "classification"])),
        ("--dims", typed(int, 2)),
        ("--classes", typed(int, 2)),
        ("--samples", typed(int, 2000)),
        ("--noise", typed(float, 0.2)),
        ("--val-frac", typed(float, 0.2, dest="valfrac")),
        ("--mode", typed(str, "supervised", choices=["supervised", "selfsup", "hybrid"])),
        ("--epochs", typed(int, 50)),
        ("--batch-size", typed(int, 128)),
        ("--lr", typed(float, 3e-3)),
        ("--weight-decay", typed(float, 0.0)),
        ("--width", typed(int, 64)),
        ("--depth", typed(int, 2)),
        ("--dropout", typed(float, 0.0)),
        ("--aug-noise", typed(float, 0.05, help="Gaussian feature noise for augmentation")),
        # self-training/hybrid knobs
        ("--labeled-frac", typed(float, 0.1)),
        ("--pseudo-thresh", typed(float, 0.9)),
        ("--warmup-epochs", typed(int, 5)),
        ("--lambda-consistency", typed(float, 0.5)),
        ("--plot", switch),
        ("--seed", typed(int, 42)),
        ("--cpu", switch),
        ("--out", required_str),
    ])

    # Eval
    register("eval", "Evaluate a saved checkpoint on a fresh synthetic test set", [
        ("--ckpt", required_str),
        ("--samples", typed(int, 2000)),
        ("--plot", switch),
        ("--cpu", switch),
        ("--out", typed(str, "")),
    ])

    # Predict
    register("predict", "Predict for given points", [
        ("--ckpt", required_str),
        ("--points", typed(str, "", help='Inline points: "x1,x2; y1,y2; ..."')),
        ("--csv", typed(str, "", help="CSV file with samples (rows) and features (columns)")),
        ("--plot", switch),
        ("--cpu", switch),
        ("--out", typed(str, "")),
    ])

    # Benchmark
    register("benchmark", "Run grid of modes/datasets/dims across seeds and summarize", [
        ("--datasets", typed(str, "moons,blobs")),
        ("--modes", typed(str, "supervised,selfsup,hybrid")),
        ("--dims", typed(int, 2)),
        ("--classes", typed(int, 2)),
        ("--samples", typed(int, 2000)),
        ("--noise", typed(float, 0.2)),
        ("--val-frac", typed(float, 0.2, dest="valfrac")),
        ("--epochs", typed(int, 40)),
        ("--batch-size", typed(int, 128)),
        ("--lr", typed(float, 3e-3)),
        ("--weight-decay", typed(float, 0.0)),
        ("--width", typed(int, 64)),
        ("--depth", typed(int, 2)),
        ("--dropout", typed(float, 0.0)),
        ("--aug-noise", typed(float, 0.05)),
        ("--labeled-frac", typed(float, 0.1)),
        ("--pseudo-thresh", typed(float, 0.9)),
        ("--warmup-epochs", typed(int, 5)),
        ("--lambda-consistency", typed(float, 0.5)),
        ("--seeds", typed(int, 3, help="number of seeds to run starting from --seed")),
        ("--seed", typed(int, 42)),
        ("--cpu", switch),
        ("--out", required_str),
    ])

    # Export ONNX
    register("export-onnx", "Export a trained checkpoint to ONNX", [
        ("--ckpt", required_str),
        ("--out", required_str),
        ("--cpu", switch),
    ])

    return parser


def main(argv: Optional[List[str]] = None):
    """Parse ``argv`` and run the selected command.

    With no arguments (``None`` or an empty list) the help text is printed and
    nothing else happens, which keeps the call safe inside notebooks.

    Raises:
        ValueError: If ``--ckpt``, ``--csv`` or ``--out`` contains ``..``,
            ``|``, ``;`` or a backtick.
    """
    parser = build_parser()

    # Notebook/Colab-safe: running with no args prints help and returns
    if argv is None or len(argv) == 0:
        parser.print_help()
        return

    args = parser.parse_args(argv)

    # Basic sanitization for file paths
    forbidden = ("..", "|", ";", "`")
    for attr in ("ckpt", "csv", "out"):
        val = getattr(args, attr, None)  # not every command defines all three
        if isinstance(val, str) and any(token in val for token in forbidden):
            raise ValueError(f"Unsafe characters in path argument: --{attr}")

    # Dispatch after handlers are defined
    handlers = {
        "train": cmdtrain,
        "eval": cmdeval,
        "predict": cmdpredict,
        "benchmark": cmdbenchmark,
        "export-onnx": cmdexport_onnx,
    }
    handler = handlers.get(args.cmd)
    if handler is None:
        parser.error(f"Unknown command: {args.cmd}")
    return handler(args)


if __name__ == "__main__":
    # Safe for notebooks and scripts: main() is called without an argument
    # list here, so it only prints the CLI help and returns. To run a command,
    # call main([...]) with an explicit argument list.
    main()