<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/unified_ai_core_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# unified_ai/core.py
from __future__ import annotations
from typing import List, Optional, Tuple, Callable, Union

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Optional deps; guarded for mock paths
try:
    import timm
except Exception:
    timm = None

try:
    from sentencepiece import SentencePieceProcessor
except Exception:
    SentencePieceProcessor = None


# ===== Utils =====

def default_device() -> torch.device:
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    name = "cpu"
    if torch.cuda.is_available():
        name = "cuda"
    return torch.device(name)

def reparam_gauss(stats: torch.Tensor, rng: Optional[torch.Generator] = None) -> torch.Tensor:
    """Draw a reparameterized sample from a diagonal Gaussian.

    Args:
        stats: Tensor whose last dimension is split into two chunks,
            ``mu`` followed by ``logvar``.
        rng: Optional generator used for the noise draw. When given it must
            live on the same device as ``stats``; ``None`` uses the global RNG.

    Returns:
        ``mu + eps * exp(0.5 * logvar)`` with ``eps`` drawn from a standard
        normal; the result has the shape of ``mu``.
    """
    mu, logvar = torch.chunk(stats, 2, dim=-1)
    std = (0.5 * logvar).exp()
    # torch.randn_like is not documented to accept ``generator``; torch.randn
    # is, so build the noise explicitly with std's shape, dtype and device.
    eps = torch.randn(std.shape, generator=rng, dtype=std.dtype, device=std.device)
    return mu + eps * std

def kl_gauss(post: torch.Tensor, prior: torch.Tensor, reduce: bool = True) -> torch.Tensor:
    """KL divergence KL(post || prior) between two diagonal Gaussians.

    Both inputs are split along the last dimension into ``mu`` and ``logvar``
    chunks. The divergence is summed over the last dimension; with
    ``reduce=True`` the mean over the remaining elements is returned,
    otherwise the un-averaged tensor.
    """
    post_mu, post_logvar = post.chunk(2, dim=-1)
    prior_mu, prior_logvar = prior.chunk(2, dim=-1)

    log_ratio = prior_logvar - post_logvar
    scaled = (post_logvar.exp() + (post_mu - prior_mu).pow(2)) / prior_logvar.exp()
    kl = (0.5 * (log_ratio + scaled - 1.0)).sum(-1)

    if reduce:
        return kl.mean()
    return kl

def maybe_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Scale ``x`` to (approximately) unit length along the last dimension.

    ``eps`` is added to the norm before dividing, so an all-zero vector maps
    to zeros instead of NaN.
    """
    denom = torch.norm(x, dim=-1, keepdim=True) + eps
    return x / denom

def pair_rotate(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate consecutive (even, odd) feature pairs of ``x`` by per-pair angles.

    Args:
        x: Tensor whose last dimension is treated as interleaved pairs
            ``(x[0], x[1]), (x[2], x[3]), ...``; it is expected to be even.
        cos, sin: Cosine/sine tables broadcastable against the pairs. The
            last dimension may hold one entry per pair, or one entry per
            feature with each value duplicated for both pair members (the
            layout produced by ``repeat_interleave(2)``).

    Returns:
        Tensor with the same shape as ``x`` holding the rotated pairs,
        re-interleaved in the original order.
    """
    x1, x2 = x[..., ::2], x[..., 1::2]
    n_pairs = x1.shape[-1]
    # Full-width tables repeat each angle for both members of a pair; keep one
    # column per pair so they line up with the even/odd halves instead of
    # failing to broadcast.
    if n_pairs > 0 and cos.shape[-1] == 2 * n_pairs:
        cos = cos[..., ::2]
    if n_pairs > 0 and sin.shape[-1] == 2 * n_pairs:
        sin = sin[..., ::2]
    y1 = x1 * cos - x2 * sin
    y2 = x1 * sin + x2 * cos
    return torch.stack([y1, y2], dim=-1).flatten(-2)

