In [10]:
# @title Unified Cognitive Architecture — single drop-in cell (Colab)
# If you restart runtime, just re-run this one cell.

# ============ 0) Environment: pin a consistent PyTorch trio (CUDA 12.6) ============
import sys, subprocess, textwrap, time, os, pathlib, json, math, random
def _sh(cmd): subprocess.check_call(cmd, shell=True)
# Install the PyTorch *trio* on the cu126 index to avoid version skew.
_sh("pip -q install 'torch==2.8.*' 'torchvision==0.23.*' 'torchaudio==2.8.*' --index-url https://download.pytorch.org/whl/cu126")

# ============ 1) Write full module to uca.py ======================================
from pathlib import Path
module_src = r'''
"""
UNIFIED COGNITIVE ARCHITECTURE - Google Colab Edition (FIXED)
===============================================================

Complete production-ready cognitive system in a single module.

Architecture:
  INPUT → L0: Perception → L1: Representation → L2: Dynamics
           ↓                                        ↓
          L3: Memory ← L4: Meta-Controller ← Loop?

This module preserves the structure of the notebook version while
providing a programmatic interface that can be tested automatically.
"""

from __future__ import annotations

from dataclasses import dataclass
from enum import IntEnum
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================================================
# SECTION 2: LAYER 0 - ADAPTIVE PERCEPTION
# ============================================================


class AdaptivePerception(nn.Module):
    """Resolution-aware perception front-end."""

    def __init__(
        self,
        input_channels: int = 3,
        hidden_dim: int = 256,
        resolutions: Optional[List[int]] = None,
    ) -> None:
        super().__init__()
        if resolutions is None:
            resolutions = [64, 256, 512]
        self.resolutions = sorted(resolutions)
        self.hidden_dim = hidden_dim

        self.encoders = nn.ModuleDict(
            {
                str(res): self._make_encoder(input_channels, hidden_dim)
                for res in self.resolutions
            }
        )

        # Budget thresholds (0-10 scale)
        self.budget_thresholds = torch.linspace(0, 10, len(self.resolutions) + 1)[1:]

    def _make_encoder(self, in_channels: int, hidden_dim: int) -> nn.Module:
        return nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, hidden_dim),
        )

    def select_resolution(self, budget: float) -> int:
        budget = torch.clamp(torch.tensor(budget), 0, 10)
        for res, thresh in zip(self.resolutions, self.budget_thresholds):
            if budget <= thresh:
                return res
        return self.resolutions[-1]

    def forward(
        self,
        x: torch.Tensor,
        budget: float = 5.0,
        return_info: bool = False,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        resolution = self.select_resolution(budget)

        if x.shape[-1] != resolution:
            x = F.interpolate(
                x,
                size=(resolution, resolution),
                mode="bilinear",
                align_corners=False,
            )

        encoded = self.encoders[str(resolution)](x)

        if return_info:
            info = {
                "resolution": resolution,
                "budget": float(budget),
                "speedup": self.resolutions[-1] / resolution,
            }
            return encoded, info
        return encoded, {}


# ============================================================
# SECTION 3: LAYER 1 - SET TRANSFORMER REPRESENTATION
# ============================================================


class SetTransformer(nn.Module):
    def __init__(
        self,
        input_dim: int = 256,
        hidden_dim: int = 256,
        num_heads: int = 8,
        num_layers: int = 4,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()

        self.input_proj = nn.Linear(input_dim, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            batch_first=True,
        )

        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)
        self.pool_query = nn.Parameter(torch.randn(1, 1, hidden_dim))

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        B, N, _ = x.shape
        x = self.input_proj(x)
        transformed = self.transformer(x, src_key_padding_mask=mask)

        query = self.pool_query.expand(B, 1, -1)
        pooled = torch.matmul(query, transformed.transpose(1, 2))

        if mask is not None:
            pooled = pooled.masked_fill(mask.unsqueeze(1), -1e9)

        weights = torch.softmax(pooled, dim=-1)
        aggregated = torch.matmul(weights, transformed)
        return self.output_proj(aggregated.squeeze(1))


# ============================================================
# SECTION 4: LAYER 2 - ACTIVE INFERENCE DYNAMICS
# ============================================================


class ActiveInferenceModule(nn.Module):
    def __init__(
        self,
        obs_dim: int = 256,
        hidden_dim: int = 256,
        latent_dim: int = 64,
        num_layers: int = 2,
    ) -> None:
        super().__init__()

        self.latent_dim = latent_dim

        self.prior_net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2),
        )

        self.posterior_net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2),
        )

        self.generative_net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, obs_dim),
        )

        self.dynamics = nn.LSTM(latent_dim, latent_dim, num_layers, batch_first=True)
        self.output_proj = nn.Linear(latent_dim, latent_dim)

    def encode(self, obs: torch.Tensor, use_prior: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        net = self.prior_net if use_prior else self.posterior_net
        params = net(obs)
        mean, logvar = torch.chunk(params, 2, dim=-1)
        return mean, logvar

    def reparameterize(self, mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        return self.generative_net(latent)

    def compute_free_energy(
        self,
        obs: torch.Tensor,
        post_mean: torch.Tensor,
        post_logvar: torch.Tensor,
        prior_mean: torch.Tensor,
        prior_logvar: torch.Tensor,
    ) -> torch.Tensor:
        latent = self.reparameterize(post_mean, post_logvar)
        recon = self.decode(latent)
        accuracy = -torch.mean((obs - recon) ** 2, dim=-1)

        complexity = -0.5 * torch.sum(
            1
            + post_logvar
            - prior_logvar
            - ((post_mean - prior_mean) ** 2 + torch.exp(post_logvar))
            / torch.exp(prior_logvar),
            dim=-1,
        )
        return (complexity - accuracy).mean()

    def forward(self, obs: torch.Tensor, num_iterations: int = 10) -> Dict[str, object]:
        prior_mean, prior_logvar = self.encode(obs, use_prior=True)

        free_energies: List[float] = []
        latent_history: List[torch.Tensor] = []
        current_latent = self.reparameterize(prior_mean, prior_logvar)

        for _ in range(num_iterations):
            post_mean, post_logvar = self.encode(obs)
            fe = self.compute_free_energy(obs, post_mean, post_logvar, prior_mean, prior_logvar)
            free_energies.append(float(fe.item()))
            current_latent = self.reparameterize(post_mean, post_logvar)
            latent_history.append(current_latent)
            prior_mean = 0.9 * prior_mean + 0.1 * post_mean
            prior_logvar = 0.9 * prior_logvar + 0.1 * post_logvar

        if len(latent_history) > 1:
            latent_seq = torch.stack(latent_history, dim=1)
            _, (h_n, _) = self.dynamics(latent_seq)
            final_latent = self.output_proj(h_n[-1])
        else:
            final_latent = self.output_proj(current_latent)

        return {
            "latent": final_latent,
            "free_energy": free_energies[-1],
            "free_energy_history": free_energies,
        }


# ============================================================
# SECTION 5: LAYER 3 - MEMORY SYSTEM
# ============================================================


class TitansWorkingMemory(nn.Module):
    def __init__(self, hidden_dim: int = 256, num_slots: int = 1024, lr: float = 0.1) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_slots = num_slots
        self.lr = lr

        self.register_buffer("memory", torch.zeros(num_slots, hidden_dim))
        self.register_buffer("access_count", torch.zeros(num_slots))

        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)

        self.slot_idx = 0

    def forward(self, query: torch.Tensor, update: bool = False) -> Tuple[torch.Tensor, Dict[str, float]]:
        B = query.shape[0]

        Q = self.query_proj(query)
        scores = torch.matmul(Q, self.memory.T) / (self.hidden_dim ** 0.5)
        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, self.memory)

        if update:
            V = self.value_proj(query)
            for i in range(B):
                if self.slot_idx < self.num_slots:
                    slot = self.slot_idx
                    self.slot_idx += 1
                else:
                    slot = self.access_count.argmin().item()

                self.memory[slot] = self.memory[slot] * (1 - self.lr) + V[i] * self.lr
                self.access_count[slot] += 1

        info = {
            "memory_usage": self.slot_idx / self.num_slots,
            "avg_access": float(self.access_count.mean().item()),
        }
        return output, info

    def reset(self) -> None:
        self.memory.zero_()
        self.access_count.zero_()
        self.slot_idx = 0


class MAPElitesArchive:
    def __init__(self, descriptor_dim: int = 2, grid_bins: int = 10) -> None:
        self.descriptor_dim = descriptor_dim
        self.grid_bins = grid_bins
        self.archive: Dict[Tuple[int, ...], Dict[str, object]] = {}
        self.descriptor_min: Optional[np.ndarray] = None
        self.descriptor_max: Optional[np.ndarray] = None

    def _discretize(self, descriptor: np.ndarray) -> Tuple[int, ...]:
        if self.descriptor_min is None:
            self.descriptor_min = descriptor.copy()
            self.descriptor_max = descriptor.copy()
        else:
            self.descriptor_min = np.minimum(self.descriptor_min, descriptor)
            self.descriptor_max = np.maximum(self.descriptor_max, descriptor)

        ranges = self.descriptor_max - self.descriptor_min
        ranges = np.where(ranges == 0, 1, ranges)
        normalized = (descriptor - self.descriptor_min) / ranges
        bins = (normalized * (self.grid_bins - 1)).astype(int)
        bins = np.clip(bins, 0, self.grid_bins - 1)
        return tuple(bins)

    def add(self, solution: torch.Tensor, fitness: float, descriptor: np.ndarray) -> bool:
        cell = self._discretize(descriptor)
        if cell not in self.archive or fitness > self.archive[cell]["fitness"]:
            self.archive[cell] = {
                "solution": solution.detach().cpu(),
                "fitness": fitness,
                "descriptor": descriptor,
            }
            return True
        return False

    def get_statistics(self) -> Dict[str, float]:
        if len(self.archive) == 0:
            return {"size": 0, "coverage": 0.0, "avg_fitness": 0.0}

        fitnesses = [entry["fitness"] for entry in self.archive.values()]
        return {
            "size": len(self.archive),
            "coverage": len(self.archive) / (self.grid_bins ** self.descriptor_dim),
            "avg_fitness": float(np.mean(fitnesses)),
            "max_fitness": float(np.max(fitnesses)),
        }

    def retrieve(self, query_descriptor: np.ndarray, k: int = 5):
        if len(self.archive) == 0:
            return []
        descs, cells = [], []
        for cell, entry in self.archive.items():
            descs.append(entry["descriptor"])
            cells.append(cell)
        descs = np.stack(descs, axis=0)
        dists = np.linalg.norm(descs - query_descriptor[None, :], axis=1)
        order = np.argsort(dists)[:k]
        return [self.archive[cells[i]] for i in order]


class MemorySystem(nn.Module):
    def __init__(
        self,
        hidden_dim: int = 256,
        working_slots: int = 1024,
        archive_bins: int = 10,
        descriptor_dim: int = 2,
    ) -> None:
        super().__init__()
        self.working_memory = TitansWorkingMemory(hidden_dim, working_slots)
        self.long_term_memory = MAPElitesArchive(descriptor_dim, archive_bins)
        self.descriptor_net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, descriptor_dim),
            nn.Tanh(),
        )

    def compute_descriptor(self, latent: torch.Tensor) -> np.ndarray:
        """Project a latent vector into descriptor space.

        Supports both batched and 1D inputs.
        """
        if latent.dim() == 1:
            latent = latent.unsqueeze(0)
        with torch.no_grad():
            desc = self.descriptor_net(latent)
        return desc.cpu().numpy()

    def forward(
        self, query: torch.Tensor, update_working: bool = False
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        return self.working_memory(query, update=update_working)

    def store_solution(self, solution: torch.Tensor, fitness: float) -> bool:
        descriptor = self.compute_descriptor(solution)
        return self.long_term_memory.add(solution, float(fitness), descriptor[0])

    def retrieve_from_archive(self, query_latent: torch.Tensor, k: int = 5):
        if query_latent.dim() == 1:
            query_latent = query_latent.unsqueeze(0)
        with torch.no_grad():
            qd = self.descriptor_net(query_latent).cpu().numpy()
        return self.long_term_memory.retrieve(qd[0], k=k)


# ============================================================
# SECTION 6: LAYER 4 - META-CONTROLLER
# ============================================================


class Action(IntEnum):
    THINK = 0
    RETRIEVE = 1
    PERCEIVE_UP = 2
    PERCEIVE_DOWN = 3
    VERIFY = 4
    STORE = 5
    EXIT = 6


class MetaController(nn.Module):
    def __init__(self, state_dim: int = 512, hidden_dim: int = 256, num_actions: int = 7) -> None:
        super().__init__()
        self.state_dim = state_dim
        self.num_actions = num_actions

        self.policy_net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions),
        )

        self.value_net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def encode_state(
        self,
        task_embedding: torch.Tensor,
        progress: float,
        budget: float,
        confidence: float,
        memory_usage: float,
        iteration: int,
    ) -> torch.Tensor:
        B = task_embedding.shape[0]
        scalars = torch.tensor(
            [progress, budget / 10.0, confidence, memory_usage, iteration / 50.0],
            device=task_embedding.device,
        ).unsqueeze(0)
        scalars = scalars.expand(B, -1)

        state = torch.cat([task_embedding, scalars], dim=-1)
        if state.shape[-1] < self.state_dim:
            padding = torch.zeros(B, self.state_dim - state.shape[-1], device=state.device)
            state = torch.cat([state, padding], dim=-1)
        return state[:, : self.state_dim]

    def forward(self, state: torch.Tensor, deterministic: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        logits = self.policy_net(state)
        probs = torch.softmax(logits, dim=-1)
        if deterministic:
            actions = torch.argmax(probs, dim=-1)
            log_probs = torch.log(probs.gather(1, actions.unsqueeze(1))).squeeze(1)
        else:
            dist = torch.distributions.Categorical(probs)
            actions = dist.sample()
            log_probs = dist.log_prob(actions)
        return actions, log_probs

    def get_value(self, state: torch.Tensor) -> torch.Tensor:
        return self.value_net(state).squeeze(-1)


# ============================================================
# SECTION 7: UNIFIED SYSTEM (FIXED)
# ============================================================


class UnifiedCognitiveSystem(nn.Module):
    def __init__(
        self,
        input_channels: int = 3,
        obs_dim: int = 256,
        hidden_dim: int = 256,
        latent_dim: int = 64,
        resolutions: Optional[List[int]] = None,
        memory_slots: int = 1024,
        archive_bins: int = 10,
        use_meta_controller: bool = False,
    ) -> None:
        super().__init__()
        if resolutions is None:
            resolutions = [64, 256, 512]

        self.obs_dim = obs_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.use_meta_controller = use_meta_controller

        self.perception = AdaptivePerception(input_channels, obs_dim, resolutions)
        self.representation = SetTransformer(obs_dim, hidden_dim, num_heads=8, num_layers=4)
        self.dynamics = ActiveInferenceModule(hidden_dim, hidden_dim, latent_dim)
        self.memory = MemorySystem(latent_dim, memory_slots, archive_bins, descriptor_dim=2)

        if use_meta_controller:
            self.meta_controller = MetaController(state_dim=512, hidden_dim=256)

        self.output_proj = nn.Linear(latent_dim, obs_dim)

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
        budget: float = 5.0,
        num_iterations: int = 10,
        update_memory: bool = True,
        return_info: bool = True,
    ) -> Dict[str, object]:
        info: Dict[str, object] = {}

        # Perception for image if present
        if "image" in inputs:
            perceived, perc_info = self.perception(inputs["image"], budget, return_info=True)
            info["perception"] = perc_info
        else:
            # Fallback: use the first provided modality directly
            perceived = next(iter(inputs.values()))
            info["perception"] = {"modalities": len(inputs)}

        # Treat available modalities as a set (image + others), projecting to obs_dim if needed
        set_elems = [perceived]
        for k, v in inputs.items():
            if k == "image":
                continue
            t = v
            if t.dim() == 1:
                t = t.unsqueeze(0)
            if t.shape[-1] != self.obs_dim:
                proj = getattr(self, f"_proj_{k}", None)
                if proj is None:
                    proj = nn.Linear(t.shape[-1], self.obs_dim).to(t.device)
                    setattr(self, f"_proj_{k}", proj)
                t = proj(t)
            set_elems.append(t)
        perceived_set = torch.stack(set_elems, dim=1)  # [B, N, obs_dim]

        represented = self.representation(perceived_set)
        info["representation"] = {"hidden_dim": represented.shape[-1], "set_size": perceived_set.shape[1]}

        dynamics_result = self.dynamics(represented, num_iterations)
        latent = dynamics_result["latent"]
        info["dynamics"] = {
            "free_energy": dynamics_result["free_energy"],
            "iterations": num_iterations,
        }

        memory_out, memory_info = self.memory(latent, update_working=update_memory)
        info["memory"] = memory_info

        combined = latent + memory_out
        output = self.output_proj(combined)

        result: Dict[str, object] = {"output": output}
        if return_info:
            result["info"] = info
            result["latent"] = latent
        return result

    def store_in_archive(self, solution: torch.Tensor, fitness: float) -> bool:
        return self.memory.store_solution(solution, fitness)

    def get_statistics(self) -> Dict[str, float]:
        archive_stats = self.memory.long_term_memory.get_statistics()
        return {
            "perception_resolutions": self.perception.resolutions,
            "archive_size": archive_stats["size"],
            "archive_coverage": archive_stats["coverage"],
            "avg_fitness": archive_stats.get("avg_fitness", 0.0),
            "hidden_dim": self.hidden_dim,
            "latent_dim": self.latent_dim,
        }


@dataclass
class Config:
    """Configuration presets for different operating regimes."""

    input_channels: int = 3
    obs_dim: int = 256
    hidden_dim: int = 256
    latent_dim: int = 64
    resolutions: Optional[List[int]] = None
    memory_slots: int = 1024
    archive_bins: int = 10
    use_meta_controller: bool = False

    @staticmethod
    def mvp() -> Dict[str, object]:
        return {
            "input_channels": 3,
            "obs_dim": 64,
            "hidden_dim": 128,
            "latent_dim": 32,
            "resolutions": [64, 256],
            "memory_slots": 256,
            "archive_bins": 10,
            "use_meta_controller": False,
        }

    @staticmethod
    def production() -> Dict[str, object]:
        return {
            "input_channels": 3,
            "obs_dim": 256,
            "hidden_dim": 256,
            "latent_dim": 64,
            "resolutions": [64, 256, 512],
            "memory_slots": 1024,
            "archive_bins": 20,
            "use_meta_controller": False,
        }

    @staticmethod
    def research() -> Dict[str, object]:
        return {
            "input_channels": 3,
            "obs_dim": 512,
            "hidden_dim": 512,
            "latent_dim": 128,
            "resolutions": [64, 256, 512, 1024],
            "memory_slots": 4096,
            "archive_bins": 50,
            "use_meta_controller": False,
        }


__all__ = [
    "AdaptivePerception",
    "SetTransformer",
    "ActiveInferenceModule",
    "TitansWorkingMemory",
    "MAPElitesArchive",
    "MemorySystem",
    "MetaController",
    "UnifiedCognitiveSystem",
    "Config",
    "Action",
]
'''
Path("uca.py").write_text(module_src)

