<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/metaintel_cli_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision torchaudio scikit-learn matplotlib pandas seaborn onnx onnxruntime pyyaml

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
MetaIntelligence CLI: train, eval, predict, benchmark, and ONNX export
Datasets: moons, blobs, circles, iris, mnist
Models: MLP (all), CNN (mnist)
Notebook/Colab-safe: functions and main guard
"""

import os
import sys
import json
import math
import time
import yaml
import argparse
import random
import datetime as dt
from dataclasses import dataclass, asdict
from typing import Tuple, Dict, Any, Optional, List

import numpy as np
import pandas as pd
import matplotlib
# Default to non-interactive backend for CLI; enable live plotting with flag
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    confusion_matrix,
    classification_report,
)
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_moons, make_blobs, make_circles, load_iris

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split

# Optional ONNX runtime
try:
    import onnx
    import onnxruntime as ort
    HAS_ONNX = True
except Exception:
    HAS_ONNX = False

# Optional MNIST
try:
    from torchvision import datasets, transforms
    HAS_TORCHVISION = True
except Exception:
    HAS_TORCHVISION = False

# ---------------------------
# Utilities
# ---------------------------

def set_seed(seed: int, deterministic: bool = True):
    """Seed the Python, NumPy and PyTorch (CPU + CUDA) random generators.

    Args:
        seed: Value handed to every generator.
        deterministic: When true, also put cuDNN in deterministic mode and
            turn its ``benchmark`` autotuner off.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    if not deterministic:
        return
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def now_tag():
    """Return the current local time formatted as ``YYYY-MM-DD_HH-MM-SS``."""
    stamp = dt.datetime.now()
    return f"{stamp:%Y-%m-%d_%H-%M-%S}"

def ensure_dir(path: str):
    """Create directory ``path`` including parents; an existing directory is not an error."""
    os.makedirs(name=path, exist_ok=True)

def save_json(path: str, obj: Dict[str, Any]):
    """Write ``obj`` to ``path`` as UTF-8 JSON with 2-space indent and sorted keys."""
    with open(path, mode="w", encoding="utf-8") as handle:
        json.dump(obj, handle, sort_keys=True, indent=2)

