In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math # For SwiGLU

# --- Helper Modules ---

class SwiGLU(nn.Module):
    """
    Gated linear unit whose gate is passed through SiLU ("swish").

    Two independent linear maps are applied to the same input; the result is
    ``silu(linear_gate(x)) * linear_value(x)``.
    Reference: https://arxiv.org/pdf/2002.05202.pdf

    Args:
        dim_in (int): Size of the last dimension of the input.
        dim_out (int): Size of the last dimension of the output.
        bias (bool): Whether both linear maps carry a bias term.
    """
    def __init__(self, dim_in, dim_out, bias=False):
        super().__init__()
        # Gate and value branches have identical (dim_in -> dim_out) shapes.
        self.linear_gate = nn.Linear(dim_in, dim_out, bias=bias)
        self.linear_value = nn.Linear(dim_in, dim_out, bias=bias)

    def forward(self, x):
        # x: (..., dim_in) -> (..., dim_out)
        # silu(g) = g * sigmoid(g); the activated gate scales the value branch.
        activated_gate = F.silu(self.linear_gate(x))
        return activated_gate * self.linear_value(x)

class MLPBlock(nn.Module):
    """
    Transformer-style feed-forward block with a SwiGLU non-linearity.

    A single up-projection produces the gate and value halves at once, the
    gated product is projected back to ``hidden_dim``, and dropout is applied
    both to the gated hidden activation and to the final output.

    Args:
        hidden_dim (int): Input and output feature size.
        hidden_mult (int): Expansion factor for the inner width (default 4).
        dropout (float): Dropout probability used at both dropout sites.
    """
    def __init__(self, hidden_dim, hidden_mult=4, dropout=0.1):
        super().__init__()
        # NOTE: the SwiGLU paper suggests shrinking the inner width to 2/3 of
        # this value; the full width is kept here for simplicity.
        inner_width = hidden_dim * hidden_mult

        # fc1 emits gate and value side by side, hence twice the inner width.
        self.fc1 = nn.Linear(hidden_dim, 2 * inner_width, bias=False)
        self.act = nn.SiLU()
        self.fc2 = nn.Linear(inner_width, hidden_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Up-project, then split the last dimension into the two SwiGLU halves.
        gate, value = torch.chunk(self.fc1(x), 2, dim=-1)
        # silu(gate) * value, followed by dropout on the hidden activation.
        gated = self.dropout(self.act(gate) * value)
        # Down-project and apply dropout once more on the way out.
        return self.dropout(self.fc2(gated))

In [None]:
# --- TTT Components ---

class TTTTask(nn.Module):
    """
    Learnable projections that produce the three views used by a TTT layer
    (multi-view reconstruction task, Section 2.3 / Figure 5 of the paper).
    """
    def __init__(self, hidden_dim, projection_rank=None):
        """
        Args:
            hidden_dim (int): Feature size of the incoming tensor.
            projection_rank (int, optional): Output size of every projection.
                                            When None, the projections are
                                            full-rank (hidden_dim -> hidden_dim).
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        rank = hidden_dim if projection_rank is None else projection_rank

        # Outer-loop parameters theta_K, theta_V, theta_Q: they define the
        # self-supervised task (Eq. 4) and the output rule (Eq. 5). The inner
        # model therefore works on vectors of size `rank`.
        self.proj_K = nn.Linear(hidden_dim, rank, bias=False)  # training view
        self.proj_V = nn.Linear(hidden_dim, rank, bias=False)  # label view
        self.proj_Q = nn.Linear(hidden_dim, rank, bias=False)  # test view

    def forward(self, x):
        """
        Project the input into its training, label and test views.

        Args:
            x (torch.Tensor): Tensor of shape (..., hidden_dim).

        Returns:
            tuple: (x_train, x_label, x_test), each of shape (..., rank).
                   x_train feeds the inner model for the inner loss, x_label
                   is the inner-loss target, x_test feeds the inner model to
                   produce the layer output.
        """
        return self.proj_K(x), self.proj_V(x), self.proj_Q(x)

In [None]:
class TTTLayer(nn.Module):
    """
    Implements a Test-Time Training (TTT) layer based on the paper.
    Uses mini-batch updates (Section 2.4) and learnable components.
    The inner model 'f' is a linear transformation with LayerNorm and Residual.

    The inner model works in a `projection_rank`-dimensional space; a learnable
    output projection maps its result back to `hidden_dim` so the layer can be
    dropped into a residual stack.
    """
    def __init__(self, hidden_dim, inner_eta_base=0.1, ttt_batch_size=16, projection_rank=None):
        """
        Args:
            hidden_dim (int): Dimension of input features.
            inner_eta_base (float): Base learning rate for the inner loop (eta_base in Sec 2.7).
            ttt_batch_size (int): Mini-batch size for inner loop updates (b in Sec 2.4).
            projection_rank (int, optional): Rank for the task projections. Defaults to hidden_dim.

        Raises:
            ValueError: If `ttt_batch_size` is smaller than 1.
        """
        super().__init__()
        if ttt_batch_size < 1:
            raise ValueError(f"ttt_batch_size must be >= 1, got {ttt_batch_size}")

        self.hidden_dim = hidden_dim
        self.inner_eta_base = inner_eta_base
        self.ttt_batch_size = ttt_batch_size
        self.projection_rank = projection_rank if projection_rank is not None else hidden_dim

        # Learnable multi-view task projections (theta_K, theta_V, theta_Q)
        self.task = TTTTask(hidden_dim, projection_rank=self.projection_rank)

        # Learnable initial weights for the inner linear model (theta_init = W_0 in Sec 2.7)
        # The inner model maps rank -> rank
        self.initial_W = nn.Parameter(torch.randn(self.projection_rank, self.projection_rank) * 0.01)

        # Learnable inner learning rate components (theta_lr in Sec 2.7)
        # Maps hidden_dim -> 1, output is used in sigmoid for gating eta
        self.theta_lr = nn.Linear(hidden_dim, 1, bias=False)

        # LayerNorm for the inner model f (Sec 2.7)
        self.inner_norm = nn.LayerNorm(self.projection_rank)

        # Projection from the inner model's rank-dim output back to hidden_dim.
        # Registered here (not lazily in forward) so it is visible to
        # parameters(), state_dict(), .to(device) and weight initialisation
        # before the first forward pass.
        self.output_proj = nn.Linear(self.projection_rank, hidden_dim, bias=False)

    def _inner_model_f(self, view, W):
        """
        Implements the inner model f(view; W) = view + LN(W @ view).
        This applies the linear transformation, LayerNorm, and residual connection.

        Args:
            view (torch.Tensor): Input view (e.g., x_train or x_test), shape (..., rank).
            W (torch.Tensor): Current inner model weights, shape (rank, rank).

        Returns:
            torch.Tensor: Output of the inner model, shape (..., rank).
        """
        # F.linear(view, W) is equivalent to torch.matmul(view, W.t()).
        transformed_view = F.linear(view, W)
        # The residual is taken around the *view* that f is applied to (Eq. 4, 5).
        return view + self.inner_norm(transformed_view)

    def forward(self, input_seq):
        """
        Processes the input sequence using mini-batch TTT updates.

        For every chunk of `ttt_batch_size` time steps the layer:
          1. projects the tokens into training / label / test views,
          2. takes one gradient step on the inner weights W using the mean
             squared reconstruction error of the chunk,
          3. emits f(x_test; W_updated) for every token of the chunk.
        Using the end-of-chunk weights for all tokens of the chunk is a
        simplification of the paper's dual form (Sec 2.5, Appendix A).

        Gradient behaviour:
          * When autograd is enabled, the inner updates stay in the graph
            (`create_graph=True`), so the outer loss back-propagates into
            `initial_W`, `theta_lr` and the task projections.
          * Under `torch.no_grad()` the inner gradient is still computed
            (inside a local `torch.enable_grad()` scope) but nothing is
            retained and the output carries no graph.

        Args:
            input_seq (torch.Tensor): Input sequence tensor of shape
                                      (seq_len, batch_size, hidden_dim).

        Returns:
            torch.Tensor: Output sequence tensor of shape
                          (seq_len, batch_size, hidden_dim).
        """
        seq_len, batch_size, _ = input_seq.shape
        if seq_len == 0:
            # Nothing to process; avoid torch.cat on an empty list.
            return input_seq.new_zeros((0, batch_size, self.hidden_dim))

        # Whether the caller wants gradients to flow through the inner loop.
        track_outer_grad = torch.is_grad_enabled()

        # W_0: start fresh for each sequence. W is only ever replaced, never
        # modified in place, so the parameter itself can be used directly.
        W = self.initial_W
        if not (track_outer_grad and W.requires_grad):
            # autograd.grad below needs a tensor that requires grad.
            W = W.detach().requires_grad_(True)

        outputs = []
        for t_start in range(0, seq_len, self.ttt_batch_size):
            # Slicing clamps at seq_len, so the last chunk may be shorter.
            x_batch = input_seq[t_start:t_start + self.ttt_batch_size]
            current_batch_size = x_batch.shape[0]

            # (current_batch_size * batch_size, hidden_dim)
            x_batch_flat = x_batch.reshape(-1, self.hidden_dim)

            # Views for the whole chunk, each (current_batch_size * batch_size, rank)
            x_train_flat, x_label_flat, x_test_flat = self.task(x_batch_flat)

            # Inner-loop gradient of the chunk-mean reconstruction loss w.r.t.
            # the weights at the start of the chunk. enable_grad makes this
            # work even when the caller runs under torch.no_grad().
            with torch.enable_grad():
                pred_flat = self._inner_model_f(x_train_flat, W)
                loss = F.mse_loss(pred_flat, x_label_flat)
                grad_W = torch.autograd.grad(loss, W, create_graph=track_outer_grad)[0]

            # Learnable learning rate: eta(x) = eta_base * sigmoid(theta_lr x),
            # averaged over the chunk because a single update is applied.
            eta = self.inner_eta_base * torch.sigmoid(self.theta_lr(x_batch_flat)).mean()

            # Single weight update for the chunk.
            W = W - eta * grad_W
            if not track_outer_grad:
                # Make W a fresh leaf for the next chunk's gradient computation.
                W = W.detach().requires_grad_(True)

            # Output rule with the *updated* weights.
            z_flat = self._inner_model_f(x_test_flat, W)
            outputs.append(z_flat.reshape(current_batch_size, batch_size, self.projection_rank))

        # (seq_len, batch_size, rank) -> (seq_len, batch_size, hidden_dim)
        output_seq_rank = torch.cat(outputs, dim=0)
        return self.output_proj(output_seq_rank)

In [None]:
class TTTModel(nn.Module):
    """
    Language model built from TTT layers inside a pre-norm, Transformer-like
    residual backbone (norm -> TTT -> residual, norm -> MLP -> residual).
    """
    def __init__(self, vocab_size, hidden_dim, num_layers, inner_eta_base=0.1,
                 ttt_batch_size=16, projection_rank=None, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Token embedding and the dropout shared by the embedding output and
        # every TTT sub-layer output.
        self.embed_tokens = nn.Embedding(vocab_size, hidden_dim)
        self.dropout = nn.Dropout(dropout)

        def build_block():
            # One backbone block: [pre-TTT norm, TTT layer, pre-MLP norm, MLP].
            return nn.ModuleList([
                nn.LayerNorm(hidden_dim),
                TTTLayer(hidden_dim,
                         inner_eta_base=inner_eta_base,
                         ttt_batch_size=ttt_batch_size,
                         projection_rank=projection_rank),
                nn.LayerNorm(hidden_dim),
                MLPBlock(hidden_dim, dropout=dropout),
            ])

        self.layers = nn.ModuleList([build_block() for _ in range(num_layers)])

        # Final normalisation followed by the next-token prediction head.
        self.final_norm = nn.LayerNorm(hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)

        # Optional: tie weights between embedding and LM head
        # self.embed_tokens.weight = self.lm_head.weight

        # Re-initialise every Linear / Embedding / LayerNorm sub-module.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """
        Per-module initialiser used with ``self.apply``.

        Linear and Embedding weights are drawn from N(0, 0.02^2); Linear biases
        (when present) are zeroed; LayerNorm is reset to weight 1 / bias 0.
        Other modules are left untouched, so ``TTTLayer.initial_W`` keeps the
        initialisation it received in ``TTTLayer.__init__``.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, token_ids):
        """
        Run the language model over a batch of token sequences.

        Args:
            token_ids (torch.Tensor): Input token IDs, shape (seq_len, batch_size).

        Returns:
            torch.Tensor: Logits for next token prediction, shape
                          (seq_len, batch_size, vocab_size).
        """
        # (seq_len, batch_size) -> (seq_len, batch_size, hidden_dim)
        hidden = self.dropout(self.embed_tokens(token_ids))

        for block in self.layers:
            pre_ttt_norm, ttt, pre_mlp_norm, mlp = block

            # TTT sub-layer with residual; dropout is applied to its output.
            ttt_out = self.dropout(ttt(pre_ttt_norm(hidden)))
            hidden = hidden + ttt_out

            # MLP sub-layer with residual; MLPBlock applies its own dropout.
            hidden = hidden + mlp(pre_mlp_norm(hidden))

        # Final norm, then project to vocabulary logits.
        return self.lm_head(self.final_norm(hidden))

