<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/colab_kernel_launcher_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
# colab_kernel_launcher.py
# MetaIntelligence end-to-end toolkit (notebook/Colab-safe)

import os
import sys
import json
import math
import time
import copy
import argparse
import random
import csv
from typing import Tuple, Dict, Any, List, Optional

import numpy as np

# Required heavy dependencies: fail fast with an actionable message if they are missing
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset
except Exception as e:
    raise RuntimeError("PyTorch is required. Please install torch before running this script.") from e

# sklearn only for data generation and metrics
try:
    from sklearn.datasets import make_blobs, make_moons, make_circles, make_classification
    from sklearn.metrics import confusion_matrix, accuracy_score
except Exception as e:
    raise RuntimeError("scikit-learn is required. Please install scikit-learn before running this script.") from e

# Matplotlib is optional: plotting will be skipped if unavailable
try:
    import matplotlib.pyplot as plt
    HAS_MPL = True
except Exception:
    HAS_MPL = False


# ----------------------------
# Utilities
# ----------------------------
def set_seed(seed: int) -> None:
    """Seed every random number generator this toolkit relies on.

    Applies *seed* to Python's ``random``, NumPy's global RNG, Torch's CPU RNG
    and all CUDA device RNGs, in that order.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


def ensure_dir(path: str) -> None:
    """Create directory *path* (including parents) unless it is empty or already exists.

    An empty string denotes the current directory and is left alone.
    """
    if not path or os.path.exists(path):
        return
    os.makedirs(path, exist_ok=True)


def save_json(obj: Dict[str, Any], path: str) -> None:
    """Write *obj* to *path* as UTF-8 JSON, indented by 2 spaces with sorted keys."""
    with open(path, mode="w", encoding="utf-8") as handle:
        json.dump(obj, handle, sort_keys=True, indent=2)


def load_json(path: str) -> Dict[str, Any]:
    """Read the UTF-8 JSON file at *path* and return the decoded object."""
    with open(path, encoding="utf-8") as handle:
        decoded = json.load(handle)
    return decoded


def save_csv(rows: List[Dict[str, Any]], path: str) -> None:
    """Write a list of dict rows to *path* as CSV with a header line.

    The header is the union of keys across all rows, in first-seen order, so
    rows with differing keys are accepted. Previously only the first row's keys
    were used and ``csv.DictWriter`` raised ``ValueError`` on any extra key.
    Cells missing from a row are written as empty strings.

    Does nothing when *rows* is empty. The parent directory of *path* is
    created if needed.
    """
    if not rows:
        return
    ensure_dir(os.path.dirname(path))
    # dict.fromkeys keeps insertion order and de-duplicates in one pass.
    fieldnames = list(dict.fromkeys(key for row in rows for key in row))
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames, restval="")
        writer.writeheader()
        writer.writerows(rows)


class Standardizer:
    """Per-column z-score scaler computing ``(X - mean) / std``.

    Columns whose standard deviation is exactly zero get a scale of 1.0, so
    they are centred but never divided by zero. Fitted statistics can be
    round-tripped through plain dicts with ``to_dict`` / ``from_dict``.
    """

    def __init__(self):
        # Column statistics; both stay None until fit() or from_dict() sets them.
        self.mean_: Optional[np.ndarray] = None
        self.scale_: Optional[np.ndarray] = None

    def fit(self, X: np.ndarray):
        """Learn column means and population std devs from *X*; returns ``self``."""
        self.mean_ = X.mean(axis=0)
        self.scale_ = X.std(axis=0)
        constant_cols = self.scale_ == 0.0
        self.scale_[constant_cols] = 1.0  # constant column: avoid divide-by-zero
        return self

    def transform(self, X: np.ndarray) -> np.ndarray:
        """Return a standardized copy of *X*.

        Raises:
            ValueError: if neither fit() nor from_dict() has populated the statistics.
        """
        is_fitted = self.mean_ is not None and self.scale_ is not None
        if not is_fitted:
            raise ValueError("Standardizer not fitted")
        centred = X - self.mean_
        return centred / self.scale_

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the statistics as JSON-friendly lists (``None`` when unset)."""
        def _listify(arr):
            return None if arr is None else arr.tolist()

        return {"mean": _listify(self.mean_), "scale": _listify(self.scale_)}

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "Standardizer":
        """Rebuild a Standardizer from ``to_dict`` output; arrays are restored as float32."""
        restored = Standardizer()
        mean_vals, scale_vals = d.get("mean"), d.get("scale")
        if mean_vals is not None:
            restored.mean_ = np.array(mean_vals, dtype=np.float32)
        if scale_vals is not None:
            restored.scale_ = np.array(scale_vals, dtype=np.float32)
        return restored