def load_json(path: str) -> Dict[str, Any]:
    """Read the UTF-8 JSON file at ``path`` and return the decoded value."""
    with open(path, mode="r", encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed

def maybe_load_config(path: Optional[str]) -> Dict[str, Any]:
    """Load an optional JSON or YAML config file into a dictionary.

    Args:
        path: Config file path. A falsy value (``None`` or ``""``) means
            "no config file".

    Returns:
        The parsed mapping; ``{}`` when ``path`` is falsy or the document is
        empty / ``null``.

    Raises:
        ValueError: If the extension is not ``.yaml``, ``.yml`` or ``.json``,
            or if the document's top-level value is not a mapping.
    """
    if not path:
        return {}
    ext = os.path.splitext(path)[1].lower()
    with open(path, "r", encoding="utf-8") as f:
        if ext in (".yaml", ".yml"):
            loaded = yaml.safe_load(f)
        elif ext == ".json":
            loaded = json.load(f)
        else:
            raise ValueError(f"Unsupported config extension: {ext}")

    # An empty YAML file (or a JSON ``null``) parses to None; honour the
    # declared Dict return type instead of leaking None to callers.
    if loaded is None:
        return {}
    if not isinstance(loaded, dict):
        raise ValueError(
            f"Config file must contain a mapping at top level, got {type(loaded).__name__}"
        )
    return loaded

def pretty_time(s: float) -> str:
    """Format a duration in seconds as ``"X.XXs"``, ``"Mm S.Ss"`` or ``"Hh Mm Ss"``."""
    if s < 60:
        return f"{s:.2f}s"
    total_minutes, sec = divmod(s, 60)
    if total_minutes >= 60:
        hours, minutes = divmod(total_minutes, 60)
        return f"{int(hours)}h {int(minutes)}m {sec:.0f}s"
    return f"{int(total_minutes)}m {sec:.1f}s"

# ---------------------------
# Data loading
# ---------------------------

@dataclass
class DataBundle:
    """Train/validation/test arrays of a tabular dataset plus its metadata."""
    # Feature matrices and label vectors for each split.
    X_train: np.ndarray
    y_train: np.ndarray
    X_val: np.ndarray
    y_val: np.ndarray
    X_test: np.ndarray
    y_test: np.ndarray
    # Per-sample input shape; build_tabular_bundle stores (n_features,).
    input_shape: Tuple[int, ...]
    # Number of distinct labels and their display names.
    num_classes: int
    class_names: List[str]
    # StandardScaler mean/scale fitted on the training split; None when the
    # features were not standardized.
    scaler_mean: Optional[List[float]] = None
    scaler_std: Optional[List[float]] = None
    is_image: bool = False  # True for MNIST

def make_synthetic(
    kind: str,
    n_samples: int = 2000,
    noise: float = 0.2,
    random_state: int = 42,
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """Generate one of the synthetic scikit-learn datasets.

    Args:
        kind: ``"moons"``, ``"blobs"`` or ``"circles"``.
        n_samples: Number of points to generate.
        noise: Noise level for moons and circles (blobs uses a fixed
            ``cluster_std`` of 1.5 instead).
        random_state: Seed passed to the generator.

    Returns:
        ``(X, y, class_names)`` with ``X`` as float32 and ``y`` as int64.

    Raises:
        ValueError: If ``kind`` is not one of the supported names.
    """
    # Each entry: (lazy sampler, class display names).
    recipes = {
        "moons": (
            lambda: make_moons(n_samples=n_samples, noise=noise, random_state=random_state),
            ["class0", "class1"],
        ),
        "blobs": (
            lambda: make_blobs(n_samples=n_samples, centers=3, cluster_std=1.5, random_state=random_state),
            ["c0", "c1", "c2"],
        ),
        "circles": (
            lambda: make_circles(n_samples=n_samples, noise=noise, factor=0.5, random_state=random_state),
            ["inner", "outer"],
        ),
    }
    if kind not in recipes:
        raise ValueError(f"Unknown synthetic dataset: {kind}")

    sample, classes = recipes[kind]
    features, labels = sample()
    return features.astype(np.float32), labels.astype(np.int64), classes

def load_iris_dataset() -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """Load scikit-learn's Iris data as ``(X float32, y int64, class_names)``."""
    bunch = load_iris()
    features = bunch["data"].astype(np.float32)
    labels = bunch["target"].astype(np.int64)
    return features, labels, list(bunch["target_names"])

def load_mnist_dataset(data_dir: str, download: bool, seed: int) -> Tuple[TensorDataset, TensorDataset, TensorDataset, List[str]]:
    """Load MNIST via torchvision and carve a 10% validation split off the training set.

    Args:
        data_dir: Root directory handed to ``datasets.MNIST``.
        download: Whether torchvision may download the data.
        seed: Seed for the generator that drives the train/val split.

    Returns:
        ``(train_ds, val_ds, test_ds, class_names)`` where the class names are
        the digit strings ``"0"`` .. ``"9"``.

    Raises:
        RuntimeError: If torchvision could not be imported.
    """
    if not HAS_TORCHVISION:
        raise RuntimeError("torchvision not available. Install torchvision to use MNIST.")

    pipeline = transforms.Compose([transforms.ToTensor()])
    common = dict(root=data_dir, download=download, transform=pipeline)
    train_full = datasets.MNIST(train=True, **common)
    test_ds = datasets.MNIST(train=False, **common)

    # Hold out one tenth of the official training set for validation.
    n_total = len(train_full)
    n_val = int(0.1 * n_total)
    split_rng = torch.Generator().manual_seed(seed)
    train_ds, val_ds = random_split(train_full, [n_total - n_val, n_val], generator=split_rng)

    digit_names = [str(digit) for digit in range(10)]
    return train_ds, val_ds, test_ds, digit_names

def build_tabular_bundle(
    X: np.ndarray,
    y: np.ndarray,
    class_names: List[str],
    test_size: float,
    val_size: float,
    seed: int,
    standardize: bool = True,
) -> DataBundle:
    """Split a tabular dataset into train/val/test and optionally standardize it.

    Args:
        X: Feature matrix of shape ``(n_samples, n_features)``.
        y: Integer class labels, one per row of ``X``.
        class_names: Display names of the classes.
        test_size: Fraction of all samples held out from training; this
            hold-out is then divided into validation and test.
        val_size: Fraction of the hold-out assigned to validation; the rest
            becomes the test split. ``0.5`` yields equal halves.
        seed: Random state for both stratified splits.
        standardize: When true, fit a ``StandardScaler`` on the training split
            only and apply it to all three splits.

    Returns:
        A ``DataBundle`` with the three splits, the input shape, class info
        and (if standardized) the scaler mean/scale as float lists.
    """
    X_train, X_hold, y_train, y_hold = train_test_split(
        X, y, test_size=test_size, random_state=seed, stratify=y
    )
    # Honour ``val_size`` (previously ignored in favour of a hard-coded 0.5):
    # the "test" side of this second split is whatever validation leaves over.
    X_val, X_test, y_val, y_test = train_test_split(
        X_hold, y_hold, test_size=1.0 - val_size, random_state=seed, stratify=y_hold
    )
    scaler_mean = None
    scaler_std = None
    if standardize:
        # Fit on train only so no validation/test statistics leak into training.
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train).astype(np.float32)
        X_val = scaler.transform(X_val).astype(np.float32)
        X_test = scaler.transform(X_test).astype(np.float32)
        scaler_mean = scaler.mean_.astype(np.float32).tolist()
        scaler_std = scaler.scale_.astype(np.float32).tolist()

    input_shape = (X.shape[1],)
    num_classes = len(np.unique(y))
    return DataBundle(
        X_train, y_train, X_val, y_val, X_test, y_test,
        input_shape=input_shape, num_classes=num_classes,
        class_names=class_names,
        scaler_mean=scaler_mean, scaler_std=scaler_std,
        is_image=False
    )

# ---------------------------
# Models
# ---------------------------

class MLP(nn.Module):
    """Fully connected classifier producing raw logits.

    Each hidden size contributes ``Linear -> [BatchNorm1d] -> ReLU -> [Dropout]``;
    a final ``Linear`` maps to ``num_classes``.
    """

    def __init__(self, in_dim: int, num_classes: int, hidden_sizes: List[int], dropout: float = 0.0, batchnorm: bool = False):
        super().__init__()
        widths = [in_dim, *hidden_sizes]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            if batchnorm:
                stack.append(nn.BatchNorm1d(fan_out))
            stack.append(nn.ReLU())
            if dropout > 0:
                stack.append(nn.Dropout(dropout))
        # Output layer: last hidden width (or in_dim when there are no hidden layers).
        stack.append(nn.Linear(widths[-1], num_classes))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return class logits for the batch ``x``."""
        logits = self.net(x)
        return logits

class SimpleCNN(nn.Module):
    """Two-block convolutional classifier for 1x28x28 (MNIST) inputs; outputs logits."""

    def __init__(self, num_classes: int = 10, dropout: float = 0.0):
        super().__init__()
        feature_layers = [
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 28x28 -> 14x14
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 14x14 -> 7x7
        ]
        head_layers = [
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        ]
        self.conv = nn.Sequential(*feature_layers)
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run the convolutional features through the dense head."""
        features = self.conv(x)
        logits = self.fc(features)
        return logits

# ---------------------------
# Training / Evaluation
# ---------------------------

def make_device(device_str: str) -> torch.device:
    """Resolve a device string; ``"auto"`` picks CUDA when available, else CPU."""
    if device_str != "auto":
        return torch.device(device_str)
    resolved = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(resolved)

def to_tensor(x: np.ndarray) -> torch.Tensor:
    """Convert a NumPy array to a torch tensor via ``torch.from_numpy``."""
    tensor = torch.from_numpy(x)
    return tensor

def accuracy_from_logits(logits: torch.Tensor, y: torch.Tensor) -> float:
    """Return the fraction of rows whose argmax over dim 1 equals the target in ``y``."""
    hits = torch.eq(logits.argmax(dim=1), y)
    return hits.float().mean().item()

def plot_training_curves(history: pd.DataFrame, out_png: str):
    """Save a two-panel figure (loss, accuracy) of train vs. val curves per epoch.

    Args:
        history: Frame with columns ``epoch``, ``train_loss``, ``val_loss``,
            ``train_acc`` and ``val_acc``.
        out_png: Destination image path.
    """
    plt.figure(figsize=(10, 4))
    # (column suffix, y-axis label, panel title) for the left and right panels.
    panels = (("loss", "loss", "Loss"), ("acc", "accuracy", "Accuracy"))
    for position, (suffix, y_label, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for split in ("train", "val"):
            plt.plot(history["epoch"], history[f"{split}_{suffix}"], label=split)
        plt.xlabel("epoch")
        plt.ylabel(y_label)
        plt.legend()
        plt.title(title)
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def plot_confusion(cm: np.ndarray, class_names: List[str], out_png: str):
    """Render the confusion matrix ``cm`` as an annotated heatmap and save it to ``out_png``."""
    plt.figure(figsize=(6, 5))
    heatmap_options = dict(
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    sns.heatmap(cm, **heatmap_options)
    plt.ylabel("True")
    plt.xlabel("Pred")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def plot_decision_boundary(model: nn.Module, bundle: DataBundle, device: torch.device, out_png: str):
    """Plot the model's decision regions over the training points (2-D tabular data only).

    The figure is drawn in the original (un-standardized) feature space. When
    the bundle carries scaler statistics, the training points are mapped back
    to original units and the evaluation grid is standardized exactly once
    before being fed to the model.

    Args:
        model: Classifier returning logits for inputs of shape ``(N, 2)``.
        bundle: Data bundle; ``X_train`` is expected to be standardized
            whenever ``scaler_mean``/``scaler_std`` are set.
        device: Device on which the model is evaluated.
        out_png: Destination image path.

    Returns:
        None. Does nothing for image data or when the input is not 2-D.
    """
    # Only for 2D tabular
    if bundle.input_shape[0] != 2 or bundle.is_image:
        return

    standardized = bundle.scaler_mean is not None and bundle.scaler_std is not None
    X_plot = bundle.X_train
    if standardized:
        # float32 stats keep the grid float32; float64 here would produce a
        # double tensor and a dtype mismatch against the float32 model.
        mean = np.array(bundle.scaler_mean, dtype=np.float32)
        std = np.array(bundle.scaler_std, dtype=np.float32)
        # X_train is stored standardized; map it back for plotting.
        X_plot = X_plot * std + mean

    # Build the grid in the same (original) space the points are drawn in.
    x_min, x_max = X_plot[:, 0].min() - 0.5, X_plot[:, 0].max() + 0.5
    y_min, y_max = X_plot[:, 1].min() - 0.5, X_plot[:, 1].max() + 0.5
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 300), np.linspace(y_min, y_max, 300))
    grid = np.c_[xx.ravel(), yy.ravel()].astype(np.float32)

    # The model was trained on standardized features: standardize the grid once.
    model_input = ((grid - mean) / std).astype(np.float32) if standardized else grid

    # Evaluate in eval mode (matters for BatchNorm/Dropout), then restore the mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(to_tensor(model_input).to(device))
            preds = logits.argmax(dim=1).cpu().numpy().reshape(xx.shape)
    finally:
        model.train(was_training)

    plt.figure(figsize=(6, 5))
    plt.contourf(xx, yy, preds, alpha=0.25, levels=np.arange(bundle.num_classes + 1) - 0.5, cmap="Set3")
    plt.scatter(X_plot[:, 0], X_plot[:, 1], c=bundle.y_train, cmap="Set1", edgecolor="k", s=20)
    plt.title("Decision boundary (train)")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def train_one_epoch(model, loader, optimizer, device, loss_fn):
    """Run one optimization pass over ``loader``.

    Returns:
        ``(mean_loss, mean_accuracy)`` weighted by batch size.
    """
    model.train()
    loss_sum, acc_sum, seen = 0.0, 0.0, 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad(set_to_none=True)
        logits = model(inputs)
        batch_loss = loss_fn(logits, targets)
        batch_loss.backward()
        optimizer.step()

        batch_acc = accuracy_from_logits(logits, targets)
        n = targets.size(0)
        # Weight by batch size so a short final batch does not skew the mean.
        loss_sum += batch_loss.item() * n
        acc_sum += batch_acc * n
        seen += n
    return loss_sum / seen, acc_sum / seen

@torch.no_grad()
def eval_epoch(model, loader, device, loss_fn):
    """Evaluate ``model`` over ``loader`` without gradient tracking.

    Returns:
        ``(mean_loss, mean_accuracy)`` weighted by batch size.
    """
    model.eval()
    loss_sum, acc_sum, seen = 0.0, 0.0, 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        logits = model(inputs)
        batch_loss = loss_fn(logits, targets)
        batch_acc = accuracy_from_logits(logits, targets)

        n = targets.size(0)
        loss_sum += batch_loss.item() * n
        acc_sum += batch_acc * n
        seen += n
    return loss_sum / seen, acc_sum / seen

def build_model(model_name: str, input_shape: Tuple[int, ...], num_classes: int, hidden_sizes: List[int], dropout: float, batchnorm: bool):
    """Construct the requested model for a given per-sample input shape.

    Args:
        model_name: ``"mlp"`` or ``"cnn"``.
        input_shape: Per-sample shape, e.g. ``(n_features,)`` or ``(1, 28, 28)``.
        num_classes: Number of output logits.
        hidden_sizes: Hidden layer widths (MLP only).
        dropout: Dropout probability.
        batchnorm: Whether the MLP uses batch normalization.

    Returns:
        An ``nn.Module``. For 1-D inputs the MLP is returned as-is; for
        multi-dimensional inputs (images) it is wrapped as
        ``Sequential(Flatten, MLP)`` sized to the flattened input.

    Raises:
        ValueError: For an unknown model name, or for ``"cnn"`` with an input
            shape that is not 3-D.
    """
    if model_name == "mlp":
        # Size the first layer to the flattened input; equals input_shape[0] for tabular data.
        in_dim = math.prod(input_shape)
        mlp = MLP(in_dim=in_dim, num_classes=num_classes, hidden_sizes=hidden_sizes, dropout=dropout, batchnorm=batchnorm)
        if len(input_shape) > 1:
            # Image-shaped batches must be flattened before the first Linear layer.
            return nn.Sequential(nn.Flatten(), mlp)
        return mlp
    elif model_name == "cnn":
        if len(input_shape) != 3:
            raise ValueError(
                f"Model 'cnn' requires an image input shape (C, H, W), got {tuple(input_shape)}"
            )
        return SimpleCNN(num_classes=num_classes, dropout=dropout)
    else:
        raise ValueError(f"Unknown model: {model_name}")

def make_dataloaders(bundle: DataBundle, is_image: bool, batch_size: int) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Wrap the bundle's NumPy splits in ``(train, val, test)`` DataLoaders.

    Only the training loader shuffles.

    Raises:
        RuntimeError: If ``is_image`` is true; MNIST uses ``make_mnist_loaders``.
    """
    if is_image:
        # MNIST: already TensorDataset from torchvision splits
        raise RuntimeError("For MNIST, use dedicated make_mnist_loaders.")

    def _as_dataset(features, labels):
        return TensorDataset(torch.from_numpy(features), torch.from_numpy(labels))

    train_ds = _as_dataset(bundle.X_train, bundle.y_train)
    val_ds = _as_dataset(bundle.X_val, bundle.y_val)
    test_ds = _as_dataset(bundle.X_test, bundle.y_test)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=False)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader

def make_mnist_loaders(train_ds, val_ds, test_ds, batch_size: int) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build ``(train, val, test)`` DataLoaders from ready-made datasets; only train shuffles."""
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=False)
    val_loader, test_loader = (
        DataLoader(split, batch_size=batch_size, shuffle=False)
        for split in (val_ds, test_ds)
    )
    return train_loader, val_loader, test_loader

