<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/unifiedai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch --quiet

In [None]:
#!/usr/bin/env python3
"""
unifiedai.py
============
End-to-end unified AI scaffold:
Perception → Fusion → Episodic Memory → RSSM → Planner
Includes synthetic training loop + planner demo.
"""

from __future__ import annotations
import argparse, math, random
from typing import List, Optional, Tuple, Callable, Union, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import trange

# --- Utils ---

def default_device() -> torch.device:
    """Return the CUDA device when torch can see one, otherwise the CPU."""
    name = "cpu"
    if torch.cuda.is_available():
        name = "cuda"
    return torch.device(name)

def reparam_gauss(stats: torch.Tensor, rng: Optional[torch.Generator] = None) -> torch.Tensor:
    """Draw a reparameterised sample from a diagonal Gaussian.

    Args:
        stats: tensor whose last dimension is ``[mu, logvar]`` concatenated
            (so its size must be even).
        rng: optional generator for reproducible noise; torch requires it to
            live on the same device as ``stats``.

    Returns:
        ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``. Gradients flow
        through ``mu`` and ``logvar``.
    """
    mu, logvar = torch.chunk(stats, 2, dim=-1)
    std = (0.5 * logvar).exp()
    # ``torch.randn_like`` does not take a ``generator`` argument in many torch
    # releases (it raised TypeError here even for rng=None), so draw with
    # ``torch.randn`` and match shape/dtype/device explicitly.
    eps = torch.randn(std.shape, generator=rng, dtype=std.dtype, device=std.device)
    return mu + eps * std

def kl_gauss(post: torch.Tensor, prior: torch.Tensor, reduce: bool = True) -> torch.Tensor:
    """Closed-form KL(post || prior) for diagonal Gaussians given as ``[mu, logvar]`` stats.

    The divergence is summed over the last (latent) dimension. With ``reduce``
    it is then averaged over all remaining dimensions; otherwise the
    per-row values are returned.
    """
    mean_post, logvar_post = torch.chunk(post, 2, dim=-1)
    mean_prior, logvar_prior = torch.chunk(prior, 2, dim=-1)
    # (sigma_p^2 + (mu_p - mu_q)^2) / sigma_q^2
    ratio = (logvar_post.exp() + (mean_post - mean_prior).pow(2)) / logvar_prior.exp()
    per_dim = logvar_prior - logvar_post + ratio - 1.0
    divergence = 0.5 * per_dim.sum(-1)
    if reduce:
        return divergence.mean()
    return divergence

def maybe_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Scale ``x`` to roughly unit L2 norm along its last dimension.

    ``eps`` is added to the norm (not clamped), so zero vectors stay zero
    instead of becoming NaN.
    """
    denominator = x.norm(dim=-1, keepdim=True) + eps
    return x / denominator

def pair_rotate(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate each interleaved feature pair ``(x[..., 2i], x[..., 2i+1])`` in 2-D.

    Args:
        x: tensor whose last dimension holds interleaved pairs (even size).
        cos, sin: rotation cosines/sines broadcastable against ``x``. Their last
            dimension may be ``D // 2`` (one angle per pair) or ``D`` (each
            angle repeated twice, the layout ``build_rope_cache`` returns).

    Returns:
        Tensor shaped like ``x`` in which pair ``i`` becomes
        ``(x1 * cos - x2 * sin, x1 * sin + x2 * cos)``.
    """
    x1, x2 = x[..., ::2], x[..., 1::2]
    n_pairs = x2.size(-1)

    def _per_pair(table):
        # Full-width tables repeat every angle twice; keep one value per pair
        # so they broadcast against the half-width x1 / x2 slices.
        if isinstance(table, torch.Tensor) and table.dim() > 0 and n_pairs and table.size(-1) == 2 * n_pairs:
            return table[..., ::2]
        return table

    cos, sin = _per_pair(cos), _per_pair(sin)
    y1 = x1 * cos - x2 * sin
    y2 = x1 * sin + x2 * cos
    # Re-interleave (y1[i], y2[i]) back into positions (2i, 2i+1).
    return torch.stack([y1, y2], dim=-1).flatten(-2)

def build_rope_cache(seq_len: int, dim: int, base: float = 10000.0, device=None, dtype=None):
    """Precompute rotary position tables.

    Returns ``(cos, sin)``, each of shape ``(seq_len, 2 * (dim // 2))``. Row ``t``
    holds ``cos/sin(t * base ** (-i / (dim // 2)))`` for pair ``i``. Every value
    is repeated twice so it lines up with interleaved feature pairs.
    """
    n_pairs = dim // 2
    exponents = torch.arange(n_pairs, device=device, dtype=dtype) / n_pairs
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(seq_len, device=device, dtype=dtype)[:, None]
    angles = positions * inv_freq[None, :]
    cos_table = angles.cos().repeat_interleave(2, dim=-1)
    sin_table = angles.sin().repeat_interleave(2, dim=-1)
    return cos_table, sin_table

# --- Perception ---