# ----------------------------
# Data generation
# ----------------------------
def generate_synthetic(
    dataset: str,
    n_samples: int,
    n_features: int,
    noise: float,
    n_classes: int,
    seed: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Generate and shuffle a synthetic classification dataset via scikit-learn.

    Args:
        dataset: one of ``"blobs"``, ``"moons"``, ``"circles"``, ``"classification"``.
        n_samples: number of samples to generate.
        n_features: input dimensionality; ``moons`` and ``circles`` require 2.
        noise: dataset-specific noise knob. Blobs use it as cluster std (floored
            at 0.05). ``classification`` uses it as the label-flip rate (capped at
            0.25) and to shrink class separation. Moons/circles pass it as their
            ``noise`` argument.
        n_classes: number of classes for ``blobs`` / ``classification``;
            ``moons`` and ``circles`` ignore it.
        seed: seed for the sklearn generator and for the final shuffle.

    Returns:
        ``(X, y)`` with ``X`` cast to float32 and ``y`` cast to int64.

    Raises:
        ValueError: for an unknown dataset, or a non-2-D request for moons/circles.
    """
    rng = np.random.RandomState(seed)
    if dataset == "blobs":
        X, y = make_blobs(n_samples=n_samples, centers=n_classes, n_features=n_features,
                          cluster_std=max(0.05, noise), random_state=seed)
    elif dataset == "moons":
        if n_features != 2:
            raise ValueError("moons supports only dims=2")
        X, y = make_moons(n_samples=n_samples, noise=noise, random_state=seed)
    elif dataset == "circles":
        if n_features != 2:
            raise ValueError("circles supports only dims=2")
        X, y = make_circles(n_samples=n_samples, factor=0.5, noise=noise, random_state=seed)
    elif dataset == "classification":
        # All but two features are informative (at least 2), but never more than
        # n_features: sklearn rejects n_informative > n_features, which the old
        # expression max(2, n_features - 2) produced for n_features == 1.
        n_informative = min(n_features, max(2, n_features - 2))
        X, y = make_classification(
            n_samples=n_samples, n_features=n_features,
            n_informative=n_informative,
            n_redundant=0, n_repeated=0,
            n_classes=n_classes, n_clusters_per_class=1,
            flip_y=min(0.25, noise), class_sep=max(0.5, 2 - 3 * noise),
            random_state=seed
        )
    else:
        raise ValueError(f"Unknown dataset: {dataset}")
    X = X.astype(np.float32)
    y = y.astype(np.int64)
    # Shuffle with the same seed so the output order is reproducible.
    idx = np.arange(len(X))
    rng.shuffle(idx)
    return X[idx], y[idx]


def train_val_split(X, y, val_frac=0.2, seed=42):
    """Randomly partition ``(X, y)`` into training and validation subsets.

    A seeded permutation of the row indices is drawn. Its first
    ``int(len(X) * val_frac)`` entries form the validation set and the
    remainder forms the training set.

    Returns:
        ``((X_train, y_train), (X_val, y_val))``
    """
    order = np.arange(len(X))
    np.random.RandomState(seed).shuffle(order)
    cut = int(len(X) * val_frac)
    val_rows, train_rows = order[:cut], order[cut:]
    return (X[train_rows], y[train_rows]), (X[val_rows], y[val_rows])


# ----------------------------
# Model
# ----------------------------
class MLP(nn.Module):
    """Feed-forward classifier: ``[Linear -> ReLU (-> Dropout)] * len(hidden) -> Linear``.

    Args:
        input_dim: number of input features.
        hidden: widths of the hidden layers, in order (may be empty).
        num_classes: size of the output layer.
        dropout: dropout probability after each hidden ReLU; ``0`` omits the layer.
    """

    def __init__(self, input_dim: int, hidden: List[int], num_classes: int, dropout: float = 0.0):
        super().__init__()
        modules: List[nn.Module] = []
        in_features = input_dim
        for out_features in hidden:
            modules.append(nn.Linear(in_features, out_features))
            modules.append(nn.ReLU(inplace=True))
            if dropout > 0:
                modules.append(nn.Dropout(p=dropout))
            in_features = out_features
        modules.append(nn.Linear(in_features, num_classes))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Return unnormalized class scores (logits) for the batch *x*."""
        return self.net(x)


def make_model(input_dim: int, num_classes: int, width: int, depth: int, dropout: float) -> MLP:
    """Build an MLP whose hidden stack has ``max(1, depth)`` layers of *width* units each."""
    n_hidden = depth if depth > 1 else 1
    return MLP(input_dim, [width for _ in range(n_hidden)], num_classes, dropout)


# ----------------------------
# Training helpers
# ----------------------------
def batch_iter(X: np.ndarray, y: Optional[np.ndarray], batch_size: int, shuffle: bool, seed: int):
    """Yield ``(X_batch, y_batch)`` mini-batches of at most *batch_size* rows.

    ``y_batch`` is ``None`` whenever *y* is ``None``. With *shuffle* set, rows are
    visited in the order of a permutation drawn from ``np.random.RandomState(seed)``.
    Otherwise they are visited in their original order. The final batch may be shorter.
    """
    order = np.arange(len(X))
    if shuffle:
        np.random.RandomState(seed).shuffle(order)
    for offset in range(0, len(order), batch_size):
        chunk = order[offset:offset + batch_size]
        yield X[chunk], (None if y is None else y[chunk])


def add_noise(x: torch.Tensor, sigma: float) -> torch.Tensor:
    """Return *x* perturbed by i.i.d. Gaussian noise with standard deviation *sigma*.

    A non-positive *sigma* disables augmentation and returns the very same tensor.
    """
    if sigma <= 0:
        return x  # augmentation disabled
    perturbation = sigma * torch.randn_like(x)
    return x + perturbation


def train_supervised(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    Xtr: np.ndarray, ytr: np.ndarray,
    Xva: np.ndarray, yva: np.ndarray,
    device: str,
    epochs: int,
    batch_size: int,
    aug_noise: float = 0.0,
    seed: int = 42,
) -> Dict[str, List[float]]:
    """Train *model* with cross-entropy, evaluating validation accuracy after each epoch.

    Each epoch reshuffles the training set with seed ``seed + epoch`` and adds
    Gaussian feature noise of std *aug_noise* to every mini-batch.

    Returns:
        ``{"train_loss": [...], "val_acc": [...]}`` with one entry per epoch.
        ``train_loss`` is the sample-weighted mean loss over the epoch.
    """
    crit = nn.CrossEntropyLoss()
    metrics: Dict[str, List[float]] = {"train_loss": [], "val_acc": []}

    # Only the validation set is reused as whole tensors, so only it is moved to
    # the device up front. Training batches are converted one at a time below.
    # The original also copied the full training set to the device without using it.
    Xva_t = torch.from_numpy(Xva).to(device)
    yva_t = torch.from_numpy(yva).to(device)
    n_train = len(Xtr)

    for ep in range(1, epochs + 1):
        model.train()
        running = 0.0
        for xb_np, yb_np in batch_iter(Xtr, ytr, batch_size, shuffle=True, seed=seed + ep):
            xb = torch.from_numpy(xb_np).to(device)
            yb = torch.from_numpy(yb_np).to(device)
            xb = add_noise(xb, aug_noise)
            logits = model(xb)
            loss = crit(logits, yb)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running += loss.item() * len(xb_np)
        # Guard against ZeroDivisionError when the training set is empty.
        train_loss = running / max(1, n_train)
        metrics["train_loss"].append(train_loss)

        # Validation in eval mode (disables dropout) without autograd.
        model.eval()
        with torch.no_grad():
            preds = model(Xva_t).argmax(dim=1)
            acc = (preds == yva_t).float().mean().item()
        metrics["val_acc"].append(acc)
        print(f"[supervised] epoch {ep:03d} | train_loss={train_loss:.4f} | val_acc={acc:.4f}")
    return metrics


def train_self_training(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    Xtr: np.ndarray, ytr: np.ndarray,
    Xva: np.ndarray, yva: np.ndarray,
    device: str,
    epochs: int,
    batch_size: int,
    labeled_frac: float,
    pseudo_thresh: float,
    aug_noise: float,
    warmup_epochs: int,
    seed: int = 42,
) -> Dict[str, List[float]]:
    """
    Simple self-training (pseudo-labelling):
      1) Split train into labeled/unlabeled by labeled_frac (at least 1 labeled sample).
      2) Warm up with supervised training on the labeled part only.
      3) Each epoch, move every unlabeled sample whose max softmax probability is
         >= pseudo_thresh into the labeled pool with its predicted label. Samples
         are never removed from the pool once added. Then train one epoch on the pool.

    Only ``ytr[lab_idx]`` is read; the true labels of the unlabeled part are unused.

    Returns:
        {"train_loss": [...], "val_acc": [...], "pseudo_added": [...]} with one entry
        per self-training epoch (warm-up epochs are not included).
    """
    n = len(Xtr)
    n_lab = max(1, int(n * labeled_frac))
    rng = np.random.RandomState(seed)
    idx = np.arange(n)
    rng.shuffle(idx)
    lab_idx = idx[:n_lab]
    unlab_idx = idx[n_lab:]

    X_lab, y_lab = Xtr[lab_idx], ytr[lab_idx]
    X_unl = Xtr[unlab_idx]

    # Warmup
    print(f"[selfsup] warmup on {len(X_lab)} labeled samples")
    _ = train_supervised(
        model, optimizer, X_lab, y_lab, Xva, yva, device,
        epochs=warmup_epochs, batch_size=batch_size, aug_noise=aug_noise, seed=seed
    )

    crit = nn.CrossEntropyLoss()
    metrics = {"train_loss": [], "val_acc": [], "pseudo_added": []}
    # The validation inputs are reused every epoch, so convert them once.
    Xva_t = torch.from_numpy(Xva).to(device)

    for ep in range(1, epochs + 1):
        # Pseudo-label selection. Run in eval mode: in train mode dropout would
        # randomly perturb the confidences that are compared to pseudo_thresh.
        added = 0
        if len(X_unl) > 0:
            model.eval()
            with torch.no_grad():
                logits = model(torch.from_numpy(X_unl).to(device))
                probs = F.softmax(logits, dim=1).cpu().numpy()
            conf = probs.max(axis=1)
            # Match the label dtype of the pool so concatenation keeps it
            # suitable for CrossEntropyLoss targets.
            plabels = probs.argmax(axis=1).astype(y_lab.dtype, copy=False)
            keep = conf >= pseudo_thresh
            added = int(keep.sum())
            if added > 0:
                X_lab = np.concatenate([X_lab, X_unl[keep]], axis=0)
                y_lab = np.concatenate([y_lab, plabels[keep]], axis=0)
                X_unl = X_unl[~keep]

        metrics["pseudo_added"].append(added)
        print(f"[selfsup] epoch {ep:03d} | added {added} pseudo-labeled samples | lab_pool={len(X_lab)} | unl={len(X_unl)}")

        # Train one epoch on current labeled pool
        model.train()
        running = 0.0
        for xb_np, yb_np in batch_iter(X_lab, y_lab, batch_size, shuffle=True, seed=seed + ep):
            xb = torch.from_numpy(xb_np).to(device)
            yb = torch.from_numpy(yb_np).to(device)
            xb = add_noise(xb, aug_noise)
            logits = model(xb)
            loss = crit(logits, yb)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running += loss.item() * len(xb_np)
        train_loss = running / max(1, len(X_lab))
        metrics["train_loss"].append(train_loss)

        # val
        model.eval()
        with torch.no_grad():
            preds = model(Xva_t).argmax(dim=1).cpu().numpy()
        acc = (preds == yva).mean().item()
        metrics["val_acc"].append(acc)
        print(f"[selfsup] epoch {ep:03d} | train_loss={train_loss:.4f} | val_acc={acc:.4f}")
    return metrics


def train_hybrid_consistency(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    Xtr: np.ndarray, ytr: np.ndarray,
    Xva: np.ndarray, yva: np.ndarray,
    device: str,
    epochs: int,
    batch_size: int,
    aug_noise: float,
    lambda_consistency: float,
    seed: int = 42,
) -> Dict[str, List[float]]:
    """
    Hybrid = supervised CE + consistency regularization between two noisy views.

    For each mini-batch, two Gaussian-noised copies (std *aug_noise*) are made.
    The loss is ``CE(view1, y) + lambda_consistency * KL(p2 || p1)``, where ``p2``
    (view 2's softmax) is a fixed target with no gradient, so the consistency
    term pulls view 1 toward view 2.

    Returns:
        {"train_loss": [...], "val_acc": [...]} with one entry per epoch.
    """
    ce = nn.CrossEntropyLoss()
    metrics = {"train_loss": [], "val_acc": []}
    # Only the validation set is used as whole tensors; the original also copied
    # the full training set to the device without using it.
    Xva_t = torch.from_numpy(Xva).to(device)
    yva_t = torch.from_numpy(yva).to(device)
    n_train = len(Xtr)

    for ep in range(1, epochs + 1):
        model.train()
        running = 0.0
        for xb_np, yb_np in batch_iter(Xtr, ytr, batch_size, shuffle=True, seed=seed + ep):
            xb = torch.from_numpy(xb_np).to(device)
            yb = torch.from_numpy(yb_np).to(device)

            x1 = add_noise(xb, aug_noise)
            x2 = add_noise(xb, aug_noise)

            logits1 = model(x1)
            logits2 = model(x2)

            loss_ce = ce(logits1, yb)
            log_p1 = F.log_softmax(logits1, dim=1)
            # Stop-gradient on the target view: p2 acts as a fixed target,
            # as is standard for consistency regularization.
            p2 = F.softmax(logits2, dim=1).detach()
            # F.kl_div(log_input, target) computes KL(target || input).
            loss_cons = F.kl_div(log_p1, p2, reduction="batchmean")
            loss = loss_ce + lambda_consistency * loss_cons

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running += loss.item() * len(xb_np)
        # Guard against ZeroDivisionError on an empty training set.
        train_loss = running / max(1, n_train)
        metrics["train_loss"].append(train_loss)

        model.eval()
        with torch.no_grad():
            preds = model(Xva_t).argmax(dim=1)
            acc = (preds == yva_t).float().mean().item()
        metrics["val_acc"].append(acc)
        print(f"[hybrid] epoch {ep:03d} | train_loss={train_loss:.4f} | val_acc={acc:.4f}")
    return metrics


# ----------------------------
# Plotting (2D only)
# ----------------------------
def maybe_plot_2d_decision_boundary(
    path: str,
    model: nn.Module,
    scaler: Standardizer,
    X: np.ndarray,
    y: np.ndarray,
    title: str = "",
    device: str = "cpu",
    grid_steps: int = 300
):
    """Save an image of the model's decision regions with the points ``(X, y)`` on top.

    ``X`` is in the original (unscaled) feature space. Grid points are passed
    through *scaler* before being fed to *model*. Prints a message and returns
    without plotting when matplotlib is unavailable or ``X`` is not 2-D.

    Args:
        path: output image path; its parent directory is created if needed.
        model: classifier returning logits.
        scaler: fitted Standardizer matching the model's training inputs.
        X: points to scatter (unscaled).
        y: integer class labels for ``X``, used as scatter colours.
        title: plot title.
        device: torch device used for inference on the grid.
        grid_steps: grid resolution along each axis.
    """
    if not HAS_MPL:
        print("Matplotlib not available; skipping plot.")
        return
    if X.shape[1] != 2:
        print("Plotting only for 2D; skipping plot.")
        return

    # Pad the data range so that edge points do not sit on the plot border.
    x_min, x_max = X[:, 0].min() - 0.6, X[:, 0].max() + 0.6
    y_min, y_max = X[:, 1].min() - 0.6, X[:, 1].max() + 0.6
    xx, yy = np.meshgrid(
        np.linspace(x_min, x_max, grid_steps, dtype=np.float32),
        np.linspace(y_min, y_max, grid_steps, dtype=np.float32)
    )
    grid = np.c_[xx.ravel(), yy.ravel()].astype(np.float32)
    # The model expects float32 input regardless of the scaler's stored dtype.
    grid_std = scaler.transform(grid).astype(np.float32, copy=False)

    model.eval()
    with torch.no_grad():
        logits = model(torch.from_numpy(grid_std).to(device))
        preds = logits.argmax(dim=1).cpu().numpy().reshape(xx.shape)

    # Use one class range for both layers so a class index gets the same colour
    # in the filled regions and in the scatter. This holds even when the model
    # never predicts a class that appears in y, or vice versa.
    n_cls = int(max(preds.max(), y.max())) + 1
    levels = np.arange(n_cls + 1) - 0.5
    vmax = n_cls - 1

    fig = plt.figure(figsize=(6.0, 5.5), dpi=140)
    try:
        plt.contourf(xx, yy, preds, alpha=0.25, levels=levels, cmap="coolwarm", vmin=0, vmax=vmax)
        plt.scatter(X[:, 0], X[:, 1], c=y, s=15, cmap="coolwarm", vmin=0, vmax=vmax,
                    edgecolors="k", linewidths=0.3)
        plt.title(title)
        plt.xlabel("x1")
        plt.ylabel("x2")
        plt.tight_layout()
        ensure_dir(os.path.dirname(path))
        plt.savefig(path)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)
    print(f"Saved plot: {path}")