def collect_preds(model, loader, device) -> Tuple[np.ndarray, np.ndarray]:
    """Run ``model`` over ``loader`` and return ``(predicted_indices, targets)`` as NumPy arrays."""
    model.eval()
    logit_chunks, target_chunks = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            batch_logits = model(inputs.to(device))
            logit_chunks.append(batch_logits.cpu().numpy())
            target_chunks.append(labels.numpy())
    stacked_logits = np.concatenate(logit_chunks, axis=0)
    targets = np.concatenate(target_chunks, axis=0)
    return stacked_logits.argmax(axis=1), targets

# ---------------------------
# Config dataclass
# ---------------------------

@dataclass
class TrainConfig:
    """All knobs for one training run; defaults describe an MLP on ``moons``."""

    # --- data / model selection ---
    dataset: str = "moons"        # one of: moons, blobs, circles, iris, mnist
    model: str = "mlp"            # mlp, or cnn (mnist only)
    data_dir: str = "./data"      # MNIST storage root
    download: bool = True         # allow MNIST download
    test_size: float = 0.2
    val_frac: float = 0.5         # validation share of the hold-out (tabular only)
    # --- optimization ---
    batch_size: int = 64
    epochs: int = 50
    lr: float = 1e-3
    weight_decay: float = 0.0
    # --- architecture ---
    hidden_sizes: Tuple[int, ...] = (64, 64)
    dropout: float = 0.0
    batchnorm: bool = False
    # --- runtime ---
    seed: int = 42
    deterministic: bool = True
    device: str = "auto"          # auto, cpu or cuda
    live_plot: bool = False
    standardize: bool = True
    noise: float = 0.2            # synthetic datasets only
    n_samples: int = 2000         # synthetic datasets only
    # --- early stopping ---
    early_stop: bool = True
    patience: int = 10
    min_delta: float = 1e-4

    # --- misc ---
    run_dir: Optional[str] = None

    def as_dict(self) -> Dict[str, Any]:
        """Return the config as a plain dict with ``hidden_sizes`` as a list (JSON-friendly)."""
        payload = asdict(self)
        payload.update(hidden_sizes=list(self.hidden_sizes))
        return payload