def build_rope_cache(seq_len: int, dim: int, base: float = 10000.0, device=None, dtype=None):
    """Build cosine/sine tables of rotary position angles.

    Args:
        seq_len: Number of positions (rows of the tables).
        dim: Feature width; ``dim // 2`` frequencies are generated.
        base: Base of the geometric frequency progression.
        device: Device on which the tables are created.
        dtype: Desired floating dtype of the returned tables. ``None`` (or a
            non-floating dtype) yields the default floating dtype.

    Returns:
        ``(cos, sin)``, each of shape ``(seq_len, 2 * (dim // 2))`` with every
        angle duplicated for the two members of a feature pair.
    """
    if dtype is not None and dtype.is_floating_point:
        out_dtype = dtype
    else:
        out_dtype = torch.get_default_dtype()
    # Half-precision dtypes cannot represent larger position indices exactly,
    # so do the angle arithmetic in at least float32 and cast only at the end.
    work_dtype = torch.promote_types(out_dtype, torch.float32)

    half = dim // 2
    idx = torch.arange(half, device=device, dtype=work_dtype)
    freqs = 1.0 / (base ** (idx / half))
    t = torch.arange(seq_len, device=device, dtype=work_dtype).unsqueeze(1)
    angles = t * freqs.unsqueeze(0)
    cos = torch.cos(angles).repeat_interleave(2, dim=-1).to(out_dtype)
    sin = torch.sin(angles).repeat_interleave(2, dim=-1).to(out_dtype)
    return cos, sin


# ===== Perception =====

class VisionEncoder(nn.Module):
    """timm backbone with partial freezing and a linear projection.

    Args:
        model_name: Name passed to ``timm.create_model``.
        embed_dim: Output width of the projection.
        freeze_blocks: Blocks with index below this value get
            ``requires_grad=False``; the remaining blocks are set trainable.
            Only applied when the backbone exposes ``blocks``.
        global_pool: Forwarded to ``timm.create_model``.
    """
    def __init__(self, model_name: str = "vit_base_patch16_224", embed_dim: int = 1024, freeze_blocks: int = 6, global_pool: str = ""):
        super().__init__()
        assert timm is not None, "timm is required for VisionEncoder"
        self.backbone = timm.create_model(model_name, pretrained=True, num_classes=0, global_pool=global_pool)
        # The patch embedding is always frozen when the backbone has one.
        if hasattr(self.backbone, "patch_embed"):
            for p in self.backbone.patch_embed.parameters():
                p.requires_grad = False
        if hasattr(self.backbone, "blocks"):
            for i, blk in enumerate(self.backbone.blocks):
                req = i >= freeze_blocks
                for p in blk.parameters():
                    p.requires_grad = req
        # Resolve the fallback lazily: an eager nested getattr raises
        # AttributeError on backbones that expose num_features but no embed_dim.
        feat_dim = getattr(self.backbone, "num_features", None)
        if feat_dim is None:
            feat_dim = self.backbone.embed_dim
        self.proj = nn.Linear(feat_dim, embed_dim)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Run the backbone's ``forward_features`` and project the result.

        A 2-D feature tensor gets a singleton token axis inserted at dim 1
        before projection.
        """
        x = self.backbone.forward_features(img)
        if x.dim() == 2:
            x = x.unsqueeze(1)
        return self.proj(x)  # (B, N, E)


class MockVisionEncoder(nn.Module):
    """Small convolutional stand-in that turns an image into ``tokens`` embeddings.

    Two stride-2 convolutions feed an adaptive average pool of size
    ``(tokens, 1)``; each pooled 64-channel column is projected to
    ``embed_dim``.
    """
    def __init__(self, embed_dim: int = 1024, tokens: int = 16):
        super().__init__()
        self.tokens = tokens
        stem = [
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
            nn.AdaptiveAvgPool2d((tokens, 1)),
        ]
        self.mlp = nn.Sequential(*stem)
        self.proj = nn.Linear(in_features=64, out_features=embed_dim)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Map an image batch to visual tokens of shape (B, T, E)."""
        pooled = self.mlp(img)                     # (B, 64, T, 1)
        per_token = pooled.squeeze(-1).permute(0, 2, 1)  # (B, T, 64)
        return self.proj(per_token)                # (B, T, E)


