In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

The Core Update Rule (Eq. 28 & 29)

In [2]:
def hope_update_rule(W, x, grad_loss, eta):
    """
    Implements the HOPE optimizer update rule (Eq 28/29 in the paper).
    
    W_{t+1} = W_t(I - x_t x_t^T) - eta * grad_L
    
    Args:
        W: Current weight state (Batch, Out_Dim, In_Dim)
        x: Input vector (Batch, In_Dim, 1)
        grad_loss: Gradient of the loss w.r.t W (or simplified: error * x.T)
        eta: Learning rate
    """
    # Term 1: W_t (I - x x^T) -> Projection / Forgetting term
    # x @ x.transpose: (B, In, 1) @ (B, 1, In) -> (B, In, In)
    # I - xxT: (B, In, In)
    I = torch.eye(W.shape[-1], device=W.device).unsqueeze(0).expand(W.shape[0], -1, -1)
    
    # Paper suggests x should be normalized for stability in this update
    x_norm = x / (torch.norm(x, dim=1, keepdim=True) + 1e-8)
    projection_matrix = I - torch.bmm(x_norm, x_norm.transpose(1, 2))
    
    term_1 = torch.bmm(W, projection_matrix)
    
    # Term 2: Gradient update
    term_2 = eta * grad_loss
    
    W_new = term_1 - term_2
    return W_new

 Self-Referential Memory (Attention Replacement)

In [3]:
class HOPERecurrentMemory(nn.Module):
    """
    The 'Working Memory' of HOPE. 
    It learns to compress context (Keys -> Values) into a weight matrix M_t
    using the nested update rule, then answers Queries.
    """
    def __init__(self, dim, head_dim, learning_rate=0.1):
        super().__init__()
        self.dim = dim
        self.head_dim = head_dim
        self.lr = learning_rate
        
        # Projections (standard like Transformer)
        self.W_q = nn.Linear(dim, head_dim, bias=False)
        self.W_k = nn.Linear(dim, head_dim, bias=False)
        self.W_v = nn.Linear(dim, head_dim, bias=False)
        self.W_o = nn.Linear(head_dim, dim, bias=False)
        
        # Norms for stability in recurrent updates
        self.ln_k = nn.LayerNorm(head_dim)
        
    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        
        q = self.W_q(x) # (B, T, H)
        k = self.ln_k(self.W_k(x)) # (B, T, H)
        v = self.W_v(x) # (B, T, H)
        
        # Initialize Memory State M_0 (Batch, Head, Head)
        # M maps Keys (Head) -> Values (Head)
        M = torch.zeros(batch_size, self.head_dim, self.head_dim, device=x.device)
        
        outputs = []
        
        # Sequential processing (Recurrent view of Linear Attention)
        # Note: In production, this can be parallelized via chunking/CUDA (like Titans/Mamba)
        for t in range(seq_len):
            k_t = k[:, t, :].unsqueeze(2) # (B, H, 1)
            v_t = v[:, t, :].unsqueeze(2) # (B, H, 1)
            q_t = q[:, t, :].unsqueeze(2) # (B, H, 1)
            
            # 1. Prediction/Loss Calculation (Implicit)
            # We want M_t to map k_t -> v_t. 
            # Error signal: (M_t * k_t - v_t)
            pred_v = torch.bmm(M, k_t)
            error = pred_v - v_t
            
            # Gradient of MSE = 2 * error * k_t.T
            grad_loss = torch.bmm(error, k_t.transpose(1, 2))
            
            # 2. Update Memory M_t -> M_{t+1} using HOPE Rule (Eq 29)
            M = hope_update_rule(M, k_t, grad_loss, self.lr)
            
            # 3. Compute Output using Updated Memory
            # y_t = M_{t+1} * q_t
            y_t = torch.bmm(M, q_t).squeeze(2)
            outputs.append(y_t)
            
        y = torch.stack(outputs, dim=1) # (B, T, H)
        return self.W_o(y)

