# MIRAS: Memory, Attentional Bias, Retention & Algorithms

> **Goal**: Implement a minimal, educational reproduction of the MIRAS framework from "It's All Connected: A Journey Through Test-Time Memorization, Attentional Bias, Retention, and Online Optimization" (Behrouz et al. 2025).

## Overview

MIRAS (meaning "Legacy" in Persian/Arabic/Turkish) is a unified framework for designing sequence models based on four fundamental design choices:

| Component | Description | Examples |
|-----------|-------------|----------|
| **Memory Architecture** | Vector, matrix, or deep MLP | Linear, 2-layer MLP |
| **Attentional Bias** | Internal objective function defining "similarity" | ℓ₂, ℓₚ, Huber, Dot-Product |
| **Retention Gate** | How to balance learning new vs. retaining old | ℓ₂, KL, Elastic Net |
| **Memory Learning Algorithm** | The optimizer used | GD, GD+momentum |

### Why MIRAS Matters

1. **Unification**: Shows that Transformers, RNNs, and SSMs are all associative memory variants
2. **Generalization**: Enables designing new architectures by mixing components
3. **Robustness**: Novel attentional biases (Huber, ℓₚ) handle outliers better
4. **Memory Management**: Novel retention gates provide better forgetting mechanisms

### Models Covered

| Model | Memory | Attentional Bias | Retention | Algorithm |
|-------|--------|-----------------|-----------|-----------|
| Linear Attention | Matrix | Dot-Product | - | GD |
| DeltaNet | Matrix | ℓ₂ | - | GD |
| Titans-LMM | k-layer MLP | ℓ₂ | ℓ₂ | GD+Momentum |
| **Moneta** | 2-layer MLP | ℓₚ | ℓq | GD |
| **Yaad** | 2-layer MLP | Huber | ℓ₂ | GD |
| **Memora** | 2-layer MLP | ℓ₂ | KL | GD |

Let's build this step by step! 🚀

## Part 1: Theoretical Foundation

### Core Concept: Associative Memory with Attentional Bias

The fundamental insight of MIRAS is that most sequence models can be viewed as **associative memory modules** that learn a mapping from keys to values:

$$M^* = \arg\min_M \mathcal{L}(M(K); V)$$

Where:
- $K \subseteq \mathbb{R}^{d_k}$ are keys (projections of input)
- $V \subseteq \mathbb{R}^{d_v}$ are values (projections of input)
- $\mathcal{L}$ is the **attentional bias** (internal objective function)
- $M$ is the memory module (parameterized or non-parametric)
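
For a linear memory and an ℓ₂ attentional bias, the objective above reduces to ordinary least squares and can be solved in one shot. The sketch below is illustrative only: the sizes, the synthetic `M_true`, and the use of `torch.linalg.lstsq` are assumptions made for the demonstration, not something prescribed by MIRAS.

```python
import torch

torch.manual_seed(0)
n, d_k, d_v = 32, 4, 3                           # hypothetical sizes: 32 key-value pairs

M_true = torch.randn(d_v, d_k, dtype=torch.float64)
K = torch.randn(n, d_k, dtype=torch.float64)     # one key per row
V = K @ M_true.T                                 # values generated by a linear map

# Solve min_X ||K X - V||_F^2; the memory is M = X^T so that M @ k ≈ v.
X = torch.linalg.lstsq(K, V).solution            # shape [d_k, d_v]
M_star = X.T

print(torch.allclose(M_star, M_true, atol=1e-8))
```

The sequence models in this notebook never solve this problem in closed form; they approximate it online, one key-value pair at a time.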

### The Learning-Retaining Viewpoint

$$W_t = \arg\min_W \tilde{\ell}_t(W; k_t, v_t) + \text{Ret}_t(W, W_{t-1})$$

- **First term**: Learn from new data
- **Second term**: Retain previously learned knowledge

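The formulation is compact but expressive: each choice of loss $\tilde{\ell}_t$ and retention term $\text{Ret}_t$ yields a different update rule, and hence a different model.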
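
One way to see how this objective recovers a familiar update is the standard proximal-step argument. As an illustrative assumption (these particular choices are not taken from the text above), let the loss term be the linearization of $\ell$ around $W_{t-1}$ and let the retention term be a scaled squared distance:

$$\tilde{\ell}_t(W; k_t, v_t) = \langle \nabla \ell(W_{t-1}; k_t, v_t),\, W - W_{t-1} \rangle, \qquad \text{Ret}_t(W, W_{t-1}) = \frac{1}{2\eta}\|W - W_{t-1}\|_2^2$$

Setting the gradient with respect to $W$ to zero gives

$$\nabla \ell(W_{t-1}; k_t, v_t) + \frac{1}{\eta}(W_t - W_{t-1}) = 0 \quad\Longrightarrow\quad W_t = W_{t-1} - \eta \nabla \ell(W_{t-1}; k_t, v_t)$$

which is one step of gradient descent. Adding a global term $\frac{\lambda}{2}\|W\|_2^2$ to $\text{Ret}_t$ gives $W_t = \alpha W_{t-1} - \alpha\eta \nabla \ell(W_{t-1}; k_t, v_t)$ with $\alpha = 1/(1 + \eta\lambda)$, i.e. gradient descent with weight decay.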

## Part 2: Setup & Data

Let's start with imports and loading our dataset.

In [46]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import requests
from typing import Optional, Tuple, Callable
import gc

torch.cuda.empty_cache()
gc.collect()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

Using device: cuda


In [2]:
dataset_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

def download_dataset(url: str, filename: str = "dataset.txt") -> str:
    response = requests.get(url)
    if response.status_code == 200:
        with open(filename, "wb") as f:
            f.write(response.content)
        print(f"Downloaded {len(response.content)} bytes to {filename}")
        return response.text
    else:
        raise RuntimeError(f"Failed to download: {response.status_code}")

text = download_dataset(dataset_url)
print(f"Dataset length: {len(text)} characters")
print(f"First 200 chars:\n{text[:200]}")

Downloaded 1115394 bytes to dataset.txt
Dataset length: 1115394 characters
First 200 chars:
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


In [3]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(f"Vocabulary size: {vocab_size}")
print(f"Characters: {''.join(chars)}")

stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

print(f"\nExample encoding: 'hello' -> {encode('hello')}")
print(f"Example decoding: {encode('hello')} -> '{decode(encode('hello'))}'")

data = torch.tensor(encode(text), dtype=torch.long)
print(f"\nData tensor shape: {data.shape}")

Vocabulary size: 65
Characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz

Example encoding: 'hello' -> [46, 43, 50, 50, 53]
Example decoding: [46, 43, 50, 50, 53] -> 'hello'

Data tensor shape: torch.Size([1115394])


In [42]:
batch_size = 32      # Reduced due to sequential memory processing
block_size = 64      # Shorter context for memory efficiency
n_embd = 128         # Smaller embeddings (transformer uses 384)

n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

print(f"Train data: {len(train_data)} tokens")
print(f"Val data: {len(val_data)} tokens")

def get_batch(split: str) -> Tuple[torch.Tensor, torch.Tensor]:
    data_split = train_data if split == 'train' else val_data
    ix = torch.randint(len(data_split) - block_size, (batch_size,))
    x = torch.stack([data_split[i:i+block_size] for i in ix])
    y = torch.stack([data_split[i+1:i+block_size+1] for i in ix])
    return x.to(device), y.to(device)

xb, yb = get_batch('train')
print(f"\nBatch shapes: x={xb.shape}, y={yb.shape}")

Train data: 1003854 tokens
Val data: 111540 tokens

Batch shapes: x=torch.Size([32, 64]), y=torch.Size([32, 64])


## Part 3: Building Blocks

### Key-Value Projections

The first building block is projecting input embeddings into keys, values, and queries. This is shared across all memory architectures.

In [28]:
class KeyValueProjection(nn.Module):
    """Projects input embeddings into keys, values, and queries.

    This is the standard projection used in attention mechanisms.
    In MIRAS, K and V are used for memory updates, Q is used for retrieval.
    """
    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.W_K = nn.Linear(d_in, d_out, bias=False)
        self.W_V = nn.Linear(d_in, d_out, bias=False)
        self.W_Q = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return self.W_K(x), self.W_V(x), self.W_Q(x)

kv_proj = KeyValueProjection(n_embd, n_embd).to(device)
test_x = torch.randn(2, 8, n_embd, device=device)
k, v, q = kv_proj(test_x)
print(f"Input shape: {test_x.shape}")
print(f"K, V, Q shapes: {k.shape}, {v.shape}, {q.shape}")

Input shape: torch.Size([2, 8, 128])
K, V, Q shapes: torch.Size([2, 8, 128]), torch.Size([2, 8, 128]), torch.Size([2, 8, 128])


### Linear Memory: Hebbian Rule (Linear Attention)

The simplest associative memory uses the **Hebbian update**:

$$M_t = \alpha M_{t-1} + v_t k_t^\top$$

With $\alpha = 1$ this is equivalent to **Linear Attention**: it simply accumulates key-value associations. With $\alpha < 1$, older associations decay exponentially.

- **Memory**: $M \in \mathbb{R}^{d \times d}$ (matrix)
- **Attentional Bias**: Dot-product similarity
- **Retention**: Exponential decay ($\alpha$)
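
The link to Linear Attention can be checked numerically. The sketch below assumes no decay ($\alpha = 1$) and toy sizes chosen only for illustration; it compares the recurrent Hebbian update with the parallel form of causal attention without softmax, $y_t = \sum_{s \le t} (q_t \cdot k_s)\, v_s$.

```python
import torch

torch.manual_seed(0)
B, T, D = 2, 5, 4                                    # hypothetical toy sizes
k, v, q = (torch.randn(B, T, D, dtype=torch.float64) for _ in range(3))

# Recurrent form: M_t = M_{t-1} + v_t k_t^T, then y_t = M_t q_t (alpha = 1).
M = torch.zeros(B, D, D, dtype=torch.float64)
recurrent = []
for t in range(T):
    M = M + torch.einsum('bd,be->bde', v[:, t], k[:, t])
    recurrent.append(torch.einsum('bde,be->bd', M, q[:, t]))
recurrent = torch.stack(recurrent, dim=1)

# Parallel form: causally masked (Q K^T) V with no softmax.
scores = torch.einsum('btd,bsd->bts', q, k)          # [B, T, T]
causal_mask = torch.tril(torch.ones(T, T, dtype=torch.float64))
parallel = torch.einsum('bts,bsd->btd', scores * causal_mask, v)

print(torch.allclose(recurrent, parallel))
```

The recurrent form needs only the $d \times d$ state at each step, while the parallel form materializes a $T \times T$ score matrix.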

In [29]:
class LinearMemoryHebbian(nn.Module):
    """Simplest associative memory - equivalent to Linear Attention.

    Update rule: M_t = α * M_{t-1} + v_t ⊗ k_t
    Retrieval:   y_t = M_t @ q_t

    This is a "write everything, forget slowly" approach.
    """
    def __init__(self, d: int):
        super().__init__()
        self.d = d

    def forward(self, keys: torch.Tensor, values: torch.Tensor,
                queries: torch.Tensor, alpha: float = 0.9) -> torch.Tensor:
        B, T, D = keys.shape
        M = torch.zeros(B, D, D, device=keys.device)
        outputs = []

        for t in range(T):
            k_t = keys[:, t]      # [B, D]
            v_t = values[:, t]    # [B, D]
            q_t = queries[:, t]   # [B, D]

            M = alpha * M + torch.einsum('bd,be->bde', v_t, k_t)
            y_t = torch.einsum('bde,be->bd', M, q_t)
            outputs.append(y_t)

        return torch.stack(outputs, dim=1)

mem_hebbian = LinearMemoryHebbian(n_embd).to(device)
out = mem_hebbian(k, v, q)
print(f"Hebbian memory output shape: {out.shape}")

Hebbian memory output shape: torch.Size([2, 8, 128])


### Linear Memory: Delta Rule (DeltaNet)

The **Delta Rule** is smarter: it removes the old value before writing the new one:

$$M_t = \alpha M_{t-1}(I - \eta k_t k_t^\top) + v_t k_t^\top$$

The $(I - \eta k_t k_t^\top)$ term "erases" the old association for $k_t$ before writing the new one. Because retrieval is $M k$ and writes have the form $v k^\top$, the erase factor multiplies $M_{t-1}$ on the right (the key side): $M_{t-1}(I - \eta k_t k_t^\top) k_t = (1 - \eta \|k_t\|^2) M_{t-1} k_t$.

**Key insight**: In this formulation the learning rate $\eta$ appears only in the "erase" term, not in the "write" term. (A plain gradient step on the ℓ₂ loss would also scale the write term by $\eta$.)

In [30]:
class LinearMemoryDelta(nn.Module):
    """Delta rule memory - removes old value before writing new.

    Update rule: M_t = α * M_{t-1} @ (I - η * k_t ⊗ k_t) + v_t ⊗ k_t

    The (I - η * k_t ⊗ k_t) term "erases" the old value associated with k_t
    before writing the new association v_t ⊗ k_t.

    This is more like "overwrite" than "accumulate".
    """
    def __init__(self, d: int):
        super().__init__()
        self.d = d

    def forward(self, keys: torch.Tensor, values: torch.Tensor,
                queries: torch.Tensor, alpha: float = 0.9, eta: float = 0.1) -> torch.Tensor:
        B, T, D = keys.shape
        M = torch.zeros(B, D, D, device=keys.device)
        I = torch.eye(D, device=keys.device).unsqueeze(0)  # [1, D, D]
        outputs = []

        for t in range(T):
            k_t = keys[:, t]      # [B, D]
            v_t = values[:, t]    # [B, D]
            q_t = queries[:, t]   # [B, D]

            kk = torch.einsum('bd,be->bde', k_t, k_t)  # [B, D, D]
            M = alpha * torch.bmm(M, I - eta * kk)     # Erase old (key side, since retrieval is M @ q)
            M = M + torch.einsum('bd,be->bde', v_t, k_t)  # Write new (no eta!)

            y_t = torch.einsum('bde,be->bd', M, q_t)
            outputs.append(y_t)

        return torch.stack(outputs, dim=1)

mem_delta = LinearMemoryDelta(n_embd).to(device)
out = mem_delta(k, v, q)
print(f"Delta memory output shape: {out.shape}")

Delta memory output shape: torch.Size([2, 8, 128])
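
The difference between "accumulate" and "overwrite" shows up when the same key is written twice. The following is a minimal sketch under stated assumptions: a unit-norm key, $\eta = 1$, no decay ($\alpha = 1$), and the erase factor applied on the key side of $M$ so that it acts on what `M @ k` retrieves. The helper names `hebbian_write` and `delta_write` are hypothetical.

```python
import torch

D = 4
k = torch.zeros(D)
k[0] = 1.0                                   # unit-norm key
v_old = torch.tensor([1.0, 2.0, 3.0, 4.0])
v_new = torch.tensor([-1.0, 0.0, 1.0, 2.0])
I = torch.eye(D)

def hebbian_write(M: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Accumulate: M + v k^T
    return M + torch.outer(v, k)

def delta_write(M: torch.Tensor, k: torch.Tensor, v: torch.Tensor, eta: float = 1.0) -> torch.Tensor:
    # Erase along the key direction, then write: M (I - eta k k^T) + v k^T
    return M @ (I - eta * torch.outer(k, k)) + torch.outer(v, k)

M_hebb = hebbian_write(hebbian_write(torch.zeros(D, D), k, v_old), k, v_new)
M_delta = delta_write(delta_write(torch.zeros(D, D), k, v_old), k, v_new)

print(M_hebb @ k)    # v_old + v_new: both writes accumulate
print(M_delta @ k)   # v_new: the second write replaced the first
```

With $\eta \|k\|^2 < 1$ the old association is only partially erased, so retrieval returns a blend of the old and new values.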


### Deep Memory Module (Titans/MIRAS)

Linear memory assumes linear dependencies. **Deep memory** can learn non-linear patterns.

The architecture (post-norm per MIRAS Eq. 5):
$$M(x) = x + \text{LayerNorm}(W_1 \sigma(W_2 x))$$

Where:
- $W_2 \in \mathbb{R}^{h \times d}$ projects up (expansion)
- $W_1 \in \mathbb{R}^{d \times h}$ projects down
- $\sigma$ is GELU activation

In [31]:
class DeepMemory(nn.Module):
    """2-layer MLP memory as in Titans/MIRAS (post-norm architecture).

    M(x) = x + LayerNorm(W1 @ σ(W2 @ x))

    This can learn non-linear key-value associations.
    """
    def __init__(self, d: int, expansion: int = 4):
        super().__init__()
        self.W2 = nn.Linear(d, d * expansion, bias=False)  # Up projection
        self.W1 = nn.Linear(d * expansion, d, bias=False)  # Down projection
        self.ln = nn.LayerNorm(d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = F.gelu(self.W2(x))
        return x + self.ln(self.W1(h))

deep_mem = DeepMemory(n_embd).to(device)
test_input = torch.randn(2, 8, n_embd, device=device)
test_output = deep_mem(test_input)
print(f"Deep memory: {test_input.shape} -> {test_output.shape}")
print(f"Parameters: {sum(p.numel() for p in deep_mem.parameters()):,}")

Deep memory: torch.Size([2, 8, 128]) -> torch.Size([2, 8, 128])
Parameters: 131,328


## Part 4: Attentional Bias Objectives

The **attentional bias** is the internal objective function that defines what "similarity" means for the memory. Different choices lead to different behaviors:

| Objective | Formula | Properties |
|-----------|---------|------------|
| ℓ₂ (MSE) | $\frac{1}{2}\|M(k) - v\|_2^2$ | Standard, smooth |
| ℓₚ | $\|M(k) - v\|_p^p$ | Robust for $p < 2$ |
| Huber | Quadratic near 0, linear far | Robust to outliers |
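
Because the memory is updated with gradients of these objectives, what matters in practice is how strongly each one reacts to a large residual. The sketch below is an illustration with hypothetical helper names (`l2_term`, `l3_term`, `huber_term`, `grad_wrt_residual`) and a fixed Huber threshold of 1.0; it evaluates the derivative of each per-coordinate loss at a small residual and at an outlier.

```python
import torch

def l2_term(r: torch.Tensor) -> torch.Tensor:
    return 0.5 * r ** 2

def l3_term(r: torch.Tensor) -> torch.Tensor:
    return r.abs() ** 3

def huber_term(r: torch.Tensor, delta: float = 1.0) -> torch.Tensor:
    return torch.where(r.abs() <= delta, 0.5 * r ** 2, delta * (r.abs() - 0.5 * delta))

def grad_wrt_residual(loss_term, r_value: float) -> float:
    """Derivative of a per-coordinate loss term at residual r_value, via autograd."""
    r = torch.tensor(r_value, dtype=torch.float64, requires_grad=True)
    (grad,) = torch.autograd.grad(loss_term(r), r)
    return grad.item()

for r_value in (0.5, 10.0):
    grads = {fn.__name__: grad_wrt_residual(fn, r_value)
             for fn in (l2_term, l3_term, huber_term)}
    print(r_value, grads)
```

At the outlier the ℓ₂ derivative grows linearly with the residual, the ℓ₃ derivative grows quadratically, and the Huber derivative is capped at the threshold.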

In [32]:
def l2_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Standard ℓ₂ loss (MSE without mean reduction).

    Returns per-sample loss for gradient computation.
    """
    return 0.5 * ((pred - target) ** 2).sum(dim=-1)


def lp_loss(pred: torch.Tensor, target: torch.Tensor, p: float = 3) -> torch.Tensor:
    """ℓₚ loss - more robust for p < 2, more sensitive for p > 2.

    p=1: Manhattan distance (most robust to outliers)
    p=2: Euclidean distance (standard)
    p=3+: More sensitive to large errors
    """
    return (torch.abs(pred - target) ** p).sum(dim=-1)


def huber_loss(pred: torch.Tensor, target: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
    """Huber loss - robust to outliers.

    Behaves like ℓ₂ for small errors, ℓ₁ for large errors.
    Delta controls the transition point.
    """
    diff = pred - target
    abs_diff = torch.abs(diff)
    quadratic = 0.5 * diff ** 2
    linear = delta * (abs_diff - 0.5 * delta)
    return torch.where(abs_diff <= delta, quadratic, linear).sum(dim=-1)

pred = torch.randn(4, n_embd, device=device)
target = torch.randn(4, n_embd, device=device)
delta = torch.ones(4, 1, device=device) * 0.5

print(f"ℓ₂ loss: {l2_loss(pred, target)}")
print(f"ℓ₃ loss: {lp_loss(pred, target, p=3)}")
print(f"Huber loss: {huber_loss(pred, target, delta)}")

ℓ₂ loss: tensor([127.6315, 132.5224, 120.8685, 147.0225], device='cuda:0')
ℓ₃ loss: tensor([641.6620, 566.5828, 533.0881, 735.3480], device='cuda:0')
Huber loss: tensor([55.2330, 61.3017, 55.1498, 61.9771], device='cuda:0')


### Gradient Computation for Memory Updates

For deep memory, we use PyTorch autograd to compute gradients. For linear memory, we can also compute closed-form gradients.
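
For the ℓ₂ loss on a linear memory, the closed form is $(Mk - v)k^\top$. As a sanity check, the sketch below compares that expression with what autograd returns; the sizes and random inputs are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
B, D = 2, 4                                   # hypothetical toy sizes
M = torch.randn(B, D, D, dtype=torch.float64, requires_grad=True)
k = torch.randn(B, D, dtype=torch.float64)
v = torch.randn(B, D, dtype=torch.float64)

# Autograd gradient of 0.5 * ||M k - v||^2, summed over the batch.
pred = torch.einsum('bde,be->bd', M, k)
loss = 0.5 * ((pred - v) ** 2).sum()
(grad_auto,) = torch.autograd.grad(loss, M)

# Closed form: (M k - v) k^T for each batch element.
grad_closed = torch.einsum('bd,be->bde', pred.detach() - v, k)

print(torch.allclose(grad_auto, grad_closed))
```

The closed form avoids building an autograd graph at every time step, which is why it is attractive for linear memories.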

In [33]:
def compute_memory_grad(memory: nn.Module, k: torch.Tensor, v: torch.Tensor,
                        loss_fn: Callable) -> Tuple[torch.Tensor, ...]:
    """Compute gradient of loss w.r.t. memory parameters using autograd.

    Works for any memory architecture (linear or deep).
    """
    for p in memory.parameters():
        if p.grad is not None:
            p.grad.zero_()

    pred = memory(k)
    loss = loss_fn(pred, v).sum()
    grads = torch.autograd.grad(loss, memory.parameters(), create_graph=False)
    return grads


def apply_memory_update(memory: nn.Module, grads: Tuple[torch.Tensor, ...],
                        alpha: float, eta: float) -> None:
    """Apply gradient update with retention (weight decay)."""
    with torch.no_grad():
        for param, grad in zip(memory.parameters(), grads):
            param.mul_(alpha).sub_(eta * grad)


def l2_gradient_linear(M: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Closed-form gradient of ℓ₂ loss for linear memory: (Mk - v) @ k.T"""
    pred = torch.einsum('bde,be->bd', M, k)  # M @ k
    error = pred - v
    return torch.einsum('bd,be->bde', error, k)  # error @ k.T


def lp_gradient_linear(M: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                       p: float = 3, eps: float = 1e-6) -> torch.Tensor:
    """Closed-form gradient of ℓₚ loss for linear memory with smooth approximations."""
    pred = torch.einsum('bde,be->bd', M, k)
    diff = pred - v
    sign_diff = torch.tanh(100 * diff)  # Smooth sign approximation
    abs_diff = torch.sqrt(diff ** 2 + eps)  # Smooth abs
    grad_coef = p * sign_diff * (abs_diff ** (p - 1))
    return torch.einsum('bd,be->bde', grad_coef, k)

M = torch.randn(2, n_embd, n_embd, device=device)
k_test = torch.randn(2, n_embd, device=device)
v_test = torch.randn(2, n_embd, device=device)

grad_l2 = l2_gradient_linear(M, k_test, v_test)
grad_l3 = lp_gradient_linear(M, k_test, v_test, p=3)
print(f"ℓ₂ gradient shape: {grad_l2.shape}")
print(f"ℓ₃ gradient shape: {grad_l3.shape}")

ℓ₂ gradient shape: torch.Size([2, 128, 128])
ℓ₃ gradient shape: torch.Size([2, 128, 128])


## Part 5: Retention Gates

**Retention** controls how much of the old memory to keep when learning new information.

| Retention | Update Form | Properties |
|-----------|-------------|------------|
| ℓ₂ | $W_t = \alpha W_{t-1} - \eta \nabla$ | Simple weight decay |
| KL | $W_t = c \cdot \text{Softmax}(\alpha \log W_{t-1} - \eta \nabla)$ | Entropy regularized |
| Elastic Net | $W_t = \text{soft\_threshold}(\lambda W_{t-1} - \zeta \nabla, \gamma)$ | Sparse forgetting |

In [34]:
def l2_retention_update(W: torch.Tensor, grad: torch.Tensor,
                        alpha: float, eta: float) -> torch.Tensor:
    """Standard ℓ₂ retention (weight decay) update.

    W_t = α * W_{t-1} - η * ∇ℓ

    The ℓ₂ regularization induces exponential decay of old weights.
    """
    return alpha * W - eta * grad


def kl_retention_update(log_W: torch.Tensor, grad: torch.Tensor,
                        alpha: float, eta: float, c: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """KL divergence retention (Memora) per MIRAS Equation 27.

    Works in log domain: log_W_t = α * log_W_{t-1} - η * ∇ℓ
    Then: W_t = c * Softmax(log_W_t)

    Returns both the updated log_W and the actual W.
    """
    log_W_new = alpha * log_W - eta * grad
    W_new = c * F.softmax(log_W_new, dim=-1)
    return log_W_new, W_new


def soft_threshold(z: torch.Tensor, gamma: float) -> torch.Tensor:
    """Proximal operator for ℓ₁ regularization (soft thresholding).

    This is the key operation for elastic net / sparse forgetting.
    """
    return torch.sign(z) * F.relu(torch.abs(z) - gamma)


def elastic_net_update(W: torch.Tensor, grad: torch.Tensor,
                       lambda_decay: float, zeta_lr: float, gamma_l1: float) -> torch.Tensor:
    """Elastic net update combining ℓ₂ decay with ℓ₁ sparsity.

    W_t = soft_threshold(λ * W_{t-1} - ζ * ∇ℓ, γ)

    Combines feature selection (ℓ₁) with bias reduction (ℓ₂).
    """
    return soft_threshold(lambda_decay * W - zeta_lr * grad, gamma_l1)

W_test = torch.randn(4, 8, device=device)
grad_test = torch.randn(4, 8, device=device)

W_l2 = l2_retention_update(W_test, grad_test, 0.9, 0.1)
_, W_kl = kl_retention_update(W_test, grad_test, 0.9, 0.1, c=1.0)
W_elastic = elastic_net_update(W_test, grad_test, 0.9, 0.1, 0.01)

print(f"ℓ₂ retention update: {W_l2.shape}")
print(f"KL retention update: {W_kl.shape}")
print(f"Elastic net update: {W_elastic.shape}")
print(f"Sparsity (elastic net zeros): {(W_elastic.abs() < 1e-6).sum().item()}/{W_elastic.numel()}")

ℓ₂ retention update: torch.Size([4, 8])
KL retention update: torch.Size([4, 8])
Elastic net update: torch.Size([4, 8])
Sparsity (elastic net zeros): 0/32
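
The 0/32 sparsity above is expected: with a threshold of 0.01, only entries whose magnitude is at most 0.01 are set to zero. The sketch below repeats the soft-thresholding operator so it runs on its own and applies it to a hand-picked vector (an illustrative assumption) with a small and a large threshold.

```python
import torch
import torch.nn.functional as F

def soft_threshold(z: torch.Tensor, gamma: float) -> torch.Tensor:
    # Same operator as in this notebook, repeated so the block is self-contained.
    return torch.sign(z) * F.relu(torch.abs(z) - gamma)

z = torch.tensor([-1.0, -0.3, 0.0, 0.2, 0.8])
for gamma in (0.01, 0.5):
    out = soft_threshold(z, gamma)
    num_zeros = (out == 0).sum().item()
    print(f"gamma={gamma}: {out.tolist()} ({num_zeros} zeros)")
```

Entries with $|z| \le \gamma$ are zeroed and the rest are shrunk toward zero by $\gamma$, which is what makes this a sparse form of forgetting.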


## Part 6: Three Novel MIRAS Models

Now we combine the building blocks into complete models. Each model has:
- A **deep memory** (2-layer MLP)
- A specific **attentional bias** (objective function)
- A specific **retention gate** (forgetting mechanism)

### Model 1: Moneta (ℓₚ-ℓq-Moneta)

**Configuration**:
- Memory: 2-layer MLP with GELU, residual + LayerNorm
- Attentional Bias: ℓₚ (p=3)
- Retention: ℓq normalization (q=4)
- Algorithm: Gradient Descent

The ℓq normalization ensures memory weights don't explode:
$$W_t = \frac{A_t}{\|A_t\|^{(q-2)/q}}$$
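
To get a feel for this normalization, note that it maps a matrix of norm $\|A\|$ to one of norm $\|A\|^{2/q}$. The sketch below assumes the Frobenius norm (as the implementation in this notebook uses) and $q = 4$, so the output norm is the square root of the input norm; the scales are arbitrary.

```python
import torch

def lq_normalize(A: torch.Tensor, q: float = 4.0) -> torch.Tensor:
    # W = A / ||A||^((q-2)/q), with ||.|| taken as the Frobenius norm (an assumption).
    norm = torch.linalg.matrix_norm(A).clamp(min=1e-8)
    return A / norm ** ((q - 2) / q)

torch.manual_seed(0)
A = torch.randn(8, 8, dtype=torch.float64)
for scale in (1.0, 100.0, 10000.0):
    W = lq_normalize(scale * A)
    norm_in = torch.linalg.matrix_norm(scale * A).item()
    norm_out = torch.linalg.matrix_norm(W).item()
    print(f"||A|| = {norm_in:.2f} -> ||W|| = {norm_out:.2f}")
```

Multiplying $A$ by 100 multiplies $\|W\|$ by only 10, so growth in the accumulator $A_t$ is damped rather than passed through to the weights.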

In [35]:
class Moneta(nn.Module):
    """ℓₚ attentional bias + ℓq retention.

    Uses a functional approach: maintains memory state explicitly
    rather than modifying nn.Module parameters in-place.

    The ℓq normalization keeps weights bounded while allowing
    ℓₚ gradients to shape the learning dynamics.
    """
    def __init__(self, d: int, expansion: int = 4, p: float = 3, q: float = 4):
        super().__init__()
        self.d = d
        self.expansion = expansion
        self.p = p
        self.q = q
        self.kv_proj = KeyValueProjection(d, d)

        self.W1_init = nn.Parameter(torch.randn(d, d * expansion) * 0.02)
        self.W2_init = nn.Parameter(torch.randn(d * expansion, d) * 0.02)
        self.ln = nn.LayerNorm(d)

    def memory_forward(self, x: torch.Tensor, W1: torch.Tensor, W2: torch.Tensor) -> torch.Tensor:
        """Functional memory forward pass."""
        h = F.gelu(x @ W2.transpose(-2, -1))
        return x + self.ln(h @ W1.transpose(-2, -1))

    def lq_normalize(self, A: torch.Tensor) -> torch.Tensor:
        """ℓq normalization: W = A / ||A||^((q-2)/q)"""
        norm = torch.norm(A, dim=(-2, -1), keepdim=True).clamp(min=1e-8)
        return A / (norm ** ((self.q - 2) / self.q))

    def forward(self, x: torch.Tensor, alpha: float = 0.9, eta: float = 0.1) -> torch.Tensor:
        k, v, q = self.kv_proj(x)
        B, T, D = k.shape

        A1 = self.W1_init.clone().unsqueeze(0).expand(B, -1, -1).contiguous()
        A2 = self.W2_init.clone().unsqueeze(0).expand(B, -1, -1).contiguous()
        outputs = []

        with torch.enable_grad():
            for t in range(T):
                k_t, v_t, q_t = k[:, t], v[:, t], q[:, t]

                W1 = self.lq_normalize(A1)
                W2 = self.lq_normalize(A2)

                W1_leaf = W1.detach().requires_grad_(True)
                W2_leaf = W2.detach().requires_grad_(True)

                pred = self.memory_forward(k_t.unsqueeze(1), W1_leaf, W2_leaf).squeeze(1)
                loss = lp_loss(pred, v_t, self.p).sum()
                grad1, grad2 = torch.autograd.grad(loss, [W1_leaf, W2_leaf])

                A1 = alpha * A1 - eta * grad1
                A2 = alpha * A2 - eta * grad2

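                # Retrieve with the updated memory (re-normalized A_t), as the other models do
                y_t = self.memory_forward(q_t.unsqueeze(1), self.lq_normalize(A1), self.lq_normalize(A2)).squeeze(1)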
                outputs.append(y_t)

        return torch.stack(outputs, dim=1)

moneta = Moneta(n_embd).to(device)
test_emb = torch.randn(2, 8, n_embd, device=device)
out = moneta(test_emb)
print(f"Moneta output shape: {out.shape}")
print(f"Moneta parameters: {sum(p.numel() for p in moneta.parameters()):,}")

Moneta output shape: torch.Size([2, 8, 128])
Moneta parameters: 180,480


### Model 2: Yaad (Robust Memory with Coping)

**Yaad** (memory in Persian) uses **Huber loss** as the attentional bias, making it robust to outliers.

**Configuration**:
- Memory: 2-layer MLP with GELU, residual + LayerNorm
- Attentional Bias: Huber loss (data-dependent threshold δ)
- Retention: ℓ₂ (standard weight decay)
- Algorithm: Gradient Descent

The insight: like human memory's coping mechanisms, extreme events (outliers) are processed differently than normal events.
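
Concretely, the gradient of the Huber loss with respect to the residual is the residual clipped to $[-\delta, \delta]$, so a data-dependent $\delta$ gives each sample its own clipping level. The sketch below illustrates this with hand-picked residuals and thresholds (assumptions for the demonstration); in the model, $\delta$ would come from a learned projection.

```python
import torch

# Same residuals for both samples, but a different threshold per sample.
residual = torch.tensor([[-3.0, 0.2, 1.5],
                         [-3.0, 0.2, 1.5]], requires_grad=True)
delta = torch.tensor([[0.5], [2.0]])          # shape [B, 1], broadcasts over features

abs_r = residual.abs()
loss = torch.where(abs_r <= delta,
                   0.5 * residual ** 2,
                   delta * (abs_r - 0.5 * delta)).sum()
(grad,) = torch.autograd.grad(loss, residual)

# The Huber gradient equals the residual clipped to [-delta, delta].
clipped = torch.minimum(torch.maximum(residual.detach(), -delta), delta)
print(grad)
print(torch.allclose(grad, clipped))
```

The sample with the smaller threshold treats both 1.5 and -3.0 as outliers, while the sample with the larger threshold clips only -3.0.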

In [36]:
class Yaad(nn.Module):
    """Huber attentional bias - robust to outliers.

    Switches between ℓ₂ and ℓ₁ gradients based on error magnitude,
    providing robustness similar to human memory's coping mechanisms.

    For small errors: behave like ℓ₂ (smooth, precise)
    For large errors: behave like ℓ₁ (robust, less sensitive)
    """
    def __init__(self, d: int, expansion: int = 4):
        super().__init__()
        self.d = d
        self.kv_proj = KeyValueProjection(d, d)
        self.delta_proj = nn.Linear(d, 1)  # Data-dependent threshold

        self.W1_init = nn.Parameter(torch.randn(d, d * expansion) * 0.02)
        self.W2_init = nn.Parameter(torch.randn(d * expansion, d) * 0.02)
        self.ln = nn.LayerNorm(d)

    def memory_forward(self, x: torch.Tensor, W1: torch.Tensor, W2: torch.Tensor) -> torch.Tensor:
        h = F.gelu(x @ W2.transpose(-2, -1))
        return x + self.ln(h @ W1.transpose(-2, -1))

    def forward(self, x: torch.Tensor, alpha: float = 0.9, eta: float = 0.1) -> torch.Tensor:
        k, v, q = self.kv_proj(x)
        B, T, D = k.shape

        W1 = self.W1_init.clone().unsqueeze(0).expand(B, -1, -1).contiguous()
        W2 = self.W2_init.clone().unsqueeze(0).expand(B, -1, -1).contiguous()
        outputs = []

        with torch.enable_grad():
            for t in range(T):
                k_t, v_t, q_t = k[:, t], v[:, t], q[:, t]

                W1_leaf = W1.detach().requires_grad_(True)
                W2_leaf = W2.detach().requires_grad_(True)

                pred = self.memory_forward(k_t.unsqueeze(1), W1_leaf, W2_leaf).squeeze(1)
                delta_t = F.softplus(self.delta_proj(x[:, t]))  # [B, 1]

                loss = huber_loss(pred, v_t, delta_t).sum()
                grad1, grad2 = torch.autograd.grad(loss, [W1_leaf, W2_leaf])

                W1 = alpha * W1 - eta * grad1
                W2 = alpha * W2 - eta * grad2

                y_t = self.memory_forward(q_t.unsqueeze(1), W1.detach(), W2.detach()).squeeze(1)
                outputs.append(y_t)

        return torch.stack(outputs, dim=1)

yaad = Yaad(n_embd).to(device)
out = yaad(test_emb)
print(f"Yaad output shape: {out.shape}")
print(f"Yaad parameters: {sum(p.numel() for p in yaad.parameters()):,}")

Yaad output shape: torch.Size([2, 8, 128])
Yaad parameters: 180,609


### Model 3: Memora (Entropy-Regularized Memory)

**Memora** uses **KL divergence** retention, keeping memory weights on a scaled probability simplex.

**Configuration**:
- Memory: 2-layer MLP with GELU, residual + LayerNorm
- Attentional Bias: ℓ₂
- Retention: KL divergence with scaling constant $c$
- Algorithm: Gradient Descent

The update rule (MIRAS Eq. 27):
$$W_t = c \cdot \text{Softmax}(\alpha_t \log(W_{t-1}) - \eta_t \nabla\ell_2)$$

The softmax ensures numerical stability and the scaling constant $c$ controls output magnitude.
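
A single step of this update can be checked in isolation. In the sketch below the values of `c`, `alpha`, and `eta` are hypothetical and a random tensor stands in for the loss gradient; the point is only that the result stays positive and that each row sums to `c`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
c, alpha, eta = 2.0, 0.9, 0.1                    # hypothetical values
W_prev = c * F.softmax(torch.randn(3, 5, dtype=torch.float64), dim=-1)
grad = torch.randn(3, 5, dtype=torch.float64)    # stand-in for the loss gradient

# W_t = c * Softmax(alpha * log(W_{t-1}) - eta * grad), normalized over the last dim
W_next = c * F.softmax(alpha * torch.log(W_prev) - eta * grad, dim=-1)

print(W_next.sum(dim=-1))      # each row sums to c
print((W_next > 0).all())      # entries stay strictly positive
```

Entries with a large positive gradient lose mass relative to the others, but no entry can reach exactly zero or change sign.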

In [37]:
class Memora(nn.Module):
    """KL divergence retention - entropy-regularized memory (MIRAS Eq. 27).

    W_t = c * Softmax(α_t * log(W_{t-1}) - η_t * ∇ℓ₂)

    Works in log-domain for numerical stability. The softmax ensures
    weights stay on a scaled probability simplex.
    """
    def __init__(self, d: int, expansion: int = 4, c: float = 1.0):
        super().__init__()
        self.d = d
        self.kv_proj = KeyValueProjection(d, d)
        self.c = nn.Parameter(torch.tensor(c))
        self.ln = nn.LayerNorm(d)

        W1_raw = torch.rand(d, d * expansion) + 0.1
        W2_raw = torch.rand(d * expansion, d) + 0.1
        self.register_buffer('W1_init', F.softmax(W1_raw, dim=-1))
        self.register_buffer('W2_init', F.softmax(W2_raw, dim=-1))

    def memory_forward(self, x: torch.Tensor, W1: torch.Tensor, W2: torch.Tensor) -> torch.Tensor:
        h = F.gelu(x @ W2.transpose(-2, -1))
        return x + self.ln(h @ W1.transpose(-2, -1))

    def forward(self, x: torch.Tensor, alpha: float = 0.9, eta: float = 0.1) -> torch.Tensor:
        k, v, q = self.kv_proj(x)
        B, T, D = k.shape

        log_W1 = torch.log(self.W1_init.clamp(min=1e-10)).unsqueeze(0).expand(B, -1, -1).contiguous()
        log_W2 = torch.log(self.W2_init.clamp(min=1e-10)).unsqueeze(0).expand(B, -1, -1).contiguous()
        outputs = []

        with torch.enable_grad():
            for t in range(T):
                k_t, v_t, q_t = k[:, t], v[:, t], q[:, t]

                W1 = self.c * F.softmax(log_W1, dim=-1)
                W2 = self.c * F.softmax(log_W2, dim=-1)

                W1_leaf = W1.detach().requires_grad_(True)
                W2_leaf = W2.detach().requires_grad_(True)

                pred = self.memory_forward(k_t.unsqueeze(1), W1_leaf, W2_leaf).squeeze(1)
                loss = l2_loss(pred, v_t).sum()
                grad1, grad2 = torch.autograd.grad(loss, [W1_leaf, W2_leaf])

                log_W1 = alpha * log_W1 - eta * grad1
                log_W2 = alpha * log_W2 - eta * grad2

                W1_out = self.c * F.softmax(log_W1.detach(), dim=-1)
                W2_out = self.c * F.softmax(log_W2.detach(), dim=-1)
                y_t = self.memory_forward(q_t.unsqueeze(1), W1_out, W2_out).squeeze(1)
                outputs.append(y_t)

        return torch.stack(outputs, dim=1)

memora = Memora(n_embd).to(device)
out = memora(test_emb)
print(f"Memora output shape: {out.shape}")
print(f"Memora parameters: {sum(p.numel() for p in memora.parameters()):,}")

Memora output shape: torch.Size([2, 8, 128])
Memora parameters: 49,409


## Part 7: Unified MIRAS Layer

Now let's create a configurable MIRAS layer that can use any combination of:
- Memory architecture (linear or deep)
- Attentional bias (ℓ₂, ℓₚ, Huber)
- Retention gate (ℓ₂, KL, elastic net)

This is the power of the MIRAS framework: mix and match components!

In [38]:
class MIRASLayer(nn.Module):
    """Complete MIRAS layer with configurable components.

    This is the unified layer that can instantiate any model from the MIRAS family
    by choosing the appropriate memory, attentional bias, and retention.
    """
    def __init__(self, d: int,
                 memory_type: str = 'deep',      # 'linear', 'deep'
                 attentional_bias: str = 'l2',   # 'l2', 'lp', 'huber'
                 retention: str = 'l2',          # 'l2', 'kl', 'elastic'
                 expansion: int = 4,
                 p: float = 3,
                 q: float = 4):
        super().__init__()
        self.d = d
        self.memory_type = memory_type
        self.attentional_bias = attentional_bias
        self.retention = retention
        self.p, self.q = p, q

        self.kv_proj = KeyValueProjection(d, d)

        if memory_type == 'linear':
            self.register_buffer('M_init', torch.zeros(d, d))
        else:
            self.W1_init = nn.Parameter(torch.randn(d, d * expansion) * 0.02)
            self.W2_init = nn.Parameter(torch.randn(d * expansion, d) * 0.02)
            self.ln = nn.LayerNorm(d)

        if attentional_bias == 'huber':
            self.delta_proj = nn.Linear(d, 1)

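        # apply_retention maps these through sigmoid / softplus, so store the inverse
        # transforms to start at an effective alpha = 0.9 and eta = 0.1
        self.alpha = nn.Parameter(torch.logit(torch.tensor([0.9])))
        self.eta = nn.Parameter(torch.log(torch.expm1(torch.tensor([0.1]))))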

        if retention == 'kl':
            self.c = nn.Parameter(torch.ones(1))
        if retention == 'elastic':
            self.gamma = nn.Parameter(torch.ones(1) * 0.01)

    def memory_forward_deep(self, x: torch.Tensor, W1: torch.Tensor, W2: torch.Tensor) -> torch.Tensor:
        h = F.gelu(x @ W2.transpose(-2, -1))
        return x + self.ln(h @ W1.transpose(-2, -1))

    def get_loss(self, pred: torch.Tensor, target: torch.Tensor, x_t: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.attentional_bias == 'l2':
            return l2_loss(pred, target).sum()
        elif self.attentional_bias == 'lp':
            return lp_loss(pred, target, self.p).sum()
        elif self.attentional_bias == 'huber':
            delta = F.softplus(self.delta_proj(x_t))
            return huber_loss(pred, target, delta).sum()
        else:
            raise ValueError(f"Unknown attentional bias: {self.attentional_bias}")

    def apply_retention(self, W: torch.Tensor, grad: torch.Tensor, log_W: Optional[torch.Tensor] = None):
        alpha = torch.sigmoid(self.alpha)
        eta = F.softplus(self.eta)

        if self.retention == 'l2':
            return l2_retention_update(W, grad, alpha, eta), None
        elif self.retention == 'kl':
            if log_W is None:
                log_W = torch.log(W.clamp(min=1e-10))
            log_W_new, W_new = kl_retention_update(log_W, grad, alpha, eta, self.c)
            return W_new, log_W_new
        elif self.retention == 'elastic':
            return elastic_net_update(W, grad, alpha, eta, self.gamma), None
        else:
            raise ValueError(f"Unknown retention: {self.retention}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        k, v, q = self.kv_proj(x)
        B, T, D = k.shape
        outputs = []

        with torch.enable_grad():
            if self.memory_type == 'linear':
                M = self.M_init.unsqueeze(0).expand(B, -1, -1).contiguous()

                for t in range(T):
                    k_t, v_t, q_t = k[:, t], v[:, t], q[:, t]
                    M_leaf = M.detach().requires_grad_(True)
                    pred = torch.einsum('bde,be->bd', M_leaf, k_t)
                    loss = self.get_loss(pred, v_t, x[:, t] if self.attentional_bias == 'huber' else None)
                    grad = torch.autograd.grad(loss, M_leaf)[0]
                    M, _ = self.apply_retention(M, grad)
                    y_t = torch.einsum('bde,be->bd', M, q_t)
                    outputs.append(y_t)
            else:
                W1 = self.W1_init.unsqueeze(0).expand(B, -1, -1).contiguous()
                W2 = self.W2_init.unsqueeze(0).expand(B, -1, -1).contiguous()
                log_W1, log_W2 = None, None

                if self.retention == 'kl':
                    W1 = F.softmax(W1, dim=-1)
                    W2 = F.softmax(W2, dim=-1)
                    log_W1 = torch.log(W1.clamp(min=1e-10))
                    log_W2 = torch.log(W2.clamp(min=1e-10))

                for t in range(T):
                    k_t, v_t, q_t = k[:, t], v[:, t], q[:, t]

                    W1_leaf = W1.detach().requires_grad_(True)
                    W2_leaf = W2.detach().requires_grad_(True)

                    pred = self.memory_forward_deep(k_t.unsqueeze(1), W1_leaf, W2_leaf).squeeze(1)
                    loss = self.get_loss(pred, v_t, x[:, t] if self.attentional_bias == 'huber' else None)
                    grad1, grad2 = torch.autograd.grad(loss, [W1_leaf, W2_leaf])

                    W1, log_W1 = self.apply_retention(W1, grad1, log_W1)
                    W2, log_W2 = self.apply_retention(W2, grad2, log_W2)

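                    # Read with detached weights: the outer loss does not backpropagate
                    # through the memory updates themselves.
                    y_t = self.memory_forward_deep(q_t.unsqueeze(1), W1.detach(), W2.detach()).squeeze(1)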
                    outputs.append(y_t)

        return torch.stack(outputs, dim=1)

miras_l2 = MIRASLayer(n_embd, memory_type='deep', attentional_bias='l2', retention='l2').to(device)
miras_lp = MIRASLayer(n_embd, memory_type='deep', attentional_bias='lp', retention='l2').to(device)
miras_huber = MIRASLayer(n_embd, memory_type='deep', attentional_bias='huber', retention='l2').to(device)

print(f"MIRAS ℓ₂: {miras_l2(test_emb).shape}")
print(f"MIRAS ℓₚ: {miras_lp(test_emb).shape}")
print(f"MIRAS Huber: {miras_huber(test_emb).shape}")

MIRAS ℓ₂: torch.Size([2, 8, 128])
MIRAS ℓₚ: torch.Size([2, 8, 128])
MIRAS Huber: torch.Size([2, 8, 128])
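
To see why the choice of attentional bias matters, it helps to compare how strongly each loss reacts to a single residual. The sketch below is a dependency-free illustration, not the notebook's `l2_loss`, `lp_loss`, or `huber_loss`: the scalar helpers `grad_l2`, `grad_lp`, and `grad_huber` are hypothetical, and they assume the conventions ½r², |r|ᵖ, and the standard Huber loss with threshold δ.

```python
import math

def grad_l2(r: float) -> float:
    """d/dr of 0.5 * r**2."""
    return r

def grad_lp(r: float, p: float = 3.0) -> float:
    """d/dr of |r|**p (assumed convention, no 1/p factor)."""
    return p * abs(r) ** (p - 1) * math.copysign(1.0, r)

def grad_huber(r: float, delta: float = 1.0) -> float:
    """d/dr of the Huber loss: quadratic inside [-delta, delta], linear outside."""
    return r if abs(r) <= delta else delta * math.copysign(1.0, r)

for r in (0.5, 10.0):  # a typical residual and an outlier
    print(f"r={r:5.1f}  l2={grad_l2(r):6.1f}  l3={grad_lp(r):6.1f}  huber={grad_huber(r):4.1f}")
```

For the outlier, the ℓ₂ gradient grows linearly and the ℓ₃ gradient quadratically, while the Huber gradient is capped at δ. That cap is what makes the Huber configuration less sensitive to a single bad key-value pair.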


## Part 8: Complete MIRAS Language Model

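Now let's build a complete language model from our MIRAS layers. The architecture:
1. Token embedding + learned position embedding
2. Stack of MIRAS blocks (memory layer + FFN, each with pre-LayerNorm and a residual connection)
3. Final LayerNorm and output projection to the vocabulary (weights tied to the token embedding)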

In [51]:
class MIRASBlock(nn.Module):
    """A single MIRAS block: Memory layer + FFN, both with residual connections and LayerNorm."""

    def __init__(self, d_model: int, memory_type: str, attentional_bias: str, retention: str, ffn_mult: int = 4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.memory = MIRASLayer(d_model, memory_type, attentional_bias, retention)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * ffn_mult),
            nn.GELU(),
            nn.Linear(d_model * ffn_mult, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.memory(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


class MIRASLanguageModel(nn.Module):
    """Complete language model using MIRAS memory layers with FFN."""

    def __init__(self, vocab_size: int, d_model: int, n_layers: int,
                 memory_type: str = 'deep',
                 attentional_bias: str = 'l2',
                 retention: str = 'l2',
                 block_size: int = 128):
        super().__init__()
        self.block_size = block_size

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(block_size, d_model)

        self.layers = nn.ModuleList([
            MIRASBlock(d_model, memory_type, attentional_bias, retention)
            for _ in range(n_layers)
        ])

        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # Weight tying: the token embedding and the output projection share one matrix.
        self.token_embedding.weight = self.lm_head.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        B, T = idx.shape

        tok_emb = self.token_embedding(idx)
        pos_emb = self.position_embedding(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb

        for layer in self.layers:
            x = layer(x)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 1.0) -> torch.Tensor:
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

model = MIRASLanguageModel(
    vocab_size=vocab_size,
    d_model=n_embd,
    n_layers=2,          # Reduced for memory (MIRAS uses more memory per layer)
    memory_type='deep',
    attentional_bias='l2',
    retention='l2',
    block_size=block_size
).to(device)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Model architecture: {model}")

Model parameters: 642,180
Model architecture: MIRASLanguageModel(
  (token_embedding): Embedding(65, 128)
  (position_embedding): Embedding(64, 128)
  (layers): ModuleList(
    (0-1): 2 x MIRASBlock(
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (memory): MIRASLayer(
        (kv_proj): KeyValueProjection(
          (W_K): Linear(in_features=128, out_features=128, bias=False)
          (W_V): Linear(in_features=128, out_features=128, bias=False)
          (W_Q): Linear(in_features=128, out_features=128, bias=False)
        )
        (ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ffn): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=512, out_features=128, bias=True)
      )
    )
  )
  (ln_f): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (lm_head): Linear(in_
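
The reported total of 642,180 parameters can be cross-checked by hand. The breakdown below is a sketch under assumptions: the printout does not list bare `nn.Parameter` tensors, so the two 128×512 deep-memory matrices and the two scalar gates per memory layer are inferred (they make the sum match) rather than read from the output.

```python
d, hidden, vocab, block_size = 128, 512, 65, 64  # sizes taken from the printout

layer_norm = 2 * d                      # weight + bias
kv_proj = 3 * d * d                     # W_K, W_V, W_Q, all without bias
deep_memory = 2 * d * hidden            # assumed W1_init and W2_init (expansion 4)
gates = 2                               # assumed scalar alpha and eta
memory_layer = kv_proj + layer_norm + deep_memory + gates

ffn = (d * hidden + hidden) + (hidden * d + d)
block_params = layer_norm + memory_layer + layer_norm + ffn

embeddings = vocab * d + block_size * d  # lm_head is tied to the token embedding
total = embeddings + 2 * block_params + layer_norm  # two blocks + final LayerNorm
print(f"{total:,}")  # 642,180
```

Because of weight tying, the 65×128 output projection is counted only once, together with the token embedding.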

## Part 9: Training

Let's train our MIRAS model on the Shakespeare dataset!

In [48]:
@torch.no_grad()
def estimate_loss(model: nn.Module, eval_iters: int = 50) -> dict:
    """Estimate loss on train and val sets."""
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


def train_miras(model: nn.Module, max_iters: int = 1000,
                eval_interval: int = 100, learning_rate: float = 1e-3):
    """Training loop for MIRAS language model."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    for step in range(max_iters):  # `step` avoids shadowing the built-in iter()
        if step % eval_interval == 0:
            losses = estimate_loss(model)
            print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        xb, yb = get_batch('train')
        _, loss = model(xb, yb)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

    losses = estimate_loss(model)
    print(f"Final: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    return model

In [52]:
model = train_miras(model, max_iters=5000, eval_interval=500)

step 0: train loss 4.2027, val loss 4.2023
step 500: train loss 2.3533, val loss 2.3724
step 1000: train loss 2.3915, val loss 2.4062
step 1500: train loss 2.3558, val loss 2.3867
step 2000: train loss 2.3163, val loss 2.3407
step 2500: train loss 2.2865, val loss 2.3343
step 3000: train loss 2.2714, val loss 2.3246
step 3500: train loss 2.2677, val loss 2.3290
step 4000: train loss 2.2502, val loss 2.3082
step 4500: train loss 2.2372, val loss 2.3074
Final: train loss 2.2355, val loss 2.2928
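
Two reference points help in reading this log. With a 65-character vocabulary (the size of the token embedding printed earlier), a model that predicts uniformly has cross-entropy ln 65, and exponentiating a loss gives the per-character perplexity. A small worked calculation using the numbers from the log:

```python
import math

vocab_size = 65                      # from Embedding(65, 128) in the model printout
uniform_loss = math.log(vocab_size)  # cross-entropy of a uniform predictor
final_val_loss = 2.2928              # last line of the training log

print(f"uniform baseline: {uniform_loss:.4f}")
print(f"val perplexity:   {math.exp(final_val_loss):.2f}")
```

The step-0 loss of about 4.20 sits close to the uniform baseline, as expected for a freshly initialized model, and the final validation loss corresponds to a perplexity of roughly 10.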


In [53]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated = model.generate(context, max_new_tokens=200)
print("Generated text:")
print(decode(generated[0].tolist()))

Generated text:

Whe oelour h aravomaresh thtl! hus,
ONUM:
GOLIO:
I'to grucos
ETA:
Mofult thoupy.



akeatouge,
INGHESS:
I anrok'f le setheage
Thavin, burelf trotthit mion! ort cr
Ditherisspllucoman.


CLIO: ma to ke.
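
The sample above was drawn with the default `temperature=1.0`. The effect of that parameter in `generate` is easiest to see on a toy distribution. The sketch below re-implements only the `softmax(logits / temperature)` step in plain Python; the three logits are made-up values for illustration.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax of logits / temperature, computed stably by subtracting the max."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]  # hypothetical next-token logits
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

Lower temperatures concentrate probability on the largest logit, which gives more conservative text. Higher temperatures flatten the distribution, which gives more varied but noisier text.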


## Part 10: Experiments & Comparisons

Let's compare different MIRAS configurations to understand the impact of each design choice.

In [24]:
def compare_models(configs: list, max_iters: int = 300, eval_interval: int = 100):
    """Compare different MIRAS configurations."""
    results = {}

    for name, config in configs:
        print(f"\n{'='*50}")
        print(f"Training: {name}")
        print(f"{'='*50}")

        model = MIRASLanguageModel(
            vocab_size=vocab_size,
            d_model=n_embd,
            n_layers=2,
            block_size=block_size,
            **config
        ).to(device)

        model = train_miras(model, max_iters=max_iters, eval_interval=eval_interval)

        losses = estimate_loss(model)
        results[name] = {
            'train_loss': losses['train'].item(),
            'val_loss': losses['val'].item(),
            'params': sum(p.numel() for p in model.parameters())
        }

    print(f"\n{'='*50}")
    print("Summary")
    print(f"{'='*50}")
    for name, res in results.items():
        print(f"{name:30s} | Train: {res['train_loss']:.4f} | Val: {res['val_loss']:.4f} | Params: {res['params']:,}")

    return results

configs = [
    ("Linear Memory + ℓ₂", {"memory_type": "linear", "attentional_bias": "l2", "retention": "l2"}),
    ("Deep Memory + ℓ₂", {"memory_type": "deep", "attentional_bias": "l2", "retention": "l2"}),
    ("Deep Memory + ℓ₃", {"memory_type": "deep", "attentional_bias": "lp", "retention": "l2"}),
    ("Deep Memory + Huber", {"memory_type": "deep", "attentional_bias": "huber", "retention": "l2"}),
]

In [None]:
# Uncomment to run comparison (takes a few minutes)
# results = compare_models(configs, max_iters=300)

## Part 11: Adding Momentum (Surprise Metric)

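The Titans insight: **An event that violates expectations is more memorable.**

**Surprise = gradient** of the memory loss with respect to the memory parameters, evaluated on the current input $(k_t, v_t)$.

**Problem**: Momentary surprise alone can miss important information that arrives right after a big surprise, because the gradient can become small once the memory has adapted.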

**Solution**: Track "past surprise" with momentum:

$$S_t = \eta_t S_{t-1} - \theta_t \nabla\ell(M_{t-1}; k_t, v_t)$$
$$M_t = (1 - \alpha_t) M_{t-1} + S_t$$

This is **gradient descent with momentum and weight decay**.
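
A scalar toy run makes the role of $S_t$ concrete. The sketch below applies the two equations to a hand-picked gradient sequence (one large surprise followed by zeros). The values of $\eta$, $\theta$, and $\alpha$ are arbitrary constants chosen for illustration, whereas the layer in the next cell learns them.

```python
def momentum_updates(grads, eta=0.9, theta=0.1, alpha=0.0, m0=0.0):
    """Run S_t = eta*S_{t-1} - theta*g_t and M_t = (1 - alpha)*M_{t-1} + S_t on scalars."""
    s, m = 0.0, m0
    history = []
    for g in grads:
        s = eta * s - theta * g      # past surprise (momentum) plus momentary surprise
        m = (1.0 - alpha) * m + s    # retention gate, then write the surprise
        history.append((s, m))
    return history

# One surprising token (gradient 1.0), then three unsurprising ones (gradient 0.0).
for t, (s, m) in enumerate(momentum_updates([1.0, 0.0, 0.0, 0.0]), start=1):
    print(f"t={t}  S={s:+.4f}  M={m:+.4f}")
```

Even though the gradient is zero from t=2 onward, $S_t$ decays geometrically instead of vanishing, so the memory keeps moving in the direction of the earlier surprise.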

In [None]:
class MIRASLayerWithMomentum(nn.Module):
    """MIRAS layer with momentum (Titans-style surprise tracking).

    Instead of directly applying gradients, we accumulate them with momentum
    to track "past surprise", which helps remember important events that
    happen after a big surprise.
    """
    def __init__(self, d: int, expansion: int = 4):
        super().__init__()
        self.d = d
        self.kv_proj = KeyValueProjection(d, d)

        self.W1_init = nn.Parameter(torch.randn(d, d * expansion) * 0.02)
        self.W2_init = nn.Parameter(torch.randn(d * expansion, d) * 0.02)
        self.ln = nn.LayerNorm(d)

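        # Raw gate parameters; forward() maps them through sigmoid / softplus.
        self.alpha = nn.Parameter(torch.ones(1) * 0.9)  # Forgetting rate in M_t = (1 - alpha) M_{t-1} + S_t
        self.eta = nn.Parameter(torch.ones(1) * 0.9)    # Momentum coefficient
        self.theta = nn.Parameter(torch.ones(1) * 0.1)  # Gradient coefficient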

    def memory_forward(self, x: torch.Tensor, W1: torch.Tensor, W2: torch.Tensor) -> torch.Tensor:
        h = F.gelu(x @ W2.transpose(-2, -1))
        return x + self.ln(h @ W1.transpose(-2, -1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        k, v, q = self.kv_proj(x)
        B, T, D = k.shape

        W1 = self.W1_init.unsqueeze(0).expand(B, -1, -1).contiguous()
        W2 = self.W2_init.unsqueeze(0).expand(B, -1, -1).contiguous()

        S1 = torch.zeros_like(W1)
        S2 = torch.zeros_like(W2)

        alpha = torch.sigmoid(self.alpha)
        eta = torch.sigmoid(self.eta)
        theta = F.softplus(self.theta)

        outputs = []

        with torch.enable_grad():
            for t in range(T):
                k_t, v_t, q_t = k[:, t], v[:, t], q[:, t]

                W1_leaf = W1.detach().requires_grad_(True)
                W2_leaf = W2.detach().requires_grad_(True)

                pred = self.memory_forward(k_t.unsqueeze(1), W1_leaf, W2_leaf).squeeze(1)
                loss = l2_loss(pred, v_t).sum()
                grad1, grad2 = torch.autograd.grad(loss, [W1_leaf, W2_leaf])

                S1 = eta * S1 - theta * grad1
                S2 = eta * S2 - theta * grad2

                W1 = (1 - alpha) * W1 + S1
                W2 = (1 - alpha) * W2 + S2

                y_t = self.memory_forward(q_t.unsqueeze(1), W1.detach(), W2.detach()).squeeze(1)
                outputs.append(y_t)

        return torch.stack(outputs, dim=1)

momentum_layer = MIRASLayerWithMomentum(n_embd).to(device)
out = momentum_layer(test_emb)
print(f"MIRAS with momentum output: {out.shape}")

MIRAS with momentum output: torch.Size([2, 8, 64])


## Part 12: Parallelizable Training (Chunked)

The recurrent updates above are sequential (slow on GPU). We can parallelize within **chunks**:

1. Split sequence into chunks of size $b$ (e.g., 16 or 64)
2. Within each chunk, use the **same** memory state for all gradients
3. Accumulate gradients in parallel, then apply

This trades off some accuracy for speed: gradients within a chunk don't see each other's updates.
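
The approximation can be checked on a scalar memory, where the ℓ₂ gradient has a closed form. The sketch below is an illustration under assumptions, not the notebook's implementation: it assumes the sequential rule $m_t = \alpha m_{t-1} - \eta g_t$, and `sequential` and `chunked` are hypothetical helper names.

```python
def grad(m, k, v):
    """Gradient of 0.5 * (m*k - v)**2 with respect to the scalar memory m."""
    return (m * k - v) * k

def sequential(m, keys, values, alpha=0.9, eta=0.1):
    """Token-by-token rule m_t = alpha*m_{t-1} - eta*g_t (assumed update order)."""
    for k, v in zip(keys, values):
        m = alpha * m - eta * grad(m, k, v)
    return m

def chunked(m, keys, values, chunk_size, alpha=0.9, eta=0.1):
    """Same rule, but every gradient in a chunk is taken at the chunk-start memory."""
    for start in range(0, len(keys), chunk_size):
        ks = keys[start:start + chunk_size]
        vs = values[start:start + chunk_size]
        c = len(ks)
        # Unrolling the sequential rule: gradient i is decayed by the c-1-i steps after it.
        acc = sum(alpha ** (c - 1 - i) * grad(m, k, v)
                  for i, (k, v) in enumerate(zip(ks, vs)))
        m = alpha ** c * m - eta * acc
    return m

keys, values = [1.0, 0.5, -1.0, 2.0], [1.0, 0.0, 2.0, -1.0]
print("sequential:", sequential(0.0, keys, values))
print("chunk=1:   ", chunked(0.0, keys, values, chunk_size=1))
print("chunk=4:   ", chunked(0.0, keys, values, chunk_size=4))
```

With `chunk_size=1` the two functions agree. With larger chunks they drift apart because later gradients ignore earlier writes. The `ChunkedLinearMemory` cell below uses exponents $C, \dots, 1$ for its decay weights, which corresponds to the variant $m_t = \alpha(m_{t-1} - \eta g_t)$ in which the fresh gradient is decayed as well.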

In [26]:
class ChunkedLinearMemory(nn.Module):
    """Chunked parallel linear memory with Delta rule.

    Within each chunk, we:
    1. Use the same memory state M_0 for computing all gradients
    2. Accumulate weighted gradients in parallel
    3. Update memory at chunk boundaries

    This is faster on GPUs at the cost of some approximation.
    """
    def __init__(self, d: int, chunk_size: int = 16):
        super().__init__()
        self.d = d
        self.chunk_size = chunk_size
        self.kv_proj = KeyValueProjection(d, d)

        self.alpha = nn.Parameter(torch.ones(1) * 0.9)
        self.eta = nn.Parameter(torch.ones(1) * 0.1)

    def process_chunk(self, M: torch.Tensor, k_chunk: torch.Tensor,
                      v_chunk: torch.Tensor, q_chunk: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Process a chunk in parallel."""
        B, C, D = k_chunk.shape
        alpha = torch.sigmoid(self.alpha)
        eta = F.softplus(self.eta)

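        # Predictions and errors for every position, all using the chunk-start memory M.
        pred_all = torch.einsum('bde,bce->bcd', M, k_chunk)
        errors = pred_all - v_chunk

        # Position i in the chunk is weighted by alpha^(C - i); shape (1, C, 1, 1).
        exponents = torch.arange(C, 0, -1, device=M.device, dtype=M.dtype)
        decay_weights = (alpha ** exponents).view(1, -1, 1, 1)

        # Per-position l2 gradients (error outer key), summed over the chunk.
        grads = torch.einsum('bcd,bce->bcde', errors, k_chunk)
        weighted_grads = (grads * decay_weights).sum(dim=1)

        M_new = (alpha ** C) * M - eta * weighted_grads

        # Queries also read the chunk-start memory, so outputs never reflect writes
        # from their own chunk (the first chunk reads the initial all-zero memory).
        outputs = torch.einsum('bde,bce->bcd', M, q_chunk)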

        return M_new, outputs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        k, v, q = self.kv_proj(x)
        B, T, D = k.shape

        M = torch.zeros(B, D, D, device=k.device, dtype=k.dtype)

        n_chunks = (T + self.chunk_size - 1) // self.chunk_size
        all_outputs = []

        for i in range(n_chunks):
            start = i * self.chunk_size
            end = min(start + self.chunk_size, T)

            k_chunk = k[:, start:end]
            v_chunk = v[:, start:end]
            q_chunk = q[:, start:end]

            M, outputs = self.process_chunk(M, k_chunk, v_chunk, q_chunk)
            all_outputs.append(outputs)

        return torch.cat(all_outputs, dim=1)

chunked_mem = ChunkedLinearMemory(n_embd, chunk_size=8).to(device)
out = chunked_mem(test_emb)
print(f"Chunked memory output: {out.shape}")

Chunked memory output: torch.Size([2, 8, 64])


## Part 13: Key Insights & Summary

### What We Built

We implemented the MIRAS framework from scratch, covering:

| Component | Implementations |
|-----------|----------------|
| **Memory** | Linear (Hebbian, Delta), Deep MLP |
| **Attentional Bias** | ℓ₂, ℓₚ, Huber |
| **Retention** | ℓ₂ decay, KL divergence, Elastic Net |
| **Algorithm** | GD, GD+Momentum |

### Key Equations

**Hebbian (Linear Attention)**:
$$M_t = \alpha M_{t-1} + v_t k_t^\top$$

**Delta Rule**:
$$M_t = \alpha M_{t-1}(I - \eta k_t k_t^\top) + \eta v_t k_t^\top$$
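
The difference between the two rules shows up when the same key is written twice with different values. The sketch below uses plain Python lists, no decay ($\alpha = 1$), and reads the memory as $M k$ as in the code cells above. The delta rule is written in its error-correcting form, and the helper names are hypothetical.

```python
def matvec(M, x):
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in M]

def outer(u, w):
    return [[u_i * w_j for w_j in w] for u_i in u]

def hebbian_write(M, k, v):
    """M <- M + v k^T (no decay, alpha = 1)."""
    vk = outer(v, k)
    return [[m + d for m, d in zip(rm, rd)] for rm, rd in zip(M, vk)]

def delta_write(M, k, v, eta=1.0):
    """M <- M - eta * (M k - v) k^T: a gradient step on 0.5 * ||M k - v||^2."""
    err = [p - t for p, t in zip(matvec(M, k), v)]
    ek = outer(err, k)
    return [[m - eta * d for m, d in zip(rm, rd)] for rm, rd in zip(M, ek)]

k = [1.0, 0.0]                      # unit-norm key, written twice
v_old, v_new = [1.0, 1.0], [3.0, -1.0]
zero = [[0.0, 0.0], [0.0, 0.0]]

M_h = hebbian_write(hebbian_write(zero, k, v_old), k, v_new)
M_d = delta_write(delta_write(zero, k, v_old), k, v_new)
print("Hebbian read:", matvec(M_h, k))  # [4.0, 0.0]: old and new values are summed
print("Delta read:  ", matvec(M_d, k))  # [3.0, -1.0]: the new value replaces the old one
```

With a unit-norm key and $\eta = 1$, the delta write first removes whatever the memory currently returns for that key and then stores the new value, which is why the read recovers `v_new` exactly.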

**Titans/MIRAS with Momentum**:
$$S_t = \eta_t S_{t-1} - \theta_t \nabla\ell(M_{t-1}; k_t, v_t)$$
$$M_t = (1 - \alpha_t) M_{t-1} + S_t$$

**Memora (KL Retention)**:
$$W_t = c \cdot \text{Softmax}(\alpha_t \log(W_{t-1}) - \eta_t \nabla\ell_2)$$
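
The KL update can be tried on a single row of weights. The sketch below implements the formula above directly in plain Python; `kl_retention_step` is a hypothetical helper, the gradient values are made up, and the notebook's `kl_retention_update` works on log-weights of full matrices and may differ in detail.

```python
import math

def kl_retention_step(w, grad, alpha=0.9, eta=0.5, c=1.0):
    """One update W_t = c * softmax(alpha * log(W_{t-1}) - eta * grad) on a single row."""
    logits = [alpha * math.log(w_i) - eta * g_i for w_i, g_i in zip(w, grad)]
    m = max(logits)                          # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [c * e / total for e in exps]

w = [0.5, 0.3, 0.2]                          # positive weights that sum to c = 1
grad = [1.0, 0.0, -1.0]                      # hypothetical loss gradient
w_new = kl_retention_step(w, grad)
print([round(x, 4) for x in w_new], "sum =", round(sum(w_new), 6))
```

The update keeps every weight positive and rescales the row to sum to $c$. In this example the entry with a negative gradient gains mass and the entry with a positive gradient loses it. With $\alpha = 1$ and a zero gradient the row is returned unchanged.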

### The MIRAS Unification

The key insight is that **most sequence models are associative memory variants**:

| Model | What it "is" |
|-------|-------------|
| Linear Attention | Hebbian memory + dot-product similarity |
| DeltaNet | Linear memory + ℓ₂ similarity |
| Mamba | Hebbian + structured decay |
| Transformer | Non-parametric memory (stores all KV pairs) |
| TTT-Linear/MLP | Trainable memory + ℓ₂ at test time |
| Titans | Deep memory + momentum + ℓ₂ |
| **Moneta** | Deep memory + ℓₚ + ℓq normalization |
| **Yaad** | Deep memory + Huber (robust) |
| **Memora** | Deep memory + KL retention |

### Next Steps

1. **Scale up**: Larger models, more data
2. **Hybrid architectures**: Combine MIRAS with attention
3. **Efficiency**: CUDA kernels for parallel scans
4. **Novel biases**: Explore other loss functions
5. **Applications**: Long-context tasks, memory-intensive problems

## Glossary

| Term | Definition |
|------|------------|
| **Attentional Bias** | Internal objective function defining "similarity" in memory |
| **Retention Gate** | Mechanism balancing new learning vs. old retention |
| **Momentary Surprise** | Gradient at current timestep |
| **Past Surprise** | Momentum carrying surprise across tokens |
| **Hebbian Rule** | Additive memory update: M += vk^T |
| **Delta Rule** | Replacement memory update: removes old before adding new |
| **MIRAS** | Framework defined by four design choices: memory architecture, attentional bias, retention gate, and memory learning algorithm |

## References

1. Behrouz et al. (2025). "It's All Connected: A Journey Through Test-Time Memorization, Attentional Bias, Retention, and Online Optimization" (MIRAS paper)
2. Behrouz et al. (2024). "Titans: Learning to Memorize at Test Time"
3. Yang et al. (2024). "Gated Delta Networks: Improving Mamba2 with Delta Rule"
4. Sun et al. (2024). "Learning to (Learn at Test Time): RNNs with Expressive Hidden States" (TTT)
5. Katharopoulos et al. (2020). "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"