<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/colab_kernel_launcher_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
# filename: colab_kernel_launcher.py

import argparse
import os
import json
import math
import random
from typing import Tuple

import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader


def str2bool(v: str) -> bool:
    """Interpret a command-line token as a boolean.

    A real ``bool`` is returned untouched. A string is stripped and
    lower-cased, then matched against the accepted spellings:
    ``yes/true/t/1/y`` mean True and ``no/false/f/0/n`` mean False.

    Raises:
        argparse.ArgumentTypeError: if the normalized text is in neither set.
    """
    if isinstance(v, bool):
        return v

    truthy = ("yes", "true", "t", "1", "y")
    falsy = ("no", "false", "f", "0", "n")

    v = v.strip().lower()
    if v not in truthy and v not in falsy:
        # The message reports the normalized token, not the raw input.
        raise argparse.ArgumentTypeError(f"Boolean value expected for --save_model, got: {v}")
    return v in truthy


def set_seed(seed: int = 42) -> None:
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) with ``seed``.

    Also switches cuDNN to deterministic mode and turns off its benchmark
    autotuner.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # safe even if no cuda
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def make_synthetic_classification(
    n_samples: int,
    n_features: int = 3,
    random_state: int = 42,
    noise: float = 0.5,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build a small binary classification dataset using only NumPy.

    Features are standard-normal draws. Labels come from thresholding a
    randomly drawn linear score (weights plus bias) after adding Gaussian
    noise with standard deviation ``noise``, so the classes are roughly,
    but not perfectly, linearly separable.

    Returns:
        ``(X, y)`` with ``X`` of shape ``(n_samples, n_features)`` and ``y``
        of shape ``(n_samples,)`` holding 0.0/1.0 labels; both are float32.
    """
    gen = np.random.default_rng(random_state)

    # Draw order (features, weights, bias, noise) fixes the output for a
    # given seed, so it must not be rearranged.
    features = gen.normal(loc=0.0, scale=1.0, size=(n_samples, n_features))
    weights = gen.normal(size=(n_features,))
    bias = gen.normal()

    clean_score = features @ weights + bias
    noisy_score = clean_score + gen.normal(scale=noise, size=n_samples)

    labels = np.greater(noisy_score, 0).astype(np.float32)
    return features.astype(np.float32), labels


class MLP(nn.Module):
    """Two-hidden-layer perceptron that emits a single logit per input row.

    Layout: Linear(input_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, 1).
    No sigmoid is applied; the output is a raw binary logit.
    """

    def __init__(self, input_dim: int = 3, hidden_dim: int = 16):
        super().__init__()
        # Layers are created in forward order and stored in a Sequential
        # named ``net``, so parameters are keyed ``net.<index>.*``.
        stack = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),  # binary logit
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the logits with the trailing size-1 axis squeezed away."""
        out = self.net(x)
        return torch.squeeze(out, dim=-1)  # shape: (batch,)


def train(
    model: nn.Module,
    loader: DataLoader,
    epochs: int,
    device: torch.device,
    lr: float = 1e-3,
) -> list:
    """Fit ``model`` with Adam, minimizing binary cross-entropy on logits.

    Args:
        model: Network whose output is one logit per sample.
        loader: Iterable of ``(features, targets)`` batches.
        epochs: Number of full passes over ``loader``.
        device: Device the model and every batch are moved to.
        lr: Adam learning rate.

    Returns:
        One entry per epoch: the sample-weighted mean training loss. An
        epoch that processed no samples records NaN.
    """
    model.to(device)
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    history = []
    for epoch in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        n_seen = 0
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device)
            optim.zero_grad()
            logits = model(xb)
            loss = criterion(logits, yb)
            loss.backward()
            optim.step()
            # The criterion returns a per-batch mean; weight it by batch size
            # so a smaller final batch does not skew the epoch average.
            batch_n = xb.size(0)
            running_loss += loss.item() * batch_n
            n_seen += batch_n

        # Divide by the samples actually processed rather than the dataset
        # length: the two differ when the loader skips samples (for example
        # drop_last=True), and an empty loader must not divide by zero.
        epoch_loss = running_loss / n_seen if n_seen else float("nan")
        history.append(epoch_loss)
        print(f"Epoch {epoch:03d}/{epochs} - loss: {epoch_loss:.4f}")
    return history