# ----------------------------
# Checkpoint I/O
# ----------------------------
def save_checkpoint(path: str, model: nn.Module, scaler: Standardizer, meta: Dict[str, Any]):
    """Persist model weights, run metadata and scaler statistics to *path* with ``torch.save``.

    The saved dict has keys ``model_state``, ``meta`` and ``scaler``. The parent
    directory of *path* is created if needed.
    """
    payload = dict(model_state=model.state_dict(), meta=meta, scaler=scaler.to_dict())
    ensure_dir(os.path.dirname(path))
    torch.save(payload, path)
    print(f"Saved checkpoint: {path}")


def load_checkpoint(path: str, map_location: str = "cpu"):
    """Load a checkpoint written by ``save_checkpoint`` and rebuild the model.

    Args:
        path: checkpoint file path (typically user-supplied via ``--ckpt``).
        map_location: device to load tensors onto and to move the model to.

    Returns:
        ``(model, scaler, meta)``: the MLP in eval mode, the restored
        Standardizer, and the metadata dict stored at training time.
    """
    try:
        # weights_only=True restricts unpickling to tensors and plain containers.
        # That is all save_checkpoint writes, and it stops a tampered checkpoint
        # file from executing arbitrary code on load.
        ckpt = torch.load(path, map_location=map_location, weights_only=True)
    except TypeError:
        # Older torch releases have no weights_only argument.
        ckpt = torch.load(path, map_location=map_location)
    meta = ckpt["meta"]
    scaler = Standardizer.from_dict(ckpt["scaler"])
    model = make_model(
        input_dim=meta["input_dim"],
        num_classes=meta["num_classes"],
        width=meta["model"]["width"],
        depth=meta["model"]["depth"],
        dropout=meta["model"]["dropout"],
    )
    model.load_state_dict(ckpt["model_state"])
    model.to(map_location)
    model.eval()
    return model, scaler, meta


