<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/colab_kernel_launcher_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
colab_kernel_launcher.py

Notebook/Colab-safe CLI training script with:
- Modes: self, supervised, hybrid
- Robust argparse that ignores Jupyter's injected -f kernel.json
- Deterministic seeding
- CSV logging, matplotlib plot, and model checkpoint
- Simple synthetic dataset and a compact MLP classifier
"""

import argparse
import csv
import json
import math
import os
import random
import sys
import time
from dataclasses import asdict, dataclass
from typing import List, Optional, Tuple

# Optional plotting (gracefully skipped if matplotlib is not available)
try:
    import matplotlib.pyplot as plt
    _HAS_MPL = True
except Exception:
    _HAS_MPL = False

import numpy as np

# Required: torch is imported unconditionally below; install it first if it is missing.
import torch
import torch.nn as nn
import torch.nn.functional as F


# ----------------------------
# Utilities
# ----------------------------

def sanitize_argv(argv: Optional[List[str]] = None) -> List[str]:
    """
    Strip the arguments that Jupyter/Colab inject into the command line.

    Two patterns are removed:
      * ``-f <value>``: the flag and the single token that follows it
        (usually the kernel connection file).
      * A stand-alone token that ends in ``.json`` AND mentions ``jupyter``
        or ``kernel``, e.g.
        ``/root/.local/share/jupyter/runtime/kernel-xxxx.json``.

    Args:
        argv: Argument list to clean. ``None`` means ``sys.argv[1:]``.

    Returns:
        A new list with the injected arguments removed; the input list is
        not modified.
    """
    if argv is None:
        argv = sys.argv[1:]
    cleaned: List[str] = []
    skip_next = False
    for arg in argv:
        if skip_next:
            # This token is the value that belonged to a preceding "-f".
            skip_next = False
            continue
        if arg == "-f":
            # Skip the -f and its value (usually a kernel json)
            skip_next = True
            continue
        # The parentheses matter: `and` binds tighter than `or`, so without
        # them every argument containing "kernel" (for example
        # `--outdir kernel_runs`) would be discarded.
        if arg.endswith(".json") and ("jupyter" in arg or "kernel" in arg):
            # Defensive: drop any lonely kernel json arg
            continue
        cleaned.append(arg)
    return cleaned


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """
    Parse the training CLI options in a notebook-tolerant way.

    The argument list is first passed through ``sanitize_argv`` and then
    parsed with ``parse_known_args``, so options this parser does not know
    are ignored instead of aborting the run.

    Args:
        argv: Raw argument list; ``None`` defers to ``sanitize_argv``'s default.

    Returns:
        Namespace holding the recognised options.
    """
    # (flag, add_argument keyword arguments), kept in the order shown by --help.
    option_table = [
        ("--mode", dict(type=str, choices=["self", "supervised", "hybrid"], default="supervised",
                        help="Training mode")),
        ("--steps", dict(type=int, default=500, help="Number of optimization steps")),
        ("--batch-size", dict(type=int, default=128, help="Batch size")),
        ("--lr", dict(type=float, default=1e-3, help="Learning rate")),
        ("--wd", dict(type=float, default=0.0, help="Weight decay")),
        ("--entropy-bonus", dict(
            type=float, default=0.0,
            help="Coefficient for entropy bonus (encourages higher output entropy if > 0)")),
        ("--label-sharpen", dict(
            type=float, default=0.0,
            help="Label smoothing factor in [0, 1). 0=one-hot, >0 mixes with uniform")),
        ("--seed", dict(type=int, default=42, help="Random seed")),
        ("--outdir", dict(type=str, default="runs", help="Base output directory")),
        ("--log-every", dict(type=int, default=50, help="Log every N steps")),
        ("--cycle", dict(type=str, default="0", help="Optional cycle tag for output grouping")),
    ]

    cli = argparse.ArgumentParser(
        description="Notebook-safe training launcher with self/supervised/hybrid modes"
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    # Notebook-safety: unknown arguments are collected and discarded rather
    # than triggering an argparse error.
    known, _ignored = cli.parse_known_args(sanitize_argv(argv))
    return known


def set_seed(seed: int) -> None:
    """
    Seed every random number generator this script uses.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices), then asks cuDNN for deterministic algorithms.

    Args:
        seed: Value handed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Determinism hints: may cost some speed in exchange for repeatability.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def timestamp() -> str:
    """Return the current local time formatted as ``YYYYMMDD-HHMMSS``."""
    now = time.localtime()
    return time.strftime("%Y%m%d-%H%M%S", now)


