# Tutorial 2: Foundational Concepts for Long Context

## 1. Overview
Before we can train models on long contexts, we must ensure they can *represent* long contexts mathematically.

### The Challenges
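1.  **Positional Encoding**: How does the model know that token #100,000 comes after token #99,999? Widely used schemes such as Rotary Position Embeddings (RoPE) tend to degrade on sequences longer than those seen in training. At those lengths the slow-rotating dimensions reach angles the model never encountered, or wrap around ("alias").
2.  **Memory Management**: A 1M-token sequence cannot practically be processed in a single pass on one GPU, because the attention key/value cache and the activations grow with sequence length.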
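
To put a number on the memory problem, here is a back-of-the-envelope estimate of the key/value cache alone. It assumes a Llama-2-7B-like shape (32 layers, 32 KV heads, head dimension 128, 16-bit values) and batch size 1. These figures are illustrative assumptions; model weights and activations come on top.

```python
# Assumed Llama-2-7B-like shape (illustrative; adjust to your model).
n_layers = 32
n_kv_heads = 32
head_dim = 128
bytes_per_value = 2  # fp16 / bf16


def kv_cache_bytes(seq_len: int) -> int:
    """Bytes of KV cache for one sequence (batch size 1)."""
    # Factor 2: one key vector and one value vector per head, per layer, per token.
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_value * seq_len


per_token = kv_cache_bytes(1)
total = kv_cache_bytes(1_000_000)
print(f"KV cache per token:     {per_token / 2**20:.2f} MiB")  # 0.50 MiB
print(f"KV cache for 1M tokens: {total / 2**30:.0f} GiB")       # 488 GiB
```

Under these assumptions the cache alone is several times the memory of a single 80 GB accelerator, before any activations are counted.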

### In this Notebook
We explore **two key techniques** that long-context Test-Time Training (TTT) relies on:
1.  **RoPE Scaling**: Raising the RoPE base frequency ("Theta") so that positions stay distinguishable across hundreds of thousands to millions of tokens.
2.  **Chunking**: Splitting an arbitrarily long token stream into fixed-size blocks that are processed one at a time.

## 2. RoPE Scaling (The 'Theta' Parameter)

### Why is this important?
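Rotary Position Embeddings (RoPE) encode position by rotating the query and key vectors. Each vector is split into pairs of dimensions, and pair $j$ of a token at position $m$ is rotated by the angle $m \cdot \theta^{-2j/d}$, where $d$ is the head dimension. Because queries and keys are both rotated, their dot product depends only on the relative offset between the two tokens.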
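
For a single pair of dimensions, the rotation is an ordinary 2-D rotation. The plain-Python sketch below checks the relative-position property: shifting the query and key positions by the same amount leaves their score unchanged. The frequency `omega=0.01` and the vectors are arbitrary illustrative values, and `rotate`/`score` are hypothetical helpers.

```python
import math


def rotate(vec, pos, omega):
    """Rotate a 2-D vector by the angle pos * omega (one RoPE dimension pair)."""
    x, y = vec
    a = pos * omega
    return (x * math.cos(a) - y * math.sin(a), x * math.sin(a) + y * math.cos(a))


def score(q, k, m, n, omega=0.01):
    """Dot product of query q at position m with key k at position n, both rotated."""
    qm, kn = rotate(q, m, omega), rotate(k, n, omega)
    return qm[0] * kn[0] + qm[1] * kn[1]


q, k = (0.3, -1.2), (0.8, 0.5)
print(score(q, k, m=10, n=3))     # offset 7
print(score(q, k, m=110, n=103))  # offset 7 again: same score up to rounding
```

Algebraically, $R(m\omega)q \cdot R(n\omega)k = q^\top R((n-m)\omega)\,k$, so only $n - m$ survives.
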
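- **The Aliasing Problem**: Rotation speed differs across dimension pairs. The fast pairs wrap around every few tokens and encode fine-grained local order; the slow pairs are what tell far-apart positions apart. If `Theta` (the base frequency) is small, even the slowest pairs complete several full turns within a long document, so distant positions can end up with similar angles on every pair, which confuses the model.
- **The Solution**: Increase `Theta`. This slows down the low-frequency pairs so that, over the target context length, they never complete a full turn and provide a non-repeating position signal. Because a new `Theta` changes the angles the model saw during pretraining, it is usually paired with training on long sequences.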
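
The per-pair wavelength (tokens per full rotation) makes this concrete. With head dimension $d$ and pair index $j$, pair $j$ rotates by $\omega_j = \theta^{-2j/d}$ radians per token, which matches the `1 / theta^(i / dim)` frequencies with $i = 2j$:

$$
\lambda_j = \frac{2\pi}{\omega_j} = 2\pi\,\theta^{2j/d}, \qquad j = 0, \dots, \tfrac{d}{2} - 1.
$$

The fastest pair ($j = 0$) wraps every $2\pi \approx 6.3$ tokens whatever $\theta$ is. The slowest pair ($j = d/2 - 1$) sets the longest wavelength; for $d = 64$:

$$
\lambda_{31} = 2\pi\,\theta^{31/32} \approx
\begin{cases}
2\pi \cdot 10^{3.875} \approx 4.7 \times 10^{4} \text{ tokens}, & \theta = 10^{4},\\
2\pi \cdot 10^{6.78} \approx 3.8 \times 10^{7} \text{ tokens}, & \theta = 10^{7}.
\end{cases}
$$

So with $\theta = 10{,}000$ even the slowest pair completes more than two full turns over a 128k-token document, while $\theta = 10^7$ keeps it well under one turn even at 1M tokens.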

Let's visualize the math behind this.

```python
import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """
    Precomputes the complex rotation factors e^(i * m * omega_j) for RoPE.

    Args:
        dim: Head dimension (size of the vector being rotated); must be even.
        end: Number of positions to precompute (positions 0 .. end - 1).
        theta: Base of the frequency spectrum. Larger values slow down the
            low-frequency dimension pairs (the 'Rotation Speed' knob).

    Returns:
        Complex tensor of shape (end, dim // 2).
    """
    # 1. Per-pair frequencies: 1 / theta^(i / dim) for i = 0, 2, 4, ..., dim - 2
    # This creates a spectrum of speeds: early pairs rotate fast, later pairs slowly.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))

    # 2. Generate positions [0, 1, ..., end - 1]
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)

    # 3. Outer product: combine positions with frequencies
    # This gives the angle m * omega_j for every position m and every pair j.
    freqs = torch.outer(t, freqs)

    # 4. Convert to unit-magnitude complex numbers (polar form) for efficient rotation
    # e^(i * angle) = cos(angle) + i*sin(angle)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis
```
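
Llama-style implementations consume this table by viewing consecutive pairs of query/key dimensions as complex numbers and multiplying them by `freqs_cis[pos]`. The stdlib-only sketch below mirrors that step for one vector, so it runs without torch; `apply_rope` is a hypothetical helper, not part of any library. Because every factor has magnitude 1, the rotation preserves the vector's norm, and position 0 leaves it unchanged.

```python
import cmath
import math
from typing import List


def apply_rope(x: List[float], pos: int, theta: float = 10000.0) -> List[float]:
    """Rotate one head vector x (even length) to position `pos`, RoPE-style.

    Dims (x[2j], x[2j+1]) are treated as the complex number x[2j] + i*x[2j+1]
    and multiplied by e^(i * pos * omega_j), the value precompute_freqs_cis
    stores at row `pos`, column j.
    """
    dim = len(x)
    out: List[float] = []
    for j in range(dim // 2):
        omega = 1.0 / theta ** (2 * j / dim)
        z = complex(x[2 * j], x[2 * j + 1]) * cmath.rect(1.0, pos * omega)
        out.extend([z.real, z.imag])
    return out


q = [0.5, -1.0, 2.0, 0.25, -0.75, 1.5, 0.1, -0.3]
rotated = apply_rope(q, pos=1000)
print(math.hypot(*q), math.hypot(*rotated))  # equal up to rounding: rotations preserve the norm
```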

### Experiment: Standard vs. Scaled Theta

The TTT paper recommends raising `Theta` from the standard **10,000** (used in Llama 2) to **10,000,000** (10 million).

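Let's build the frequency tables for an 8k-token setting and for a 128k-token document. `Theta` does not change a table's shape; it changes how far the slow dimension pairs rotate across the sequence.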

```python
dim = 64

# Scenario A: Standard model (e.g., Llama 2 base)
# Designed for roughly 4k-8k tokens of context.
seq_len_short = 8192
theta_short = 10_000.0

print(f"\n1. Standard Context ({seq_len_short} tokens)")
freqs_short = precompute_freqs_cis(dim, seq_len_short, theta_short)
print(f"   Frequency Table Shape: {freqs_short.shape}")
print(f"   Theta: {theta_short:,.0f} (Fast Rotation)")


# Scenario B: TTT / long context
# Designed for 128k-1M tokens of context.
# Following the TTT paper, theta is scaled to 10 million to prevent aliasing.
seq_len_long = 128_000
theta_long = 10_000_000.0

print(f"\n2. Long Context ({seq_len_long} tokens)")
freqs_long = precompute_freqs_cis(dim, seq_len_long, theta_long)
print(f"   Frequency Table Shape: {freqs_long.shape}")
print(f"   Theta: {theta_long:,.0f} (Slow Rotation)")
print("   Note: Theta does not change the table's shape; it slows the low-frequency pairs.")
```

```
1. Standard Context (8192 tokens)
   Frequency Table Shape: torch.Size([8192, 32])
   Theta: 10,000 (Fast Rotation)

2. Long Context (128000 tokens)
   Frequency Table Shape: torch.Size([128000, 32])
   Theta: 10,000,000 (Slow Rotation)
   Note: Theta does not change the table's shape; it slows the low-frequency pairs.
```
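
Both tables have the same 32 frequency pairs per row. A quick way to see what the larger `Theta` buys is to count how many of those pairs complete at least one full rotation within the 128,000-token document. This is a plain-Python check; `pairs_that_wrap` is a hypothetical helper.

```python
import math


def pairs_that_wrap(dim: int, theta: float, seq_len: int) -> int:
    """Count RoPE frequency pairs that complete at least one full turn within seq_len tokens."""
    count = 0
    for j in range(dim // 2):
        wavelength = 2 * math.pi * theta ** (2 * j / dim)  # tokens per full rotation of pair j
        if wavelength <= seq_len:
            count += 1
    return count


for theta in (10_000.0, 10_000_000.0):
    n = pairs_that_wrap(64, theta, 128_000)
    print(f"theta={theta:,.0f}: {n} of 32 pairs wrap within 128,000 tokens")
# theta=10,000: 32 of 32 pairs wrap within 128,000 tokens
# theta=10,000,000: 20 of 32 pairs wrap within 128,000 tokens
```

With `Theta` at 10 million, the 12 slowest pairs never complete a turn, so their angles never repeat anywhere in the document.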


## 3. Streaming & Chunking

### The TTT Advantage
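In a standard Transformer, processing a 1-million-token document means caching keys and values for all 1M tokens, which takes memory linear in sequence length ($O(N)$). Attention compute grows quadratically ($O(N^2)$), and so does memory if the full attention-score matrix is materialized.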
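
The quadratic term is easy to underestimate. Assuming 16-bit scores, a dense score matrix for a single attention head in a single layer already reaches terabytes at 1M tokens. For comparison, the sketch also sizes the same matrix restricted to one 4,096-token chunk. `score_matrix_bytes` is a hypothetical helper.

```python
def score_matrix_bytes(seq_len: int, bytes_per_value: int = 2) -> int:
    """Bytes for one dense seq_len x seq_len attention-score matrix (one head, one layer)."""
    return seq_len * seq_len * bytes_per_value


full = score_matrix_bytes(1_000_000)
chunk = score_matrix_bytes(4096)
print(f"1M tokens:   {full / 1e12:.0f} TB")    # 2 TB
print(f"4096 tokens: {chunk / 2**20:.0f} MiB")  # 32 MiB
```

Fused kernels such as FlashAttention avoid materializing this matrix, which keeps attention memory linear in $N$; the quadratic compute cost remains.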

In **TTT**, we don't process the whole document at once. We process it in **Chunks**.

1.  **Read Chunk 1**: Update Weights $\rightarrow$ Discard tokens.
2.  **Read Chunk 2**: Update Weights $\rightarrow$ Discard tokens.
3.  **Read Chunk 3**: ...

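The information persists in the **Updated Weights** ($W_t$), not in a token cache. This is why the state TTT carries from one chunk to the next has constant size ($O(1)$) regardless of sequence length; peak memory is set by the model and the chunk size, not by how many tokens have been read.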
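
A toy sketch of this chunk-then-update pattern, assuming a linear "inner model" $W$ trained by one gradient step per chunk on the reconstruction loss $\lVert W x - x \rVert^2$. The functions `random_chunks`, `reconstruction_loss`, and `ttt_stream` are hypothetical, and the model and loss are illustrative stand-ins, not the TTT paper's actual inner model or self-supervised objective. What the sketch does show is the memory pattern: chunks are produced lazily and dropped after their update, and only `W` persists.

```python
import random
from typing import Iterable, Iterator, List, Tuple

Vector = List[float]


def random_chunks(n_chunks: int, chunk_size: int, dim: int, seed: int = 0) -> Iterator[List[Vector]]:
    """Hypothetical token source: yields chunks of random embeddings lazily."""
    rng = random.Random(seed)
    for _ in range(n_chunks):
        yield [[rng.uniform(-1.0, 1.0) for _ in range(dim)] for _ in range(chunk_size)]


def reconstruction_loss(W: List[Vector], xs: List[Vector]) -> float:
    """Toy self-supervised loss: mean of ||W x - x||^2 over xs."""
    dim = len(W)
    total = 0.0
    for x in xs:
        for r in range(dim):
            pred = sum(W[r][c] * x[c] for c in range(dim))
            total += (pred - x[r]) ** 2
    return total / len(xs)


def ttt_stream(chunks: Iterable[List[Vector]], dim: int, lr: float = 0.1) -> Tuple[List[Vector], int]:
    """Read chunks one at a time; only the weight matrix W is carried between chunks."""
    W = [[0.0] * dim for _ in range(dim)]  # persistent state W_t
    updates = 0
    for chunk in chunks:
        # Gradient of the mean loss: d/dW ||W x - x||^2 = 2 (W x - x) x^T
        grad = [[0.0] * dim for _ in range(dim)]
        for x in chunk:
            err = [sum(W[r][c] * x[c] for c in range(dim)) - x[r] for r in range(dim)]
            for r in range(dim):
                for c in range(dim):
                    grad[r][c] += 2.0 * err[r] * x[c] / len(chunk)
        # One gradient step per chunk; the chunk is then no longer referenced.
        for r in range(dim):
            for c in range(dim):
                W[r][c] -= lr * grad[r][c]
        updates += 1
    return W, updates


dim = 4
W, updates = ttt_stream(random_chunks(n_chunks=20, chunk_size=50, dim=dim), dim=dim)
held_out = next(random_chunks(n_chunks=1, chunk_size=200, dim=dim, seed=1))
W0 = [[0.0] * dim for _ in range(dim)]
print(updates)  # 20
print(reconstruction_loss(W0, held_out), reconstruction_loss(W, held_out))  # loss drops after streaming
```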

```python
import math

# Simulation of a massive document stream
total_tokens = 1_000_000  # A very long book (roughly 4 MB of English text at ~4 characters per token)
chunk_size = 4096         # Tokens processed per weight update (illustrative value)

# Round up: 1,000,000 is not a multiple of 4,096, so the final chunk is partial.
num_chunks = math.ceil(total_tokens / chunk_size)
last_chunk = total_tokens - (num_chunks - 1) * chunk_size

print("\n3. Streaming Simulation")
print(f"   Total Tokens: {total_tokens:,}")
print(f"   Chunk Size: {chunk_size}")
print(f"   Total Updates Required: {num_chunks} (last chunk: {last_chunk} tokens)")

print("\n   TTT Flow:")
print("   [Start] -> [Chunk 1] -> [Update Weights] -> [Chunk 2] -> [Update Weights] ... -> [End]")
print("   Result: The final model has 'read' the whole book but never held it all in RAM at once.")
```

```
3. Streaming Simulation
   Total Tokens: 1,000,000
   Chunk Size: 4096
   Total Updates Required: 245 (last chunk: 576 tokens)

   TTT Flow:
   [Start] -> [Chunk 1] -> [Update Weights] -> [Chunk 2] -> [Update Weights] ... -> [End]
   Result: The final model has 'read' the whole book but never held it all in RAM at once.
```
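
The simulation only counts updates. A reader that actually streams the document, holding one chunk at a time, can be sketched with a generator over any iterable of token ids. `iter_chunks` is a hypothetical helper; note that 1,000,000 is not a multiple of 4,096, so the final chunk is shorter than the rest.

```python
from itertools import islice
from typing import Iterable, Iterator, List


def iter_chunks(stream: Iterable[int], chunk_size: int) -> Iterator[List[int]]:
    """Yield consecutive chunks of at most chunk_size token ids from any iterable.

    Only the current chunk is held in memory, so `stream` can be a file reader
    or generator that never materializes the whole document.
    """
    it = iter(stream)
    while True:
        chunk = list(islice(it, chunk_size))
        if not chunk:  # stream exhausted
            return
        yield chunk


# range() produces ids lazily, standing in for a real token stream.
sizes = [len(chunk) for chunk in iter_chunks(range(1_000_000), 4096)]
print(len(sizes), sizes[0], sizes[-1])  # 245 4096 576
```

On Python 3.12 and later, `itertools.batched(stream, 4096)` produces the same grouping, yielding tuples instead of lists.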