class TextEncoder(nn.Module):
    """SentencePiece tokenization + embedding + RoPE.

    Args:
        vocab_path: Passed to ``SentencePieceProcessor`` to load the model.
        embed_dim: Width of the token embedding.
        max_len: Maximum number of token ids kept per text.
        rope_base: Base forwarded to ``build_rope_cache``.
    """
    def __init__(self, vocab_path: str, embed_dim: int = 1024, max_len: int = 512, rope_base: float = 10000.0):
        super().__init__()
        assert SentencePieceProcessor is not None, "sentencepiece is required for TextEncoder"
        self.tok = SentencePieceProcessor(vocab_path)
        self.vocab_size = self.tok.get_piece_size()
        self.embed = nn.Embedding(self.vocab_size, embed_dim)
        self.max_len = max_len
        self.rope_base = rope_base

    def tokenize(self, texts: Union[str, List[str]]):
        """Encode text(s) into a padded id tensor and a boolean validity mask.

        Returns:
            ``(toks, mask)``, both of shape (B, T) on CPU, where T is the
            longest truncated sequence in the batch. ``mask`` is True on real
            tokens and False on padding.
        """
        if isinstance(texts, str):
            texts = [texts]
        ids = [self.tok.encode(t, out_type=int)[: self.max_len] for t in texts]
        max_t = max(len(x) for x in ids) if ids else 1
        # Use the tokenizer's pad id when it reports a non-negative one, else 0.
        pad_id = self.tok.pad_id() if hasattr(self.tok, "pad_id") and self.tok.pad_id() >= 0 else 0
        toks = torch.full((len(ids), max_t), pad_id, dtype=torch.long)
        mask = torch.zeros((len(ids), max_t), dtype=torch.bool)
        for i, seq in enumerate(ids):
            L = len(seq)
            toks[i, :L] = torch.tensor(seq, dtype=torch.long)
            mask[i, :L] = True
        return toks, mask

    def forward(self, texts: Union[str, List[str]]):
        """Embed text(s) and apply rotary position rotation to the embeddings.

        Returns:
            ``(x, mask)`` with shapes (B, T, E) and (B, T), both on the
            embedding's device.
        """
        toks, mask = self.tokenize(texts)
        device = self.embed.weight.device
        toks, mask = toks.to(device), mask.to(device)
        x = self.embed(toks)
        T, E = x.shape[1], x.shape[2]
        # Pair rotation needs an even width; pad one zero feature and trim it
        # again after rotating.
        pad = E % 2 != 0
        if pad:
            x = F.pad(x, (0, 1))
            E += 1
        cos, sin = build_rope_cache(T, E, base=self.rope_base, device=device, dtype=x.dtype)
        # build_rope_cache duplicates every angle for both members of a pair
        # (width E), whereas pair_rotate multiplies the even/odd halves
        # (width E/2). Keep one column per pair so the shapes line up.
        x = pair_rotate(x, cos[None, :, ::2], sin[None, :, ::2])
        if pad:
            x = x[..., : self.embed.embedding_dim]
        return x, mask  # (B, T, E), (B, T)