class MockVisionEncoder(nn.Module):
    """Tiny conv stack that turns an image batch into ``tokens`` embedding vectors.

    Two stride-2 convolutions are followed by adaptive average pooling onto a
    ``(tokens, 1)`` grid. Each token therefore averages a horizontal band of
    rows across the full image width before a linear projection to
    ``embed_dim``.
    """

    def __init__(self, embed_dim: int = 1024, tokens: int = 16):
        super().__init__()
        self.tokens = tokens
        # Exposed so checkpointing code can recover the width instead of guessing.
        self.embed_dim = embed_dim
        # Attribute name kept as ``mlp`` (although convolutional) so existing
        # state_dict keys stay valid.
        self.mlp = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1),
            nn.GELU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.GELU(),
            nn.AdaptiveAvgPool2d((tokens, 1)),
        )
        self.proj = nn.Linear(64, embed_dim)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Map ``img`` of shape (batch, 3, H, W) to (batch, tokens, embed_dim)."""
        x = self.mlp(img)                    # (batch, 64, tokens, 1)
        x = x.squeeze(-1).transpose(1, 2)    # (batch, tokens, 64)
        return self.proj(x)

class MockTextEncoder(nn.Module):
    """Character-level toy text encoder with sinusoidal position encodings.

    Each character becomes token id ``(ord(c) % 255) + 1``, so ids lie in
    ``1..255`` and ``0`` is the padding id. Sequences are truncated to
    ``max_len`` characters and right-padded to the longest one in the batch.
    """

    def __init__(self, vocab_size: int = 32000, embed_dim: int = 1024, max_len: int = 128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # Exposed so checkpointing code can recover the width instead of guessing.
        self.embed_dim = embed_dim
        self.max_len = max_len

    @staticmethod
    def _sinusoidal(length: int, dim: int, device, dtype) -> torch.Tensor:
        """Standard (Vaswani et al.) sinusoidal table of shape (length, dim).

        Feature ``2i`` is ``sin(pos / 10000 ** (2i / dim))`` and feature ``2i+1``
        is the matching cosine at the same frequency.
        """
        pos = torch.arange(length, device=device, dtype=torch.float32).unsqueeze(1)
        k = torch.arange(dim, device=device)
        # Features 2i and 2i+1 share one frequency.
        exponent = (2 * (k // 2)).to(torch.float32) / dim
        angles = pos / (10000.0 ** exponent)
        pe = torch.where(k % 2 == 0, torch.sin(angles), torch.cos(angles))
        return pe.to(dtype)

    def forward(self, texts: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode one string or a list of strings.

        Returns:
            ``(x, mask)``: ``x`` has shape (batch, T, embed_dim) and ``mask`` is a
            bool (batch, T) tensor, True at real (non-padding) characters.
        """
        if isinstance(texts, str):
            texts = [texts]
        toks = [[(ord(c) % 255) + 1 for c in t[: self.max_len]] for t in texts]
        max_t = max((len(seq) for seq in toks), default=1)
        pad_id = 0
        ids = torch.full((len(toks), max_t), pad_id, dtype=torch.long)
        mask = torch.zeros((len(toks), max_t), dtype=torch.bool)
        for row, seq in enumerate(toks):
            if seq:
                ids[row, : len(seq)] = torch.tensor(seq, dtype=torch.long)
                mask[row, : len(seq)] = True
        x = self.embed(ids.to(self.embed.weight.device))
        pe = self._sinusoidal(max_t, x.size(-1), x.device, x.dtype)
        x = x + pe.unsqueeze(0)
        return x, mask.to(x.device)

# --- Fusion + Heads ---

class TokenPool(nn.Module):
    """Attention pooling of a token sequence against one learned query vector."""

    def __init__(self, dim: int):
        super().__init__()
        self.q = nn.Parameter(torch.randn(dim))

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Pool ``x`` over dimension 1 using softmax(x . q) weights.

        Args:
            x: token features; scores are ``(x * q).sum(-1)`` and the weighted
                sum runs over dim 1, i.e. (batch, tokens, dim) -> (batch, dim).
            mask: optional bool tensor matching the score shape; True keeps a
                token, False excludes it from the softmax.
        """
        scores = (x * self.q).sum(-1)
        if mask is not None:
            # Use the dtype's most negative finite value: a literal -1e9
            # overflows to -inf in float16, and a fully masked row of -inf
            # softmaxes to NaN.
            scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)
        weights = scores.softmax(-1)
        return (weights.unsqueeze(-1) * x).sum(1)

class FusionObsEncoder(nn.Module):
    """Fuse pooled vision tokens and pooled text tokens into one observation embedding."""

    def __init__(self, embed_dim: int = 1024, obs_dim: int = 1024):
        super().__init__()
        joint_dim = 2 * embed_dim
        self.v_pool = TokenPool(embed_dim)
        self.t_pool = TokenPool(embed_dim)
        self.mlp = nn.Sequential(
            nn.LayerNorm(joint_dim),
            nn.Linear(joint_dim, obs_dim),
            nn.GELU(),
            nn.Linear(obs_dim, obs_dim),
        )

    def forward(self, v_tokens: torch.Tensor, t_tokens: torch.Tensor, t_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Pool each modality (only text is masked), concatenate, and project to ``obs_dim``."""
        pooled_vision = self.v_pool(v_tokens, None)
        pooled_text = self.t_pool(t_tokens, t_mask)
        joint = torch.cat((pooled_vision, pooled_text), dim=-1)
        return self.mlp(joint)