# ---------------------------
# Live plotting helper
# ---------------------------

class LivePlotter:
    """Optional interactive window showing loss/accuracy curves during training.

    When ``enable`` is false every method is a no-op. When the interactive
    backend cannot be set up (e.g. headless server or notebook without Tk),
    live plotting is disabled with a message instead of aborting training.
    """

    def __init__(self, enable: bool = False):
        self.enable = enable
        if not self.enable:
            return
        try:
            matplotlib.use("TkAgg")
            plt.ion()
            self.fig, (self.ax1, self.ax2) = plt.subplots(1, 2, figsize=(10, 4))
        except Exception as exc:  # backend import or display failure: degrade gracefully
            print(f"Live plotting disabled (interactive backend unavailable): {exc}")
            self.enable = False
            plt.ioff()
            # Restore the non-interactive backend so static plots still work.
            matplotlib.use("Agg")
            return

        self.ax1.set_title("Loss"); self.ax2.set_title("Accuracy")
        self.ax1.set_xlabel("epoch"); self.ax2.set_xlabel("epoch")
        self.ax1.grid(True); self.ax2.grid(True)
        self.train_loss_line, = self.ax1.plot([], [], label="train")
        self.val_loss_line,   = self.ax1.plot([], [], label="val")
        self.ax1.legend()
        self.train_acc_line, = self.ax2.plot([], [], label="train")
        self.val_acc_line,   = self.ax2.plot([], [], label="val")
        self.ax2.legend()
        self.fig.tight_layout()

    def update(self, history: pd.DataFrame):
        """Redraw both panels from ``history`` (columns: epoch, train/val loss and acc)."""
        if not self.enable:
            return
        x = history["epoch"].values
        self.train_loss_line.set_data(x, history["train_loss"].values)
        self.val_loss_line.set_data(x, history["val_loss"].values)
        self.ax1.relim(); self.ax1.autoscale_view()

        self.train_acc_line.set_data(x, history["train_acc"].values)
        self.val_acc_line.set_data(x, history["val_acc"].values)
        self.ax2.relim(); self.ax2.autoscale_view()

        # Brief pause lets the GUI event loop repaint the figure.
        plt.pause(0.001)

    def close(self):
        """Leave interactive mode and close the figure (no-op when disabled)."""
        if self.enable:
            plt.ioff()
            plt.close(self.fig)

# ---------------------------
# Train command
# ---------------------------