Continuum Memory System (CMS)

In [4]:
class CMSLayer(nn.Module):
    """
    A single layer of the Continuum Memory System.
    Concept: It's an MLP that holds state and updates its own weights 
    based on a local surprise signal at a specific frequency.
    """
    def __init__(self, dim, hidden_dim, update_freq=1, learning_rate=0.01):
        super().__init__()
        self.dim = dim
        self.update_freq = update_freq
        self.lr = learning_rate
        
        # We hold weights as buffers so we can manually update them per-batch if needed,
        # but for this demo, we make them Parameters that update via a custom logic.
        # To simulate 'fast weights' per sample is expensive, so we share weights 
        # across the batch but update them temporally during the forward pass.
        self.w1 = nn.Parameter(torch.randn(hidden_dim, dim) * 0.02)
        self.w2 = nn.Parameter(torch.randn(dim, hidden_dim) * 0.02)
        
        self.act = nn.SiLU()
        self.norm = nn.LayerNorm(dim)

    def forward_computation(self, x, w1, w2):
        # x: (B, D)
        h = F.linear(x, w1)
        h = self.act(h)
        return F.linear(h, w2)

    def forward(self, x, global_step=0):
        """
        x: (Batch, Seq, Dim)
        """
        batch_size, seq_len, dim = x.shape
        outputs = []
        
        # Clone current weights for temporary modification (Plasticity)
        # In a full implementation, these would be stateful per batch.
        curr_w1 = self.w1.clone()
        curr_w2 = self.w2.clone()
        
        for t in range(seq_len):
            input_t = x[:, t, :]
            
            # Standard FFN pass
            out_t = self.forward_computation(input_t, curr_w1, curr_w2)
            outputs.append(out_t)
            
            # CMS Update Logic (Eq 31)
            # Check frequency
            current_time = global_step + t
            if current_time % self.update_freq == 0:
                # Calculate "Surprise" / Gradient
                # Here we use a self-supervised reconstruction proxy: 
                # Ideally, CMS tries to predict input_t or reconstruct features.
                # For this demo, we treat the FFN output as a prediction of the input (Autoencoder-like)
                # to generate a gradient signal.
                
                loss = F.mse_loss(out_t, input_t) # Simple local objective
                
                # Manual Gradient (approximate for demo speed)
                grad_w1 = torch.autograd.grad(loss, curr_w1, retain_graph=True)[0]
                grad_w2 = torch.autograd.grad(loss, curr_w2, retain_graph=True)[0]
                
                # Apply update (SGD)
                curr_w1 = curr_w1 - self.lr * grad_w1
                curr_w2 = curr_w2 - self.lr * grad_w2
                
                # Detach to prevent VRAM explosion (Truncated BPTT equivalent)
                curr_w1 = curr_w1.detach().requires_grad_(True)
                curr_w2 = curr_w2.detach().requires_grad_(True)

        return torch.stack(outputs, dim=1) + x # Residual connection

class ContinuumMemorySystem(nn.Module):
    """
    Eq 30: Nested MLPs with different frequencies.
    Input -> Fast MLP -> Medium MLP -> Slow MLP -> Output
    """
    def __init__(self, dim, expansion_factor=4):
        super().__init__()
        hidden = dim * expansion_factor
        
        # Level 1: Fast (Updates every step)
        self.fast_mlp = CMSLayer(dim, hidden, update_freq=1, learning_rate=0.1)
        
        # Level 2: Slow (Updates every 16 steps)
        # Representing "Long-term" knowledge storage
        self.slow_mlp = CMSLayer(dim, hidden, update_freq=16, learning_rate=0.01)
        
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        # Chain of thoughts / computations
        h = self.fast_mlp(x)
        out = self.slow_mlp(h)
        return self.norm(out)

Full HOPE Block and Model