class MockTextEncoder(nn.Module):
    """Character-code text embedding with additive sinusoidal positions.

    Needs no tokenizer: every character is mapped to ``(ord(c) % 255) + 1``
    and id 0 is used for padding.
    """
    def __init__(self, vocab_size: int = 32000, embed_dim: int = 1024, max_len: int = 128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.max_len = max_len

    def forward(self, texts: Union[str, List[str]]):
        """Return ``(x, mask)``: embeddings (B, T, E) and a boolean mask (B, T)."""
        if isinstance(texts, str):
            texts = [texts]

        # Character codes, truncated to max_len per text.
        coded = [[(ord(ch) % 255) + 1 for ch in text][: self.max_len] for text in texts]
        width = max(len(seq) for seq in coded) if coded else 1

        ids = torch.full((len(coded), width), 0, dtype=torch.long)
        mask = torch.zeros((len(coded), width), dtype=torch.bool)
        for row, seq in enumerate(coded):
            n = len(seq)
            ids[row, :n] = torch.tensor(seq, dtype=torch.long)
            mask[row, :n] = True

        x = self.embed(ids.to(self.embed.weight.device))

        # One positional column per feature: sine on even indices, cosine on odd.
        feat = x.size(-1)
        pos = torch.arange(width, device=x.device).float()
        columns = []
        for k in range(feat):
            if k % 2 == 0:
                columns.append(torch.sin(pos / 10000**(2*k/feat)))
            else:
                columns.append(torch.cos(pos / 10000**(2*(k-1)/feat)))
        pe = torch.stack(columns, dim=-1)

        x = x + pe.unsqueeze(0)
        return x, mask.to(x.device)


# ===== Fusion and heads =====

class TokenPool(nn.Module):
    """Attention pooling over tokens with a single learned query vector."""
    def __init__(self, dim: int):
        super().__init__()
        self.q = nn.Parameter(torch.randn(dim))

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Pool ``x`` (B, T, E) into (B, E) using softmax weights over tokens.

        Args:
            x: Token features of shape (B, T, E).
            mask: Optional (B, T) mask; truthy entries mark tokens that take
                part in the pooling. Non-boolean masks are converted to bool.

        Returns:
            Weighted sum of the tokens, shape (B, E).
        """
        q = self.q[None, None, :]            # (1,1,E)
        attn = (x * q).sum(-1)               # (B, T)
        if mask is not None:
            # ``~`` is a bitwise NOT on integer masks, so coerce to bool first.
            keep = mask.to(torch.bool)
            # Fill with the dtype's own minimum: a literal such as -1e9 does
            # not fit into float16.
            attn = attn.masked_fill(~keep, torch.finfo(attn.dtype).min)
        w = attn.softmax(-1)
        return (w.unsqueeze(-1) * x).sum(1)  # (B, E)

class FusionObsEncoder(nn.Module):
    """Pool vision and text tokens separately, then fuse them into one observation vector."""
    def __init__(self, embed_dim: int = 1024, obs_dim: int = 1024):
        super().__init__()
        fused_dim = embed_dim * 2
        self.v_pool = TokenPool(embed_dim)
        self.t_pool = TokenPool(embed_dim)
        self.mlp = nn.Sequential(
            nn.LayerNorm(fused_dim),
            nn.Linear(fused_dim, obs_dim),
            nn.GELU(),
            nn.Linear(obs_dim, obs_dim),
        )

    def forward(self, v_tokens: torch.Tensor, t_tokens: torch.Tensor, t_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the fused observation embedding of shape (B, obs_dim).

        Vision tokens are pooled without a mask; text tokens are pooled with
        ``t_mask``. The two pooled vectors are concatenated and passed
        through the MLP.
        """
        pooled_vision = self.v_pool(v_tokens, None)
        pooled_text = self.t_pool(t_tokens, t_mask)
        fused = torch.cat((pooled_vision, pooled_text), dim=-1)
        return self.mlp(fused)  # (B, obs_dim)

class RewardHead(nn.Module):
    """Scalar score predicted from the deterministic state ``h``."""
    def __init__(self, latent: int = 256):
        super().__init__()
        hidden = latent // 2
        layers = [
            nn.LayerNorm(latent),
            nn.Linear(latent, hidden),
            nn.GELU(),
            nn.Linear(hidden, 1),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Return one score per row of ``h`` with the trailing unit axis removed."""
        score = self.mlp(h)
        return score.squeeze(-1)  # (B,)

class ObsDecoder(nn.Module):
    """Reconstruct the observation embedding from the state pair ``(h, z)``."""
    def __init__(self, latent: int = 256, obs_dim: int = 1024):
        super().__init__()
        in_dim = latent * 2
        layers = [
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, obs_dim),
            nn.GELU(),
            nn.Linear(obs_dim, obs_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, h: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Concatenate ``h`` and ``z`` on the last axis and decode to (B, obs_dim)."""
        state = torch.cat((h, z), dim=-1)
        return self.net(state)  # (B, obs_dim)


# ===== Episodic memory =====

class EpisodicMemory(nn.Module):
    """Slot memory with cosine-similarity reads and EMA or ring-buffer writes.

    All state lives in registered buffers (``keys``, ``values``, ``age``,
    ``ptr``); writes run under ``torch.no_grad``.
    """
    def __init__(self, dim: int = 1024, slots: int = 1024, temperature: float = 0.1, ema: float = 0.1):
        super().__init__()
        self.dim = dim
        self.slots = slots
        self.temperature = temperature
        self.ema = ema
        initial_keys = F.normalize(torch.randn(slots, dim), dim=-1)
        self.register_buffer("keys", initial_keys)
        self.register_buffer("values", torch.zeros(slots, dim))
        self.register_buffer("age", torch.zeros(slots, dtype=torch.long))
        self.register_buffer("ptr", torch.zeros((), dtype=torch.long))

    @torch.no_grad()
    def write(self, k: torch.Tensor, v: torch.Tensor, strategy: str = "nearest"):
        """Store key/value rows in memory.

        ``strategy == "ring"`` overwrites slots in pointer order. Any other
        value blends each row into its most similar slot using the EMA rate.
        Touched slots get their age reset, then every slot's age is
        incremented by one.
        """
        if k.dim() == 1:
            k = k.unsqueeze(0)
        if v.dim() == 1:
            v = v.unsqueeze(0)
        k = maybe_norm(k)

        if strategy == "ring":
            for i, key_row in enumerate(k):
                slot = int(self.ptr.item() % self.slots)
                self.keys[slot] = key_row
                self.values[slot] = v[i]
                self.age[slot] = 0
                self.ptr += 1
        else:
            similarity = maybe_norm(self.keys) @ k.T  # (S,B)
            nearest = similarity.argmax(dim=0).tolist()
            keep = 1 - self.ema
            for i, slot in enumerate(nearest):
                blended_key = keep * self.keys[slot] + self.ema * k[i]
                self.keys[slot] = F.normalize(blended_key, dim=-1)
                self.values[slot] = keep * self.values[slot] + self.ema * v[i]
                self.age[slot] = 0
        self.age += 1

    def read(self, q: torch.Tensor, topk: int = 0) -> torch.Tensor:
        """Return a softmax-weighted mix of stored values for each query.

        A 1-D query yields a 1-D result. When ``topk`` is non-zero and smaller
        than the slot count, only the ``topk`` best-matching slots take part.
        """
        single = q.dim() == 1
        queries = q.unsqueeze(0) if single else q
        logits = (maybe_norm(queries) @ maybe_norm(self.keys).T) / self.temperature

        if not topk or topk >= self.slots:
            weights = F.softmax(logits, dim=-1)
            out = weights @ self.values
        else:
            best, slot_ids = logits.topk(topk, dim=-1)
            weights = F.softmax(best, dim=-1)
            out = (weights.unsqueeze(-1) * self.values[slot_ids]).sum(1)

        if single:
            return out.squeeze(0)
        return out


# ===== RSSM =====

class LatentRSSM(nn.Module):
    """Dreamer-style RSSM with deterministic ``h`` and stochastic ``z``.

    Args:
        latent: Width of both ``h`` and ``z``.
        action_dim: Width of the action vector fed to the GRU cell.
        obs_dim: Width of the observation embedding used by the posterior.
    """
    def __init__(self, latent: int = 256, action_dim: int = 32, obs_dim: int = 1024):
        super().__init__()
        self.latent, self.action_dim, self.obs_dim = latent, action_dim, obs_dim
        self.gru = nn.GRUCell(latent + action_dim, latent)
        # Both heads emit 2 * latent values, later split into mean and log-variance.
        self.prior = nn.Sequential(nn.LayerNorm(latent), nn.Linear(latent, 2 * latent))
        self.post = nn.Sequential(nn.LayerNorm(latent + obs_dim), nn.Linear(latent + obs_dim, 2 * latent))

    def init_state(self, batch: int, device=None):
        """Return zero-initialised ``(h, z)``, each of shape (batch, latent).

        Args:
            batch: Number of rows.
            device: Target device. ``None`` uses the device of this module's
                own weights, so the state can be fed straight into ``step``
                without a device mismatch.
        """
        ref = self.gru.weight_ih
        if device is None:
            device = ref.device
        # Match the weights' dtype too, so a converted module gets a matching state.
        h = torch.zeros(batch, self.latent, device=device, dtype=ref.dtype)
        z = torch.zeros(batch, self.latent, device=device, dtype=ref.dtype)
        return h, z

    def step(self, h: torch.Tensor, z: torch.Tensor, action: torch.Tensor, obs_embed: Optional[torch.Tensor] = None, rng: Optional[torch.Generator] = None):
        """Advance the model by one step.

        Args:
            h, z: Previous deterministic and stochastic states.
            action: Action tensor concatenated with ``z`` as GRU input.
            obs_embed: Optional observation embedding. When given, ``z`` is
                sampled from the posterior; otherwise from the prior.
            rng: Optional generator forwarded to the sampler.

        Returns:
            ``(h, z, prior_stats, post_stats)``; ``post_stats`` is ``None``
            when no observation was supplied.
        """
        x = torch.cat([z, action], dim=-1)
        h = self.gru(x, h)
        prior_stats = self.prior(h)
        if obs_embed is not None:
            post_stats = self.post(torch.cat([h, obs_embed], dim=-1))
            z = reparam_gauss(post_stats, rng=rng)
        else:
            post_stats = None
            z = reparam_gauss(prior_stats, rng=rng)
        return h, z, prior_stats, post_stats


# ===== Planner =====

class HybridPlanner(nn.Module):
    """Turn an LLM-proposed step list into an imagined RSSM rollout with rewards.

    The LLM proposes steps, an optional validator may trigger one revision,
    and the resulting steps are rolled forward through the RSSM prior.
    """
    def __init__(
        self,
        llm,
        rssm: LatentRSSM,
        action_fn: Callable[[str, torch.device, int], torch.Tensor],
        reward_fn: Callable[[torch.Tensor], torch.Tensor],
        validate_fn: Optional[Callable[[List[str]], bool]] = None,
        max_depth: int = 5,
    ):
        super().__init__()
        self.llm = llm
        self.rssm = rssm
        self.action_fn = action_fn
        self.reward_fn = reward_fn
        self.validate_fn = validate_fn
        self.max_depth = max_depth

    @torch.no_grad()
    def plan(self, goal_desc: str, init_h: torch.Tensor, init_z: torch.Tensor, batch: int = 1):
        """Roll out at most ``max_depth`` proposed steps from ``(init_h, init_z)``.

        Returns:
            ``(traj, total_reward)`` where ``traj`` holds one
            ``(h, action, reward)`` tuple of cloned tensors per executed step
            and ``total_reward`` is the sum of the per-step mean rewards.
        """
        steps: List[str] = self.llm.propose_actions(goal_desc, max_steps=self.max_depth)
        rejected = self.validate_fn is not None and not self.validate_fn(steps)
        if rejected:
            steps = self.llm.revise_actions(plan=steps, max_steps=self.max_depth)

        device = init_h.device
        h, z = init_h, init_z
        total_reward = torch.zeros((), device=device)
        traj = []
        for step_text in steps[: self.max_depth]:
            act = self.action_fn(step_text, device, self.rssm.action_dim)
            if act.dim() == 1:
                act = act.unsqueeze(0)
            # Broadcast the action across the batch when its leading size differs.
            if act.size(0) != batch:
                act = act.expand(batch, -1)
            # No observation is passed, so the RSSM samples z from its prior.
            h, z, _, _ = self.rssm.step(h, z, act, obs_embed=None)
            r_t = self.reward_fn(h)  # (B,)
            total_reward = total_reward + r_t.mean()
            traj.append((h.clone(), act.clone(), r_t.clone()))
        return traj, total_reward