def cmd_train(args: argparse.Namespace):
    """Train a model and write all artifacts of the run to its run directory.

    Configuration is resolved as: ``TrainConfig`` defaults, overridden by the
    optional config file (``args.config``), overridden by any attribute on
    ``args`` that is not ``None``.

    Artifacts written to the run directory: ``config.json``, ``metrics.csv``
    (rewritten every epoch), ``best_model.pt`` (state dict with the best
    validation accuracy), ``meta.json``, ``training_curves.png`` and, for 2-D
    tabular data, ``decision_boundary.png``.

    Args:
        args: Namespace carrying optional ``config`` plus any ``TrainConfig``
            field names.

    Raises:
        ValueError: If the resolved dataset name is not supported.
    """
    # 1) Load defaults and optional config file
    base_cfg = TrainConfig()
    file_cfg = maybe_load_config(getattr(args, "config", None))
    # Merge: defaults <- file_cfg (if any)
    cfg_dict = {**base_cfg.as_dict(), **(file_cfg or {})}

    # 2) CLI overrides (only when provided, i.e. attribute present and not None)
    # NOTE(review): the parser's store_true flags appear to default to False
    # rather than None, so they would always override file/default values —
    # confirm against the parser definition.
    override_keys = (
        "dataset", "model", "data_dir", "download", "test_size", "val_frac",
        "batch_size", "epochs", "lr", "weight_decay", "dropout", "batchnorm",
        "seed", "deterministic", "device", "live_plot", "standardize", "noise",
        "n_samples", "early_stop", "patience", "min_delta", "run_dir",
    )
    for key in override_keys:
        val = getattr(args, key, None)
        if val is not None:
            cfg_dict[key] = val
    # hidden sizes (list)
    if getattr(args, "hidden_sizes", None):
        cfg_dict["hidden_sizes"] = args.hidden_sizes

    # 3) Build TrainConfig
    cfg = TrainConfig(**cfg_dict)

    # 4) Reproducibility and device
    set_seed(cfg.seed, cfg.deterministic)
    device = make_device(cfg.device)

    # 5) Prepare run directory and save config
    run_tag = f"{now_tag()}_{cfg.dataset}_{cfg.model}"
    run_dir = cfg.run_dir or os.path.join("runs", run_tag)
    ensure_dir(run_dir)
    save_json(os.path.join(run_dir, "config.json"), cfg.as_dict())

    # 6) Load data
    if cfg.dataset in ["moons", "blobs", "circles"]:
        X, y, class_names = make_synthetic(cfg.dataset, n_samples=cfg.n_samples, noise=cfg.noise, random_state=cfg.seed)
        bundle = build_tabular_bundle(
            X, y, class_names,
            test_size=cfg.test_size,
            val_size=cfg.val_frac,
            seed=cfg.seed,
            standardize=cfg.standardize,
        )
        is_image = False
    elif cfg.dataset == "iris":
        X, y, class_names = load_iris_dataset()
        bundle = build_tabular_bundle(
            X, y, class_names,
            test_size=cfg.test_size,
            val_size=cfg.val_frac,
            seed=cfg.seed,
            standardize=cfg.standardize,
        )
        is_image = False
    elif cfg.dataset == "mnist":
        train_ds, val_ds, test_ds, class_names = load_mnist_dataset(cfg.data_dir, cfg.download, cfg.seed)
        is_image = True
    else:
        raise ValueError("Unsupported dataset.")

    # 7) Dataloaders
    if not is_image:
        train_loader, val_loader, test_loader = make_dataloaders(bundle, False, cfg.batch_size)
        input_shape = bundle.input_shape
        num_classes = bundle.num_classes
    else:
        train_loader, val_loader, test_loader = make_mnist_loaders(train_ds, val_ds, test_ds, cfg.batch_size)
        input_shape = (1, 28, 28)
        num_classes = 10

    # 8) Build model, optimizer, loss
    model = build_model(cfg.model, input_shape, num_classes, list(cfg.hidden_sizes), cfg.dropout, cfg.batchnorm).to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    # 9) Training loop with early stopping on val_acc
    best_val_acc = -1.0
    best_state = None
    best_epoch = -1
    epochs_no_improve = 0

    history_rows = []
    live = LivePlotter(enable=bool(cfg.live_plot))
    t0 = time.time()

    try:
        for epoch in range(1, cfg.epochs + 1):
            train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, device, loss_fn)
            val_loss, val_acc = eval_epoch(model, val_loader, device, loss_fn)

            history_rows.append({
                "epoch": epoch,
                "train_loss": train_loss,
                "val_loss": val_loss,
                "train_acc": train_acc,
                "val_acc": val_acc,
            })
            history = pd.DataFrame(history_rows)
            history.to_csv(os.path.join(run_dir, "metrics.csv"), index=False)
            live.update(history)

            if val_acc > best_val_acc + cfg.min_delta:
                best_val_acc = val_acc
                # clone(): on CPU `.cpu()` returns the live parameter tensor itself,
                # so without a copy the snapshot would silently follow later updates.
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                best_epoch = epoch
                epochs_no_improve = 0
                torch.save(best_state, os.path.join(run_dir, "best_model.pt"))
            else:
                epochs_no_improve += 1

            print(f"[{epoch:03d}/{cfg.epochs}] train_loss={train_loss:.4f} val_loss={val_loss:.4f} "
                  f"train_acc={train_acc:.4f} val_acc={val_acc:.4f}")

            if cfg.early_stop and epochs_no_improve >= cfg.patience:
                print(f"Early stopping at epoch {epoch}. Best val_acc={best_val_acc:.4f} at epoch {best_epoch}.")
                break
    finally:
        # Always release the interactive figure, even if training raised.
        live.close()

    elapsed = time.time() - t0
    print(f"Training finished in {pretty_time(elapsed)}. Best val_acc={best_val_acc:.4f} (epoch {best_epoch}).")

    # 10) Persist meta
    meta = {
        "class_names": class_names,
        "is_image": is_image,
        "input_shape": input_shape,
        "num_classes": num_classes,
        "scaler_mean": getattr(bundle, "scaler_mean", None) if not is_image else None,
        "scaler_std": getattr(bundle, "scaler_std", None) if not is_image else None,
        "best_epoch": best_epoch,
        "elapsed_seconds": elapsed,
    }
    save_json(os.path.join(run_dir, "meta.json"), meta)

    # 11) Static plots (best-effort: a plotting failure must not fail the run)
    try:
        plot_training_curves(pd.read_csv(os.path.join(run_dir, "metrics.csv")), os.path.join(run_dir, "training_curves.png"))
    except Exception as e:
        print(f"Training curves plot failed: {e}")

    if not is_image:
        try:
            if bundle.input_shape[0] == 2:
                # Draw the boundary of the best checkpoint, not the last epoch.
                if best_state is not None:
                    model.load_state_dict(best_state)
                plot_decision_boundary(model.to(device), bundle, device, os.path.join(run_dir, "decision_boundary.png"))
        except Exception as e:
            print(f"Decision boundary plot failed: {e}")

    print(f"Artifacts saved to: {run_dir}")

# ---------------------------
# Eval command
# ---------------------------

