<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/unifiedai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch --quiet

In [None]:
#!/usr/bin/env python3
"""
UnifiedAI: Train a conditional VAE-style dynamics model on synthetic data,
then plan toward a goal via random shooting. Notebook/CLI friendly.

- Ignores unknown args (e.g., Jupyter's -f file.json)
- Deterministic toggle and CUDA fallback
- Provenance: config hash, metrics CSV, checkpoints
- Modes: run (train+plan) or grid (sweep seeds/epochs/plan_candidates)
"""

import argparse, math, random, os, json, time, hashlib, csv
from dataclasses import dataclass, asdict
from typing import Tuple, Dict, Any, List
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# ----------------------------
# Utilities
# ----------------------------
def set_seed(seed: int):
    """Seed every RNG this script relies on with the same integer.

    Covers Python's ``random``, NumPy's legacy global generator, PyTorch's
    default generator, and the generators of all CUDA devices, in that order.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

def set_deterministic(enabled: bool):
    """Switch cuDNN between deterministic mode and autotuned (benchmark) mode.

    ``enabled=True`` sets ``cudnn.deterministic`` and clears ``cudnn.benchmark``;
    ``enabled=False`` does the opposite.
    """
    cudnn = torch.backends.cudnn
    cudnn.deterministic = enabled
    cudnn.benchmark = not enabled

def ensure_dir(path: str):
    """Create ``path`` (including parents) unless it is empty or already exists.

    An empty string is treated as "no directory requested". Any existing
    filesystem entry at ``path`` is left untouched.
    """
    if not path or os.path.exists(path):
        return
    os.makedirs(path, exist_ok=True)

def now_tag() -> str:
    """Return the current UTC time as a sortable ``YYYYmmdd-HHMMSS`` tag.

    Uses a timezone-aware ``datetime.now(timezone.utc)`` rather than
    ``datetime.utcnow()``, which is deprecated since Python 3.12. The
    formatted string is the same.
    """
    # Imported locally because the file-level import only brings in `datetime`.
    from datetime import timezone

    return datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")

def cfg_hash(cfg_dict: Dict[str, Any]) -> str:
    """Return a 12-hex-character fingerprint of a JSON-serialisable config dict.

    The dict is serialised canonically (sorted keys, no whitespace), so key
    insertion order does not affect the result. The fingerprint is the first
    12 hex digits of the SHA-256 of that UTF-8 text.
    """
    canonical = json.dumps(cfg_dict, separators=(",", ":"), sort_keys=True)
    digest = hashlib.sha256()
    digest.update(canonical.encode("utf-8"))
    return digest.hexdigest()[:12]

def write_json(path: str, obj: Dict[str, Any]):
    """Write ``obj`` to ``path`` as UTF-8 JSON, indented by 2 with sorted keys.

    An existing file at ``path`` is overwritten.
    """
    with open(path, mode="w", encoding="utf-8") as out:
        json.dump(obj, out, sort_keys=True, indent=2)

def append_csv(path: str, header: List[str], row: Dict[str, Any]):
    """Append one row to a CSV file, writing the header first when needed.

    Args:
        path: CSV file to append to; created if missing.
        header: Column names, in output order.
        row: Mapping of column name to value. Keys not in ``header`` are
            ignored; missing columns are written as empty fields.

    The header is written when the file does not exist yet *or* exists but is
    empty (e.g. pre-created with ``touch``), so the file never ends up with
    data rows and no header.
    """
    needs_header = not os.path.exists(path) or os.path.getsize(path) == 0
    with open(path, "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=header, extrasaction="ignore")
        if needs_header:
            writer.writeheader()
        writer.writerow(row)

def human_time(sec: float) -> str:
    """Format a duration given in seconds for display.

    Below one second the result is whole milliseconds (``"250 ms"``), below
    one minute it is seconds with one decimal (``"1.5 s"``), and otherwise
    minutes plus zero-padded whole seconds (``"2m05s"``).
    """
    if sec < 1:
        text = f"{sec*1000:.0f} ms"
    elif sec < 60:
        text = f"{sec:.1f} s"
    else:
        minutes, seconds = divmod(sec, 60)
        text = f"{int(minutes)}m{int(seconds):02d}s"
    return text

# ----------------------------
# Config
# ----------------------------
@dataclass
class Config:
    """Hyperparameters and runtime settings for one train+plan run.

    Field defaults mirror the CLI defaults declared in ``build_parser``.
    ``run_tag`` and ``config_hash`` are not user inputs: they are filled in
    after construction (see ``args_to_cfg`` / ``run_grid``).
    """
    # Model/data
    state_dim: int = 8           # size of the state vector
    action_dim: int = 4          # size of the action vector
    latent_dim: int = 16         # size of the VAE latent z
    hidden_dim: int = 128        # hidden width of every MLP
    # Train
    epochs: int = 1              # passes over the synthetic dataset
    batch_size: int = 64
    steps_per_epoch: int = 200   # dataset holds steps_per_epoch * batch_size transitions
    lr: float = 3e-3             # Adam learning rate
    kl_warmup_epochs: int = 1    # KL weight is min(1, (epoch+1) / kl_warmup_epochs)
    # Runtime
    seed: int = 1337             # seeds the global RNGs and the synthetic dataset
    device: str = "cpu"          # "cpu" or "cuda"
    deterministic: bool = True   # forwarded to set_deterministic (cuDNN flags)
    # Planning
    plan_depth: int = 12         # number of actions in each candidate sequence
    plan_batch: int = 8          # number of (all-zero) start states in the final evaluation
    plan_candidates: int = 64    # number of random action sequences scored
    action_scale: float = 1.0    # multiplier applied to the standard-normal candidate actions
    goal: str = "ones"           # "ones", "zeros", or anything else for a random-normal goal
    # I/O
    output_dir: str = ""         # empty means the caller picks a default under "runs/"
    save_checkpoint: bool = True # when true, model.pt and config.json are written
    # Derived
    run_tag: str = ""            # UTC timestamp tag from now_tag()
    config_hash: str = ""        # cfg_hash() of the fields above, excluding the two derived ones

# ----------------------------
# Synthetic Dynamics Dataset
# ----------------------------
class SyntheticDynamics(Dataset):
    """In-memory dataset of synthetic one-step transitions ``(s, a, s')``.

    Transitions follow ``s' = tanh(A s + B a) + noise_std * eps`` where ``A``
    has its singular values clipped to at most 0.9 and ``B`` is a random input
    matrix. Data is generated as trajectories of ``traj_len`` states and then
    truncated to exactly ``steps_per_epoch * batch_size`` transitions.

    All randomness (matrices, initial states, actions *and* transition noise)
    is drawn from one NumPy generator seeded with ``seed``. The dataset is
    therefore fully determined by its constructor arguments and does not
    depend on, or advance, PyTorch's global RNG.

    Args:
        state_dim: Size of the state vector.
        action_dim: Size of the action vector.
        steps_per_epoch: Number of batches one epoch should contain.
        batch_size: Transitions per batch.
        traj_len: States per trajectory (yields ``traj_len - 1`` transitions).
        noise_std: Standard deviation of the additive Gaussian noise.
        seed: Seed for the NumPy generator.

    Raises:
        ValueError: If ``traj_len < 2`` or no transitions are requested.
    """

    def __init__(self, state_dim, action_dim, steps_per_epoch, batch_size, traj_len=16, noise_std=0.01, seed=1234):
        super().__init__()
        if traj_len < 2:
            raise ValueError(f"traj_len must be >= 2 to form transitions, got {traj_len}")
        total_pairs = steps_per_epoch * batch_size
        if total_pairs < 1:
            raise ValueError(
                f"steps_per_epoch * batch_size must be >= 1, got {steps_per_epoch} * {batch_size}"
            )

        rng = np.random.default_rng(seed)

        def draw(dim):
            """Standard-normal float32 vector of length ``dim`` from the seeded generator."""
            return torch.from_numpy(rng.standard_normal(dim).astype(np.float32))

        # Clip singular values so the linear part A is a contraction (norm <= 0.9).
        A = rng.standard_normal((state_dim, state_dim)) * 0.3
        u, sing, vh = np.linalg.svd(A, full_matrices=False)
        sing = np.clip(sing, 0, 0.9)
        self.A = torch.tensor((u @ np.diag(sing) @ vh).astype(np.float32))
        self.B = torch.tensor((rng.standard_normal((state_dim, action_dim)) * 0.5).astype(np.float32))

        # Generate enough whole trajectories to cover total_pairs, then truncate.
        num_trajs = math.ceil(total_pairs / (traj_len - 1))
        s_list, a_list, sp_list = [], [], []
        for _ in range(num_trajs):
            s = draw(state_dim)
            for _t in range(traj_len - 1):
                a = draw(action_dim)
                noise = draw(state_dim) * noise_std
                s_next = torch.tanh(self.A @ s + self.B @ a) + noise
                s_list.append(s)
                a_list.append(a)
                sp_list.append(s_next)
                s = s_next
        self.s = torch.stack(s_list)[:total_pairs]
        self.a = torch.stack(a_list)[:total_pairs]
        self.sp = torch.stack(sp_list)[:total_pairs]

    def __len__(self):
        """Number of stored transitions."""
        return self.s.shape[0]

    def __getitem__(self, idx):
        """Return the ``(state, action, next_state)`` triple at ``idx``."""
        return self.s[idx], self.a[idx], self.sp[idx]

# ----------------------------
# Model
# ----------------------------
class MLP(nn.Module):
    """Three-layer perceptron: Linear -> SiLU -> Linear -> SiLU -> Linear.

    Both hidden layers have width ``hidden_dim``. The layers live in
    ``self.net`` (an ``nn.Sequential``) at indices 0..4.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=128):
        super().__init__()
        widths = [in_dim, hidden_dim, hidden_dim, out_dim]
        layers = []
        for position, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if position > 0:
                # Activation goes between consecutive linear layers only.
                layers.append(nn.SiLU())
            layers.append(nn.Linear(fan_in, fan_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network along the last dimension of ``x``."""
        return self.net(x)

class CondVAE(nn.Module):
    """Conditional VAE over one-step transitions ``(s, a) -> s'``.

    Holds three MLPs: ``encoder`` for the posterior q(z | s, a, s'), ``prior``
    for the learned prior p(z | s, a), and ``decoder`` predicting s' from
    (s, a, z). Encoder and prior each output ``2 * latent_dim`` values that
    are split in half into a mean and a log-variance.
    """

    def __init__(self, state_dim, action_dim, latent_dim=16, hidden_dim=128):
        super().__init__()
        cond_dim = state_dim + action_dim
        self.encoder = MLP(cond_dim + state_dim, 2 * latent_dim, hidden_dim)
        self.prior = MLP(cond_dim, 2 * latent_dim, hidden_dim)
        self.decoder = MLP(cond_dim + latent_dim, state_dim, hidden_dim)

    def posterior(self, s, a, sp):
        """Return ``(mu, logvar)`` of q(z | s, a, s')."""
        features = torch.cat([s, a, sp], dim=-1)
        return self.encoder(features).chunk(2, dim=-1)

    def prior_params(self, s, a):
        """Return ``(mu, logvar)`` of p(z | s, a)."""
        features = torch.cat([s, a], dim=-1)
        return self.prior(features).chunk(2, dim=-1)

    def reparam(self, mu, logv):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(logv / 2)``."""
        std = torch.exp(0.5 * logv)
        eps = torch.randn_like(mu)
        return mu + eps * std

    def kl_gauss(self, mu_q, lv_q, mu_p, lv_p):
        """KL(q || p) between diagonal Gaussians, summed over the last dimension."""
        v_q = lv_q.exp()
        v_p = lv_p.exp()
        sq_diff = (mu_q - mu_p) ** 2
        ratio = (v_q + sq_diff) / v_p
        per_dim = 0.5 * (lv_p - lv_q - 1 + ratio)
        return per_dim.sum(-1)

    def decode(self, s, a, z):
        """Predict the next state from state, action and latent."""
        return self.decoder(torch.cat([s, a, z], dim=-1))

    def loss(self, s, a, sp, kl_w=1.0):
        """Return ``(total, stats)`` for one batch.

        ``total`` is the batch-mean reconstruction MSE plus ``kl_w`` times the
        batch-mean KL between posterior and prior. ``stats`` carries the
        scalar tensors under the keys ``"loss"``, ``"recon"`` and ``"kl"``.
        """
        mu_q, lv_q = self.posterior(s, a, sp)
        mu_p, lv_p = self.prior_params(s, a)
        z = self.reparam(mu_q, lv_q)
        prediction = self.decode(s, a, z)
        recon_mean = F.mse_loss(prediction, sp, reduction='none').mean(-1).mean()
        kl_mean = self.kl_gauss(mu_q, lv_q, mu_p, lv_p).mean()
        total = recon_mean + kl_w * kl_mean
        return total, {"loss": total, "recon": recon_mean, "kl": kl_mean}

    @torch.no_grad()
    def rollout_mean_latent(self, s0, acts):
        """Roll the model forward using the prior mean as the latent at every step.

        ``acts`` must be 3-D; its first axis is iterated as time. Returns the
        visited states stacked along a new leading axis, starting with ``s0``
        (so the result has one more entry than ``acts`` has steps).
        """
        horizon, _batch, _act_dim = acts.shape
        trajectory = [s0]
        current = s0
        for step in range(horizon):
            action = acts[step]
            prior_mu, _prior_lv = self.prior_params(current, action)
            current = self.decode(current, action, prior_mu)
            trajectory.append(current)
        return torch.stack(trajectory, 0)

# ----------------------------
# Train & Plan
# ----------------------------
def train_model(cfg: Config):
    """Train a ``CondVAE`` on freshly generated synthetic dynamics.

    Args:
        cfg: Run configuration (model sizes, optimiser settings, device, seed).

    Returns:
        ``(model, metrics, train_time)`` where ``metrics`` maps ``"loss"``,
        ``"recon"`` and ``"kl"`` to the floats of the *last* optimisation step,
        and ``train_time`` is the training-loop duration in seconds. If no
        step runs at all (e.g. ``cfg.epochs == 0``), the metrics are NaN
        rather than the function failing on an unbound variable.
    """
    device = torch.device(cfg.device)
    ds = SyntheticDynamics(cfg.state_dim, cfg.action_dim, cfg.steps_per_epoch, cfg.batch_size, seed=cfg.seed)
    dl = DataLoader(ds, batch_size=cfg.batch_size, shuffle=True, drop_last=True)
    model = CondVAE(cfg.state_dim, cfg.action_dim, cfg.latent_dim, cfg.hidden_dim).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=cfg.lr)

    # Placeholders so the return value is well-formed even with zero steps.
    metrics = {"loss": float("nan"), "recon": float("nan"), "kl": float("nan")}
    warmup_epochs = max(1, cfg.kl_warmup_epochs)

    t0 = time.perf_counter()
    for ep in range(cfg.epochs):
        # The KL weight depends only on the epoch: linear ramp up to 1.0.
        kl_w = min(1.0, (ep + 1) / warmup_epochs)
        pbar = tqdm(dl, desc=f"train[{ep+1}/{cfg.epochs}]", leave=False)
        for s, a, sp in pbar:
            s, a, sp = s.to(device), a.to(device), sp.to(device)
            opt.zero_grad(set_to_none=True)
            loss, st = model.loss(s, a, sp, kl_w)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            # Convert once per step; reused for the progress bar and the return value.
            metrics = {name: float(st[name]) for name in ("loss", "recon", "kl")}
            pbar.set_postfix({"loss": f"{metrics['loss']:.3f}",
                              "recon": f"{metrics['recon']:.4f}",
                              "kl": f"{metrics['kl']:.2f}",
                              "kl_w": f"{kl_w:.2f}"})
    train_time = time.perf_counter() - t0

    return model, metrics, train_time

@torch.no_grad()
def plan(model: CondVAE, cfg: Config):
    """Random-shooting planner: sample action sequences, keep the best one.

    ``cfg.plan_candidates`` action sequences of length ``cfg.plan_depth`` are
    drawn from a scaled standard normal. They are rolled out from the
    all-zero state with ``model.rollout_mean_latent`` and scored by the mean
    squared distance of the final state to the goal. The best sequence is then
    re-evaluated on ``cfg.plan_batch`` zero start states.

    All candidates are scored in a single batched rollout (candidate index is
    the batch axis) instead of one model call per candidate.

    Args:
        model: Trained dynamics model exposing ``rollout_mean_latent``.
        cfg: Planning settings (depth, candidates, batch, scale, goal, device).

    Returns:
        ``(mse, info, plan_time)``: final-state MSE of the best sequence,
        ``{"best_idx": int}``, and the planning duration in seconds.

    Raises:
        ValueError: If ``cfg.plan_candidates`` is smaller than 1.
    """
    if cfg.plan_candidates < 1:
        raise ValueError(f"plan_candidates must be >= 1, got {cfg.plan_candidates}")

    device = torch.device(cfg.device)
    if cfg.goal == 'ones':
        goal = torch.ones(cfg.state_dim, device=device)
    elif cfg.goal == 'zeros':
        goal = torch.zeros(cfg.state_dim, device=device)
    else:
        goal = torch.randn(cfg.state_dim, device=device)

    s0 = torch.zeros(cfg.plan_batch, cfg.state_dim, device=device)
    # Shape (plan_depth, plan_candidates, action_dim): column k is candidate k.
    cand_actions = torch.randn(cfg.plan_depth, cfg.plan_candidates, cfg.action_dim, device=device) * cfg.action_scale

    t0 = time.perf_counter()
    # Every candidate starts from the same single state, broadcast over the candidate axis.
    start = s0[0:1].expand(cfg.plan_candidates, -1)
    cand_final = model.rollout_mean_latent(start, cand_actions)[-1]  # (candidates, state_dim)
    scores = ((cand_final - goal) ** 2).mean(-1)
    best_idx = torch.argmin(scores).item()

    best = cand_actions[:, best_idx].unsqueeze(1).expand(-1, cfg.plan_batch, -1).contiguous()
    final_states = model.rollout_mean_latent(s0, best)[-1]
    mse = ((final_states - goal) ** 2).mean().item()
    plan_time = time.perf_counter() - t0
    return mse, {"best_idx": best_idx}, plan_time

# ----------------------------
# Grid runner
# ----------------------------
def run_grid(base_cfg: Config, seeds: List[int], plan_candidates_list: List[int], epochs_list: List[int]):
    """Train and plan for every (seed, plan_candidates, epochs) combination.

    Each combination runs on its own copy of ``base_cfg``; the caller's
    config object is never modified. The CUDA-to-CPU fallback is resolved
    once, *before* any config hash is computed, so every run's hash reflects
    the device that was actually used.

    Args:
        base_cfg: Template configuration shared by all runs.
        seeds: Seeds to sweep (outermost loop).
        plan_candidates_list: ``plan_candidates`` values to sweep.
        epochs_list: ``epochs`` values to sweep (innermost loop).

    Returns:
        Path of the metrics CSV, which receives one row per run. Checkpoints
        and per-run ``config.json`` go to ``<out_dir>/<hash>-<tag>/`` when
        ``save_checkpoint`` is set.
    """
    # Local imports: neither name is part of the file-level imports.
    from dataclasses import replace
    from itertools import product

    out_dir = base_cfg.output_dir or os.path.join("runs", f"grid-{now_tag()}")
    ensure_dir(out_dir)
    metrics_path = os.path.join(out_dir, "metrics.csv")
    header = [
        "timestamp","config_hash","device","deterministic","seed",
        "epochs","steps_per_epoch","batch_size","lr","kl_warmup_epochs",
        "plan_depth","plan_batch","plan_candidates","action_scale","goal",
        "loss","recon","kl","train_time","plan_time","planner_mse","checkpoint"
    ]
    print(f"[grid] Output: {out_dir}")

    device = base_cfg.device
    if device == "cuda" and not torch.cuda.is_available():
        print("[warn] CUDA not available, falling back to CPU")
        device = "cpu"

    for seed, pc, ep in product(seeds, plan_candidates_list, epochs_list):
        cfg = replace(
            base_cfg,
            seed=seed, plan_candidates=pc, epochs=ep, device=device,
            run_tag=now_tag(), config_hash="",
        )
        # Hash everything except the two derived fields.
        cd = asdict(cfg)
        cd.pop("run_tag", None)
        cd.pop("config_hash", None)
        cfg.config_hash = cfg_hash(cd)
        run_dir = os.path.join(out_dir, f"{cfg.config_hash}-{cfg.run_tag}")
        ensure_dir(run_dir)

        set_deterministic(cfg.deterministic)
        set_seed(cfg.seed)

        model, tr_metrics, train_time = train_model(cfg)
        mse, info, plan_time = plan(model, cfg)

        ckpt_path = ""
        if cfg.save_checkpoint:
            ckpt_path = os.path.join(run_dir, "model.pt")
            torch.save({"model": model.state_dict(), "cfg": asdict(cfg)}, ckpt_path)
            write_json(os.path.join(run_dir, "config.json"), asdict(cfg))

        row = {
            "timestamp": cfg.run_tag,
            "config_hash": cfg.config_hash,
            "device": cfg.device,
            "deterministic": cfg.deterministic,
            "seed": cfg.seed,
            "epochs": cfg.epochs,
            "steps_per_epoch": cfg.steps_per_epoch,
            "batch_size": cfg.batch_size,
            "lr": cfg.lr,
            "kl_warmup_epochs": cfg.kl_warmup_epochs,
            "plan_depth": cfg.plan_depth,
            "plan_batch": cfg.plan_batch,
            "plan_candidates": cfg.plan_candidates,
            "action_scale": cfg.action_scale,
            "goal": cfg.goal,
            "loss": tr_metrics.get("loss", float("nan")),
            "recon": tr_metrics.get("recon", float("nan")),
            "kl": tr_metrics.get("kl", float("nan")),
            "train_time": train_time,
            "plan_time": plan_time,
            "planner_mse": mse,
            "checkpoint": ckpt_path,
        }
        append_csv(metrics_path, header, row)
        print(f"[grid][{cfg.config_hash}] seed={cfg.seed} epochs={cfg.epochs} cand={cfg.plan_candidates} "
              f"| loss={row['loss']:.3f} mse={mse:.4f}")

    print(f"[grid] Done. Metrics CSV at: {metrics_path}")
    return metrics_path

# ----------------------------
# CLI
# ----------------------------
def _str2bool(value) -> bool:
    """Parse a CLI boolean: 1/true/yes/y (case-insensitive) is True, anything else False."""
    return str(value).lower() in {"1", "true", "yes", "y"}


def build_parser():
    """Build the command-line parser for the run/grid entry point.

    Every ``Config`` field that is user-settable has a matching ``--option``
    with the same default. Grid sweeps are given as comma-separated integer
    strings. ``choices`` are tuples (not sets) so that usage and error
    messages list the alternatives in a fixed order.
    """
    p = argparse.ArgumentParser(description="UnifiedAI (Notebook/CLI friendly)", add_help=True)
    # Modes
    p.add_argument("--mode", choices=("run", "grid"), default="run", help="run once or grid sweep")
    # Model/data
    p.add_argument("--state_dim", type=int, default=8)
    p.add_argument("--action_dim", type=int, default=4)
    p.add_argument("--latent_dim", type=int, default=16)
    p.add_argument("--hidden_dim", type=int, default=128)
    # Train
    p.add_argument("--epochs", type=int, default=1)
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--steps_per_epoch", type=int, default=200)
    p.add_argument("--lr", type=float, default=3e-3)
    p.add_argument("--kl_warmup_epochs", type=int, default=1)
    # Runtime
    p.add_argument("--seed", type=int, default=1337)
    p.add_argument("--device", choices=("cpu", "cuda"), default="cpu")
    p.add_argument("--deterministic", type=_str2bool, default=True)
    # Planning
    p.add_argument("--plan_depth", type=int, default=12)
    p.add_argument("--plan_batch", type=int, default=8)
    p.add_argument("--plan_candidates", type=int, default=64)
    p.add_argument("--action_scale", type=float, default=1.0)
    p.add_argument("--goal", choices=("ones", "zeros", "random"), default="ones")
    # I/O
    p.add_argument("--output_dir", type=str, default="")
    p.add_argument("--save_checkpoint", type=_str2bool, default=True)
    # Grid params (comma-separated integers; empty means "use the single-run value")
    p.add_argument("--grid_seeds", type=str, default="")
    p.add_argument("--grid_plan_candidates", type=str, default="")
    p.add_argument("--grid_epochs", type=str, default="")
    return p

def args_to_cfg(args: argparse.Namespace) -> Config:
    """Build a ``Config`` from parsed CLI arguments and fill its derived fields.

    The user-settable fields are copied by name from ``args``. ``run_tag`` is
    set to the current timestamp tag and ``config_hash`` to the hash of all
    fields except those two derived ones.
    """
    copied_fields = (
        "state_dim", "action_dim", "latent_dim", "hidden_dim",
        "epochs", "batch_size", "steps_per_epoch", "lr", "kl_warmup_epochs",
        "seed", "device", "deterministic",
        "plan_depth", "plan_batch", "plan_candidates", "action_scale", "goal",
        "output_dir", "save_checkpoint",
    )
    cfg = Config(**{name: getattr(args, name) for name in copied_fields})
    cfg.run_tag = now_tag()
    hashable = {
        key: value
        for key, value in asdict(cfg).items()
        if key not in ("run_tag", "config_hash")
    }
    cfg.config_hash = cfg_hash(hashable)
    return cfg

def parse_list_arg(text: str, fallback: List[int]) -> List[int]:
    """Parse a comma-separated string of integers, e.g. ``"1, 2,3"``.

    Args:
        text: Raw CLI value. Surrounding whitespace and empty items are ignored.
        fallback: Returned when ``text`` yields no values at all. This covers
            the empty string as well as separator- or whitespace-only input
            such as ``","``, which would otherwise produce an empty sweep.

    Returns:
        The parsed integers in input order, or ``fallback``.

    Raises:
        ValueError: If a non-empty item is not a valid integer.
    """
    if not text:
        return fallback
    tokens = (part.strip() for part in text.split(","))
    values = [int(tok) for tok in tokens if tok]
    return values or fallback

def main():
    """CLI / notebook entry point: parse arguments, then run once or sweep a grid.

    Unknown arguments are ignored so the script also works when a notebook
    kernel passes its own flags. In ``run`` mode the model is trained, the
    planner is executed, and a checkpoint (optional), ``config.json`` and a
    one-row ``metrics.csv`` are written to the output directory.
    """
    args, _ignored = build_parser().parse_known_args()  # notebook-safe
    if args.device == "cuda" and not torch.cuda.is_available():
        print("[warn] CUDA not available, falling back to CPU")
        args.device = "cpu"

    cfg = args_to_cfg(args)
    set_deterministic(cfg.deterministic)
    set_seed(cfg.seed)

    grid_mode = args.mode == "grid"
    if not cfg.output_dir:
        prefix = "grid" if grid_mode else "run"
        cfg.output_dir = os.path.join("runs", f"{prefix}-{cfg.config_hash}-{cfg.run_tag}")
    ensure_dir(cfg.output_dir)

    print(f"[config] {cfg}")

    if grid_mode:
        # Each sweep axis defaults to the single value from the base config.
        run_grid(
            cfg,
            parse_list_arg(args.grid_seeds, [cfg.seed]),
            parse_list_arg(args.grid_plan_candidates, [cfg.plan_candidates]),
            parse_list_arg(args.grid_epochs, [cfg.epochs]),
        )
        print(f"[summary] GRID COMPLETE -> {cfg.output_dir}")
        return

    # Single run
    started = time.time()
    model, tr_metrics, train_time = train_model(cfg)
    mse, info, plan_time = plan(model, cfg)

    ckpt_path = ""
    if cfg.save_checkpoint:
        ckpt_path = os.path.join(cfg.output_dir, "model.pt")
        torch.save({"model": model.state_dict(), "cfg": asdict(cfg)}, ckpt_path)
        write_json(os.path.join(cfg.output_dir, "config.json"), asdict(cfg))

    # Columns copied verbatim from the config, in CSV order.
    cfg_columns = (
        "config_hash", "device", "deterministic", "seed",
        "epochs", "steps_per_epoch", "batch_size", "lr", "kl_warmup_epochs",
        "plan_depth", "plan_batch", "plan_candidates", "action_scale", "goal",
    )
    metric_columns = ("loss", "recon", "kl")
    header = [
        "timestamp", *cfg_columns, *metric_columns,
        "train_time", "plan_time", "planner_mse", "checkpoint",
    ]

    row = {"timestamp": cfg.run_tag}
    row.update((name, getattr(cfg, name)) for name in cfg_columns)
    row.update((name, tr_metrics.get(name, float("nan"))) for name in metric_columns)
    row.update(train_time=train_time, plan_time=plan_time, planner_mse=mse, checkpoint=ckpt_path)

    metrics_path = os.path.join(cfg.output_dir, "metrics.csv")
    append_csv(metrics_path, header, row)

    print(f"[planner] Best idx={info['best_idx']} Mean final MSE={mse:.4f}")
    total_time = time.time() - started
    print(f"[summary] RUN COMPLETE | train={human_time(train_time)} plan={human_time(plan_time)} total={human_time(total_time)} "
          f"| loss={row['loss']:.3f} mse={mse:.4f} | out={cfg.output_dir}")

# Run the entry point only when this module is executed as "__main__"
# (script or notebook cell), not when it is imported.
if __name__ == "__main__":
    main()