<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/Complete_Unified_AI_System_(Single_File).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch, torch.nn as nn, torch.nn.functional as F
import timm
from sentencepiece import SentencePieceProcessor

# ---------- 1. Perception ----------------------------------------------------
class VisionEncoder(nn.Module):
    """Encode images into ``embed_dim``-wide vectors using a timm backbone.

    The backbone is created with ``num_classes=0`` so it returns pooled
    features, which a linear layer then projects to ``embed_dim``.

    Args:
        embed_dim: output width of the projection.
        model_name: timm model identifier. Defaults to the CLIP ViT-L/14
            checkpoint the original code hard-coded.
        pretrained: whether timm should load pretrained weights.
    """

    def __init__(
        self,
        embed_dim=1024,
        model_name='vit_large_patch14_clip_224.laion2b_ft_in12k_in1k',
        pretrained=True,
    ):
        super().__init__()
        self.backbone = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0
        )
        # Width of the pooled features returned with num_classes=0. Newer timm
        # exposes it as `head_hidden_size`; older releases only have
        # `num_features`. `embed_dim` exists only on ViT-style models, which
        # is why it is not used here. For the default ViT all three agree.
        feat_dim = getattr(
            self.backbone, 'head_hidden_size', self.backbone.num_features
        )
        self.proj = nn.Linear(feat_dim, embed_dim)

    def forward(self, img):
        """Return projected image features.

        NOTE(review): assumes ``img`` is an image batch the backbone accepts
        (e.g. ``(batch, 3, H, W)``); the output is then ``(batch, embed_dim)``.
        """
        return self.proj(self.backbone(img))

class TextEncoder(nn.Module):
    """Embed a string as a sequence of learned SentencePiece token embeddings.

    Args:
        vocab_path: path to a SentencePiece model file.
        embed_dim: width of each token embedding.
    """

    def __init__(self, vocab_path, embed_dim=1024):
        super().__init__()
        self.tok = SentencePieceProcessor(model_file=vocab_path)
        # One embedding row per SentencePiece piece id.
        self.embed = nn.Embedding(self.tok.get_piece_size(), embed_dim)

    def forward(self, txt):
        """Tokenize ``txt`` and return its embeddings, shape ``(num_tokens, embed_dim)``.

        An empty string yields a ``(0, embed_dim)`` tensor.
        """
        ids = self.tok.encode(txt, out_type=int)
        # dtype must be explicit: torch.tensor([]) infers float32, which
        # nn.Embedding rejects, so an empty token list would otherwise crash.
        ids = torch.tensor(ids, dtype=torch.long, device=self.embed.weight.device)
        return self.embed(ids)

# ---------- 2. Memory --------------------------------------------------------
class EpisodicMemory(nn.Module):
    """Fixed-size key/value memory with hard writes and soft (softmax) reads.

    Args:
        dim: width of every key and value vector.
        slots: number of memory slots.
    """

    def __init__(self, dim=1024, slots=4096):
        super().__init__()
        self.keys   = nn.Parameter(torch.randn(slots, dim))
        self.values = nn.Parameter(torch.randn(slots, dim))

    @torch.no_grad()
    def write(self, k, v):
        """Overwrite the slot whose stored key has the highest dot product with ``k``.

        Args:
            k: key(s), shape ``(dim,)`` or ``(batch, dim)``.
            v: value(s) paired with ``k``, same leading shape.

        Rows of a batch are written one after another, so later rows see the
        memory as updated by earlier ones.

        Raises:
            ValueError: if ``k`` and ``v`` hold different numbers of rows.
        """
        keys = k.detach().reshape(-1, self.keys.shape[-1])
        vals = v.detach().reshape(-1, self.values.shape[-1])
        if keys.shape[0] != vals.shape[0]:
            raise ValueError(
                f"write() got {keys.shape[0]} keys but {vals.shape[0]} values"
            )
        for key, val in zip(keys, vals):
            # Similarity of this key against every stored key -> (slots,).
            idx = torch.argmax(self.keys @ key)
            # Write through .data (as before) so the autograd version counter
            # is untouched and an earlier read() in the same graph can still
            # be backpropagated.
            self.keys.data[idx] = key
            self.values.data[idx] = val

    def read(self, q):
        """Return the softmax(q · keys)-weighted average of the stored values."""
        scores = q @ self.keys.T
        w = F.softmax(scores, dim=-1)
        return w @ self.values