In [None]:
# -------------------------
# Example Usage
# -------------------------
if __name__ == '__main__':
    # Hyperparameters for illustration.
    vocab_size = 10000
    hidden_dim = 512
    num_layers = 4
    inner_eta_base = 0.05 # Base learning rate for inner loop
    ttt_batch_size = 16   # Mini-batch size for TTT updates
    projection_rank = 128 # Optional: Use lower rank for projections (e.g., hidden_dim // 4)
    dropout = 0.1
    seq_len = 64 # Increased sequence length
    batch_size = 8

    # Instantiate the model
    model = TTTModel(vocab_size, hidden_dim, num_layers,
                     inner_eta_base=inner_eta_base,
                     ttt_batch_size=ttt_batch_size,
                     projection_rank=projection_rank,
                     dropout=dropout)

    print(f"Model Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    # Move model to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Using device: {device}")

    # This is an inference demo: switch to eval mode so dropout is disabled
    # and the logits are deterministic.
    model.eval()

    # Create some dummy token IDs.
    # Shape: (seq_len, batch_size)
    token_ids = torch.randint(0, vocab_size, (seq_len, batch_size), device=device)

    # --- Forward pass ---
    print("\nRunning forward pass...")
    try:
        with torch.no_grad(): # Use no_grad for inference example
            # Compare on device.type so an indexed device such as "cuda:0"
            # is also recognised as a GPU.
            if device.type == "cuda":
                # Use autocast for mixed precision on GPU
                with torch.autocast(device_type="cuda"):
                    logits = model(token_ids)
            else:
                logits = model(token_ids)
        print("Forward pass successful!")
        print("Logits shape:", logits.shape) # Expected: (seq_len, batch_size, vocab_size)

        # --- Optional: Check backward pass (requires gradients) ---
        # print("\nRunning backward pass (requires gradients)...")
        # model.train()
        # optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) # Example optimizer
        # optimizer.zero_grad()
        #
        # if device.type == "cuda":
        #     with torch.autocast(device_type="cuda"):
        #         logits = model(token_ids)
        #         # Dummy loss for testing backward pass
        #         dummy_loss = logits.mean()
        # else:
        #     logits = model(token_ids)
        #     dummy_loss = logits.mean()
        #
        # print("Dummy Loss:", dummy_loss.item())
        # dummy_loss.backward()
        # optimizer.step()
        # print("Backward pass successful!")

    except Exception as e:
        # Top-level demo boundary: report the failure with a full traceback
        # instead of letting the notebook cell die silently.
        print(f"\nAn error occurred during the forward/backward pass: {e}")
        import traceback
        traceback.print_exc()