def ensure_dir(path: str) -> None:
    """
    Create directory ``path`` along with any missing parents.

    An already existing directory is not an error.
    """
    os.makedirs(name=path, exist_ok=True)


# ----------------------------
# Data generation (synthetic)
# ----------------------------

@dataclass
class DataConfig:
    """Settings for the synthetic Gaussian-blob dataset."""
    # Number of features per sample; features beyond the first two are noise padding.
    dim: int = 2
    # Number of label classes the dataset is configured for.
    n_classes: int = 2
    # Distance of each blob center from the origin.
    radius: float = 3.0
    # Scale of the Gaussian noise around each center (padding features use half of it).
    spread: float = 1.0


def sample_blobs(batch_size: int, cfg: "DataConfig", device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Draw one shuffled batch of Gaussian blobs, one blob per class.

    Blob centers are evenly spaced on a circle of radius ``cfg.radius`` in the
    first two feature dimensions (with two classes they sit at angles 0 and
    pi). When ``cfg.dim > 2`` the remaining features are zero-mean Gaussian
    noise with scale ``0.5 * cfg.spread``. Samples are split as evenly as
    possible between classes; any remainder goes to the later classes.

    Randomness comes from NumPy's global generator, so results follow
    ``np.random.seed``.

    Args:
        batch_size: Total number of samples to draw.
        cfg: Dataset settings; ``dim``, ``n_classes``, ``radius`` and
            ``spread`` are read.
        device: Device on which the returned tensors are created.

    Returns:
        x: float32 tensor of shape [B, dim]
        y: int64 tensor of shape [B] with labels in ``range(cfg.n_classes)``

    Raises:
        ValueError: If ``cfg.dim < 2`` or ``cfg.n_classes < 1``.
    """
    if cfg.dim < 2:
        # The blob centers live in a 2D plane; fewer features cannot hold them
        # and would silently produce a tensor wider than cfg.dim.
        raise ValueError(f"cfg.dim must be >= 2, got {cfg.dim}")
    if cfg.n_classes < 1:
        raise ValueError(f"cfg.n_classes must be >= 1, got {cfg.n_classes}")

    n_classes = cfg.n_classes
    # Integer boundaries give per-class counts that sum to batch_size; for two
    # classes this yields (B // 2, B - B // 2).
    bounds = [(k * batch_size) // n_classes for k in range(n_classes + 1)]

    features = []
    labels = []
    for k in range(n_classes):
        count = bounds[k + 1] - bounds[k]
        # Centers on a circle
        angle = 2.0 * math.pi * k / n_classes
        center = np.array([cfg.radius * math.cos(angle), cfg.radius * math.sin(angle)])
        features.append(np.random.randn(count, 2) * cfg.spread + center)
        labels.append(np.full(count, k, dtype=np.int64))

    x = np.concatenate(features, axis=0)
    if cfg.dim > 2:
        extra = np.random.randn(batch_size, cfg.dim - 2) * (0.5 * cfg.spread)
        x = np.hstack([x, extra])

    y = np.concatenate(labels, axis=0)

    # Shuffle so classes are interleaved within the batch
    idx = np.random.permutation(batch_size)
    x = x[idx]
    y = y[idx]

    x = torch.tensor(x, dtype=torch.float32, device=device)
    y = torch.tensor(y, dtype=torch.long, device=device)
    return x, y


def augment_noise(x: torch.Tensor, sigma: float = 0.2) -> torch.Tensor:
    """
    Return a copy of ``x`` perturbed by Gaussian noise.

    Args:
        x: Input tensor; it is not modified.
        sigma: Multiplier applied to the standard-normal noise.

    Returns:
        Tensor with the same shape as ``x``.
    """
    noise = torch.randn_like(x)
    return x + sigma * noise


# ----------------------------
# Model
# ----------------------------

class MLP(nn.Module):
    """
    Compact multilayer perceptron classifier.

    Two hidden ``Linear`` layers, each followed by an in-place ReLU, feed a
    final ``Linear`` layer that outputs one raw logit per class.
    """

    def __init__(self, in_dim: int, n_classes: int, hidden: int = 64):
        """
        Args:
            in_dim: Number of input features.
            n_classes: Number of output logits.
            hidden: Width of both hidden layers.
        """
        super().__init__()
        widths = [in_dim, hidden, hidden, n_classes]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU(inplace=True))
        # The output layer emits raw logits, so drop the activation after it.
        stack.pop()
        self.net = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of inputs to class logits."""
        logits = self.net(x)
        return logits


# ----------------------------
# Loss helpers
# ----------------------------

def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """
    Entropy (natural log) of the softmax distribution over the last axis.

    A constant of 1e-12 is added to the probabilities before taking the log
    so a probability of exactly zero does not produce ``-inf``.

    Returns:
        One entropy value per row, i.e. the input shape without its last axis.
    """
    p = F.softmax(logits, dim=-1)
    log_p = torch.log(p + 1e-12)
    return -torch.sum(p * log_p, dim=-1)  # shape [B]


def cross_entropy_soft_targets(logits: torch.Tensor, target_probs: torch.Tensor) -> torch.Tensor:
    """
    Batch-averaged cross-entropy against soft (probability) targets.

    Per example:
      L = - sum_k target_probs[k] * log_softmax(logits)[k]

    Returns:
        Scalar tensor: the mean of the per-example losses.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    weighted = torch.sum(target_probs * log_probs, dim=-1)
    return -weighted.mean()


def one_hot(num_classes: int, labels: torch.Tensor) -> torch.Tensor:
    """
    One-hot encode integer ``labels`` as a float32 tensor.

    Returns:
        Tensor shaped like ``labels`` with a trailing axis of size ``num_classes``.
    """
    encoded = F.one_hot(labels, num_classes=num_classes)
    return encoded.to(torch.float32)


def apply_label_smoothing(target_one_hot: torch.Tensor, smoothing: float) -> torch.Tensor:
    """
    Blend one-hot targets with the uniform distribution over classes.

      target = (1 - smoothing) * one_hot + smoothing * uniform

    Args:
        target_one_hot: Targets whose last axis indexes classes.
        smoothing: Mixing weight, intended to lie in [0, 1). Values <= 0
            return the input tensor itself, unchanged.

    Returns:
        The smoothed targets, same shape as the input.
    """
    if smoothing <= 0.0:
        return target_one_hot
    class_count = target_one_hot.shape[-1]
    flat = torch.full_like(target_one_hot, 1.0 / class_count)
    keep = 1.0 - smoothing
    return keep * target_one_hot + smoothing * flat


def consistency_kl(logits_a: torch.Tensor, logits_b: torch.Tensor) -> torch.Tensor:
    """
    Symmetrised KL divergence between two sets of predictions.

    Each logit tensor is turned into probabilities with softmax (clamped to
    at least 1e-8 so the logs stay finite). KL(a||b) and KL(b||a) are
    computed per example, summed, averaged over the batch and halved.

    Returns:
        Scalar tensor.
    """
    floor = 1e-8
    prob_a = F.softmax(logits_a, dim=-1).clamp_min(floor)
    prob_b = F.softmax(logits_b, dim=-1).clamp_min(floor)
    log_a = torch.log(prob_a)
    log_b = torch.log(prob_b)

    forward_kl = torch.sum(prob_a * (log_a - log_b), dim=-1)
    reverse_kl = torch.sum(prob_b * (log_b - log_a), dim=-1)
    both = forward_kl + reverse_kl
    return 0.5 * both.mean()


# ----------------------------
# Training
# ----------------------------

@dataclass
class TrainConfig:
    """
    Everything one training run needs.

    The first eleven fields mirror the command-line options; ``dim``,
    ``n_classes`` and ``device`` have defaults.
    """
    # Values supplied from the CLI.
    mode: str
    steps: int
    batch_size: int
    lr: float
    wd: float
    entropy_bonus: float
    label_sharpen: float
    seed: int
    outdir: str
    log_every: int
    cycle: str
    # Values with built-in defaults.
    dim: int = 2
    n_classes: int = 2
    device: str = "cuda_if_available"

    def device_obj(self) -> torch.device:
        """
        Resolve ``self.device`` into a ``torch.device``.

        The sentinel ``"cuda_if_available"`` selects CUDA when
        ``torch.cuda.is_available()`` is true and the CPU otherwise; any
        other string is passed to ``torch.device`` as given.
        """
        name = self.device
        if name == "cuda_if_available":
            name = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.device(name)


def train(cfg: TrainConfig) -> None:
    """
    Run one training job described by ``cfg`` and write its artifacts to disk.

    Artifacts are written to
    ``<outdir>/cycle-<cycle>/<timestamp>_<mode>_seed<seed>/``:
      * ``config.json``  - the configuration as JSON
      * ``metrics.csv``  - one row per logged step
      * ``model.pt``     - model ``state_dict`` plus the configuration
      * ``metrics.png``  - metric curves (only when matplotlib imported)

    Every step draws a fresh synthetic batch. Depending on ``cfg.mode`` the
    loss is the supervised soft-target cross-entropy, the consistency KL
    between two noisy views, or their sum; ``cfg.entropy_bonus`` times the
    mean prediction entropy is subtracted from it.

    Metrics are logged on step 1, on the final step, and on every
    ``cfg.log_every``-th step. A non-positive ``cfg.log_every`` disables the
    periodic logging only.

    Args:
        cfg: Run configuration.
    """
    set_seed(cfg.seed)

    device = cfg.device_obj()
    data_cfg = DataConfig(dim=cfg.dim, n_classes=cfg.n_classes)

    # Output directory: <outdir>/cycle-<cycle>/<timestamp>_<mode>_seed<seed>
    base = os.path.join(cfg.outdir, f"cycle-{cfg.cycle}")
    run_dir = os.path.join(base, f"{timestamp()}_{cfg.mode}_seed{cfg.seed}")
    ensure_dir(run_dir)

    # Save config
    with open(os.path.join(run_dir, "config.json"), "w") as f:
        json.dump(asdict(cfg), f, indent=2)

    model = MLP(in_dim=cfg.dim, n_classes=cfg.n_classes).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)

    # Logging setup: write the header once; rows are appended as they occur.
    csv_path = os.path.join(run_dir, "metrics.csv")
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([
            "step",
            "mode",
            "loss_total",
            "loss_sup",
            "loss_ssl",
            "entropy",
            "lr",
        ])

    step_hist = []  # actual step number of each logged record (plot x-axis)
    loss_hist = []
    sup_hist = []
    ssl_hist = []
    ent_hist = []

    for step in range(1, cfg.steps + 1):
        model.train()
        x, y = sample_blobs(cfg.batch_size, data_cfg, device)

        logits = model(x)

        # Zero placeholders keep the total and the log columns well defined
        # for whichever branch the current mode leaves out.
        loss_sup = torch.tensor(0.0, device=device)
        loss_ssl = torch.tensor(0.0, device=device)

        # Supervised branch
        if cfg.mode in ("supervised", "hybrid"):
            y_1h = one_hot(cfg.n_classes, y)
            y_soft = apply_label_smoothing(y_1h, cfg.label_sharpen)
            loss_sup = cross_entropy_soft_targets(logits, y_soft)

        # Self-supervised branch (consistency across two noisy views)
        if cfg.mode in ("self", "hybrid"):
            x_a = augment_noise(x, 0.25)
            x_b = augment_noise(x, 0.25)
            logits_a = model(x_a)
            logits_b = model(x_b)
            loss_ssl = consistency_kl(logits_a, logits_b)

        # Entropy bonus (encourage exploration by increasing output entropy)
        ent = entropy_from_logits(logits).mean()
        total = loss_sup + loss_ssl - cfg.entropy_bonus * ent

        opt.zero_grad(set_to_none=True)
        total.backward()
        opt.step()

        # Log. The `log_every > 0` guard avoids a ZeroDivisionError from the
        # modulo when periodic logging is switched off with 0.
        periodic = cfg.log_every > 0 and (step % cfg.log_every) == 0
        if periodic or step == 1 or step == cfg.steps:
            lr_cur = opt.param_groups[0]["lr"]
            step_hist.append(step)
            loss_hist.append(float(total.detach().cpu()))
            sup_hist.append(float(loss_sup.detach().cpu()))
            ssl_hist.append(float(loss_ssl.detach().cpu()))
            ent_hist.append(float(ent.detach().cpu()))

            # Re-open in append mode per record so every row is on disk even
            # if the run is interrupted.
            with open(csv_path, "a", newline="") as f:
                writer = csv.writer(f)
                writer.writerow([
                    step, cfg.mode,
                    f"{loss_hist[-1]:.6f}",
                    f"{sup_hist[-1]:.6f}",
                    f"{ssl_hist[-1]:.6f}",
                    f"{ent_hist[-1]:.6f}",
                    f"{lr_cur:.8f}",
                ])

            print(f"[{step:5d}/{cfg.steps}] "
                  f"mode={cfg.mode} "
                  f"loss={loss_hist[-1]:.4f} "
                  f"sup={sup_hist[-1]:.4f} "
                  f"ssl={ssl_hist[-1]:.4f} "
                  f"ent={ent_hist[-1]:.4f} "
                  f"lr={lr_cur:.2e}")

    # Save model
    ckpt_path = os.path.join(run_dir, "model.pt")
    torch.save({"model_state": model.state_dict(), "config": asdict(cfg)}, ckpt_path)

    # Plot metrics
    if _HAS_MPL and len(loss_hist) > 0:
        fig, ax = plt.subplots(1, 1, figsize=(7, 4))
        # Plot against the real step numbers: step 1 and the final step are
        # logged off-cadence, so record indices would distort the x-axis.
        ax.plot(step_hist, loss_hist, label="total")
        ax.plot(step_hist, sup_hist, label="supervised")
        ax.plot(step_hist, ssl_hist, label="self")
        ax.plot(step_hist, ent_hist, label="entropy")
        ax.set_title(f"Training - {cfg.mode}")
        ax.set_xlabel(f"Logged steps (every {cfg.log_every})")
        ax.set_ylabel("Value")
        ax.grid(True, alpha=0.3)
        ax.legend()
        fig.tight_layout()
        fig_path = os.path.join(run_dir, "metrics.png")
        fig.savefig(fig_path, dpi=140)
        plt.close(fig)

    print(f"Done. Outputs saved to: {run_dir}")


# ----------------------------
# Entry points
# ----------------------------

def main(argv: Optional[List[str]] = None) -> None:
    """
    Command-line entry point: parse options, build the config, train.

    Args:
        argv: Argument list forwarded to ``parse_args``; ``None`` lets it
            fall back to its own default.
    """
    ns = parse_args(argv)

    # Options copied into the config unchanged.
    passthrough = (
        "mode",
        "steps",
        "batch_size",
        "lr",
        "wd",
        "entropy_bonus",
        "label_sharpen",
        "seed",
        "outdir",
        "log_every",
    )
    settings = {name: getattr(ns, name) for name in passthrough}
    # The cycle tag is always stored as a string.
    settings["cycle"] = str(ns.cycle)

    train(TrainConfig(**settings))


if __name__ == "__main__":
    # Runs both as a script and inside a notebook cell: sanitize_argv strips
    # the injected -f kernel.json before argparse sees it.
    _EXIT_HINT = "Use 'exit', 'quit', or Ctrl-D to exit."
    try:
        main()
    except SystemExit:
        # Mirror IPython friendliness, then let the exit propagate.
        print(_EXIT_HINT, file=sys.stderr)
        raise