def cmd_eval(args: argparse.Namespace):
    """Evaluate a trained run on its test split and write metrics, plot and report.

    Reads ``config.json``, ``meta.json`` and ``best_model.pt`` from
    ``args.run_dir``, rebuilds the dataset with the saved settings, and writes
    ``eval.json``, ``confusion_matrix.png`` and ``classification_report.txt``
    back into the run directory.

    Args:
        args: Namespace with ``run_dir`` and ``device`` (falsy ``device``
            falls back to the run's configured device).

    Raises:
        FileNotFoundError: If ``run_dir`` is not an existing directory.
        RuntimeError: If the run used MNIST but torchvision is unavailable.
        ValueError: If the saved dataset name is not supported.
    """
    run_dir = args.run_dir
    # Explicit raise instead of assert: asserts are stripped under ``python -O``.
    if not os.path.isdir(run_dir):
        raise FileNotFoundError(f"Run dir not found: {run_dir}")

    cfg = load_json(os.path.join(run_dir, "config.json"))
    meta = load_json(os.path.join(run_dir, "meta.json"))

    device = make_device(args.device or cfg.get("device", "auto"))
    class_names = meta["class_names"]
    is_image = meta["is_image"]
    num_classes = meta["num_classes"]
    input_shape = tuple(meta["input_shape"])

    # Reconstruct dataset (same seed and split settings as training)
    if cfg["dataset"] == "mnist":
        if not HAS_TORCHVISION:
            raise RuntimeError("torchvision not available for MNIST eval.")
        train_ds, val_ds, test_ds, _ = load_mnist_dataset(cfg["data_dir"], cfg.get("download", True), cfg.get("seed", 42))
        _, _, test_loader = make_mnist_loaders(train_ds, val_ds, test_ds, cfg["batch_size"])
    else:
        if cfg["dataset"] in ["moons", "blobs", "circles"]:
            X, y, class_names = make_synthetic(cfg["dataset"], n_samples=cfg["n_samples"], noise=cfg["noise"], random_state=cfg["seed"])
        elif cfg["dataset"] == "iris":
            X, y, class_names = load_iris_dataset()
        else:
            raise ValueError(f"Unsupported dataset: {cfg['dataset']}")
        bundle = build_tabular_bundle(
            X, y, class_names,
            test_size=cfg["test_size"],
            val_size=cfg["val_frac"],
            seed=cfg["seed"],
            standardize=cfg["standardize"],
        )
        # make_dataloaders takes (bundle, is_image, batch_size); there is no `is_train`.
        _, _, test_loader = make_dataloaders(bundle, False, cfg["batch_size"])

    # Build model
    model = build_model(
        cfg["model"],
        input_shape,
        num_classes,
        cfg.get("hidden_sizes", [64, 64]),
        cfg.get("dropout", 0.0),
        cfg.get("batchnorm", False)
    ).to(device)

    # Load weights and prepare loss
    state = torch.load(os.path.join(run_dir, "best_model.pt"), map_location="cpu")
    model.load_state_dict(state)
    loss_fn = nn.CrossEntropyLoss()

    # Evaluate
    test_loss, test_acc = eval_epoch(model, test_loader, device, loss_fn)
    preds, targets = collect_preds(model, test_loader, device)

    # Metrics. precision_recall_fscore_support returns a 4-tuple
    # (precision, recall, fscore, support); support is unused here.
    pr, rc, f1, _ = precision_recall_fscore_support(targets, preds, average="macro", zero_division=0)
    # Pin the label set so matrix/report stay aligned with class_names even
    # when a class is missing from the test split or never predicted.
    label_ids = list(range(num_classes))
    cm = confusion_matrix(targets, preds, labels=label_ids)

    results = {
        "test_loss": float(test_loss),
        "test_acc": float(test_acc),
        "precision_macro": float(pr),
        "recall_macro": float(rc),
        "f1_macro": float(f1),
    }
    save_json(os.path.join(run_dir, "eval.json"), results)
    print("Evaluation:", results)

    # Plots and report
    plot_confusion(cm, class_names, os.path.join(run_dir, "confusion_matrix.png"))
    report = classification_report(
        targets, preds, labels=label_ids, target_names=class_names, digits=4, zero_division=0
    )
    with open(os.path.join(run_dir, "classification_report.txt"), "w", encoding="utf-8") as f:
        f.write(report)
    print("Saved confusion matrix and classification report.")

# ---------------------------
# Predict command
# ---------------------------

def cmd_predict(args: argparse.Namespace):
    """Predict classes for rows of a feature CSV using a trained tabular run.

    Every column of the input CSV is treated as a feature. The output CSV is
    the input plus ``pred_index``, ``pred_label`` and ``pred_prob`` (top-1
    softmax probability).

    Args:
        args: Namespace with ``run_dir``, ``input_csv``, ``output_csv`` and
            ``device`` (falsy ``device`` falls back to the run's config).

    Raises:
        FileNotFoundError: If the run directory or the input CSV is missing.
        RuntimeError: If the run is an image model.
        ValueError: If the CSV column count differs from the model input width.
    """
    run_dir = args.run_dir
    input_csv = args.input_csv
    output_csv = args.output_csv

    # Explicit raises instead of asserts: asserts are stripped under ``python -O``.
    if not os.path.isdir(run_dir):
        raise FileNotFoundError(f"Run dir not found: {run_dir}")
    if not os.path.isfile(input_csv):
        raise FileNotFoundError(f"Input CSV not found: {input_csv}")

    cfg = load_json(os.path.join(run_dir, "config.json"))
    meta = load_json(os.path.join(run_dir, "meta.json"))
    class_names = meta["class_names"]
    is_image = meta["is_image"]
    input_shape = tuple(meta["input_shape"])
    num_classes = meta["num_classes"]

    if is_image:
        raise RuntimeError("Predict from CSV is for tabular models only.")

    device = make_device(args.device or cfg.get("device", "auto"))

    # Load model
    model = build_model(
        cfg["model"],
        input_shape,
        num_classes,
        cfg.get("hidden_sizes", [64, 64]),
        cfg.get("dropout", 0.0),
        cfg.get("batchnorm", False),
    ).to(device)
    state = torch.load(os.path.join(run_dir, "best_model.pt"), map_location="cpu")
    model.load_state_dict(state)
    model.eval()

    # Load features
    df = pd.read_csv(input_csv)
    X = df.values.astype(np.float32)

    # Fail early with a clear message rather than a matmul shape error.
    n_features = input_shape[0]
    if X.shape[1] != n_features:
        raise ValueError(
            f"Input CSV has {X.shape[1]} columns but the model expects {n_features} features"
        )

    # Standardize if training used scaler
    mean = meta.get("scaler_mean", None)
    std = meta.get("scaler_std", None)
    if mean is not None and std is not None:
        X = (X - np.array(mean, dtype=np.float32)) / np.array(std, dtype=np.float32)

    with torch.no_grad():
        logits = model(torch.from_numpy(X).to(device)).cpu().numpy()
        pred_idx = logits.argmax(axis=1)
        pred_label = [class_names[i] for i in pred_idx]
        proba = torch.softmax(torch.from_numpy(logits), dim=1).numpy()

    out = df.copy()
    out["pred_index"] = pred_idx
    out["pred_label"] = pred_label
    out["pred_prob"] = proba.max(axis=1)  # top-1 probability
    out.to_csv(output_csv, index=False)

    print(f"Predictions saved to {output_csv}")

# ---------------------------
# Export ONNX command
# ---------------------------