class RewardHead(nn.Module):
    """Small MLP that maps a latent vector to one scalar reward prediction."""

    def __init__(self, latent: int = 256):
        super().__init__()
        hidden = latent // 2
        layers = [
            nn.LayerNorm(latent),
            nn.Linear(latent, hidden),
            nn.GELU(),
            nn.Linear(hidden, 1),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Return one value per row of ``h`` (the trailing size-1 dim is dropped)."""
        scores = self.mlp(h)
        return scores.squeeze(-1)

class ObsDecoder(nn.Module):
    """Reconstruct an observation embedding from the concatenated RSSM state ``[h, z]``."""

    def __init__(self, latent: int = 256, obs_dim: int = 1024):
        super().__init__()
        state_dim = 2 * latent
        self.net = nn.Sequential(
            nn.LayerNorm(state_dim),
            nn.Linear(state_dim, obs_dim),
            nn.GELU(),
            nn.Linear(obs_dim, obs_dim),
        )

    def forward(self, h: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Decode; ``h`` and ``z`` must together have ``2 * latent`` features."""
        state = torch.cat([h, z], dim=-1)
        return self.net(state)

# --- Episodic Memory ---

class EpisodicMemory(nn.Module):
    """Slot-based key/value memory addressed by cosine similarity.

    ``read`` returns a temperature-scaled softmax mix of stored values,
    optionally restricted to the ``topk`` most similar slots. ``write`` has two
    strategies:

    - ``strategy="ring"`` overwrites slots round-robin.
    - Any other strategy string (default ``"nearest"``) EMA-blends each row into
      its most similar slot.
    """

    def __init__(self, dim: int = 1024, slots: int = 1024, temperature: float = 0.1, ema: float = 0.1):
        super().__init__()
        self.dim = dim
        self.slots = slots
        self.temperature = temperature
        self.ema = ema
        # Buffers, not parameters: saved in state_dict and moved by .to(), but
        # never updated by an optimizer.
        self.register_buffer("keys", F.normalize(torch.randn(slots, dim), dim=-1))
        self.register_buffer("values", torch.zeros(slots, dim))
        # Reset to 0 when a slot is written; every slot gets +1 per write call.
        self.register_buffer("age", torch.zeros(slots, dtype=torch.long))
        # Monotonic counter; ring writes go to slot ``ptr % slots``.
        self.register_buffer("ptr", torch.zeros((), dtype=torch.long))

    @torch.no_grad()
    def write(self, k: torch.Tensor, v: torch.Tensor, strategy: str = "nearest"):
        """Store key rows ``k`` with value rows ``v`` (1-D inputs count as one row)."""
        if k.dim() == 1:
            k = k.unsqueeze(0)
        if v.dim() == 1:
            v = v.unsqueeze(0)
        k = maybe_norm(k)
        if strategy == "ring":
            self._write_ring(k, v)
        else:
            self._write_nearest(k, v)
        # All slots age by one per call, including those just reset to 0.
        self.age += 1

    def _write_ring(self, k: torch.Tensor, v: torch.Tensor) -> None:
        """Overwrite slots round-robin starting at ``ptr`` (one slot per key row)."""
        for row in range(k.size(0)):
            slot = int(self.ptr.item() % self.slots)
            self.keys[slot] = k[row]
            self.values[slot] = v[row]
            self.age[slot] = 0
            self.ptr += 1

    def _write_nearest(self, k: torch.Tensor, v: torch.Tensor) -> None:
        """EMA-blend each row into its most similar slot; rows are applied in order."""
        # Targets are chosen once, against the keys as they were before this call.
        similarity = maybe_norm(self.keys) @ k.T
        keep = 1 - self.ema
        for row, slot in enumerate(similarity.argmax(dim=0).tolist()):
            blended_key = keep * self.keys[slot] + self.ema * k[row]
            self.keys[slot] = F.normalize(blended_key, dim=-1)
            self.values[slot] = keep * self.values[slot] + self.ema * v[row]
            self.age[slot] = 0

    def read(self, q: torch.Tensor, topk: int = 0) -> torch.Tensor:
        """Return the attention-weighted value for each query row of ``q``.

        A 1-D ``q`` yields a 1-D result. A non-zero ``topk`` smaller than
        ``slots`` restricts the softmax to the ``topk`` highest-scoring slots.
        """
        is_vector = q.dim() == 1
        if is_vector:
            q = q.unsqueeze(0)
        scores = (maybe_norm(q) @ maybe_norm(self.keys).T) / self.temperature
        if topk and topk < self.slots:
            best_scores, best_slots = scores.topk(topk, dim=-1)
            weights = F.softmax(best_scores, dim=-1)
            result = (weights.unsqueeze(-1) * self.values[best_slots]).sum(1)
        else:
            result = F.softmax(scores, dim=-1) @ self.values
        return result.squeeze(0) if is_vector else result

# --- RSSM ---

class LatentRSSM(nn.Module):
    """Recurrent state-space model: GRU state ``h`` plus a sampled Gaussian latent ``z``.

    ``step`` feeds ``[z, action]`` into a GRU cell to update ``h``. The prior over
    the next ``z`` depends on ``h`` only; the posterior also sees ``obs_embed``.
    """

    def __init__(self, latent: int = 256, action_dim: int = 32, obs_dim: int = 1024):
        super().__init__()
        self.latent = latent
        self.action_dim = action_dim
        self.obs_dim = obs_dim
        self.gru = nn.GRUCell(latent + action_dim, latent)
        # Both heads emit ``[mu, logvar]`` concatenated (2 * latent features).
        self.prior = nn.Sequential(nn.LayerNorm(latent), nn.Linear(latent, 2 * latent))
        post_in = latent + obs_dim
        self.post = nn.Sequential(nn.LayerNorm(post_in), nn.Linear(post_in, 2 * latent))

    def init_state(self, batch: int, device=None):
        """Return zero ``(h, z)`` tensors of shape (batch, latent) on ``device`` (default: auto)."""
        if not device:
            device = default_device()
        shape = (batch, self.latent)
        return torch.zeros(shape, device=device), torch.zeros(shape, device=device)

    def step(self, h, z, action, obs_embed=None, rng=None):
        """Advance one step.

        Returns ``(h, z, prior_stats, post_stats)``. ``post_stats`` is None when
        no observation is given, in which case ``z`` is sampled from the prior.
        """
        h = self.gru(torch.cat([z, action], dim=-1), h)
        prior_stats = self.prior(h)
        post_stats = None
        if obs_embed is None:
            z = reparam_gauss(prior_stats, rng=rng)
        else:
            post_stats = self.post(torch.cat([h, obs_embed], dim=-1))
            z = reparam_gauss(post_stats, rng=rng)
        return h, z, prior_stats, post_stats

# --- LLM stub + Policy/Reward ---

class DummyLLM:
    """Deterministic stand-in for an LLM planner.

    ``propose_actions`` cycles through ``action_vocab`` in order, ignoring the
    goal text. ``revise_actions`` rotates a plan left by one step.
    """

    def __init__(self, action_vocab: List[str]):
        self.action_vocab = action_vocab

    def propose_actions(self, goal_desc: str, max_steps: int = 5) -> List[str]:
        """Return ``max_steps`` actions cycling through the vocabulary.

        The result does not depend on ``goal_desc``. The global ``random``
        state is left untouched: re-seeding it here served no purpose, because
        nothing below consumes randomness. It also overwrote the caller's seed,
        and str ``hash`` is salted per process anyway.
        """
        vocab = self.action_vocab
        return [vocab[i % len(vocab)] for i in range(max_steps)]

    def revise_actions(self, plan: List[str], max_steps: int = 5) -> List[str]:
        """Rotate ``plan`` left by one step and truncate it to ``max_steps``.

        Only two copies of the plan are concatenated, so the result is shorter
        than ``max_steps`` when ``max_steps >= 2 * len(plan)``.
        """
        return (plan + plan)[1 : 1 + max_steps]

class TextualPolicyToAction:
    """Map textual action names to dense vectors through a random embedding table.

    Unknown names fall back to index 0. When the requested ``action_dim``
    differs from the table width, a random projection is drawn once per
    (table width, requested width, device). The same name then always maps
    to the same vector.
    """

    def __init__(self, action_dim: int = 32, actions: Dict[str, int] | None = None):
        self.actions = actions or {n: i for i, n in enumerate(
            ["move_left", "move_right", "move_up", "move_down", "pick", "place", "wait", "scan"]
        )}
        self.n = max(self.actions.values()) + 1
        self.emb = nn.Embedding(self.n, action_dim)
        with torch.no_grad():
            nn.init.normal_(self.emb.weight, std=0.2)
        # Cached width-adapting projections (see class docstring).
        self._projections: Dict[Tuple[int, int, str], torch.Tensor] = {}

    def __call__(self, action_text: str, device: torch.device, action_dim: int) -> torch.Tensor:
        """Return the (action_dim,) vector for ``action_text`` on ``device``."""
        # This class is not an nn.Module, so nothing else moves the table when
        # the planner is sent to CUDA. Move it here; this is cheap when it is
        # already on ``device``.
        self.emb.to(device)
        idx = torch.tensor(self.actions.get(action_text, 0), device=device, dtype=torch.long)
        vec = self.emb(idx)
        if vec.size(-1) != action_dim:
            key = (vec.size(-1), action_dim, str(vec.device))
            W = self._projections.get(key)
            if W is None:
                W = torch.empty(vec.size(-1), action_dim, device=vec.device)
                nn.init.kaiming_uniform_(W, a=math.sqrt(5))
                self._projections[key] = W
            vec = vec @ W
        return vec

class SimpleReward:
    """Reward equal to the negative squared L2 distance from a latent goal vector."""

    def __init__(self, latent: int = 256):
        # The goal starts at the origin of the latent space.
        self.goal = torch.zeros(latent)

    def set_goal(self, g: torch.Tensor):
        """Replace the goal; it is detached so no graph is retained."""
        self.goal = g.detach()

    def __call__(self, h: torch.Tensor) -> torch.Tensor:
        """Return ``-||h - goal||^2`` over the last dimension of ``h``."""
        target = self.goal.to(h.device).expand_as(h)
        diff = h - target
        return -diff.pow(2).sum(dim=-1)

# --- Planner ---

class HybridPlanner(nn.Module):
    """Roll out an LLM-proposed action list open-loop through the RSSM prior.

    The LLM proposes up to ``max_depth`` textual steps. If ``validate_fn``
    rejects them, the LLM revises once. Each step is mapped to an action
    vector and imagined with ``rssm.step(..., obs_embed=None)``, then scored by
    ``reward_fn``.
    """

    def __init__(self, llm, rssm, action_fn, reward_fn, validate_fn=None, max_depth: int = 5):
        super().__init__()
        self.llm = llm
        self.rssm = rssm
        self.action_fn = action_fn
        self.reward_fn = reward_fn
        self.validate_fn = validate_fn
        self.max_depth = max_depth

    @torch.no_grad()
    def plan(self, goal_desc, init_h, init_z, batch: int = 1):
        """Return ``(trajectory, total_reward)``.

        ``trajectory`` is a list of ``(h, action, reward)`` clones, one per step.
        ``total_reward`` is the sum over steps of the batch-mean reward.
        """
        depth = self.max_depth
        steps = self.llm.propose_actions(goal_desc, max_steps=depth)
        rejected = self.validate_fn is not None and not self.validate_fn(steps)
        if rejected:
            steps = self.llm.revise_actions(plan=steps, max_steps=depth)

        device = init_h.device
        h, z = init_h, init_z
        trajectory = []
        total_reward = torch.zeros((), device=device)
        for step_text in steps[:depth]:
            action = self.action_fn(step_text, device, self.rssm.action_dim)
            if action.dim() == 1:
                action = action.unsqueeze(0)
            # Share a single action across the batch.
            if action.size(0) != batch:
                action = action.expand(batch, -1)
            h, z, _, _ = self.rssm.step(h, z, action, obs_embed=None)
            reward = self.reward_fn(h)
            total_reward = total_reward + reward.mean()
            trajectory.append((h.clone(), action.clone(), reward.clone()))
        return trajectory, total_reward

# --- Synthetic env + train/demo ---

def synthetic_env_step(state: torch.Tensor, action: torch.Tensor, W_a: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Advance a toy linear environment by one step.

    ``next = 0.95 * state + action @ W_a + N(0, 0.01^2)`` noise.

    Args:
        state: current state, last dimension S.
        action: action, last dimension A.
        W_a: optional fixed (A, S) action-to-state matrix. When omitted, a fresh
            random matrix (std 0.05) is drawn on every call, so the dynamics
            change from step to step. Pass one for a stationary environment.

    Returns:
        The next state, shaped like ``state``.
    """
    S = state.size(-1)
    if W_a is None:
        W_a = torch.randn(action.size(-1), S, device=state.device) * 0.05
    noise = torch.randn_like(state) * 0.01
    # Scaling by 0.95 equals multiplying by 0.95 * I, without the O(S^2) matmul.
    return state * 0.95 + action @ W_a + noise

def build_obs_from_state(state: torch.Tensor, obs_dim: int, W: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Linearly project ``state`` (last dim S) to an ``obs_dim``-wide observation.

    Args:
        state: state tensor, last dimension S.
        obs_dim: output width.
        W: optional fixed (S, obs_dim) projection. When omitted, a fresh random
            projection scaled by ``1/sqrt(S)`` is drawn on every call, so the
            same state maps to different observations across calls.
    """
    S = state.size(-1)
    if W is None:
        W = torch.randn(S, obs_dim, device=state.device) / (S ** 0.5)
    return state @ W

def train_loop(steps: int, batch: int, device: torch.device, seed: Optional[int] = 42):
    """Train the fusion encoder, RSSM and decoder on synthetic data.

    Each step does the following:

    1. Encode a random image and a fixed prompt. The encoders are frozen (not
       in the optimizer), so they run under ``no_grad``.
    2. Fuse them into an observation embedding and advance the RSSM posterior.
    3. Minimise one combined loss:
       - reconstruction MSE against a projection of the synthetic state;
       - a free-nats KL(post || prior);
       - an auxiliary episodic-memory retrieval MSE.

    One backward pass is made per step. The recurrent ``(h, z)`` is then
    detached (truncated BPTT of one step). Without this, the next step would
    backpropagate into a graph that was already freed.

    NOTE(review): ``build_obs_from_state`` and ``synthetic_env_step`` are
    called without fixed matrices, so they resample their random weights on
    every call. The targets are therefore not a stable function of the state.

    Args:
        steps: number of optimisation steps.
        batch: batch size.
        device: device for all modules and tensors.
        seed: passed to ``torch.manual_seed`` before the modules are built;
            ``None`` keeps the caller's RNG state.

    Returns:
        dict with the modules under ``vision``, ``text``, ``fuse``, ``rssm``,
        ``decoder`` and ``memory``.
    """
    if seed is not None:
        torch.manual_seed(seed)
    E, latent, action_dim, obs_dim = 1024, 256, 32, 1024
    venc, tenc = MockVisionEncoder(E).to(device), MockTextEncoder(embed_dim=E).to(device)
    fuse = FusionObsEncoder(E, obs_dim).to(device)
    rssm = LatentRSSM(latent, action_dim, obs_dim).to(device)
    dec = ObsDecoder(latent, obs_dim).to(device)
    mem = EpisodicMemory(obs_dim, slots=512).to(device)

    params = [*rssm.parameters(), *dec.parameters(), *fuse.parameters()]
    opt = optim.AdamW(params, lr=3e-4, weight_decay=1e-4)
    free_nats, kl_scale, aux_scale = 1.0, 1.0, 0.1

    state = torch.zeros(batch, latent, device=device)
    h, z = rssm.init_state(batch, device=device)
    # The prompt and the (frozen, deterministic) text encoder never change, so
    # encode once instead of every step.
    with torch.no_grad():
        ttoks, tmask = tenc(["demo text"] * batch)

    pbar = trange(steps, desc="train")
    for _ in pbar:
        imgs = torch.randn(batch, 3, 128, 128, device=device)
        with torch.no_grad():
            vtoks = venc(imgs)
        obs_embed = fuse(vtoks, ttoks, tmask)
        target_obs = build_obs_from_state(state, obs_dim).detach()
        action = torch.tanh(torch.randn(batch, action_dim, device=device))

        h, z, prior_stats, post_stats = rssm.step(h, z, action, obs_embed=obs_embed)
        pred_obs = dec(h, z)
        recon = F.mse_loss(pred_obs, target_obs)
        kl = kl_gauss(post_stats, prior_stats)
        # Free nats: no KL pressure until the divergence exceeds ``free_nats``.
        kl_free = torch.clamp(kl - free_nats, min=0.0)

        # Episodic memory: write the fused obs -> target pair, then read it back
        # as an auxiliary retrieval target (gradients reach ``fuse`` via the query).
        mem.write(obs_embed.detach(), target_obs)
        retrieved = mem.read(obs_embed, topk=8)
        aux = F.mse_loss(retrieved, target_obs)

        loss = recon + kl_scale * kl_free + aux_scale * aux
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(params, 1.0)
        opt.step()

        # Cut the recurrent graph so the next step does not backprop into it.
        h, z = h.detach(), z.detach()
        state = synthetic_env_step(state, action)
        pbar.set_postfix(loss=loss.item(), recon=recon.item(), kl=kl.item(), aux=aux.item())

    return {
        "vision": venc,
        "text": tenc,
        "fuse": fuse,
        "rssm": rssm,
        "decoder": dec,
        "memory": mem,
    }

# --- Demo planning ---

@torch.no_grad()
def demo_planning(artifacts: dict, goal_desc: str, batch: int, plan_depth: int, device: torch.device):
    """Roll out a DummyLLM plan through the trained RSSM prior and print a summary.

    If the episodic memory has been written to, the reward goal is seeded from
    a top-8 memory read with a random probe. That read is mapped to the latent
    size by a random projection. Otherwise the goal stays at the origin.

    Returns:
        ``(trajectory, total_reward)`` from ``HybridPlanner.plan``. The trajectory
        is empty when ``plan_depth`` is 0.
    """
    rssm: LatentRSSM = artifacts["rssm"]
    mem: EpisodicMemory = artifacts["memory"]

    # Build LLM, action mapper, reward
    default_actions = ["move_left", "move_right", "move_up", "move_down", "pick", "place", "wait", "scan"]
    llm = DummyLLM(action_vocab=default_actions)
    action_fn = TextualPolicyToAction(action_dim=rssm.action_dim)
    reward_fn = SimpleReward(latent=rssm.latent)
    planner = HybridPlanner(llm, rssm, action_fn, reward_fn, validate_fn=None, max_depth=plan_depth).to(device)

    # Initialize latent state
    h, z = rssm.init_state(batch, device=device)

    # Optional goal: only once memory has been written (age > 0 somewhere).
    if mem.values.numel() > 0 and mem.age.max() > 0:
        # Use a small top-k read from a random query to seed a goal
        probe = torch.randn(batch, mem.values.size(-1), device=device)
        goal_vec = mem.read(probe, topk=8).mean(0)
        # Project to latent size with a small random map (no grads)
        P = torch.randn(mem.values.size(-1), rssm.latent, device=device) / (mem.values.size(-1) ** 0.5)
        reward_fn.set_goal(goal_vec @ P)
    # Else default goal is zero (already set)

    traj, total_reward = planner.plan(goal_desc, h, z, batch=batch)

    print("Plan summary")
    print(f"- Steps: {len(traj)}")
    print(f"- Total reward (mean over batch per step): {total_reward.item():.4f}")
    # Show last hidden norm for a quick sanity metric (nothing to show for an empty plan).
    if traj:
        last_h, _, _ = traj[-1]
        print(f"- Last hidden L2 norm (avg over batch): {last_h.norm(dim=-1).mean().item():.3f}")
    return traj, total_reward

# --- Entry point ---

def select_device(name: str) -> torch.device:
    """Resolve a device name; ``"auto"`` picks CUDA when available, otherwise the CPU."""
    if name != "auto":
        return torch.device(name)
    return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

if __name__ == "__main__":
    # argparse and random are already imported at module level.
    parser = argparse.ArgumentParser(description="UnifiedAI minimal end-to-end demo")
    parser.add_argument("--train-steps", type=int, default=200, help="Training iterations")
    parser.add_argument("--batch", type=int, default=16, help="Batch size")
    parser.add_argument("--plan-depth", type=int, default=5, help="Planner horizon")
    parser.add_argument("--device", type=str, default="auto", help="cpu | cuda | auto")
    parser.add_argument("--goal", type=str, default="navigate to target and stabilize", help="High-level goal text")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    # parse_known_args instead of parse_args: notebook kernels (Jupyter/Colab)
    # also run this cell with __name__ == "__main__" and typically put their own
    # "-f <connection file>" in sys.argv, which parse_args would reject and exit on.
    args, unknown = parser.parse_known_args()
    if unknown:
        print(f"Ignoring unrecognized arguments: {unknown}")

    # Seeding
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    # NOTE(review): train_loop calls torch.manual_seed itself, so --seed does not
    # govern the torch RNG during training — confirm this is intended.

    device = select_device(args.device)
    print(f"Using device: {device}")

    artifacts = train_loop(steps=args.train_steps, batch=args.batch, device=device)
    demo_planning(artifacts, goal_desc=args.goal, batch=args.batch, plan_depth=args.plan_depth, device=device)

# --- Checkpointing and reproducibility ---

def set_seed(seed: int) -> None:
    """Seed Python's ``random``, torch (and every CUDA device) and NumPy when importable.

    NumPy is optional. Any exception raised while importing it is swallowed,
    and NumPy seeding is then skipped.
    """
    random.seed(seed)
    try:
        import numpy as _np
    except Exception:
        _np = None
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if _np is not None:
        _np.random.seed(seed)

def _infer_embed_dim(venc: nn.Module, tenc: nn.Module, fallback: int) -> int:
    """Best-effort recovery of the encoders' embedding width E.

    Checks, in order:

    1. An explicit ``embed_dim`` attribute (text encoder, then vision encoder).
    2. The text ``embed`` table width.
    3. The vision ``proj`` output width.
    4. ``fallback``.
    """
    candidates = (
        getattr(tenc, "embed_dim", None),
        getattr(venc, "embed_dim", None),
        getattr(getattr(tenc, "embed", None), "embedding_dim", None),
        getattr(getattr(venc, "proj", None), "out_features", None),
    )
    for width in candidates:
        if width is not None:
            return int(width)
    return int(fallback)


def artifact_state(artifacts: dict) -> dict:
    """Collect config metadata and weights for the trained artifacts.

    Args:
        artifacts: mapping with keys ``vision``, ``text``, ``fuse``, ``rssm``,
            ``decoder`` and ``memory`` (as returned by ``train_loop``).

    Returns:
        ``{"meta": ..., "state": ...}``:

        - ``meta`` holds the sizes needed to rebuild the modules. The torch
          version is stored as a plain ``str``.
        - ``state`` holds each module's state_dict plus the episodic-memory
          buffers (``keys``, ``values``, ``age``, ``ptr``).

        Tensors are references to live storage, not copies.
    """
    rssm: LatentRSSM = artifacts["rssm"]
    mem: EpisodicMemory = artifacts["memory"]
    venc = artifacts["vision"]
    tenc = artifacts["text"]
    fuse: nn.Module = artifacts["fuse"]
    dec: nn.Module = artifacts["decoder"]

    # Previously E silently fell back to obs_dim, which is wrong whenever E != obs_dim.
    E = _infer_embed_dim(venc, tenc, rssm.obs_dim)
    vocab_size = getattr(getattr(tenc, "embed", None), "num_embeddings", None)
    meta = {
        "E": E,
        "vocab_size": int(vocab_size) if vocab_size is not None else None,
        "latent": int(rssm.latent),
        "action_dim": int(rssm.action_dim),
        "obs_dim": int(rssm.obs_dim),
        "memory": {
            "dim": int(getattr(mem, "dim", rssm.obs_dim)),
            "slots": int(mem.slots),
            "temperature": float(mem.temperature),
            "ema": float(mem.ema),
        },
        "versions": {
            # str(): torch.__version__ is a TorchVersion subclass, which
            # weights_only unpicklers may refuse to reconstruct.
            "torch": str(torch.__version__),
            "unifiedai": "0.1.0",
        },
    }
    return {
        "meta": meta,
        "state": {
            "vision": venc.state_dict(),
            "text": tenc.state_dict(),
            "fuse": fuse.state_dict(),
            "rssm": rssm.state_dict(),
            "decoder": dec.state_dict(),
            "memory": {
                "keys": mem.keys,
                "values": mem.values,
                "age": mem.age,
                # Ring-write position; without it "ring" writes restart at slot 0 after reload.
                "ptr": mem.ptr,
            },
        },
    }

def save_artifacts(artifacts: dict, path: str) -> None:
    """Write ``artifact_state(artifacts)`` to ``path`` with ``torch.save``."""
    torch.save(artifact_state(artifacts), path)

def load_artifacts(path: str, device: Optional[torch.device] = None) -> dict:
    """Rebuild the artifact dict from a checkpoint written by ``save_artifacts``.

    Modules are reconstructed from the sizes in ``meta``, and their weights are
    loaded from ``state``. Episodic-memory buffers are copied in place.

    Keys missing from older checkpoints fall back as follows:

    - ``vocab_size`` -> 32000, the ``MockTextEncoder`` default.
    - memory ``dim`` -> ``obs_dim``.
    - memory ``ptr`` -> left at 0.

    Args:
        path: checkpoint file.
        device: target device; ``None`` selects CUDA when available.

    Returns:
        dict with the same keys as ``train_loop``'s result.
    """
    device = device or select_device("auto")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    pkg = torch.load(path, map_location=device)
    meta, state = pkg["meta"], pkg["state"]
    E = meta["E"]
    latent, action_dim, obs_dim = meta["latent"], meta["action_dim"], meta["obs_dim"]
    mem_cfg = meta["memory"]
    vocab_size = meta.get("vocab_size") or 32000
    mem_dim = mem_cfg.get("dim", obs_dim)

    # Rebuild modules
    venc = MockVisionEncoder(E).to(device)
    tenc = MockTextEncoder(vocab_size=vocab_size, embed_dim=E).to(device)
    fuse = FusionObsEncoder(E, obs_dim).to(device)
    rssm = LatentRSSM(latent, action_dim, obs_dim).to(device)
    dec = ObsDecoder(latent, obs_dim).to(device)
    mem = EpisodicMemory(mem_dim, slots=mem_cfg["slots"], temperature=mem_cfg["temperature"], ema=mem_cfg["ema"]).to(device)

    # Load weights
    venc.load_state_dict(state["vision"])
    tenc.load_state_dict(state["text"])
    fuse.load_state_dict(state["fuse"])
    rssm.load_state_dict(state["rssm"])
    dec.load_state_dict(state["decoder"])

    mem_state = state["memory"]
    with torch.no_grad():
        mem.keys.copy_(mem_state["keys"].to(device))
        mem.values.copy_(mem_state["values"].to(device))
        mem.age.copy_(mem_state["age"].to(device))
        if "ptr" in mem_state:
            mem.ptr.copy_(mem_state["ptr"].to(device))

    return {
        "vision": venc,
        "text": tenc,
        "fuse": fuse,
        "rssm": rssm,
        "decoder": dec,
        "memory": mem,
    }

# --- Smoke tests (optional, call manually) ---

def run_smoke_tests(device: torch.device = None) -> None:
    """Exercise every component once on tiny random inputs and check output shapes.

    The checks use ``assert``, so they are disabled under ``python -O``. One
    ``[PASS]`` line is printed per stage: fusion, rssm+decoder, memory,
    planner.
    """
    if not device:
        device = select_device("auto")
    print("[TEST] device:", device)

    # Small sizes keep the run fast.
    E, latent, action_dim, obs_dim = 256, 128, 16, 512

    def on_device(module: nn.Module) -> nn.Module:
        return module.to(device)

    venc = on_device(MockVisionEncoder(E))
    tenc = on_device(MockTextEncoder(embed_dim=E))
    fuse = on_device(FusionObsEncoder(E, obs_dim))
    rssm = on_device(LatentRSSM(latent, action_dim, obs_dim))
    dec = on_device(ObsDecoder(latent, obs_dim))
    mem = on_device(EpisodicMemory(obs_dim, slots=64))

    B = 4

    # Stage 1: perception + fusion.
    vtoks = venc(torch.randn(B, 3, 128, 128, device=device))
    ttoks, tmask = tenc(["hello world"] * B)
    fused = fuse(vtoks, ttoks, tmask)
    assert fused.shape == (B, obs_dim), "[FAIL] Fusion output shape mismatch"
    print("[PASS] fusion")

    # Stage 2: one posterior RSSM step, then decode.
    h, z = rssm.init_state(B, device=device)
    act = torch.randn(B, action_dim, device=device)
    h2, z2, _prior, _post = rssm.step(h, z, act, obs_embed=fused)
    pred = dec(h2, z2)
    assert pred.shape == (B, obs_dim), "[FAIL] Decoder output shape mismatch"
    print("[PASS] rssm+decoder")

    # Stage 3: memory write/read round trip.
    mem.write(fused, pred.detach())
    retrieved = mem.read(fused, topk=8)
    assert retrieved.shape == (B, obs_dim), "[FAIL] Memory read shape mismatch"
    print("[PASS] memory")

    # Stage 4: planner roll-out from the post-step state.
    vocabulary = ["move_left", "move_right", "move_up", "move_down", "pick", "place", "wait", "scan"]
    planner = HybridPlanner(
        DummyLLM(vocabulary),
        rssm,
        TextualPolicyToAction(action_dim=action_dim),
        SimpleReward(latent=latent),
        max_depth=3,
    ).to(device)
    traj, R = planner.plan("test goal", h2, z2, batch=B)
    assert len(traj) == 3 and isinstance(R, torch.Tensor), "[FAIL] Planner trajectory"
    print("[PASS] planner")
    print("[TEST] All smoke tests passed.")

# --- Public API ---

__version__ = "0.1.0"
# Names exported by ``from unifiedai import *``; includes every public class.
__all__ = [
    "MockVisionEncoder",
    "MockTextEncoder",
    "TokenPool",
    "FusionObsEncoder",
    "RewardHead",
    "ObsDecoder",
    "EpisodicMemory",
    "LatentRSSM",
    "DummyLLM",
    "TextualPolicyToAction",
    "SimpleReward",
    "HybridPlanner",
    "synthetic_env_step",
    "build_obs_from_state",
    "train_loop",
    "demo_planning",
    "select_device",
    "set_seed",
    "artifact_state",
    "save_artifacts",
    "load_artifacts",
    "run_smoke_tests",
]