<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/unified_ai_memory_eem_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# unified_ai/memory/eem.py
from __future__ import annotations
from typing import Optional, Tuple, Callable, Dict, Any
import torch
import torch.nn as nn
import torch.nn.functional as F

def _as_complex(r: torch.Tensor, i: Optional[torch.Tensor] = None) -> torch.Tensor:
    if i is None:
        i = torch.zeros_like(r)
    return torch.complex(r, i)

def _norm_complex(z: torch.Tensor, eps: float = 1e-6, dim: int = -1) -> torch.Tensor:
    mag = torch.sqrt((z.real**2 + z.imag**2).sum(dim=dim, keepdim=True) + eps)
    return z / mag

def _phase_unitary(z: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    phase = torch.complex(torch.cos(theta), torch.sin(theta))
    return z * phase

def _stack_householder(U: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    v = v / (v.norm() + 1e-6)
    proj = torch.matmul(U, v)
    return U - 2.0 * proj.unsqueeze(-1) * v

class EntangledEpisodicMemory(nn.Module):
    """
    Complex Hilbert-space key-value memory with measurement-based retrieval.

    - keys: (S, D_c) complex, stored as two real tensors ``keys_r`` / ``keys_i``
    - values: (S, D_v) real

    Keys (at write time) and queries (at read time) both pass through
    ``_apply_unitary``. That is a learnable per-dimension phase rotation
    (``theta``) followed by ``householder_layers`` learnable reflections of the
    concatenated ``[real, imag]`` vector. Each step preserves the vector norm,
    up to the small epsilon used when normalising reflection vectors.

    Stored keys are transformed with the parameters current at write time.
    ``read`` does not re-transform them.

    Read:
      scores: |<q,k>|^2 / tau -> softmax -> weighted sum of values
    Write:
      'ring' append or 'nearest' EMA into best-matching slot (no-grad).
      Every key held in memory has (approximately) unit complex norm. The real
      and imaginary parts are always normalised jointly, never separately.

    Args:
        slots: number of memory slots ``S``.
        key_dim: complex key dimension ``D_c``.
        value_dim: value dimension ``D_v``.
        temperature: softmax temperature ``tau`` for reads (clamped to >= 1e-6).
        ema: blend weight given to the incoming key/value in 'nearest' writes.
        trainable_memory: if True, keys/values are ``nn.Parameter``s; otherwise
            they are registered as buffers.
        householder_layers: number of learnable Householder reflections.
        dtype: real dtype for keys, values and transform parameters.
        device: device on which the tensors are allocated.
    """
    def __init__(
        self,
        slots: int = 256,
        key_dim: int = 32,
        value_dim: int = 16,
        temperature: float = 0.2,
        ema: float = 0.2,
        trainable_memory: bool = False,
        householder_layers: int = 1,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.slots = slots
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.temperature = temperature
        self.ema = ema
        self.householder_layers = householder_layers

        init_keys_r = torch.randn(slots, key_dim, dtype=dtype, device=device)
        init_keys_i = torch.randn(slots, key_dim, dtype=dtype, device=device)
        # Normalise the real and imaginary parts jointly so each initial key has
        # unit complex norm, the same convention write() uses. Normalising each
        # part separately would give norm sqrt(2) and bias the first EMA blend.
        joint_norm = (
            torch.cat([init_keys_r, init_keys_i], dim=-1)
            .norm(dim=-1, keepdim=True)
            .clamp_min(1e-12)
        )
        init_keys_r = init_keys_r / joint_norm
        init_keys_i = init_keys_i / joint_norm
        init_values = torch.zeros(slots, value_dim, dtype=dtype, device=device)

        if trainable_memory:
            self.keys_r = nn.Parameter(init_keys_r)
            self.keys_i = nn.Parameter(init_keys_i)
            self.values = nn.Parameter(init_values)
        else:
            self.register_buffer("keys_r", init_keys_r)
            self.register_buffer("keys_i", init_keys_i)
            self.register_buffer("values", init_values)

        # Per-dimension phase angles (radians); zeros = identity rotation.
        self.theta = nn.Parameter(torch.zeros(key_dim, dtype=dtype, device=device))

        if self.householder_layers > 0:
            # Each reflection vector acts on the real vector [Re z, Im z] (length 2 * key_dim).
            self.house_v = nn.ParameterList(
                [nn.Parameter(F.normalize(torch.randn(2 * key_dim, dtype=dtype, device=device), dim=0))
                 for _ in range(self.householder_layers)]
            )

        # Per-slot counter: reset to 0 when a slot is written, +1 after every write call.
        self.register_buffer("age", torch.zeros(slots, dtype=torch.long, device=device))
        # Monotonically increasing write cursor for the 'ring' strategy (taken modulo slots).
        self.register_buffer("ptr", torch.zeros((), dtype=torch.long, device=device))

    def _apply_unitary(self, z_c: torch.Tensor) -> torch.Tensor:
        """Apply the learnable phase rotation, then the Householder reflections.

        ``z_c`` is complex with last dimension ``key_dim``; the result has the
        same shape. The reflections act on the real vector ``[Re z, Im z]`` of
        length ``2 * key_dim``, which is then split back into a complex tensor.
        """
        z_c = _phase_unitary(z_c, self.theta)
        if self.householder_layers > 0:
            stacked = torch.cat([z_c.real, z_c.imag], dim=-1)
            for v in self.house_v:
                stacked = _stack_householder(stacked, v)
            D = z_c.size(-1)
            z_c = torch.complex(stacked[..., :D], stacked[..., D:])
        return z_c

    def _similarity(self, q_c: torch.Tensor, k_c: torch.Tensor) -> torch.Tensor:
        """Return squared overlaps ``|<q, k>|^2`` of unit-normalised vectors.

        ``q_c`` is (B, D) and ``k_c`` is (S, D), both complex; the result is
        a real (B, S) tensor with values in [0, 1].
        """
        qn = _norm_complex(q_c)
        kn = _norm_complex(k_c)
        # (B,1,D) @ (1,D,S) -> (B,1,S) -> (B,S)
        scores = torch.abs(qn.unsqueeze(1) @ torch.conj(kn).unsqueeze(0).transpose(-1, -2)) ** 2
        return scores.squeeze(1)

    @torch.no_grad()
    def write(self, k: torch.Tensor, v: torch.Tensor, strategy: str = "nearest"):
        """Store key/value pairs in memory without gradient tracking.

        Args:
            k: key(s), shape ``(D_c,)`` or ``(B, D_c)``; real or complex. A real
                key gets a zero imaginary part. Keys are transformed by
                ``_apply_unitary`` and normalised before storage.
            v: value(s) with a leading batch dimension matching ``k``. When ``k``
                is 1-D, ``v`` is unsqueezed along with it.
            strategy: ``"ring"`` overwrites slots in round-robin order starting
                at ``ptr``. ``"nearest"`` EMA-blends each pair into the slot
                whose stored key is most similar.

        Raises:
            ValueError: if ``strategy`` is not ``"ring"`` or ``"nearest"``.

        Side effects:
            Updates ``keys_r``, ``keys_i``, ``values`` and ``age`` in place;
            ``"ring"`` also advances ``ptr`` by the batch size. Written slots
            have their age reset to 0, then every slot's age is incremented by 1.
        """
        if strategy not in ("ring", "nearest"):
            raise ValueError(f"Unknown write strategy: {strategy}")

        if k.dim() == 1:
            k = k.unsqueeze(0)
            v = v.unsqueeze(0)
        k_c = k if torch.is_complex(k) else _as_complex(k)
        k_c = _norm_complex(self._apply_unitary(k_c))

        if strategy == "ring":
            # Read the cursor once (one device sync) and advance it by the batch
            # size afterwards. This is equivalent to bumping it after every item.
            start = int(self.ptr.item())
            batch = k_c.size(0)
            for offset in range(batch):
                idx = (start + offset) % self.slots
                self.keys_r[idx] = k_c[offset].real
                self.keys_i[idx] = k_c[offset].imag
                self.values[idx] = v[offset]
                self.age[idx] = 0
            self.ptr += batch
        else:  # "nearest"
            mem_c = _as_complex(self.keys_r, self.keys_i)
            sims = self._similarity(k_c, mem_c)  # (B,S)
            idxs = sims.argmax(dim=-1)
            for i, idx in enumerate(idxs.tolist()):
                old_key = _as_complex(self.keys_r[idx], self.keys_i[idx])
                # Blend and renormalise as ONE complex vector. Normalising the
                # real and imaginary parts separately rescales them independently
                # and rotates the key away from the blended direction; with
                # ema=1 it would not even reproduce the written key.
                new_key = _norm_complex((1 - self.ema) * old_key + self.ema * k_c[i])
                self.keys_r[idx] = new_key.real
                self.keys_i[idx] = new_key.imag
                self.values[idx] = (1 - self.ema) * self.values[idx] + self.ema * v[i]
                self.age[idx] = 0

        self.age += 1

    def read(self, q: torch.Tensor, topk: int = 0, return_weights: bool = False):
        """Retrieve a softmax-weighted mixture of stored values for query ``q``.

        Args:
            q: query/queries, shape ``(D_c,)`` or ``(B, D_c)``; real or complex.
            topk: number of highest-scoring slots to attend over. ``0`` (the
                default) or any value ``>= slots`` attends over all slots.
            return_weights: if True, also return the attention weights.

        Returns:
            ``(out, weights)``. ``out`` is ``(B, D_v)``, or ``(D_v,)`` for a 1-D
            query. When ``return_weights`` is True, ``weights`` is ``(B, topk)``
            or ``(B, S)``, squeezed likewise; otherwise it is ``None``.

        Gradients flow to ``theta`` and ``house_v`` through the query
        transform, and to keys/values when they are trainable parameters.
        """
        single = q.dim() == 1
        if single:
            q = q.unsqueeze(0)
        q_c = q if torch.is_complex(q) else _as_complex(q)
        q_c = self._apply_unitary(q_c)
        k_c = _as_complex(self.keys_r, self.keys_i)

        scores = self._similarity(q_c, k_c) / max(self.temperature, 1e-6)
        if topk and topk < self.slots:
            top_scores, idxs = scores.topk(topk, dim=-1)
            weights = torch.softmax(top_scores, dim=-1)
            # self.values[idxs] is (B, topk, D_v); weight it and sum over topk.
            out = (weights.unsqueeze(-1) * self.values[idxs]).sum(dim=1)
        else:
            weights = torch.softmax(scores, dim=-1)
            out = weights @ self.values

        if single:
            out = out.squeeze(0)
            weights = weights.squeeze(0)
        return (out, weights) if return_weights else (out, None)