def evaluate_accuracy(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Return the fraction of samples in ``loader`` that ``model`` labels correctly.

    The model is put in eval mode and run without gradient tracking. A sample
    is predicted positive when the sigmoid of its logit is at least 0.5.
    Returns NaN when the loader yields no samples.
    """
    model.eval()
    n_hits, n_seen = 0, 0
    with torch.no_grad():
        for features, targets in loader:
            targets = targets.to(device)
            scores = torch.sigmoid(model(features.to(device)))
            predicted = scores.ge(0.5).float()
            n_hits += predicted.eq(targets).sum().item()
            n_seen += targets.numel()

    if not n_seen:
        return math.nan
    return n_hits / n_seen


def ensure_outdir(path: str) -> None:
    """Create directory ``path``, including missing parents, if it is absent.

    An existing directory is left as is. An empty ``path`` is treated as the
    current working directory and nothing is created: ``os.makedirs("")``
    would raise ``FileNotFoundError``, while ``os.path.join("", name)`` still
    resolves to a file in the current directory.
    """
    if path:
        os.makedirs(path, exist_ok=True)


def save_artifacts(
    outdir: str,
    model: nn.Module,
    config: dict,
    history: list,
    metrics: dict,
) -> None:
    """Write the model weights and a JSON run summary into ``outdir``.

    Two files are produced after making sure ``outdir`` exists:
    ``model.pt`` holds the model's ``state_dict`` and ``run_summary.json``
    holds ``config``, ``history`` and ``metrics`` under keys of the same
    names, indented by two spaces.
    """
    ensure_outdir(outdir)

    weights_path = os.path.join(outdir, "model.pt")
    summary_path = os.path.join(outdir, "run_summary.json")
    summary = {
        "config": config,
        "history": history,
        "metrics": metrics,
    }

    torch.save(model.state_dict(), weights_path)
    with open(summary_path, "w") as fh:
        json.dump(summary, fh, indent=2)
    print(f"Saved model and summary to: {outdir}")


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the demo's command-line options.

    Unknown arguments are ignored so the script also runs inside
    Jupyter/Colab, where the kernel launcher passes extra flags.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``.

    Returns:
        The parsed namespace. ``save_model`` is already a ``bool``.
    """
    parser = argparse.ArgumentParser(
        description="End-to-end demo script that trains a tiny classifier and works in Colab/Jupyter.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument("--n_samples", type=int, default=1000, help="Number of synthetic samples to generate.")
    parser.add_argument("--epochs", type=int, default=20, help="Number of training epochs.")
    parser.add_argument("--batch_size", type=int, default=32, help="Mini-batch size.")
    # Converting through ``type=`` lets argparse report a bad value as a
    # normal usage error instead of an uncaught ArgumentTypeError raised
    # after parsing. The string default is passed through ``type`` as well,
    # so the help text still shows "true" while the parsed value is a bool.
    parser.add_argument(
        "--save_model",
        type=str2bool,
        default="true",
        help="Whether to save model artifacts (true/false).",
    )
    parser.add_argument(
        "--test_input",
        type=float,
        nargs=3,
        metavar=("X1", "X2", "X3"),
        help="Three feature values to run a single prediction after training.",
    )
    parser.add_argument("--outdir", type=str, default="outputs", help="Directory to save artifacts.")
    parser.add_argument("--random_state", type=int, default=42, help="Random seed for reproducibility.")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate.")
    parser.add_argument("--hidden_dim", type=int, default=16, help="Hidden layer size for the MLP.")

    # Important for Jupyter/Colab: ignore unknown args like "-f <kernel.json>"
    args, _ = parser.parse_known_args(argv)
    return args


def main():
    """Run the demo end to end.

    Steps: parse arguments, seed the RNGs, generate the synthetic dataset,
    split it 80/20 in row order, train the MLP, report validation accuracy,
    optionally score the ``--test_input`` vector, and save artifacts unless
    saving is disabled.
    """
    args = parse_args()
    set_seed(args.random_state)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    # The feature count is pinned to 3 because --test_input takes exactly
    # three values.
    n_features = 3
    X, y = make_synthetic_classification(
        n_samples=args.n_samples,
        n_features=n_features,
        random_state=args.random_state,
        noise=0.5,
    )

    # The first 80% of rows train the model; the remainder is held out.
    cut = int(0.8 * len(X))
    train_ds, val_ds = (
        TensorDataset(torch.from_numpy(feats), torch.from_numpy(labels))
        for feats, labels in ((X[:cut], y[:cut]), (X[cut:], y[cut:]))
    )
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, drop_last=False)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, drop_last=False)

    model = MLP(input_dim=n_features, hidden_dim=args.hidden_dim)
    history = train(model, train_loader, epochs=args.epochs, device=device, lr=args.lr)
    val_acc = evaluate_accuracy(model, val_loader, device=device)
    print(f"Validation accuracy: {val_acc:.4f}")

    metrics = {
        "val_accuracy": float(val_acc),
        "final_loss": float(history[-1]) if history else float("nan"),
    }

    # Optional single prediction from --test_input
    if args.test_input is not None:
        sample = torch.tensor(args.test_input, dtype=torch.float32, device=device).unsqueeze(0)
        model.eval()
        with torch.no_grad():
            prob = torch.sigmoid(model(sample)).item()
        test_pred = {
            "input": [float(value) for value in args.test_input],
            "probability": float(prob),
            "predicted_label": int(prob >= 0.5),
        }
        print(f"Test input: {test_pred['input']} -> prob={test_pred['probability']:.4f}, label={test_pred['predicted_label']}")
        metrics["test_prediction"] = test_pred

    # Record the effective settings next to the results.
    config = {
        key: getattr(args, key)
        for key in (
            "n_samples",
            "epochs",
            "batch_size",
            "save_model",
            "outdir",
            "random_state",
            "lr",
            "hidden_dim",
        )
    }
    config["device"] = str(device)

    if not args.save_model:
        print("Skipping artifact save (save_model=false).")
    else:
        save_artifacts(args.outdir, model, config=config, history=history, metrics=metrics)

    print("Done.")


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()