<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/unified_ai_core_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch, torch.nn as nn, torch.nn.functional as F
import timm
from sentencepiece import SentencePieceProcessor

# ---------- 1. Perception ----------------------------------------------------
class VisionEncoder(nn.Module):
    """Image encoder: a timm backbone followed by a linear projection.

    Args:
        embed_dim: Width of the embedding produced by the projection layer.
        model_name: Identifier handed to ``timm.create_model``. Defaults to
            the CLIP ViT-L/14 checkpoint that was previously hard-coded.
        pretrained: Whether ``timm.create_model`` should load pretrained
            weights. Pass ``False`` to build the architecture only.
    """

    def __init__(
        self,
        embed_dim=1024,
        model_name='vit_large_patch14_clip_224.laion2b_ft_in12k_in1k',
        pretrained=True,
    ):
        super().__init__()
        # num_classes=0 asks timm to build the model without a classifier head.
        self.backbone = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0
        )
        # ViT backbones expose `embed_dim`; fall back to `num_features` so a
        # non-ViT `model_name` can still be projected.
        feat_dim = getattr(self.backbone, 'embed_dim', None)
        if feat_dim is None:
            feat_dim = self.backbone.num_features
        self.proj = nn.Linear(feat_dim, embed_dim)

    def forward(self, img):
        """Run the backbone on ``img`` and project its output.

        The projection acts on the last dimension, so the result has the
        backbone output's shape with the last dimension replaced by
        ``embed_dim``.

        NOTE(review): with ``num_classes=0`` timm is documented to return
        pooled features, so this is presumably (B, embed_dim) rather than
        (B, N_tokens, embed_dim) -- confirm, and call
        ``backbone.forward_features`` instead if per-token output is needed.
        """
        feats = self.backbone(img)
        return self.proj(feats)

class TextEncoder(nn.Module):
    """Tokenise a string with SentencePiece and embed the resulting ids.

    Args:
        vocab_path: Path to the SentencePiece model file.
        embed_dim: Width of each token embedding.
    """

    def __init__(self, vocab_path, embed_dim=1024):
        super().__init__()
        self.tok = SentencePieceProcessor(model_file=vocab_path)
        # One embedding row per piece in the tokenizer's vocabulary.
        self.embed = nn.Embedding(self.tok.get_piece_size(), embed_dim)

    def forward(self, txt):
        """Return the embeddings of the pieces ``txt`` is encoded into.

        Args:
            txt: Text to encode.

        Returns:
            Tensor of shape (T, embed_dim), where T is the number of ids the
            tokenizer returns; T is 0 when the tokenizer yields no pieces.
        """
        piece_ids = self.tok.encode(txt, out_type=int)
        # dtype must be explicit: torch.tensor([]) would otherwise infer
        # float32 for an empty id list and nn.Embedding rejects float indices.
        ids = torch.tensor(
            piece_ids,
            dtype=torch.long,
            device=self.embed.weight.device,
        )
        return self.embed(ids)  # (T, embed_dim)

# ---------- 2. Memory --------------------------------------------------------
class EpisodicMemory(nn.Module):
    """Fixed-size key/value memory with dot-product addressing.

    Args:
        dim: Width of every key and value vector.
        slots: Number of memory slots.
    """

    def __init__(self, dim=1024, slots=4096):
        super().__init__()
        self.keys = nn.Parameter(torch.randn(slots, dim))
        self.values = nn.Parameter(torch.randn(slots, dim))

    def write(self, k, v):
        """Overwrite the slot whose key has the largest dot product with ``k``.

        Args:
            k: Key of shape (dim,) or (B, dim).
            v: Value with the same shape as ``k``.

        For a batch, each row picks its own slot. If several rows pick the
        same slot, which of them ends up stored is not guaranteed.
        """
        key = k.detach()
        val = v.detach()
        # Score every slot against every key: (slots,) for a 1-D key,
        # (B, slots) for a batch. Taking the diagonal of this matrix would
        # only compare slot i with key i and always pick slot 0 for B == 1.
        scores = key @ self.keys.detach().T
        idx = scores.argmax(dim=-1)
        # Written through .data so the update is not tracked by autograd.
        self.keys.data[idx] = key
        self.values.data[idx] = val

    def read(self, q):  # q: (B, dim)
        """Return the softmax-weighted mixture of values for queries ``q``."""
        scores = q @ self.keys.T
        weights = F.softmax(scores, dim=-1)
        return weights @ self.values  # (B, dim)