def cmd_export_onnx(args: argparse.Namespace):
    """Export a trained run's best checkpoint to ``model.onnx`` in its run directory.

    The model is rebuilt on CPU, exported with a dynamic batch axis, and, when
    ONNX Runtime is importable, run once as a sanity check; only then is
    ``onnx_exported: true`` recorded in ``meta.json``.

    Args:
        args: Namespace with ``run_dir``.

    Raises:
        FileNotFoundError: If ``run_dir`` is not an existing directory.
    """
    run_dir = args.run_dir
    # Explicit raise instead of assert: asserts are stripped under ``python -O``.
    if not os.path.isdir(run_dir):
        raise FileNotFoundError(f"Run dir not found: {run_dir}")
    cfg = load_json(os.path.join(run_dir, "config.json"))
    meta = load_json(os.path.join(run_dir, "meta.json"))
    input_shape = tuple(meta["input_shape"])
    num_classes = meta["num_classes"]

    # Build and load model on CPU for portability
    model = build_model(
        cfg["model"],
        input_shape,
        num_classes,
        cfg.get("hidden_sizes", [64, 64]),
        cfg.get("dropout", 0.0),
        cfg.get("batchnorm", False),
    )
    state = torch.load(os.path.join(run_dir, "best_model.pt"), map_location="cpu")
    model.load_state_dict(state)
    model.eval()

    # One sample of the stored per-sample shape; covers both (features,) and (C, H, W).
    dummy = torch.randn(1, *input_shape, dtype=torch.float32)

    onnx_path = os.path.join(run_dir, "model.onnx")
    torch.onnx.export(
        model,
        dummy,
        onnx_path,
        input_names=["input"],
        output_names=["logits"],
        dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
        opset_version=13,
    )
    print(f"Exported ONNX to {onnx_path}")

    # Optional: validate with ONNX Runtime
    if HAS_ONNX:
        sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
        outs = sess.run(["logits"], {"input": dummy.numpy()})
        print("ONNX runtime check OK. Output shape:", np.array(outs[0]).shape)
        meta2 = meta.copy()
        meta2["onnx_exported"] = True
        save_json(os.path.join(run_dir, "meta.json"), meta2)

# ---------------------------
# Benchmark command
# ---------------------------

def cmd_benchmark(args: argparse.Namespace):
    """Train and evaluate the same configuration over several seeds and aggregate results.

    Runs ``args.seeds`` train+eval cycles with seeds ``base_seed``,
    ``base_seed + 1``, ...; each gets its own directory under ``runs/``. The
    per-seed test metrics are collected into ``benchmark.csv`` and a box plot
    of test accuracy inside a ``runs/benchmark_*`` directory.

    Args:
        args: Namespace with ``seeds``, ``base_seed``, ``dataset``, ``model``,
            ``epochs``, ``batch_size``, ``lr``, ``hidden_sizes``, ``dropout``,
            ``batchnorm``, ``device``, ``n_samples``, ``noise``, ``patience``
            and ``min_delta``.

    Raises:
        ValueError: If ``args.seeds`` is less than 1.
    """
    # With zero runs the summary frame is empty and has no "test_acc" column.
    if args.seeds < 1:
        raise ValueError(f"Benchmark requires at least one seed (got seeds={args.seeds})")

    # Run multiple seeds quickly and aggregate performance
    results = []
    t0 = time.time()
    for i in range(args.seeds):
        seed = args.base_seed + i

        tmp_cfg = TrainConfig(
            dataset=args.dataset,
            model=args.model,
            epochs=args.epochs,
            batch_size=args.batch_size,
            lr=args.lr,
            weight_decay=0.0,
            hidden_sizes=tuple(args.hidden_sizes or [64, 64]),
            dropout=args.dropout,
            batchnorm=bool(args.batchnorm),
            seed=seed,
            device=args.device,
            deterministic=True,
            n_samples=args.n_samples,
            noise=args.noise,
            standardize=True,
            early_stop=True,
            patience=args.patience,
            min_delta=args.min_delta,
        )

        # The seed suffix keeps run directories distinct within the same second.
        run_tag = f"{now_tag()}_{tmp_cfg.dataset}_{tmp_cfg.model}_seed{seed}"
        tmp_cfg.run_dir = os.path.join("runs", run_tag)

        # Train
        train_ns = argparse.Namespace(config=None, **tmp_cfg.as_dict())
        cmd_train(train_ns)

        # Eval
        eval_ns = argparse.Namespace(run_dir=tmp_cfg.run_dir, device=tmp_cfg.device)
        cmd_eval(eval_ns)

        # Collect metrics
        eval_res = load_json(os.path.join(tmp_cfg.run_dir, "eval.json"))
        row = {
            "run_dir": tmp_cfg.run_dir,
            "seed": seed,
            **eval_res,
        }
        results.append(row)

    # Save results
    df = pd.DataFrame(results)
    bench_dir = os.path.join("runs", f"benchmark_{args.dataset}_{args.model}_{now_tag()}")
    ensure_dir(bench_dir)
    df.to_csv(os.path.join(bench_dir, "benchmark.csv"), index=False)
    print("Benchmark summary:")
    print(df.describe(include="all"))

    # Plot distribution of test_acc
    plt.figure(figsize=(6, 4))
    sns.boxplot(y=df["test_acc"], color="#5B8FF9")
    plt.title(f"Test accuracy distribution ({args.dataset}, {args.model}, n={args.seeds})")
    plt.ylabel("test_acc")
    plt.tight_layout()
    plt.savefig(os.path.join(bench_dir, "acc_distribution.png"), dpi=150)
    plt.close()

    elapsed = time.time() - t0
    print(f"Benchmark finished in {pretty_time(elapsed)}. Artifacts: {bench_dir}")

# ---------------------------
# CLI parser
# ---------------------------

