<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/unified_ai_memory_eem_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# unified_ai/memory/eem.py
# Self-contained: complex episodic memory + reflection scaffold + smoke test.

from __future__ import annotations
from typing import Optional, Tuple, Callable, Dict, Any, List

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def _as_complex(r: torch.Tensor, i: Optional[torch.Tensor] = None) -> torch.Tensor:
    if i is None:
        i = torch.zeros_like(r)
    return torch.complex(r, i)


def _norm_complex(z: torch.Tensor, eps: float = 1e-6, dim: int = -1) -> torch.Tensor:
    # Normalize complex vector along dim: z / ||z||
    mag = torch.sqrt((z.real**2 + z.imag**2).sum(dim=dim, keepdim=True) + eps)
    return z / mag


def _inner_complex(a: torch.Tensor, b: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # ⟨a, b⟩ = sum(conj(a) * b) along dim
    return torch.sum(torch.conj(a) * b, dim=dim)


def _phase_unitary(z: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    # Element-wise complex phase: e^{iθ} ⊙ z
    # z: (..., D) complex; theta: (D,) real
    phase = torch.complex(torch.cos(theta), torch.sin(theta))  # (D,)
    return z * phase


def _stack_householder(U: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Apply Householder reflection H(v) = I - 2 vv^T / (v^T v) on real matrix U
    # v: (D,) real; U: (..., D) real; reflection along last dim.
    v = v / (v.norm() + 1e-6)
    # Compute U - 2*(U v) v^T
    proj = torch.matmul(U, v)  # (..., )
    return U - 2.0 * proj.unsqueeze(-1) * v


class EntangledEpisodicMemory(nn.Module):
    """
    Entangled episodic memory in a complex Hilbert space.

    - Keys live in C^D, stored as separate real (``keys_r``) and imaginary (``keys_i``) tensors.
    - Retrieval weights are a softmax over "measurement" scores:
      p_i = softmax_i(|⟨q, k_i⟩|^2 / τ), where q and k_i are normalized to unit length first.
      The score ignores global phase, so k and e^{iφ}·k (e.g. -k) are indistinguishable.
    - A learnable transform U(θ) is applied to queries (on read) and keys (on write):
      a diagonal phase rotation e^{iθ}, optionally followed by real Householder
      reflections on the stacked [Re; Im] vector. NOTE: those reflections are orthogonal
      on R^{2D} but not, in general, unitary on C^D.
    - Reads are differentiable; writes run under ``torch.no_grad`` (EMA into the nearest
      slot, or ring-buffer overwrite).

    Shapes:
      - keys: (S, D), values: (S, V)
      - read:
          q: (B, D) or (D,), real or complex (real inputs are promoted to complex)
          returns: (B, V), or (V,) for a 1-D query
      - write:
          k: (B, D) or (D,), real or complex; v: (B, V) or (V,)

    Args:
      slots: memory slots S
      key_dim: complex key dimension D
      value_dim: value dimension V (real)
      temperature: softmax temperature τ
      ema: EMA mixing weight for 'nearest' writes (weight given to the new entry)
      trainable_memory: if True, keys/values are nn.Parameters; otherwise registered buffers
      householder_layers: if >0, apply that many real Householder reflections on the [Re; Im] stack
      dtype: real dtype of all storage (complex views use the matching complex dtype)
      device: device for all storage
    """
    def __init__(
        self,
        slots: int = 4096,
        key_dim: int = 512,
        value_dim: int = 512,
        temperature: float = 0.1,
        ema: float = 0.1,
        trainable_memory: bool = False,
        householder_layers: int = 0,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.slots = slots
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.temperature = temperature
        self.ema = ema
        self.householder_layers = householder_layers

        # Memory storage: random row-normalized real and imaginary key parts, zero values.
        # Only key direction matters, because _similarity renormalizes the complex vectors.
        init_keys_r = F.normalize(torch.randn(slots, key_dim, dtype=dtype, device=device), dim=-1)
        init_keys_i = F.normalize(torch.randn(slots, key_dim, dtype=dtype, device=device), dim=-1)
        init_values = torch.zeros(slots, value_dim, dtype=dtype, device=device)

        if trainable_memory:
            self.keys_r = nn.Parameter(init_keys_r)
            self.keys_i = nn.Parameter(init_keys_i)
            self.values = nn.Parameter(init_values)
        else:
            self.register_buffer("keys_r", init_keys_r)
            self.register_buffer("keys_i", init_keys_i)
            self.register_buffer("values", init_values)

        # Learnable diagonal phase θ ∈ R^D (zeros => identity rotation at init).
        self.theta = nn.Parameter(torch.zeros(key_dim, dtype=dtype, device=device))

        # Optional Householder vectors acting on the concatenated [Re(z); Im(z)] ∈ R^{2D}.
        if self.householder_layers > 0:
            self.house_v = nn.ParameterList(
                [nn.Parameter(F.normalize(torch.randn(2 * key_dim, dtype=dtype, device=device), dim=0))
                 for _ in range(self.householder_layers)]
            )

        # age: incremented once per write() call, reset for slots touched by that call.
        # ptr: monotonically increasing ring-buffer pointer, wrapped with % slots on use.
        self.register_buffer("age", torch.zeros(slots, dtype=torch.long, device=device))
        self.register_buffer("ptr", torch.zeros((), dtype=torch.long, device=device))

    # ------------- Internals -------------
    def _apply_unitary(self, z_c: torch.Tensor) -> torch.Tensor:
        """
        Apply U to complex vector(s): phase rotation + optional stacked Householder over [Re; Im].
        z_c: (..., D) complex
        """
        z_c = _phase_unitary(z_c, self.theta)  # element-wise unitary in C^D
        if self.householder_layers > 0:
            # Real-augment: concat [Re; Im] along the feature dim -> (..., 2D)
            stacked = torch.cat([z_c.real, z_c.imag], dim=-1)
            for v in self.house_v:
                stacked = _stack_householder(stacked, v)
            # Split back into real / imaginary halves.
            D = z_c.size(-1)
            z_c = torch.complex(stacked[..., :D], stacked[..., D:])
        return z_c

    def _similarity(self, q_c: torch.Tensor, k_c: torch.Tensor, mode: str = "inner") -> torch.Tensor:
        """
        Compute similarity scores between a batch of queries and all keys.

        q_c: (B, D) complex
        k_c: (S, D) complex
        mode: 'inner' -> |⟨q, k⟩|^2 on unit vectors; 'cos' -> real cosine on [Re; Im]
        returns: (B, S) real scores
        Raises ValueError for an unknown mode.
        """
        # Normalize to lie on the complex unit sphere.
        qn = _norm_complex(q_c, dim=-1)
        kn = _norm_complex(k_c, dim=-1)
        if mode == "inner":
            # scores[b, s] = |Σ_d q_bd · conj(k_sd)|^2 = |⟨q_b, k_s⟩|^2  (magnitude is conj-symmetric)
            return torch.abs(qn @ kn.conj().transpose(-1, -2)) ** 2  # (B, D) @ (D, S) -> (B, S)
        if mode == "cos":
            qr = torch.cat([qn.real, qn.imag], dim=-1)  # (B, 2D)
            kr = torch.cat([kn.real, kn.imag], dim=-1)  # (S, 2D)
            return F.cosine_similarity(qr.unsqueeze(1), kr.unsqueeze(0), dim=-1)  # (B, S)
        raise ValueError(f"Unknown similarity mode: {mode}")

    # ------------- Public API -------------
    @torch.no_grad()
    def write(self, k: torch.Tensor, v: torch.Tensor, strategy: str = "nearest"):
        """
        Write key-value pairs into memory.

        k: (B, D) or (D,), real or complex; v: (B, V) or (V,) real
        strategy:
          'nearest' -> EMA-blend each pair into its most similar slot ('inner' score)
          'ring'    -> overwrite slots in ring-buffer order starting at ``ptr``
        Raises ValueError for an unknown strategy (memory is left unchanged).
        """
        if strategy not in ("ring", "nearest"):
            raise ValueError(f"Unknown write strategy: {strategy}")

        if k.dim() == 1:
            k = k.unsqueeze(0)
            v = v.unsqueeze(0)

        # Promote to complex if needed, then apply the unitary to canonicalize writer-side keys.
        k_c = k if torch.is_complex(k) else _as_complex(k)
        k_c = _norm_complex(self._apply_unitary(k_c))

        if strategy == "ring":
            start = int(self.ptr.item())
            for i in range(k_c.size(0)):
                idx = (start + i) % self.slots
                self.keys_r[idx] = k_c[i].real
                self.keys_i[idx] = k_c[i].imag
                self.values[idx] = v[i]
                self.age[idx] = 0
            self.ptr += k_c.size(0)
        else:  # "nearest"
            # Nearest slots are chosen against a snapshot of memory taken before any update.
            mem_c = _as_complex(self.keys_r, self.keys_i)  # (S, D)
            idxs = self._similarity(k_c, mem_c).argmax(dim=-1)  # (B,)
            for i, idx in enumerate(idxs.tolist()):
                # EMA on the complex key, then renormalize the complex vector jointly.
                # Normalizing Re and Im separately would force them to equal norm, e.g. the
                # imaginary part of a slot fed real keys would never decay.
                slot = _as_complex(self.keys_r[idx], self.keys_i[idx])
                blended = _norm_complex((1 - self.ema) * slot + self.ema * k_c[i])
                self.keys_r[idx] = blended.real
                self.keys_i[idx] = blended.imag
                self.values[idx] = (1 - self.ema) * self.values[idx] + self.ema * v[i]
                self.age[idx] = 0

        self.age += 1

    def read(
        self,
        q: torch.Tensor,
        topk: int = 0,
        similarity: str = "inner",
        return_weights: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Read from memory with measurement probabilities.

        q: (B, D) or (D,), real or complex input query
        topk: if >0 (and < slots), restrict the softmax to the top-k keys per query
        similarity: 'inner' (|⟨q,k⟩|^2) or 'cos' on [Re;Im]
        return_weights: also return attention weights (B, S) or (B, K)

        Returns:
          (values, weights) where values is (B, V) (or (V,) for a 1-D query) and weights is
          (B, S) / (B, K) when return_weights is True, else None.
        """
        single = q.dim() == 1
        if single:
            q = q.unsqueeze(0)

        # Promote real to complex, apply the unitary on the query side.
        q_c = q if torch.is_complex(q) else _as_complex(q)
        q_c = self._apply_unitary(q_c)

        k_c = _as_complex(self.keys_r, self.keys_i)
        scores = self._similarity(q_c, k_c, mode=similarity) / max(self.temperature, 1e-6)  # (B, S)

        if 0 < topk < self.slots:
            vals, idxs = scores.topk(topk, dim=-1)          # (B, K)
            weights = F.softmax(vals, dim=-1)               # (B, K)
            gathered = self.values[idxs]                    # (B, K, V)
            out = (weights.unsqueeze(-1) * gathered).sum(dim=1)  # (B, V)
        else:
            weights = F.softmax(scores, dim=-1)  # (B, S)
            out = weights @ self.values          # (B, V)

        if not return_weights:
            return (out.squeeze(0) if single else out), None
        if single:
            return out.squeeze(0), weights.squeeze(0)
        return out, weights

    # Convenience: functional multi-read with different topk/temperatures without rebuilding the module
    def attend(self, q: torch.Tensor, *, topk: int = 0, temperature: Optional[float] = None) -> torch.Tensor:
        """
        Read values only, optionally with a one-off temperature override.

        The module's ``temperature`` is restored afterwards, even if ``read`` raises.
        """
        if temperature is None:
            return self.read(q, topk=topk)[0]
        saved_tau = self.temperature
        self.temperature = temperature
        try:
            out, _ = self.read(q, topk=topk)
        finally:
            self.temperature = saved_tau
        return out


# -------------------------
# Reflection scaffolding
# -------------------------

def reflect_on_episode(
    agent_log: str,
    prompt_template: str,
    llm_fn: Optional[Callable[[str], str]] = None,
    parse_fn: Optional[Callable[[str], Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """
    Generate introspective commentary on an episode via an LLM (or a stand-in function).

    Args:
      agent_log: raw transcript text (actions, states, outcomes)
      prompt_template: ``str.format`` template containing a '{log}' placeholder; any other
        literal braces must be escaped as '{{' / '}}'
      llm_fn: function(prompt) -> str. Defaults to a heuristic stub that lists up to three
        transcript lines mentioning "fail"/"error", plus canned hypotheses and improvements.
      parse_fn: function(text) -> dict. Defaults to a parser that splits the text into
        "Failures:", "Hypotheses:" and "Improvements:" sections of "- " bullet items.

    Returns:
      With the default parser: dict with keys failures (list), hypotheses (list),
      improvements (list), raw (str). A custom ``parse_fn`` determines its own output.

    Raises:
      ValueError: if ``prompt_template`` lacks the '{log}' placeholder.
    """
    if "{log}" not in prompt_template:
        raise ValueError("prompt_template must contain '{log}' placeholder")

    prompt = prompt_template.format(log=agent_log)

    def _default_llm(_prompt: str) -> str:
        """Heuristic stand-in for an LLM: flag transcript lines mentioning fail/error."""
        # Scan the transcript only, not the whole prompt: instruction text in a typical
        # template (e.g. "Identify ... failure points") would otherwise be reported as a
        # failure itself.
        lines = [ln.strip() for ln in agent_log.splitlines() if ln.strip()]
        failures = [ln for ln in lines if "fail" in ln.lower() or "error" in ln.lower()]
        hyp = [
            "Insufficient situational memory caused plan drift.",
            "Overconfident prior led to under-exploration.",
        ]
        imps = [
            "Increase retrieval top-k and use recency-weighted writes.",
            "Inject uncertainty-aware exploration and replan when surprise spikes.",
        ]
        return (
            "Failures:\n- " + "\n- ".join(failures[:3] or ["No explicit failures found; inspect reward shaping."]) + "\n\n"
            "Hypotheses:\n- " + "\n- ".join(hyp) + "\n\n"
            "Improvements:\n- " + "\n- ".join(imps)
        )

    def _strip_bullet(item: str) -> str:
        """Drop one leading '-' bullet marker and surrounding whitespace."""
        # Unlike item.lstrip("- "), this keeps a minus sign that belongs to the content
        # (e.g. "- -3 reward" -> "-3 reward", not "3 reward").
        if item.startswith("-"):
            item = item[1:]
        return item.strip()

    def _default_parse(s: str) -> Dict[str, Any]:
        """Split text into 'Failures:' / 'Hypotheses:' / 'Improvements:' bullet sections."""
        out: Dict[str, Any] = {"failures": [], "hypotheses": [], "improvements": [], "raw": s}
        sec: Optional[str] = None
        buf: List[str] = []

        def _flush() -> None:
            # Reads the current sec/buf bindings; lines seen before any header are dropped.
            if sec and buf:
                out[sec] = [_strip_bullet(b) for b in buf]

        for ln in s.splitlines():
            line = ln.strip()
            if not line:
                continue
            lower = line.lower()
            if lower.startswith("failures:"):
                _flush()
                sec, buf = "failures", []
            elif lower.startswith("hypotheses:"):
                _flush()
                sec, buf = "hypotheses", []
            elif lower.startswith("improvements:"):
                _flush()
                sec, buf = "improvements", []
            else:
                buf.append(line)
        _flush()
        return out

    llm = llm_fn or _default_llm
    parser = parse_fn or _default_parse
    return parser(llm(prompt))


# -------------------------
# Smoke test
# -------------------------

def _smoke():
    """End-to-end smoke test for EntangledEpisodicMemory and reflect_on_episode.

    Writes two clusters of keys with distinct value signatures (1.0 vs 3.0) into a ring
    buffer, reads with queries near cluster 1, and asserts the retrieved values are closer
    to cluster 1's signature. Then runs the default reflection stub on a short transcript
    and prints a sample of its output.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(7)

    key_dim, value_dim, B = 32, 16, 16

    # Memory
    mem = EntangledEpisodicMemory(
        slots=128, key_dim=key_dim, value_dim=value_dim, temperature=0.2, ema=0.2,
        trainable_memory=False, householder_layers=1, device=device
    )

    # Two clusters around ORTHOGONAL directions (first vs. second half of the key dims).
    # The 'inner' score |<q, k>|^2 ignores global phase, so clusters built around +c and -c
    # (e.g. randn + 2 vs. randn - 2) would look identical to the memory.
    half = key_dim // 2
    base1 = torch.zeros(key_dim, device=device)
    base1[:half] = 1.0
    base2 = torch.zeros(key_dim, device=device)
    base2[half:] = 1.0
    k1 = F.normalize(base1 + 0.3 * torch.randn(B, key_dim, device=device), dim=-1)
    k2 = F.normalize(base2 + 0.3 * torch.randn(B, key_dim, device=device), dim=-1)
    v1 = torch.ones(B, value_dim, device=device)          # value signature for cluster 1
    v2 = torch.zeros(B, value_dim, device=device) + 3.0   # value signature for cluster 2

    mem.write(k1, v1, strategy="ring")
    mem.write(k2, v2, strategy="ring")

    # Queries near cluster 1
    q = k1[:4] + 0.05 * torch.randn(4, key_dim, device=device)
    out, w = mem.read(q, topk=8, return_weights=True)
    assert out.shape == (4, value_dim)
    print("[PASS] Read shape:", out.shape, "| weights:", None if w is None else w.shape)

    # Outputs should lean towards v1 (~1's) rather than v2 (~3's).
    m = out.mean().item()
    print(f"[INFO] Output mean ~ {m:.3f} (closer to 1.0 => matched cluster 1)")
    assert abs(m - 1.0) < abs(m - 3.0), f"read did not favor cluster 1 (output mean {m:.3f})"

    # Reflection test
    tmpl = """You are a meta-cognitive agent.

Below is a transcript of your actions, states, and outcomes during Task #274.

Please:
1. Identify at least 2 failure points or suboptimal decisions.
2. Hypothesize why these failures occurred.
3. Propose improvements to policy, memory use, or planning strategy.

Transcript:
{log}"""
    log = """
Step 5: Missed object due to stale plan. Repeated scan without updating map.
Step 12: Reached dead-end; planner failed to backtrack.
Outcome: Goal not achieved; time limit exceeded.
"""
    reflection = reflect_on_episode(log, tmpl)
    print("[REFLECTION] Failures:", reflection["failures"][:2])
    print("[REFLECTION] Improvements (sample):", reflection["improvements"][:2])


if __name__ == "__main__":
    _smoke()