# ============ 2) Import + quick smoke ==========================================
import torch, numpy as np, time
from uca import UnifiedCognitiveSystem

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device, "| torch:", torch.__version__, "| CUDA:", torch.version.cuda)

cfg = dict(input_channels=3, obs_dim=256, hidden_dim=256, latent_dim=64, resolutions=[64,256])
sysmod = UnifiedCognitiveSystem(**cfg).to(device).eval()

# multimodal toy inputs
B = 2
inputs = {
    "image": torch.randn(B,3,224,224, device=device),
    "text":  torch.randn(B,256, device=device),
    "audio": torch.randn(B,256, device=device),
}

# forward and iteration sweep (test-time scaling)
for iters in [1, 5, 10, 20]:
    t0=time.time()
    out = sysmod.forward(inputs, budget=5.0, num_iterations=iters, update_memory=True, return_info=True)
    ms = (time.time()-t0)*1000
    fe = out["info"]["dynamics"]["free_energy"]
    set_sz = out["info"]["representation"]["set_size"]
    print(f"{iters:>2} iters | {ms:6.1f} ms | FE {fe:8.4f} | set_size={set_sz}")

# archive round-trip
latent = out["latent"].detach()
stored = sysmod.store_in_archive(latent[0], fitness=float(np.random.rand()))
neighbors = sysmod.memory.retrieve_from_archive(latent[0], k=3)
print(f"Archive store: {stored} | retrieved: {len(neighbors)} | stats:", sysmod.get_statistics())


ImportError: cannot import name 'UnifiedCognitiveSystem' from 'uca' (/content/uca.py)