# ---------- 3. Latent World Model -------------------------------------------
class LatentRSSM(nn.Module):
    """Single-step recurrent state-space model with prior/posterior heads.

    Args:
        latent: Width of the recurrent state.
        action_dim: Width of the action vector.
        obs_dim: Width of the observation embedding.
    """

    def __init__(self, latent=256, action_dim=32, obs_dim=1024):
        super().__init__()
        self.rnn = nn.GRU(latent + action_dim, latent, batch_first=True)
        self.post = nn.Linear(latent + obs_dim, 2 * latent)
        self.prior = nn.Linear(latent, 2 * latent)

    def forward(self, prev_state, action, obs_embed):
        """Advance the model by one step.

        Args:
            prev_state: Previous state, (B, latent) or (latent,).
            action: Action, (B, action_dim) or (action_dim,).
            obs_embed: Observation embedding, (B, obs_dim) or (obs_dim,), or
                ``None`` when no observation is available (e.g. imagination
                rollouts).

        Returns:
            ``(post_stats, prior_stats, h)``: both stats tensors have last
            dimension ``2 * latent`` and ``h`` has last dimension ``latent``.
            ``post_stats`` is ``None`` when ``obs_embed`` is ``None``.
        """
        rnn_input = torch.cat([prev_state, action], -1)
        # The GRU is batch_first, so the length-1 time axis goes right before
        # the feature axis: (B, D) -> (B, 1, D), (D,) -> (1, D). Using
        # unsqueeze(0) made the GRU read the batch as one length-B sequence.
        _, h = self.rnn(rnn_input.unsqueeze(-2))
        h = h.squeeze(0)  # drop the num_layers axis
        prior_stats = self.prior(h)
        if obs_embed is None:
            # Without an observation only the prior can be computed.
            post_stats = None
        else:
            post_stats = self.post(torch.cat([h, obs_embed], -1))
        return post_stats, prior_stats, h

# ---------- 4. Planner -------------------------------------------------------
class HybridPlanner(nn.Module):
    """Plans with an LLM and scores the plan by rolling out a world model.

    Relies on ``pddl_validator``, ``textual_policy_to_action``, ``reparam``
    and ``reward_head``, which are not defined in this class and must be
    available at module level.

    Args:
        llm: Object providing ``chain_of_thought(goal_desc)`` and
            ``fix(plan=...)``.
        rssm: Callable world model invoked as
            ``rssm(state, action, obs_embed=None)`` and returning a 3-tuple
            whose second element holds the prior statistics.
        max_depth: Maximum number of plan steps to imagine.
    """

    def __init__(self, llm, rssm, max_depth=5):
        super().__init__()
        self.llm = llm
        self.rssm = rssm
        self.max_depth = max_depth

    @torch.no_grad()
    def plan(self, goal_desc, init_state):
        """Draft a plan for ``goal_desc``, repair it if invalid, and imagine it.

        Returns:
            The ``(traj, total_reward)`` pair produced by :meth:`imagine`.
        """
        prelim_plan = self.llm.chain_of_thought(goal_desc)
        if not pddl_validator(prelim_plan):
            prelim_plan = self.llm.fix(plan=prelim_plan)
        return self.imagine(prelim_plan, init_state)

    def imagine(self, plan, state):
        """Roll the world model forward along ``plan`` starting at ``state``.

        Args:
            plan: Indexable sequence of plan steps.
            state: Initial latent state fed to the world model.

        Returns:
            ``(traj, total_reward)`` where ``traj`` is a list of
            ``(state, action, reward)`` tuples, one per imagined step.
        """
        traj, total_reward = [], 0.0
        # A plan may be shorter than max_depth; never index past its end.
        horizon = min(self.max_depth, len(plan))
        for t in range(horizon):
            act = textual_policy_to_action(plan[t])
            # Only the prior statistics are used: the next state is sampled
            # from them, so the other outputs are discarded.
            _, prior_stats, _ = self.rssm(state, act, obs_embed=None)
            state = reparam(prior_stats)
            r_t = reward_head(state)
            traj.append((state, act, r_t))
            total_reward += r_t
        return traj, total_reward