# In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Union


class DepthwiseSeparableConv1d(nn.Module):
    """
    Depthwise separable 1D convolution as used in the Titans architecture.

    A per-channel (depthwise) convolution along the sequence axis is followed by a
    1x1 (pointwise) convolution that mixes channels. Input and output both have
    shape [batch_size, seq_len, dim]. The sequence length is preserved for every
    kernel size, odd or even.

    Args:
        in_dim: Number of channels (feature dimension).
        kernel_size: Temporal kernel size of the depthwise convolution.
        bias: Whether both convolutions use a bias term.
        causal: If True (default), zero-pad only on the left. Output position t
            then depends only on inputs at positions <= t, which is required when
            the result feeds causal attention or next-token prediction. If False,
            pad symmetrically ("same" padding), so each position also sees up to
            kernel_size // 2 future positions.
    """
    def __init__(self, in_dim: int, kernel_size: int = 3, bias: bool = True, causal: bool = True):
        super().__init__()
        self.kernel_size = kernel_size
        self.causal = causal

        # Padding is applied explicitly in forward() so that it can be asymmetric.
        total_pad = kernel_size - 1
        left_pad = total_pad if causal else total_pad // 2
        self.padding = (left_pad, total_pad - left_pad)

        self.depthwise = nn.Conv1d(
            in_channels=in_dim,
            out_channels=in_dim,
            kernel_size=kernel_size,
            groups=in_dim,
            bias=bias,
        )
        self.pointwise = nn.Conv1d(
            in_channels=in_dim,
            out_channels=in_dim,
            kernel_size=1,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution to x of shape [batch_size, seq_len, dim]."""
        # Conv1d expects channels first: [batch_size, dim, seq_len].
        x = x.transpose(1, 2)
        # Zero padding (left, right) keeps the output length equal to seq_len.
        x = F.pad(x, self.padding)
        x = self.depthwise(x)
        x = self.pointwise(x)
        # Back to [batch_size, seq_len, dim].
        return x.transpose(1, 2)


class NeuralLongTermMemory(nn.Module):
    """
    Neural Long-Term Memory module as described in Section 3.1.

    The memory is an MLP ``M`` that learns to memorize at test time. Each call to
    :meth:`update_memory` takes one gradient step on the associative loss
    ``||M(key) - value||^2``. The step uses a decaying accumulation of past
    gradients ("surprise") and a weight-decay term ("forgetting"). Their rates come
    from small sigmoid gating networks applied to the key.

    The ``nn.Linear`` weights of the MLP hold the learned *initial* memory.
    Test-time updates are stored in a separate list of fast weights
    (``self.fast_weights``) instead of overwriting those parameters. As a result,
    :meth:`reset_state` really restores the initial memory, and optimizer steps on
    the parameters are never clobbered. Only the linear weight matrices are
    updated; biases and LayerNorm parameters are used as-is.

    Args:
        dim: Input and output dimension of the memory.
        memory_depth: Number of linear layers in the memory MLP (>= 1). Depth 1 is a
            single linear map ``dim -> dim``.
        hidden_dim: Hidden width of the MLP; defaults to ``4 * dim``.
        activation: Activation applied after each hidden LayerNorm; defaults to a
            fresh ``nn.SiLU()`` per instance.

    Raises:
        ValueError: If ``memory_depth < 1``.
    """
    def __init__(
        self,
        dim: int,
        memory_depth: int = 2,
        hidden_dim: Optional[int] = None,
        activation: Optional[nn.Module] = None,
    ):
        super().__init__()

        if memory_depth < 1:
            raise ValueError(f"memory_depth must be >= 1, got {memory_depth}")
        if hidden_dim is None:
            hidden_dim = dim * 4

        self.dim = dim
        self.memory_depth = memory_depth

        # Layer widths: dim -> hidden -> ... -> hidden -> dim (just dim -> dim for depth 1).
        widths = [dim] + [hidden_dim] * (memory_depth - 1) + [dim]
        self.layers = nn.ModuleList(
            nn.Linear(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:])
        )
        # One LayerNorm per hidden layer, i.e. after every linear layer except the last.
        self.layer_norms = nn.ModuleList(nn.LayerNorm(width) for width in widths[1:-1])

        self.activation = activation if activation is not None else nn.SiLU()

        # Gating networks: alpha = forgetting rate, theta = learning rate,
        # eta = decay of the past surprise (momentum).
        self.alpha_net = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())
        self.theta_net = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())
        self.eta_net = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())

        # Per-sequence test-time state; both are cleared by reset_state().
        self.past_surprise: Optional[List[torch.Tensor]] = None
        self.fast_weights: Optional[List[torch.Tensor]] = None

    def reset_state(self, batch_size: int = 1, device: torch.device = torch.device('cpu')):
        """Discard test-time updates: clear the momentum and revert to the learned weights.

        ``batch_size`` and ``device`` are accepted for interface compatibility and
        are not used.
        """
        self.past_surprise = None
        self.fast_weights = None

    def _current_weights(self) -> List[torch.Tensor]:
        """Return the weight matrices the memory currently holds, detached from autograd."""
        if self.fast_weights is not None:
            return list(self.fast_weights)
        return [layer.weight.detach() for layer in self.layers]

    def _retrieval_weights(self) -> List[torch.Tensor]:
        """Return the weight matrices used when reading from the memory.

        Before any update these are the parameters themselves. After updates the
        values equal the fast weights. Gradients still pass straight through to the
        parameters (``W + (W_fast - W).detach()``), so the learned initial weights
        keep receiving a training signal from later reads.
        """
        if self.fast_weights is None:
            return [layer.weight for layer in self.layers]
        return [
            layer.weight + (fast - layer.weight.detach())
            for layer, fast in zip(self.layers, self.fast_weights)
        ]

    def _mlp(self, x: torch.Tensor, weights: List[torch.Tensor]) -> torch.Tensor:
        """Run the memory MLP on x, using `weights` for the linear layers' weight matrices."""
        for layer, norm, weight in zip(self.layers[:-1], self.layer_norms, weights[:-1]):
            x = self.activation(norm(F.linear(x, weight, layer.bias)))
        return F.linear(x, weights[-1], self.layers[-1].bias)

    def compute_gradient(self, mem_weights: List[torch.Tensor], key: torch.Tensor,
                          value: torch.Tensor) -> List[torch.Tensor]:
        """
        Compute the gradient of the loss with respect to the memory weights.
        Loss = ||M(key) - value||^2 (summed over all elements).

        The gradient is taken with respect to detached copies of ``mem_weights``.
        Grad mode is enabled locally, so this also works inside ``torch.no_grad()``.
        """
        with torch.enable_grad():
            weight_copies = [w.detach().clone().requires_grad_(True) for w in mem_weights]
            # key/value are detached: this inner loss only differentiates w.r.t. the copies.
            output = self._mlp(key.detach(), weight_copies)
            loss = F.mse_loss(output, value.detach(), reduction='sum')
            grads = torch.autograd.grad(loss, weight_copies)
        return list(grads)

    def forward_memory(self, key: torch.Tensor) -> torch.Tensor:
        """Forward pass through the memory without updating weights."""
        return self._mlp(key, self._retrieval_weights())

    def update_memory(self, key: torch.Tensor, value: torch.Tensor) -> None:
        """Take one test-time update step on (key, value) (Equations 13 and 14).

        The memory is shared by every batch element. The data-dependent rates are
        therefore averaged over all elements of ``key`` into scalars; with a single
        key vector this equals the per-token rate. The update runs without building
        an autograd graph, so the stored state does not grow across steps.
        """
        weights = self._current_weights()

        with torch.no_grad():
            alpha = self.alpha_net(key).mean()  # Forgetting rate
            theta = self.theta_net(key).mean()  # Learning rate
            eta = self.eta_net(key).mean()      # Past surprise decay

        # Momentary surprise = gradient of the associative loss.
        momentary_surprise = self.compute_gradient(weights, key, value)

        if self.past_surprise is None:
            self.past_surprise = [torch.zeros_like(grad) for grad in momentary_surprise]

        with torch.no_grad():
            new_weights = []
            for i, (weight, grad) in enumerate(zip(weights, momentary_surprise)):
                # Equation 14: decayed past surprise minus the scaled gradient.
                self.past_surprise[i] = eta * self.past_surprise[i] - theta * grad
                # Equation 13: forget a fraction of the old weights, then add the surprise.
                new_weights.append((1 - alpha) * weight + self.past_surprise[i])
        self.fast_weights = new_weights

    def forward(self, query: Optional[torch.Tensor], key: torch.Tensor = None, value: torch.Tensor = None,
                update: bool = True) -> Optional[torch.Tensor]:
        """
        Forward pass through the memory module.

        Args:
            query: Input tensor to retrieve memory, or None to only update
            key: Key tensor for updating memory (optional)
            value: Value tensor for updating memory (optional)
            update: Whether to update the memory weights

        Returns:
            Memory output for the given query, or None when ``query`` is None
        """
        # The update happens before retrieval, so the query sees the new memory.
        if key is not None and value is not None and update:
            self.update_memory(key, value)

        if query is None:
            return None
        return self.forward_memory(query)


class TitansAttention(nn.Module):
    """
    Multi-head self-attention module for the Titans model.

    Q, K and V are each produced by a linear projection followed by a
    depthwise-separable 1D convolution over the sequence. By default a causal
    (lower-triangular) mask is applied to the attention scores.

    Args:
        dim: Model dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        conv_kernel_size: Kernel size of the Q/K/V convolutions.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        dropout: float = 0.0,
        conv_kernel_size: int = 3,
    ):
        super().__init__()

        # Fail early: otherwise the per-head reshape in forward() fails with an opaque error.
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # Projection matrices
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

        # Convolutional layers for Q, K, V
        self.q_conv = DepthwiseSeparableConv1d(dim, kernel_size=conv_kernel_size)
        self.k_conv = DepthwiseSeparableConv1d(dim, kernel_size=conv_kernel_size)
        self.v_conv = DepthwiseSeparableConv1d(dim, kernel_size=conv_kernel_size)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: Input of shape [batch_size, seq_len, dim].
            mask: Optional boolean mask broadcastable to
                [batch_size, num_heads, seq_len, seq_len]. True marks query/key
                pairs that may attend. Defaults to a causal mask.

        Returns:
            Tensor of shape [batch_size, seq_len, dim].
        """
        batch_size, seq_len, _ = x.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # [batch, seq, dim] -> [batch, heads, seq, head_dim]
            return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # NOTE: the mask below constrains only the attention step. Whether the
        # Q/K/V convolutions also respect causality depends on the padding used by
        # DepthwiseSeparableConv1d.
        q = split_heads(self.q_conv(self.q_proj(x)))
        k = split_heads(self.k_conv(self.k_proj(x)))
        v = split_heads(self.v_conv(self.v_proj(x)))

        # Scaled dot-product scores.
        attn_scores = torch.matmul(q / math.sqrt(self.head_dim), k.transpose(-2, -1))

        if mask is None:
            # Causal mask (lower triangular), broadcast over batch and heads.
            mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device))
            mask = mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, seq_len]

        # Use the dtype's most negative finite value instead of -1e9. -1e9 is not
        # representable in float16, so masked_fill would raise under half precision.
        attn_scores = attn_scores.masked_fill(~mask, torch.finfo(attn_scores.dtype).min)

        attn_weights = self.dropout(F.softmax(attn_scores, dim=-1))
        context = torch.matmul(attn_weights, v)

        # [batch, heads, seq, head_dim] -> [batch, seq, dim]
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.dim)
        return self.out_proj(context)


class TitansMAC(nn.Module):
    """
    Memory as a Context (MAC) architecture for Titans model, as described in Section 4.1.

    The input is processed in consecutive segments of ``segment_size`` tokens. For
    each segment the layer:

    1. reads the long-term memory with the segment's query projections,
    2. runs attention over ``[persistent tokens, memory readout, segment]`` and
       keeps the outputs at the segment positions,
    3. updates the long-term memory one token at a time from the segment's
       key/value projections,
    4. reads the updated memory with the attention output and mixes the two with a
       learned sigmoid gate, plus a residual connection to the segment, and
    5. applies a residual feed-forward block on the normalized result, followed by
       a LayerNorm.

    Long-term memory state persists across calls. Call :meth:`reset_memory`
    between independent sequences.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        memory_depth: int = 2,
        dropout: float = 0.0,
        num_persistent_tokens: int = 8,
        segment_size: int = 512,
        conv_kernel_size: int = 3,
    ):
        super().__init__()

        self.dim = dim
        self.segment_size = segment_size
        self.num_persistent_tokens = num_persistent_tokens

        # Persistent memory (learnable, input-independent tokens)
        self.persistent_memory = nn.Parameter(torch.randn(num_persistent_tokens, dim))

        # Neural long-term memory
        self.long_term_memory = NeuralLongTermMemory(dim, memory_depth=memory_depth)

        # Attention module
        self.attention = TitansAttention(
            dim=dim,
            num_heads=num_heads,
            dropout=dropout,
            conv_kernel_size=conv_kernel_size,
        )

        # Projections producing the memory's queries, keys and values
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

        # Layer normalization
        self.layer_norm1 = nn.LayerNorm(dim)
        self.layer_norm2 = nn.LayerNorm(dim)

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.SiLU(),
            nn.Linear(4 * dim, dim),
            nn.Dropout(dropout),
        )

        # Gating mechanism for combining attention and memory outputs
        self.gate = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
            nn.Sigmoid(),
        )

    def reset_memory(self, batch_size: int = 1, device: torch.device = torch.device('cpu')):
        """Reset the long-term memory state."""
        self.long_term_memory.reset_state(batch_size, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input of shape [batch_size, seq_len, dim].

        Returns:
            Tensor of shape [batch_size, seq_len, dim].
        """
        batch_size, seq_len, _ = x.shape
        if seq_len == 0:
            # Nothing to process; torch.cat below would reject an empty list.
            return x

        # The same learnable persistent tokens for every batch element (expand is a view).
        persistent_tokens = self.persistent_memory.unsqueeze(0).expand(batch_size, -1, -1)

        outputs = []
        for start in range(0, seq_len, self.segment_size):
            segment = x[:, start:start + self.segment_size, :]

            q = self.q_proj(segment)
            k = self.k_proj(segment)
            v = self.v_proj(segment)

            # Read the long-term memory: one readout token per segment token.
            # NOTE(review): every readout token precedes every segment token in the
            # attention context, so segment position t can attend to M(q_j) for
            # j > t. Confirm this within-segment visibility is intended.
            memory_output = self.long_term_memory(q)

            context = torch.cat([persistent_tokens, memory_output, segment], dim=1)
            attn_output = self.attention(context)

            # Keep only the outputs at the segment positions.
            prefix_len = self.num_persistent_tokens + memory_output.size(1)
            segment_attn_output = attn_output[:, prefix_len:, :]

            # Test-time memorization: one memory update per token, using the
            # segment's key/value projections. update_memory is called directly
            # because only the update is needed here, not a retrieval.
            for j in range(segment.size(1)):
                self.long_term_memory.update_memory(k[:, j, :], v[:, j, :])

            # Read the updated memory with the attention output.
            # NOTE(review): the memory has by now absorbed every token of this
            # segment, including positions after t; confirm this is intended.
            final_memory_output = self.long_term_memory(segment_attn_output)

            # Gate between attention and memory outputs, with a residual to the input.
            gate_values = self.gate(torch.cat([segment_attn_output, final_memory_output], dim=-1))
            output = (
                segment
                + segment_attn_output * gate_values
                + final_memory_output * (1 - gate_values)
            )

            # Feed-forward on the normalized output, residual, then final norm.
            ffn_output = self.ffn(self.layer_norm1(output))
            outputs.append(self.layer_norm2(output + ffn_output))

        return torch.cat(outputs, dim=1)


class TitansModel(nn.Module):
    """
    Complete Titans model with configurable number of layers.

    Token embeddings are added to learned position embeddings, passed through a
    stack of ``TitansMAC`` layers, normalized, and projected to vocabulary logits.

    Args:
        dim: Model/embedding dimension.
        num_layers: Number of stacked ``TitansMAC`` layers.
        num_heads, memory_depth, dropout, num_persistent_tokens, segment_size,
        conv_kernel_size: Forwarded unchanged to every ``TitansMAC`` layer.
        vocab_size: Vocabulary size, used for the embedding rows and the logit
            width. Defaults to 50304.
        max_seq_len: Number of learned position embeddings; position ids must be
            smaller than this. Defaults to 8192.
    """
    def __init__(
        self,
        dim: int,
        num_layers: int = 6,
        num_heads: int = 8,
        memory_depth: int = 2,
        dropout: float = 0.0,
        num_persistent_tokens: int = 8,
        segment_size: int = 512,
        conv_kernel_size: int = 3,
        vocab_size: int = 50304,
        max_seq_len: int = 8192,
    ):
        super().__init__()

        self.dim = dim
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len

        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, dim)

        # Learned absolute position embedding
        self.pos_embedding = nn.Embedding(max_seq_len, dim)

        # Titans layers
        self.layers = nn.ModuleList([
            TitansMAC(
                dim=dim,
                num_heads=num_heads,
                memory_depth=memory_depth,
                dropout=dropout,
                num_persistent_tokens=num_persistent_tokens,
                segment_size=segment_size,
                conv_kernel_size=conv_kernel_size,
            )
            for _ in range(num_layers)
        ])

        # Output projection to vocabulary logits
        self.output_proj = nn.Linear(dim, vocab_size)

        # Final layer normalization
        self.norm = nn.LayerNorm(dim)

    def reset_memories(self, batch_size: int = 1, device: torch.device = torch.device('cpu')):
        """Reset all memory states."""
        for layer in self.layers:
            layer.reset_memory(batch_size, device)

    def forward(self, input_ids: torch.Tensor, positions: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass for the complete model.

        Args:
            input_ids: Tensor of token ids [batch_size, seq_len]
            positions: Optional position ids [batch_size, seq_len]; defaults to
                0..seq_len-1 for every batch element

        Returns:
            Logits for next token prediction [batch_size, seq_len, vocab_size]

        Raises:
            ValueError: If ``positions`` is not given and ``seq_len`` exceeds
                ``max_seq_len``.
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        if positions is None:
            # Check up front: the embedding lookup would otherwise fail with an IndexError.
            if seq_len > self.max_seq_len:
                raise ValueError(
                    f"sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}"
                )
            positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)

        x = self.token_embedding(input_ids) + self.pos_embedding(positions)

        for layer in self.layers:
            x = layer(x)

        return self.output_proj(self.norm(x))
    
# Example usage, guarded so that importing this file does not build and run a model.
# (In a notebook cell __name__ is "__main__", so the example still runs there.)
if __name__ == "__main__":
    model = TitansModel(
        dim=512,               # Embedding dimension
        num_layers=6,          # Number of Titans layers
        num_heads=8,           # Number of attention heads
        memory_depth=2,        # Depth of the memory MLP
        dropout=0.1,
        num_persistent_tokens=8,
        segment_size=512       # Size of segments for processing
    )

    # Reset memory states at the start of a new sequence
    model.reset_memories(batch_size=1)

    input_ids = torch.randint(0, 50000, (1, 1024))  # Example input

    # Inference only: avoid building an autograd graph across 1024 tokens x 6 layers.
    # The memory's inner gradient computation enables grad locally (torch.enable_grad).
    with torch.no_grad():
        logits = model(input_ids)
    print(logits.shape)

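# Error recorded when this cell was run:
#   TypeError: linear(): argument 'input' (position 1) must be Tensor, not NoneType
# Cause: TitansMAC.forward updates the memory via
# self.long_term_memory(query=None, key=..., value=..., update=True).
# NeuralLongTermMemory.forward then passes that None query to forward_memory,
# and its first nn.Linear call rejects it.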