# ----------------------------
# Command implementations
# ----------------------------
def cmd_train(args):
    """Handle the ``train`` subcommand.

    Generates a synthetic dataset, splits it into train/validation, standardizes
    it with training-split statistics, and trains an MLP in the requested mode
    (``supervised`` / ``selfsup`` / ``hybrid``). Writes ``config.json``,
    ``checkpoint.pt``, ``history.csv`` and, with ``--plot`` on 2-D data,
    ``decision_boundary.png`` into ``args.out``.
    """
    set_seed(args.seed)
    device = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"
    ensure_dir(args.out)

    # --val-frac may be stored under ``val_frac`` or ``valfrac`` depending on the
    # parser's dest. Accept either instead of failing with AttributeError.
    val_frac = getattr(args, "val_frac", None)
    if val_frac is None:
        val_frac = getattr(args, "valfrac", 0.2)

    # Data
    X_all, y_all = generate_synthetic(
        dataset=args.dataset,
        n_samples=args.samples,
        n_features=args.dims,
        noise=args.noise,
        n_classes=args.classes,
        seed=args.seed
    )
    (Xtr, ytr), (Xva, yva) = train_val_split(X_all, y_all, val_frac=val_frac, seed=args.seed + 1)

    # Scale using training statistics only, so validation data does not leak in.
    scaler = Standardizer().fit(Xtr)
    Xtr_s = scaler.transform(Xtr)
    Xva_s = scaler.transform(Xva)

    # Model
    model = make_model(
        input_dim=args.dims, num_classes=args.classes,
        width=args.width, depth=args.depth, dropout=args.dropout,
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Train
    t0 = time.time()
    if args.mode == "supervised":
        hist = train_supervised(model, optimizer, Xtr_s, ytr, Xva_s, yva, device,
                                args.epochs, args.batch_size, args.aug_noise, args.seed)
    elif args.mode == "selfsup":
        hist = train_self_training(
            model, optimizer, Xtr_s, ytr, Xva_s, yva, device,
            epochs=args.epochs, batch_size=args.batch_size,
            labeled_frac=args.labeled_frac, pseudo_thresh=args.pseudo_thresh,
            aug_noise=args.aug_noise, warmup_epochs=args.warmup_epochs, seed=args.seed
        )
    elif args.mode == "hybrid":
        hist = train_hybrid_consistency(
            model, optimizer, Xtr_s, ytr, Xva_s, yva, device,
            epochs=args.epochs, batch_size=args.batch_size,
            aug_noise=args.aug_noise, lambda_consistency=args.lambda_consistency, seed=args.seed
        )
    else:
        raise ValueError(f"Unknown mode: {args.mode}")
    dt = time.time() - t0

    # Final val accuracy (eval mode disables dropout).
    model.eval()
    with torch.no_grad():
        logits = model(torch.from_numpy(Xva_s).to(device))
        acc = (logits.argmax(dim=1).cpu().numpy() == yva).mean().item()

    # Save artifacts
    meta = {
        "dataset": args.dataset,
        "input_dim": args.dims,
        "num_classes": args.classes,
        "noise": args.noise,
        "mode": args.mode,
        "val_frac": val_frac,
        "seed": args.seed,
        "model": {"width": args.width, "depth": args.depth, "dropout": args.dropout},
        "opt": {"lr": args.lr, "weight_decay": args.weight_decay},
        "train": {
            "epochs": args.epochs, "batch_size": args.batch_size,
            "aug_noise": args.aug_noise,
            "labeled_frac": args.labeled_frac,
            "pseudo_thresh": args.pseudo_thresh,
            "warmup_epochs": args.warmup_epochs,
            "lambda_consistency": args.lambda_consistency
        },
        "val_acc": acc,
        "elapsed_sec": dt
    }

    save_json(meta, os.path.join(args.out, "config.json"))
    save_checkpoint(os.path.join(args.out, "checkpoint.pt"), model, scaler, meta)

    # Plot 2D boundary
    if args.plot and Xtr.shape[1] == 2:
        maybe_plot_2d_decision_boundary(
            path=os.path.join(args.out, "decision_boundary.png"),
            model=model, scaler=scaler, X=Xtr, y=ytr, title=f"{args.dataset} ({args.mode})", device=device
        )

    # Metrics CSV: one row per epoch. Size by the longest series so no epoch is
    # dropped if the series ever differ in length; shorter series are padded with "".
    n_rows = max((len(v) for v in hist.values()), default=0)
    rows = []
    for i in range(n_rows):
        row = {"epoch": i + 1}
        for k, v in hist.items():
            row[k] = v[i] if i < len(v) else ""
        rows.append(row)
    save_csv(rows, os.path.join(args.out, "history.csv"))

    print(f"Training done. Val acc={acc:.4f}. Artifacts in: {args.out}")


def cmd_eval(args):
    """Handle the ``eval`` subcommand: score a checkpoint on a freshly generated test set.

    The test set reuses the checkpoint's dataset configuration with seed
    ``meta["seed"] + 999``, so it differs from the training data. Writes
    ``eval_metrics.csv``, ``confusion_matrix.csv`` and, with ``--plot`` on 2-D
    inputs, ``eval_decision_boundary.png`` to ``--out``. The default output
    location is the checkpoint's directory.
    """
    device = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"
    model, scaler, meta = load_checkpoint(args.ckpt, map_location=device)

    # Generate a fresh test set with same dataset/config
    Xte, yte = generate_synthetic(
        dataset=meta["dataset"],
        n_samples=args.samples,
        n_features=meta["input_dim"],
        noise=meta["noise"],
        n_classes=meta["num_classes"],
        seed=meta["seed"] + 999  # distinct from the training seed
    )
    Xte_s = scaler.transform(Xte)

    model.eval()
    with torch.no_grad():
        logits = model(torch.from_numpy(Xte_s).to(device))
        preds = logits.argmax(dim=1).cpu().numpy()
    acc = accuracy_score(yte, preds)
    # Fix the label set so row/col i always means class i, even if some class
    # is absent from both yte and preds.
    cm = confusion_matrix(yte, preds, labels=list(range(meta["num_classes"])))

    print(f"Eval accuracy: {acc:.4f}")
    out_dir = args.out or os.path.dirname(args.ckpt)
    ensure_dir(out_dir)

    # Save metrics CSV
    save_csv([{"metric": "accuracy", "value": acc}], os.path.join(out_dir, "eval_metrics.csv"))

    # Save confusion matrix CSV in long format (row = true class, col = predicted).
    cm_rows = [
        {"row": i, "col": j, "count": int(cm[i, j])}
        for i in range(cm.shape[0])
        for j in range(cm.shape[1])
    ]
    save_csv(cm_rows, os.path.join(out_dir, "confusion_matrix.csv"))

    # Optional plot (2D only)
    if args.plot and meta["input_dim"] == 2:
        maybe_plot_2d_decision_boundary(
            path=os.path.join(out_dir, "eval_decision_boundary.png"),
            model=model, scaler=scaler, X=Xte, y=yte,
            title=f"{meta['dataset']} (eval)", device=device
        )


def loadcsv_points(path: str) -> np.ndarray:
    """Read the numeric rows of a CSV file into a float32 array, one row per sample.

    Rows containing any non-numeric cell (such as a header line) are skipped,
    blank cells are ignored, and rows with no remaining values are dropped.

    Raises:
        ValueError: if no numeric rows are found or the rows differ in length.
    """
    def _parse(cells):
        # Returns None for rows that contain a non-numeric cell.
        try:
            return [float(cell) for cell in cells if cell.strip() != ""]
        except ValueError:
            return None

    with open(path, "r", encoding="utf-8") as handle:
        parsed = [_parse(cells) for cells in csv.reader(handle)]
    data = [vals for vals in parsed if vals]

    if not data:
        raise ValueError(f"No numeric rows found in CSV: {path}")
    width = len(data[0])
    if any(len(vals) != width for vals in data):
        raise ValueError("Inconsistent number of columns in CSV rows")
    return np.array(data, dtype=np.float32)


def cmd_predict(args):
    """Handle the ``predict`` subcommand: classify user-supplied points with a checkpoint.

    Input comes from ``--points`` (``"x1,x2; y1,y2; ..."``) or ``--csv``; if both
    are given, ``--points`` wins. Writes ``predictions.csv`` (predicted class and
    per-class probabilities) to ``--out``, which defaults to the checkpoint's
    directory. With ``--plot`` on 2-D inputs it also saves
    ``predict_boundary.png`` (decision regions over a reference sample from the
    training distribution) and ``predict_points.png`` (the user points coloured
    by prediction).

    Raises:
        ValueError: if no input is given, ``--points`` is empty or ragged, or the
            feature count does not match the model.
    """
    device = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"
    model, scaler, meta = load_checkpoint(args.ckpt, map_location=device)
    d = meta["input_dim"]

    if args.points:
        pts: List[List[float]] = []
        for chunk in args.points.split(";"):
            chunk = chunk.strip()
            if not chunk:
                continue
            pts.append([float(v.strip()) for v in chunk.split(",")])
        # Reject empty or ragged input explicitly; otherwise X would not be 2-D
        # and X.shape[1] below would fail with an unhelpful IndexError.
        if not pts:
            raise ValueError("--points did not contain any points")
        if any(len(p) != len(pts[0]) for p in pts):
            raise ValueError("All --points entries must have the same number of coordinates")
        X = np.array(pts, dtype=np.float32)
    elif args.csv:
        X = loadcsv_points(args.csv)
    else:
        raise ValueError("Provide --points \"x1,x2; ...\" or --csv path.csv")

    if X.shape[1] != d:
        raise ValueError(f"Input dims mismatch: model expects {d}, got {X.shape[1]}")

    Xs = scaler.transform(X)
    model.eval()
    with torch.no_grad():
        logits = model(torch.from_numpy(Xs).to(device))
        probs = F.softmax(logits, dim=1).cpu().numpy()
        preds = probs.argmax(axis=1)

    # Save predictions
    out_dir = args.out or os.path.dirname(args.ckpt)
    ensure_dir(out_dir)
    rows = []
    for i, (p, pr) in enumerate(zip(preds, probs)):
        row = {"index": i, "pred": int(p)}
        row.update({f"prob_{k}": float(v) for k, v in enumerate(pr)})
        rows.append(row)
    pred_path = os.path.join(out_dir, "predictions.csv")
    save_csv(rows, pred_path)
    print(f"Saved predictions to {pred_path}")

    # Optional visualization if 2D
    if args.plot and d == 2:
        # A reference cloud drawn from the training configuration sets the plot
        # bounds and background points for the decision-boundary image.
        Xref, yref = generate_synthetic(
            dataset=meta["dataset"],
            n_samples=500,
            n_features=d,
            noise=meta["noise"],
            n_classes=meta["num_classes"],
            seed=meta["seed"] + 202
        )
        maybe_plot_2d_decision_boundary(
            path=os.path.join(out_dir, "predict_boundary.png"),
            model=model, scaler=scaler, X=Xref, y=yref,
            title=f"{meta['dataset']} (predict overlay)", device=device
        )
        # Separate scatter of the user's points, coloured by predicted class.
        if HAS_MPL:
            points_path = os.path.join(out_dir, "predict_points.png")
            fig = plt.figure(figsize=(5.5, 4.8), dpi=140)
            try:
                plt.scatter(X[:, 0], X[:, 1], c=preds, s=50, cmap="coolwarm", edgecolors="k")
                plt.title("Predicted points")
                plt.tight_layout()
                plt.savefig(points_path)
            finally:
                plt.close(fig)
            print(f"Saved plot: {points_path}")


_BENCHMARK_MODES = ("supervised", "selfsup", "hybrid")


def _benchmark_one_run(args, ds: str, mode: str, sd: int, device: str, val_frac: float) -> float:
    """Train one (dataset, mode, seed) configuration and return its validation accuracy."""
    # Seed per run so each result is reproducible on its own, independent of
    # how many runs came before it.
    set_seed(sd)
    X_all, y_all = generate_synthetic(ds, n_samples=args.samples, n_features=args.dims,
                                      noise=args.noise, n_classes=args.classes, seed=sd)
    (Xtr, ytr), (Xva, yva) = train_val_split(X_all, y_all, val_frac=val_frac, seed=sd + 1)
    scaler = Standardizer().fit(Xtr)
    Xtr_s = scaler.transform(Xtr)
    Xva_s = scaler.transform(Xva)

    model = make_model(args.dims, args.classes, args.width, args.depth, args.dropout).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    if mode == "supervised":
        train_supervised(model, opt, Xtr_s, ytr, Xva_s, yva, device,
                         args.epochs, args.batch_size, args.aug_noise, sd)
    elif mode == "selfsup":
        train_self_training(model, opt, Xtr_s, ytr, Xva_s, yva, device, args.epochs, args.batch_size,
                            labeled_frac=args.labeled_frac, pseudo_thresh=args.pseudo_thresh,
                            aug_noise=args.aug_noise, warmup_epochs=args.warmup_epochs, seed=sd)
    elif mode == "hybrid":
        train_hybrid_consistency(model, opt, Xtr_s, ytr, Xva_s, yva, device, args.epochs, args.batch_size,
                                 aug_noise=args.aug_noise, lambda_consistency=args.lambda_consistency,
                                 seed=sd)
    else:
        raise ValueError(f"Unknown mode: {mode}")

    # Explicit eval mode: with epochs == 0 the model would still be in train
    # mode, so dropout would be active during scoring.
    model.eval()
    with torch.no_grad():
        logits = model(torch.from_numpy(Xva_s).to(device))
        preds = logits.argmax(dim=1).cpu().numpy()
    return float(accuracy_score(yva, preds))


def _summarize_benchmark(results: List[Dict[str, Any]], dims: int) -> List[Dict[str, Any]]:
    """Aggregate per-run accuracies into one mean/std/min/max row per (dataset, mode)."""
    grouped: Dict[Tuple[str, str], List[float]] = {}
    for r in results:
        grouped.setdefault((r["dataset"], r["mode"]), []).append(r["val_acc"])
    summary_rows = []
    for (ds, mode), vals in grouped.items():
        arr = np.asarray(vals, dtype=np.float64)
        summary_rows.append({
            "dataset": ds,
            "mode": mode,
            "dims": dims,
            "runs": len(vals),
            "acc_mean": float(arr.mean()),
            # Sample std (ddof=1) needs at least two runs.
            "acc_std": float(arr.std(ddof=1)) if len(vals) > 1 else 0.0,
            "acc_min": float(arr.min()),
            "acc_max": float(arr.max()),
        })
    return summary_rows


def cmd_benchmark(args):
    """Handle the ``benchmark`` subcommand: run a datasets x modes x seeds grid.

    Each run gets its own ``<out>/<dataset>-<mode>-d<dims>-s<seed>/val_acc.csv``.
    All runs are written to ``<out>/benchmark_results.csv`` and the per
    (dataset, mode) aggregates to ``<out>/benchmark_summary.csv``.

    Raises:
        ValueError: if ``--modes`` names an unknown mode (checked before any run).
    """
    set_seed(args.seed)
    device = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"
    ensure_dir(args.out)

    # --val-frac may be stored under ``val_frac`` or ``valfrac`` depending on the
    # parser's dest; accept either.
    val_frac = getattr(args, "val_frac", None)
    if val_frac is None:
        val_frac = getattr(args, "valfrac", 0.2)

    # Ignore empty entries such as the one left by a trailing comma.
    datasets = [x.strip() for x in args.datasets.split(",") if x.strip()]
    modes = [x.strip() for x in args.modes.split(",") if x.strip()]
    unknown = sorted(set(modes) - set(_BENCHMARK_MODES))
    if unknown:
        # Fail fast instead of after earlier runs have already spent time.
        raise ValueError(f"Unknown mode: {', '.join(unknown)}")
    seeds = list(range(args.seed, args.seed + args.seeds))

    results: List[Dict[str, Any]] = []
    total = len(datasets) * len(modes) * len(seeds)
    k = 0
    for ds in datasets:
        for mode in modes:
            for sd in seeds:
                k += 1
                tag = f"{ds}-{mode}-d{args.dims}-s{sd}"
                out_dir = os.path.join(args.out, tag)
                print(f"[{k}/{total}] Running {tag}")
                acc = _benchmark_one_run(args, ds, mode, sd, device, val_frac)
                results.append({
                    "dataset": ds, "mode": mode, "dims": args.dims, "seed": sd,
                    "val_acc": acc
                })
                # Save minimal artifacts per run
                ensure_dir(out_dir)
                save_csv([results[-1]], os.path.join(out_dir, "val_acc.csv"))

    save_csv(results, os.path.join(args.out, "benchmark_results.csv"))
    summary_path = os.path.join(args.out, "benchmark_summary.csv")
    save_csv(_summarize_benchmark(results, args.dims), summary_path)
    print(f"Benchmark complete. Summary saved to {summary_path}")


def cmd_export_onnx(args):
    """Handle the ``export-onnx`` subcommand: export a checkpoint's model to ONNX.

    Writes the ONNX graph to ``args.out`` (input ``"input"``, output ``"logits"``,
    dynamic batch dimension, opset 13). Next to it, writes
    ``<stem>.preproc.json`` with the Standardizer statistics and training
    metadata, because the exported graph expects already standardized inputs.
    Prints a message and returns without exporting if the ``onnx`` package
    cannot be imported.
    """
    try:
        import onnx  # noqa: F401
    except Exception:
        print("onnx package not found. Please install onnx to export. Skipping.")
        return

    device = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"
    model, scaler, meta = load_checkpoint(args.ckpt, map_location=device)
    model.eval()

    dummy = torch.randn(1, meta["input_dim"], device=device)
    dynamic_axes = {"input": {0: "batch"}, "logits": {0: "batch"}}
    ensure_dir(os.path.dirname(args.out))
    torch.onnx.export(
        model, dummy, args.out,
        input_names=["input"], output_names=["logits"],
        opset_version=13, dynamic_axes=dynamic_axes,
    )
    print(f"Exported ONNX model to: {args.out}")

    # Save scaler alongside as JSON (needed to standardize inputs pre-ONNX)
    preproc_path = os.path.splitext(args.out)[0] + ".preproc.json"
    save_json({"standardizer": scaler.to_dict(), "meta": meta}, preproc_path)


# Keep the previous name so existing references continue to resolve.
cmdexportonnx = cmd_export_onnx


# ----------------------------
# Argparse
# ----------------------------
import argparse
from typing import Optional, List

def build_parser():
    """Build the CLI parser with ``train``, ``eval``, ``predict``, ``benchmark``
    and ``export-onnx`` subcommands. The chosen subcommand is stored in ``args.cmd``.

    ``--val-frac`` is stored as ``args.val_frac``, the name cmd_train reads.
    Previously it was stored as ``valfrac``, which made ``train`` fail with
    AttributeError.
    """
    p = argparse.ArgumentParser(
        description="MetaIntelligence end-to-end toolkit (notebook/Colab-safe)"
    )
    sub = p.add_subparsers(dest="cmd", required=True)

    # Train
    t = sub.add_parser("train", help="Train a model (self/supervised/hybrid)")
    t.add_argument("--dataset", type=str, default="moons",
                   choices=["moons", "circles", "blobs", "classification"])
    t.add_argument("--dims", type=int, default=2)
    t.add_argument("--classes", type=int, default=2)
    t.add_argument("--samples", type=int, default=2000)
    t.add_argument("--noise", type=float, default=0.2)
    t.add_argument("--val-frac", dest="val_frac", type=float, default=0.2,
                   help="Fraction of samples held out for validation")

    t.add_argument("--mode", type=str, default="supervised",
                   choices=["supervised", "selfsup", "hybrid"])
    t.add_argument("--epochs", type=int, default=50)
    t.add_argument("--batch-size", type=int, default=128)
    t.add_argument("--lr", type=float, default=3e-3)
    t.add_argument("--weight-decay", type=float, default=0.0)
    t.add_argument("--width", type=int, default=64)
    t.add_argument("--depth", type=int, default=2)
    t.add_argument("--dropout", type=float, default=0.0)
    t.add_argument("--aug-noise", type=float, default=0.05,
                   help="Gaussian feature noise for augmentation")

    # self-training/hybrid knobs
    t.add_argument("--labeled-frac", type=float, default=0.1)
    t.add_argument("--pseudo-thresh", type=float, default=0.9)
    t.add_argument("--warmup-epochs", type=int, default=5)
    t.add_argument("--lambda-consistency", type=float, default=0.5)

    t.add_argument("--plot", action="store_true")
    t.add_argument("--seed", type=int, default=42)
    t.add_argument("--cpu", action="store_true")
    t.add_argument("--out", type=str, required=True)

    # Eval
    e = sub.add_parser("eval", help="Evaluate a saved checkpoint on a fresh synthetic test set")
    e.add_argument("--ckpt", type=str, required=True)
    e.add_argument("--samples", type=int, default=2000)
    e.add_argument("--plot", action="store_true")
    e.add_argument("--cpu", action="store_true")
    e.add_argument("--out", type=str, default="")

    # Predict
    pr = sub.add_parser("predict", help="Predict for given points")
    pr.add_argument("--ckpt", type=str, required=True)
    pr.add_argument("--points", type=str, default="", help='Inline points: "x1,x2; y1,y2; ..."')
    pr.add_argument("--csv", type=str, default="", help="CSV file with samples (rows) and features (columns)")
    pr.add_argument("--plot", action="store_true")
    pr.add_argument("--cpu", action="store_true")
    pr.add_argument("--out", type=str, default="")

    # Benchmark
    b = sub.add_parser("benchmark", help="Run grid of modes/datasets/dims across seeds and summarize")
    b.add_argument("--datasets", type=str, default="moons,blobs")
    b.add_argument("--modes", type=str, default="supervised,selfsup,hybrid")
    b.add_argument("--dims", type=int, default=2)
    b.add_argument("--classes", type=int, default=2)
    b.add_argument("--samples", type=int, default=2000)
    b.add_argument("--noise", type=float, default=0.2)
    b.add_argument("--val-frac", dest="val_frac", type=float, default=0.2,
                   help="Fraction of samples held out for validation")
    b.add_argument("--epochs", type=int, default=40)
    b.add_argument("--batch-size", type=int, default=128)
    b.add_argument("--lr", type=float, default=3e-3)
    b.add_argument("--weight-decay", type=float, default=0.0)
    b.add_argument("--width", type=int, default=64)
    b.add_argument("--depth", type=int, default=2)
    b.add_argument("--dropout", type=float, default=0.0)
    b.add_argument("--aug-noise", type=float, default=0.05)
    b.add_argument("--labeled-frac", type=float, default=0.1)
    b.add_argument("--pseudo-thresh", type=float, default=0.9)
    b.add_argument("--warmup-epochs", type=int, default=5)
    b.add_argument("--lambda-consistency", type=float, default=0.5)
    b.add_argument("--seeds", type=int, default=3, help="number of seeds to run starting from --seed")
    b.add_argument("--seed", type=int, default=42)
    b.add_argument("--cpu", action="store_true")
    b.add_argument("--out", type=str, required=True)

    # Export ONNX
    x = sub.add_parser("export-onnx", help="Export a trained checkpoint to ONNX")
    x.add_argument("--ckpt", type=str, required=True)
    x.add_argument("--out", type=str, required=True)
    x.add_argument("--cpu", action="store_true")

    return p


def main(argv: Optional[List[str]] = None):
    """Parse CLI arguments and dispatch to the matching command handler.

    Args:
        argv: arguments without the program name. If None, ``sys.argv[1:]`` is
            used for normal script runs. Inside an IPython/Jupyter kernel
            (``ipykernel`` present in ``sys.modules``) no arguments are used,
            because ``sys.argv`` then holds kernel-launcher flags. With no
            arguments the help text is printed and nothing else runs.

    Returns:
        Whatever the selected handler returns.

    Raises:
        ValueError: if a path argument (``--ckpt``, ``--csv``, ``--out``) contains
            ``..``, ``|``, ``;`` or a backtick.
    """
    parser = build_parser()

    if argv is None:
        # Previously argv=None always became [], so command-line invocations
        # were silently ignored. Keep the help-only behaviour in notebooks.
        argv = [] if "ipykernel" in sys.modules else sys.argv[1:]
    if not argv:
        parser.print_help()
        return None

    args = parser.parse_args(argv)

    # Basic sanitization for file paths
    for attr in ["ckpt", "csv", "out"]:
        if hasattr(args, attr):
            val = getattr(args, attr)
            if isinstance(val, str) and any(ch in val for ch in ["..", "|", ";", "`"]):
                raise ValueError(f"Unsafe characters in path argument: --{attr}")

    # Handler names must match the functions defined above. The previous
    # table referenced undefined names (cmdtrain, ...) and raised NameError.
    dispatch = {
        "train": cmd_train,
        "eval": cmd_eval,
        "predict": cmd_predict,
        "benchmark": cmd_benchmark,
        "export-onnx": cmdexportonnx,
    }

    handler = dispatch.get(args.cmd)
    if handler is None:
        parser.error(f"Unknown command: {args.cmd}")
    return handler(args)


if __name__ == "__main__":
    # Script entry point: hand control to main() using its default arguments.
    main()