# Tutorial 1: Introduction to Test-Time Training (TTT)

## 📚 Introduction

Welcome to the first lesson in the **End-to-End Test-Time Training** tutorial series.

### The Problem: Memory Bottlenecks
Traditional Transformer models (like GPT-4, Llama 3) process text using **Attention**.
- To "remember" what you said 5 minutes ago, they store a **Key-Value Cache (KV Cache)**.
- **Short Context**: This is fine.
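- **Long Context (1M+ tokens)**: The KV Cache grows linearly with the number of tokens and becomes massive. Storing keys and values for millions of previous tokens in GPU memory is slow and expensive: depending on the model architecture, a single sequence can need hundreds of gigabytes to terabytes.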
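
To get a feel for the numbers, here is a rough back-of-the-envelope estimate of KV cache size. The helper name `kv_cache_bytes` and the model configuration are illustrative assumptions (roughly the shape of an 8B-parameter model with grouped-query attention), not figures taken from any specific deployment.

```python
def kv_cache_bytes(num_tokens, num_layers, num_kv_heads, head_dim, bytes_per_value=2):
    """Bytes needed to cache keys and values for one sequence."""
    # Factor of 2: one key vector and one value vector per token, per layer, per KV head.
    return 2 * num_tokens * num_layers * num_kv_heads * head_dim * bytes_per_value

# Assumed, illustrative configuration: 32 layers, 8 KV heads, head dimension 128, fp16 values.
config = dict(num_layers=32, num_kv_heads=8, head_dim=128, bytes_per_value=2)

for num_tokens in (8_000, 128_000, 1_000_000):
    gib = kv_cache_bytes(num_tokens, **config) / 2**30
    print(f"{num_tokens:>9,} tokens -> {gib:8.1f} GiB")
```

The cost is linear in the number of tokens, so going from thousands of tokens to millions multiplies the memory footprint by the same factor. Larger models, or models without grouped-query attention, scale the per-token cost up further.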

### The Solution: "Context as Weights"
**Test-Time Training (TTT)** proposes a radical shift.
Instead of storing the history in a static cache, we **train a neural network** on the history.

- **Old Way (Attention)**: `Memory = [List of past tokens]`
- **New Way (TTT)**: `Memory = [Weights of a Model]`

The "memory" of the conversation is compressed into the updated weights of this internal network. As a new token arrives, we run a gradient descent step to update these weights.
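
Written as an update rule, one such step looks like the following. The symbols are notation assumed here for illustration: $W_t$ is the memory after token $t$, $\eta$ is the inner learning rate, and $\ell$ is a self-supervised loss computed from the current token $x_t$.

$$
W_t = W_{t-1} - \eta \, \nabla_W \, \ell(W_{t-1};\, x_t)
$$

The toy layer in Section 1 uses a reconstruction loss for $\ell$: with a key $k_t$ and a target value $v_t$ derived from $x_t$, it takes $\ell(W; x_t) = \operatorname{mean}\big[(k_t W^\top - v_t)^2\big]$.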

---

## 1. Defining the TTT Layer

Here we will build a **Toy Model** from scratch to visualize how weights can be updated on-the-fly during a forward pass.

### Why do we need a custom layer?
Standard PyTorch layers (`nn.Linear`) are designed to have fixed weights during inference. We need a layer where the weights (`inner_weights`) are **mutable** and updated *per token*.
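
One way to get per-step mutable weights is to keep them in a plain tensor and apply them functionally. The snippet below is a minimal sketch of that idea, separate from the layer built in this tutorial; the dimensions, the random targets, and the learning rate are arbitrary assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# A plain tensor standing in for `inner_weights`; shape (out_features, in_features).
W = torch.zeros(4, 3, requires_grad=True)
lr = 0.1  # illustrative inner-loop learning rate

for step in range(3):
    x = torch.randn(1, 3)        # one "token"
    target = torch.randn(1, 4)   # what the memory should output for it
    pred = F.linear(x, W)        # same math as nn.Linear without bias: x @ W.T
    loss = F.mse_loss(pred, target)
    (grad,) = torch.autograd.grad(loss, W)
    # Build a new tensor rather than modifying W in place.
    W = (W - lr * grad).detach().requires_grad_(True)
    print(f"step {step}: loss = {loss.item():.4f}, ||W|| = {W.norm().item():.4f}")
```

Here the weights are detached after every step, which is enough for pure inference. The layer in this tutorial instead keeps the computation graph, so that the update itself stays differentiable.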

### Key Components
1.  **`inner_weights` (The "Memory")**: In a standard RNN, the hidden state $h_t$ is a vector. In TTT, the hidden state is a **Matrix** $W_t$ that we learn.
2.  **`query`, `key`, `value` projections**: These are fixed weights learned during pre-training. They decide *what* to store in our memory matrix.
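
A quick shape check can make the roles of these components concrete. This is a standalone sketch using the same dimensions as the simulation later in this tutorial; the batch size of 2 is an arbitrary choice.

```python
import torch
import torch.nn as nn

batch, input_dim, hidden_dim = 2, 8, 16

# Fixed projections (the "meta-parameters").
key_proj = nn.Linear(input_dim, hidden_dim)
value_proj = nn.Linear(input_dim, input_dim)
query_proj = nn.Linear(input_dim, hidden_dim)

# Matrix-valued hidden state: one row per output feature, one column per key feature.
W = torch.zeros(input_dim, hidden_dim)

x_t = torch.randn(batch, input_dim)
k_t, v_t, q_t = key_proj(x_t), value_proj(x_t), query_proj(x_t)

prediction = k_t @ W.T   # what the memory predicts for this key
readout = q_t @ W.T      # what the memory returns for this query

print("k_t:", tuple(k_t.shape), "v_t:", tuple(v_t.shape), "q_t:", tuple(q_t.shape))
print("W:", tuple(W.shape), "prediction:", tuple(prediction.shape), "readout:", tuple(readout.shape))
```

Keys and queries live in the `hidden_dim` space, while values and outputs live in the `input_dim` space, so the prediction can be compared directly against the target value.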

In [1]:
import torch
import torch.nn as nn

class SimpleTTTLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim, learning_rate):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate
        
        # --- TRADITIONAL WEIGHTS (Fixed during inference) ---
        # These are "Meta-Parameters". They don't change at test time.
        # Their job is to process the input into signals that are useful for training the inner weights.
        self.query_proj = nn.Linear(input_dim, hidden_dim)
        self.key_proj = nn.Linear(input_dim, hidden_dim) 
        self.value_proj = nn.Linear(input_dim, input_dim) 
        
        # --- TTT WEIGHTS (Updated during inference) ---
        # This matrix acts as the "hidden state". It starts at zero (Empty Memory).
        # Shape: (input_dim, hidden_dim). It maps a hidden_dim key to an input_dim value via k @ W^T.
        self.inner_weights = nn.Parameter(torch.zeros(input_dim, hidden_dim))
        
    def forward(self, x_sequence):
        """
        Processing a sequence token-by-token using the TTT logic.
        x_sequence: (Batch, SeqLen, Dim)
        """
        batch_size, seq_len, _ = x_sequence.shape
        outputs = []
        
        # --- [CRITICAL STEP]: Reset Memory ---
        # Every sequence starts from the stored initialization, so memory from a
        # previous document never leaks into this one.
        # .clone() gives us a working copy; self.inner_weights itself is never modified here.
        current_W = self.inner_weights.clone() 
        
        print(f"\nProcessing sequence of length {seq_len}...")
        
        for t in range(seq_len):
            # Extract current token
            x_t = x_sequence[:, t, :] # (Batch, Dim)
            
            # =========================================================
            # INNER LOOP: The "Training" Phase
            # Use the current token to update our memory (current_W).
            # =========================================================
            
            # 1. Create training signals from input
            # k_t = What pattern are we looking for?
            # v_t_target = What is the value associated with that pattern?
            k_t = self.key_proj(x_t)
            v_t_target = self.value_proj(x_t)
            
            # 2. Forward pass through our "Memory Network" (current_W)
            # We check what our memory *currently* predicts given the key.
            # Prediction = k_t @ W^T, shape (Batch, input_dim)
            v_t_pred = torch.matmul(k_t, current_W.t()) 
            
            # 3. Calculate Loss (Reconstruction Error)
            # How well did our memory predict this token? 
            # If the error is high, it means this is 'new information' we need to store.
            loss = torch.mean((v_t_pred - v_t_target) ** 2)
            
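            # 4. Update the Weights (Gradient Descent step)
            # We calculate how to change current_W to reduce the loss (i.e., remember this token).
            # create_graph=True keeps the update itself differentiable, so an outer training
            # loop could backpropagate through it to the query/key/value projections.
            grad = torch.autograd.grad(loss, current_W, create_graph=True)[0]
            current_W = current_W - self.learning_rate * grad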
            
            # =========================================================
            # OUTER LOOP: The "Inference" Phase
            # Use the UPDATED memory to generate/process the output.
            # =========================================================
            
            q_t = self.query_proj(x_t)
            
            # Process query using the *updated* memory.
            # This plays the role of Attention(Q, K, V): the query reads from memory.
            # It is an analogy, not an exact equivalence to softmax attention.
            output = torch.matmul(q_t, current_W.t()) 
            outputs.append(output)
            
            # Visualization logs
            if t % 5 == 0:
                print(f"  Token {t}: Adaptation Loss = {loss.item():.6f}")

        return torch.stack(outputs, dim=1)
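
The gradient in step 4 has a simple closed form, which the sketch below checks against autograd. For the mean-squared loss used in the layer, the gradient with respect to the memory is the outer product of the prediction error and the key, scaled by `2 / N`, where `N` is the number of elements in the prediction. The tensors here are random stand-ins, not values taken from the layer.

```python
import torch

torch.manual_seed(0)
batch, input_dim, hidden_dim = 1, 8, 16

W = torch.randn(input_dim, hidden_dim, requires_grad=True)
k = torch.randn(batch, hidden_dim)   # stand-in for k_t
v = torch.randn(batch, input_dim)    # stand-in for v_t_target

# Loss exactly as in the layer: mean squared error over all elements.
pred = k @ W.T
loss = torch.mean((pred - v) ** 2)
(grad_autograd,) = torch.autograd.grad(loss, W)

# Closed form: d/dW mean((k W^T - v)^2) = (2 / N) * (k W^T - v)^T k
residual = (pred - v).detach()
grad_manual = (2.0 / residual.numel()) * residual.T @ k

matches = torch.allclose(grad_autograd, grad_manual, atol=1e-6)
print("closed form matches autograd:", matches)
```

Seen this way, each token writes an "error times key" correction into the memory: a token the memory already predicts well leaves it almost unchanged.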

## 2. Running the Simulation

Now we run the model on a random sequence of inputs.

### What to look for:
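Watch the `Adaptation Loss` printed in the loop. It is measured *before* the update for that token, so it tells you how well the memory built from earlier tokens predicts the current one.
- **High Loss**: The model is encountering data its memory does not yet explain.
- **Decreasing Loss**: The model is successfully "learning" the patterns in the sequence.

Do not expect the loss to fall monotonically. Each token gets only a single gradient step and the input here is random, so an unusual token can still produce a spike (see Token 15 in the output below).

**Observation**: The variable `current_W` (our hidden state) is evolving *per token*. This variable is effectively a compressed representation of the entire history.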
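
To see the "decreasing loss" case in isolation, here is a minimal sketch that shows the memory the same key/value pair over and over. It does not use `SimpleTTTLayer`; the unit-norm key and the learning rate of 1.0 are assumptions chosen so that every step shrinks the error by a fixed factor.

```python
import torch

torch.manual_seed(0)
input_dim, hidden_dim = 8, 16
lr = 1.0  # illustrative value, not the one used in the simulation

# One fixed key/value pair, presented repeatedly (a perfectly predictable "sequence").
k = torch.randn(1, hidden_dim)
k = k / k.norm()                 # unit-norm key keeps the step size well behaved
v = torch.randn(1, input_dim)

W = torch.zeros(input_dim, hidden_dim, requires_grad=True)
losses = []
for t in range(10):
    loss = torch.mean((k @ W.T - v) ** 2)   # measured before the update, as in the layer
    (grad,) = torch.autograd.grad(loss, W)
    W = (W - lr * grad).detach().requires_grad_(True)
    losses.append(loss.item())

# With a unit-norm key, each step scales the error by (1 - 2 * lr / input_dim) = 0.75,
# so the loss shrinks by 0.75 ** 2 = 0.5625 per step.
print(", ".join(f"{value:.4f}" for value in losses))
```

A real sequence sits somewhere between this fully repetitive case and pure noise, which is why the loss in the simulation trends rather than falling smoothly.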

In [2]:
input_dim = 8
hidden_dim = 16
lr = 0.1 # High learning rate to make the adaptation obvious
seq_len = 20

print("Initializing TTT Layer...")
model = SimpleTTTLayer(input_dim, hidden_dim, lr)

# Create random input sequence
# Batch size = 1, Sequence Length = 20, Dimensions = 8
x = torch.randn(1, seq_len, input_dim)

print(f"Standard Input Shape: {x.shape} (Batch, Seq, Dim)")

# Run the Forward Pass (which includes the Inner Training Loop)
output = model(x)

print(f"Final Output Shape: {output.shape}")

print("\n✅ Done! You have successfully run a Test-Time Training loop.")

Initializing TTT Layer...
Standard Input Shape: torch.Size([1, 20, 8]) (Batch, Seq, Dim)

Processing sequence of length 20...
  Token 0: Adaptation Loss = 0.566423
  Token 5: Adaptation Loss = 0.221849
  Token 10: Adaptation Loss = 0.126281
  Token 15: Adaptation Loss = 0.523686
Final Output Shape: torch.Size([1, 20, 8])

✅ Done! You have successfully run a Test-Time Training loop.