# ---------- 3. World Model ---------------------------------------------------
class LatentRSSM(nn.Module):
    """One-step recurrent latent model producing posterior and prior statistics.

    Args:
        latent: width of the latent state (also the GRU hidden size).
        action_dim: width of the action vector.
        obs_dim: width of the observation embedding.
    """

    def __init__(self, latent=256, action_dim=32, obs_dim=1024):
        super().__init__()
        self.rnn   = nn.GRU(latent + action_dim, latent, batch_first=True)
        # Both heads emit 2 * latent values; consumers such as `reparam`
        # split them into two halves (mean, log-std).
        self.post  = nn.Linear(latent + obs_dim, 2 * latent)
        self.prior = nn.Linear(latent, 2 * latent)

    def forward(self, prev_state, action, obs_embed):
        """Advance one step.

        Args:
            prev_state: ``(latent,)`` or ``(batch, latent)``.
            action: ``(action_dim,)`` or ``(batch, action_dim)``.
            obs_embed: ``(obs_dim,)`` or ``(batch, obs_dim)``.

        Returns:
            ``(post_stats, prior_stats, h)`` with widths ``2*latent``,
            ``2*latent`` and ``latent`` and the same leading shape as the inputs.
        """
        rnn_input = torch.cat([prev_state, action], -1)
        # Insert a length-1 time axis just before the features:
        # (D,) -> (1, D) is an unbatched one-step sequence, and
        # (B, D) -> (B, 1, D) is a batch of one-step sequences under
        # batch_first=True. unsqueeze(0) would instead treat a batch as one
        # long sequence.
        _, h = self.rnn(rnn_input.unsqueeze(-2))
        h = h.squeeze(0)  # drop the num_layers axis
        post_stats  = self.post(torch.cat([h, obs_embed], -1))
        prior_stats = self.prior(h)
        return post_stats, prior_stats, h

# ---------- 4. Planner -------------------------------------------------------
class LLMStub:
    """Deterministic stand-in for an LLM planner that returns canned plans."""

    # Stored as tuples so every call hands out a fresh, independent list.
    _DRAFT_PLAN = ("move_left", "jump", "move_right", "duck", "jump")
    _REPAIRED_PLAN = ("move_left", "move_right", "duck", "jump", "duck")

    def chain_of_thought(self, goal_desc):
        """Return the fixed draft plan; ``goal_desc`` is ignored."""
        return list(self._DRAFT_PLAN)

    def fix(self, plan):
        """Return the fixed repaired plan; ``plan`` is ignored."""
        return list(self._REPAIRED_PLAN)

def reparam(stats):
    """Draw a reparameterized Gaussian sample from concatenated statistics.

    ``stats`` is split in half along its last axis into a mean and a log
    standard deviation; the result is ``mean + noise * exp(log_std)`` with
    standard-normal ``noise`` of the same shape.
    """
    loc, log_scale = torch.chunk(stats, 2, dim=-1)
    scale = log_scale.exp()
    noise = torch.randn_like(scale)
    return loc + noise * scale

def reward_head(latent):
    """Map latent state(s) to a reward in (0, 1).

    Averages over the last (feature) axis and applies a sigmoid. A
    ``(latent,)`` input gives a 0-d tensor, as before. A ``(batch, latent)``
    input gives one reward per row instead of a single reward mixed across
    the batch.
    """
    return torch.sigmoid(latent.mean(dim=-1))

def textual_policy_to_action(text_cmd):
    """Encode a textual command as a length-4 one-hot action vector.

    Known commands map to indices ``move_left``=0, ``move_right``=1,
    ``jump``=2, ``duck``=3. Any other command yields an all-zero vector.
    """
    commands = ("move_left", "move_right", "jump", "duck")
    slot_of = {name: pos for pos, name in enumerate(commands)}
    one_hot = torch.zeros(len(slot_of))
    slot = slot_of.get(text_cmd)
    if slot is not None:
        one_hot[slot] = 1.0
    return one_hot

def pddl_validator(plan):
    """Placeholder plan validator that accepts every plan.

    NOTE(review): stub; ``plan`` is not inspected and the result is always True.
    """
    del plan  # unused by this stub
    return True