In [5]:
class HOPEBlock(nn.Module):
    def __init__(self, dim, head_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.memory = HOPERecurrentMemory(dim, head_dim)
        
        self.norm2 = nn.LayerNorm(dim)
        self.cms = ContinuumMemorySystem(dim)
        
    def forward(self, x):
        # 1. Working Memory (Attention-like)
        h = x + self.memory(self.norm1(x))
        
        # 2. Continuum Memory (FFN-like)
        out = h + self.cms(self.norm2(h))
        return out

class HOPEModel(nn.Module):
    def __init__(self, vocab_size, dim, depth, head_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.layers = nn.ModuleList([
            HOPEBlock(dim, head_dim) for _ in range(depth)
        ])
        self.norm_f = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab_size)
        
    def forward(self, input_ids):
        x = self.embedding(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm_f(x)
        logits = self.head(x)
        return logits

Toy Experiment: Associative Recall

In [7]:
def toy_training_loop():
    # Config
    vocab_size = 100
    dim = 64
    depth = 2
    head_dim = 16
    seq_len = 32
    batch_size = 4
    
    # Init Model
    model = HOPEModel(vocab_size, dim, depth, head_dim)
    
    # HOPE introduces internal optimization loops. 
    # The outer optimizer (Meta-learner) trains the initialization 
    # and the learning rates of the inner modules.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    
    print("Running HOPE Toy Task...")
    
    # Dummy Data: Copy task (Input: A B C ... -> Target: A B C ...)
    inputs = torch.randint(0, vocab_size, (batch_size, seq_len))
    targets = torch.roll(inputs, shifts=-1, dims=1) # Next token prediction
    
    model.train()
    
    for step in range(100):
        optimizer.zero_grad()
        
        # Forward pass runs the internal Nested Learning (inner optimizers)
        logits = model(inputs)
        
        # Loss calculation (Standard CE)
        loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
        
        loss.backward()
        optimizer.step()
        
        print(f"Step {step+1}, Loss: {loss.item():.4f}")
        
    print("Done. Model runs successfully with nested optimization logic.")

if __name__ == "__main__":
    toy_training_loop()

Running HOPE Toy Task...
Step 1, Loss: 4.7313
Step 2, Loss: 4.5962
Step 3, Loss: 4.4845
Step 4, Loss: 4.3810
Step 5, Loss: 4.2878
Step 6, Loss: 4.2005
Step 7, Loss: 4.1187
Step 8, Loss: 4.0420
Step 9, Loss: 3.9680
Step 10, Loss: 3.9103
Step 11, Loss: 3.8326
Step 12, Loss: 3.7665
Step 13, Loss: 3.7046
Step 14, Loss: 3.6426
Step 15, Loss: 3.5780
Step 16, Loss: 3.5302
Step 17, Loss: 3.4673
Step 18, Loss: 3.4019
Step 19, Loss: 3.3474
Step 20, Loss: 3.2902
Step 21, Loss: 3.2318
Step 22, Loss: 3.1934
Step 23, Loss: 3.1251
Step 24, Loss: 3.0862
Step 25, Loss: 3.0286
Step 26, Loss: 2.9789
Step 27, Loss: 2.9412
Step 28, Loss: 2.9026
Step 29, Loss: 2.8404
Step 30, Loss: 2.7853
Step 31, Loss: 2.7415
Step 32, Loss: 2.6956
Step 33, Loss: 2.6426
Step 34, Loss: 2.6001
Step 35, Loss: 2.5486
Step 36, Loss: 2.4946
Step 37, Loss: 2.4490
Step 38, Loss: 2.4068
Step 39, Loss: 2.3586
Step 40, Loss: 2.3082
Step 41, Loss: 2.2733
Step 42, Loss: 2.2182
Step 43, Loss: 2.1777
Step 44, Loss: 2.1352
Step 45, Loss: 2

Loss after 100 iters. 4.7313-->0.5931