def build_parser():
    """Build the argparse parser for the MetaIntelligence CLI.

    The parser has one required sub-command, stored in ``args.command``:
    ``train``, ``eval``, ``predict``, ``export-onnx`` or ``benchmark``.

    The run directory can be spelled ``--run-dir`` or ``--run_dir``; both
    populate ``args.run_dir``. On ``train`` the option is optional and the
    attribute is absent from the namespace unless the flag is given.

    Returns:
        argparse.ArgumentParser: the fully configured top-level parser.
    """
    p = argparse.ArgumentParser(description="MetaIntelligence CLI: train/eval/predict/benchmark/export-onnx")
    sub = p.add_subparsers(dest="command", required=True)

    dataset_choices = ["moons", "blobs", "circles", "iris", "mnist"]
    model_choices = ["mlp", "cnn"]
    device_choices = ["auto", "cpu", "cuda"]

    def str2bool(text):
        """Parse an explicit true/false CLI value; reject anything else."""
        value = str(text).strip().lower()
        if value in ("true", "t", "yes", "y", "1", "on"):
            return True
        if value in ("false", "f", "no", "n", "0", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean (true/false), got {text!r}")

    def add_run_dir(sp, **kwargs):
        # Register both spellings so the hyphenated flag and the underscore
        # form (matching the other snake_case options) map to args.run_dir.
        sp.add_argument("--run-dir", "--run_dir", dest="run_dir", type=str, **kwargs)

    # Common options helpers
    def add_common_train(sp):
        sp.add_argument("--config", type=str, default=None, help="JSON or YAML config path")
        sp.add_argument("--dataset", type=str, choices=dataset_choices)
        sp.add_argument("--model", type=str, choices=model_choices)
        sp.add_argument("--data_dir", type=str, help="Data dir (MNIST)")
        sp.add_argument("--download", type=str2bool, help="Download MNIST (true/false)")
        sp.add_argument("--test_size", type=float, help="Test split size (tabular)")
        sp.add_argument("--val_frac", type=float, help="Validation fraction from hold-out (tabular)")
        sp.add_argument("--batch_size", type=int)
        sp.add_argument("--epochs", type=int)
        sp.add_argument("--lr", type=float)
        sp.add_argument("--weight_decay", type=float)
        sp.add_argument("--hidden_sizes", nargs="*", type=int, help="Hidden sizes for MLP")
        sp.add_argument("--dropout", type=float)
        sp.add_argument("--batchnorm", action="store_true")
        sp.add_argument("--seed", type=int)
        sp.add_argument("--deterministic", action="store_true")
        sp.add_argument("--device", type=str, choices=device_choices)
        sp.add_argument("--live_plot", action="store_true")
        sp.add_argument("--standardize", action="store_true")
        sp.add_argument("--noise", type=float, help="Synthetic dataset noise")
        sp.add_argument("--n_samples", type=int, help="Synthetic dataset samples")
        sp.add_argument("--early_stop", action="store_true")
        sp.add_argument("--patience", type=int)
        sp.add_argument("--min_delta", type=float)
        # SUPPRESS keeps the namespace identical to before when the flag is
        # omitted; when given, run_dir reaches cmd_train the same way the
        # benchmark command passes it.
        add_run_dir(sp, default=argparse.SUPPRESS, help="Run directory for this training run")

    sp_train = sub.add_parser("train", help="Train a model")
    add_common_train(sp_train)

    sp_eval = sub.add_parser("eval", help="Evaluate best checkpoint on test set")
    add_run_dir(sp_eval, required=True)
    sp_eval.add_argument("--device", type=str, choices=device_choices, default=None)

    sp_predict = sub.add_parser("predict", help="Predict on CSV (tabular only)")
    add_run_dir(sp_predict, required=True)
    sp_predict.add_argument("--input-csv", type=str, required=True)
    sp_predict.add_argument("--output-csv", type=str, required=True)
    sp_predict.add_argument("--device", type=str, choices=device_choices, default=None)

    sp_export = sub.add_parser("export-onnx", help="Export best checkpoint to ONNX")
    add_run_dir(sp_export, required=True)

    sp_bench = sub.add_parser("benchmark", help="Benchmark multiple seeds")
    sp_bench.add_argument("--dataset", type=str, choices=dataset_choices, required=True)
    sp_bench.add_argument("--model", type=str, choices=model_choices, required=True)
    sp_bench.add_argument("--epochs", type=int, default=30)
    sp_bench.add_argument("--batch_size", type=int, default=128)
    sp_bench.add_argument("--lr", type=float, default=1e-3)
    sp_bench.add_argument("--hidden_sizes", nargs="*", type=int, default=[128, 64])
    sp_bench.add_argument("--dropout", type=float, default=0.0)
    sp_bench.add_argument("--batchnorm", action="store_true")
    sp_bench.add_argument("--seeds", type=int, default=3)
    sp_bench.add_argument("--base_seed", type=int, default=42)
    sp_bench.add_argument("--device", type=str, choices=device_choices, default="auto")
    sp_bench.add_argument("--n_samples", type=int, default=4000)
    sp_bench.add_argument("--noise", type=float, default=0.2)
    sp_bench.add_argument("--patience", type=int, default=8)
    sp_bench.add_argument("--min_delta", type=float, default=1e-4)

    return p

# ---------------------------
# Main
# ---------------------------

def _strip_kernel_args(raw_args):
    """Return raw_args without the kernel connection file Jupyter/Colab injects.

    Only the ``-f <file>.json`` pair and bare ``kernel-*.json`` paths are
    removed, so genuine values such as ``--config cfg.json`` are preserved.
    """
    kept = []
    pending = iter(list(raw_args))
    for arg in pending:
        if arg == "-f":
            follower = next(pending, None)
            if follower is not None and follower.endswith(".json"):
                continue  # drop the whole "-f <connection file>" pair
            kept.append(arg)
            if follower is not None:
                kept.append(follower)
            continue
        base = os.path.basename(arg)
        if base.startswith("kernel-") and base.endswith(".json"):
            continue
        kept.append(arg)
    return kept


def _print_notebook_usage(parser):
    """Print the parser help followed by copy-pasteable notebook examples."""
    parser.print_help()
    print("\n📌 Example usage (in notebooks):")
    print('main(["train", "--dataset", "moons", "--epochs", "30"])')
    print('main(["eval", "--run-dir", "<run_dir>"])')
    print('main(["export-onnx", "--run-dir", "<run_dir>"])')


def main(argv=None):
    """CLI entry point that is safe to call from a notebook cell.

    Args:
        argv: list of argument strings; when None, ``sys.argv[1:]`` is used.

    Returns:
        None. Argument errors are reported by printing usage rather than by
        letting argparse's SystemExit propagate and kill the notebook cell.
    """
    import sys

    # Fix for Colab: ignore injected kernel path like .../kernel-XXXX.json
    raw_args = argv if argv is not None else sys.argv[1:]
    filtered_args = _strip_kernel_args(raw_args)

    parser = build_parser()

    # Avoid parser crash: early exit if no valid args remain
    if not filtered_args:
        _print_notebook_usage(parser)
        return

    try:
        args = parser.parse_args(filtered_args)
    except SystemExit as exc:
        # Notebook safety: don't crash the cell. With exit code 0 argparse has
        # already printed the requested help, so only add usage after errors.
        if exc.code not in (0, None):
            _print_notebook_usage(parser)
        return

    handlers = {
        "train": cmd_train,
        "eval": cmd_eval,
        "predict": cmd_predict,
        "export-onnx": cmd_export_onnx,
        "benchmark": cmd_benchmark,
    }
    handler = handlers.get(getattr(args, "command", None))
    if handler is None:
        _print_notebook_usage(parser)
        return
    handler(args)