class HybridPlanner(nn.Module):
    """Draft a plan with an LLM, then roll it out in the RSSM's latent space.

    Args:
        llm: object with ``chain_of_thought(goal_desc)`` and ``fix(plan)``,
            each returning a list of textual action commands.
        rssm: a ``LatentRSSM``-like module. Its ``rnn`` (GRU) and ``post``
            (Linear) layers are read to recover the action and observation
            widths it expects.
        max_depth: maximum number of plan steps to imagine.
    """

    def __init__(self, llm, rssm, max_depth=5):
        super().__init__()
        self.llm = llm
        self.rssm = rssm
        self.max_depth = max_depth

    @torch.no_grad()
    def plan(self, goal_desc, init_state):
        """Draft a plan for ``goal_desc`` and imagine it from ``init_state``.

        Returns:
            ``(traj, total_reward)`` as produced by ``imagine``.
        """
        prelim_plan = self.llm.chain_of_thought(goal_desc)
        if not pddl_validator(prelim_plan):
            prelim_plan = self.llm.fix(prelim_plan)
        return self.imagine(prelim_plan, init_state)

    def _rssm_widths(self):
        """Return ``(action_dim, obs_dim)`` expected by ``self.rssm``."""
        latent = self.rssm.rnn.hidden_size
        # The GRU consumes cat([state, action]); the posterior head consumes
        # cat([h, obs_embed]). Subtracting the latent width recovers each.
        action_dim = self.rssm.rnn.input_size - latent
        obs_dim = self.rssm.post.in_features - latent
        return action_dim, obs_dim

    def imagine(self, plan, state):
        """Roll out up to ``max_depth`` steps of ``plan`` through the RSSM prior.

        Args:
            plan: sequence of textual commands; steps beyond ``max_depth`` are
                ignored and shorter plans are rolled out in full.
            state: initial latent state, ``(latent,)`` or ``(batch, latent)``.

        Returns:
            ``(traj, total_reward)`` where ``traj`` is a list of
            ``(state, action, reward)`` per step and ``total_reward`` is the
            sum of step rewards (``0.0`` for an empty plan).

        Raises:
            ValueError: if an encoded action is wider than the RSSM accepts.
        """
        action_dim, obs_dim = self._rssm_widths()
        batch_shape = state.shape[:-1]
        # There are no real observations during imagination; feed zeros with
        # the state's dtype/device so the RSSM's posterior head still runs.
        obs_embed = state.new_zeros(*batch_shape, obs_dim)
        traj, total_reward = [], 0.0
        for cmd in plan[: self.max_depth]:
            act = textual_policy_to_action(cmd).to(state)
            if act.shape[-1] > action_dim:
                raise ValueError(
                    f"action vector of width {act.shape[-1]} exceeds the "
                    f"RSSM action width {action_dim}"
                )
            # Zero-pad the one-hot command to the RSSM's action width.
            act = F.pad(act, (0, action_dim - act.shape[-1]))
            _, prior_stats, _ = self.rssm(
                state, act.expand(*batch_shape, -1), obs_embed=obs_embed
            )
            # Imagination samples the next state from the prior, not the
            # posterior.
            state = reparam(prior_stats)
            r_t = reward_head(state)
            traj.append((state, act, r_t))
            total_reward += r_t
        return traj, total_reward

# ---------- 5. Runner --------------------------------------------------------
if __name__ == "__main__":
    from pathlib import Path

    vocab_path = "spiece.model"  # Provide actual vocab file path

    vision   = VisionEncoder()
    # SentencePiece fails if the model file is missing. The text encoder is
    # not needed for the planning demo below, so skip it instead of aborting.
    if Path(vocab_path).is_file():
        text = TextEncoder(vocab_path)
    else:
        text = None
        print(f"Skipping TextEncoder: vocab file {vocab_path!r} not found")
    memory   = EpisodicMemory()
    rssm     = LatentRSSM()
    llm      = LLMStub()
    planner  = HybridPlanner(llm, rssm)

    init_state = torch.randn(256)
    goal = "navigate obstacle course"
    traj, reward = planner.plan(goal, init_state)

    # reward may be a 0-d tensor or the float 0.0 (empty plan); float()
    # handles both.
    print(f"Planned {len(traj)} steps | Total Reward: {float(reward):.4f}")