# HMT Data Pipeline: Hands-On Exploration

Now that you understand WikiText-103, let's use the **HMT data pipeline** you just built!

**In this notebook:**
1. Load data using your `WikiTextDataset` and `LongContextDataLoader`
2. Inspect batches and understand padding
3. Simulate how HMT will process long articles
4. Compare standard transformer truncation vs HMT segmentation
5. Prepare for implementing HMT memory components

**Learning Goals:**
- Understand the data flow from raw text → tokens → batches → HMT segments
- See why segmenting with memory scales better than truncating to a fixed context window
- Get intuition for the three-level memory hierarchy

In [None]:
# Setup
import sys
sys.path.append('..')

import torch
from transformers import GPT2Tokenizer
import numpy as np
import matplotlib.pyplot as plt

from hmt.data import WikiTextDataset, LongContextDataLoader, create_dataloaders
from hmt.utils import get_device
from hmt.config import HMTConfig

# Check device
device = get_device()
print(f"🚀 Using device: {device}")
print(f"   MPS available: {torch.backends.mps.is_available()}")

## 1. Loading Data with HMT Pipeline

Let's use the data utilities you built in `src/hmt/data.py`.
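Before using the real utilities, it helps to see the core of a padding collate step in isolation. The sketch below assumes right padding (real tokens first, pads after). That is the layout the slicing `article_0[:article_0_length]` later in this notebook relies on. `pad_collate` is a hypothetical helper working on plain Python lists, not the actual code in `src/hmt/data.py`.

```python
from typing import Dict, List


def pad_collate(sequences: List[List[int]], pad_id: int) -> Dict[str, List[List[int]]]:
    """Right-pad token-id lists to the longest one and build an attention mask.

    Hypothetical helper: a guess at what a long-context collate step does,
    not the implementation in src/hmt/data.py.
    """
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)              # real tokens first, padding after
        attention_mask.append([1] * len(seq) + [0] * n_pad)   # 1 = real token, 0 = padding
    return {"input_ids": input_ids, "attention_mask": attention_mask}


toy_batch = pad_collate([[5, 6, 7], [8, 9]], pad_id=50256)
```

Padding only to the longest article in the batch, rather than to a global maximum, keeps wasted positions down when article lengths vary a lot.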

In [None]:
# Initialize tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# GPT-2 has no pad token, so reuse EOS; the attention mask (not the token id) marks padding
tokenizer.pad_token = tokenizer.eos_token

print(f"Tokenizer: GPT-2")
print(f"Vocabulary size: {len(tokenizer)}")
print(f"Pad token: '{tokenizer.pad_token}' (id: {tokenizer.pad_token_id})")

In [None]:
# Create dataloaders using your utility function
print("Creating dataloaders...\n")

dataloaders = create_dataloaders(
    tokenizer=tokenizer,
    batch_size=4,
    min_length=128,
    max_length=4096,  # Limit for notebook speed
)

train_loader = dataloaders['train']
val_loader = dataloaders['validation']
test_loader = dataloaders['test']

print(f"✅ Dataloaders created:")
print(f"  Train:      {len(train_loader)} batches")
print(f"  Validation: {len(val_loader)} batches")
print(f"  Test:       {len(test_loader)} batches")

## 2. Inspecting a Batch

Let's look at what a batch looks like.
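Because this notebook reuses GPT-2's end-of-text token (id 50256) as the pad token, the attention mask is the only reliable way to tell padding from real tokens. The sketch below uses arbitrary token ids. It computes batch-level padding efficiency from the mask and shows how counting "non-pad ids" goes wrong if an end-of-text token ever appears inside an article. Both helpers are hypothetical.

```python
from typing import List


def padding_efficiency(attention_mask: List[List[int]]) -> float:
    """Fraction of batch positions holding real tokens (mask == 1)."""
    real = sum(sum(row) for row in attention_mask)
    total = sum(len(row) for row in attention_mask)
    return real / total


def count_by_pad_id(input_ids: List[int], pad_id: int) -> int:
    """Naive length count treating every pad_id as padding (wrong when pad == eos)."""
    return sum(1 for tok in input_ids if tok != pad_id)


EOS = 50256                            # GPT-2 end-of-text id, reused as pad_token here
ids = [464, 3290, EOS, 318, EOS, EOS]  # arbitrary ids: one real EOS in the text, then two pads
mask = [1, 1, 1, 1, 0, 0]

eff = padding_efficiency([mask])
naive = count_by_pad_id(ids, EOS)      # miscounts the real EOS as padding
true_len = sum(mask)                   # mask-based count is correct
```

Here the mask gives 4 real tokens while the id-based count gives 3, which is why the cells below count tokens with `attention_mask.sum()`.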

In [None]:
# Get a batch from validation set
batch = next(iter(val_loader))

print("📦 Batch structure:")
print(f"  Keys: {list(batch.keys())}")
print(f"\n  input_ids shape: {batch['input_ids'].shape}")
print(f"  attention_mask shape: {batch['attention_mask'].shape}")

batch_size, seq_len = batch['input_ids'].shape
print(f"\n  Batch size: {batch_size} articles")
print(f"  Sequence length: {seq_len} tokens (padded to longest in batch)")

In [None]:
# Analyze each article in the batch
print("\n📊 Article lengths in this batch:\n")

for i in range(batch_size):
    # Count non-padded tokens using attention mask
    actual_length = batch['attention_mask'][i].sum().item()
    padding = seq_len - actual_length
    
    print(f"  Article {i}:")
    print(f"    Actual tokens: {actual_length}")
    print(f"    Padding:       {padding}")
    print(f"    Efficiency:    {actual_length/seq_len*100:.1f}%")
    print()

In [None]:
# Decode and display the first article
article_0 = batch['input_ids'][0]
article_0_length = batch['attention_mask'][0].sum().item()

# Remove padding (assumes right padding: real tokens first, pads after)
article_0_clean = article_0[:article_0_length]

# Decode
text = tokenizer.decode(article_0_clean)

print("📖 First article in batch:")
print("="*80)
print(text[:500])
print("...")
print(f"\nTotal length: {article_0_length} tokens")

## 3. Simulating HMT Segmentation

Now let's see how HMT would process this article with its segmentation approach.
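The segmentation that `visualize_hmt_processing` below describes in printouts can be sketched directly on a token list. `split_into_segments` is a hypothetical helper. It assumes, as the printout states, that sensory memory is the last `k` tokens of the previous segment and that the representation encoder reads the first `j` tokens of the current segment.

```python
from typing import Dict, List


def split_into_segments(tokens: List[int], L: int, k: int, j: int) -> List[Dict[str, List[int]]]:
    """Cut a token list into HMT-style segments (hypothetical sketch).

    For each segment returns:
      - "segment": the current L tokens (the last segment may be shorter)
      - "sensory": the last k tokens of the previous segment (empty for segment 0)
      - "representation_input": the first min(j, len(segment)) tokens of the segment
    """
    segments = []
    for start in range(0, len(tokens), L):
        segment = tokens[start:start + L]
        # For start == 0 this slice is tokens[0:0] == [], so segment 0 has no sensory memory
        sensory = tokens[max(0, start - k):start]
        segments.append({
            "segment": segment,
            "sensory": sensory,
            "representation_input": segment[:j],  # slicing already caps at len(segment)
        })
    return segments


toy_segments = split_into_segments(list(range(10)), L=4, k=2, j=2)
```

With 10 tokens and `L=4`, this yields segments `[0..3]`, `[4..7]`, `[8, 9]`, where the last two carry `[2, 3]` and `[6, 7]` as sensory memory.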

In [None]:
# HMT configuration
config = HMTConfig(
    segment_length=512,
    num_memory_embeddings=300,
    sensory_memory_size=32,
)

print("⚙️  HMT Configuration:")
print(f"  Segment length (L):          {config.segment_length} tokens")
print(f"  Representation length (j):   {config.representation_length} tokens (L/2)")
print(f"  Sensory memory size (k):     {config.sensory_memory_size} tokens")
print(f"  Long-term memory cache (N):  {config.num_memory_embeddings} embeddings")

In [None]:
def visualize_hmt_processing(tokens, config):
    """
    Visualize how HMT processes a long sequence.
    
    Returns information about each segment and the processing flow.
    """
    L = config.segment_length
    k = config.sensory_memory_size
    j = config.representation_length
    
    num_tokens = len(tokens)
    num_segments = (num_tokens + L - 1) // L
    
    print(f"\n🔄 HMT Processing Pipeline:")
    print("="*80)
    print(f"Total tokens: {num_tokens}")
    print(f"Number of segments: {num_segments}")
    print(f"\nProcessing flow:\n")
    
    segments_info = []
    
    for seg_id in range(num_segments):
        # Current segment bounds
        start = seg_id * L
        end = min(start + L, num_tokens)
        seg_len = end - start
        
        # Sensory memory from previous segment
        if seg_id > 0:
            sensory_start = max(0, start - k)
            sensory_len = start - sensory_start
        else:
            sensory_len = 0
        
        print(f"📍 Segment {seg_id}:")
        print(f"  Position: tokens {start:4d} - {end:4d}")
        print(f"  Current segment:  {seg_len:3d} tokens")
        
        if seg_id > 0:
            num_memories = min(seg_id, config.num_memory_embeddings)  # cache holds at most N
            print(f"  Sensory memory:   {sensory_len:3d} tokens (from segment {seg_id-1})")
            print(f"  Memory retrieval: Query {num_memories} past memory embeddings")
            print(f"  Total context:    {seg_len + sensory_len:3d} tokens + retrieved memories")
        else:
            print(f"  Sensory memory:   None (first segment)")
            print(f"  Memory retrieval: None (first segment)")
            print(f"  Total context:    {seg_len:3d} tokens")
        
        print(f"  Encoding:         First {min(j, seg_len)} tokens → representation embedding")
        print(f"  Output:           New memory embedding → cache")
        print()
        
        segments_info.append({
            'id': seg_id,
            'start': start,
            'end': end,
            'length': seg_len,
            'sensory': sensory_len,
        })
    
    return segments_info

# Visualize processing for our article
segments_info = visualize_hmt_processing(article_0_clean, config)

## 4. Comparing Approaches: Standard Transformer vs HMT

Let's visualize the key difference.
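A fair comparison needs one consistent cost measure. The sketch below counts query-key pairs in self-attention, which is a rough proxy that ignores heads, layers, feed-forward cost, and HMT's retrieved memory embeddings. It compares full attention over the whole article, a truncated window, and HMT segments that each attend over their own tokens plus `k` sensory tokens. `attention_scores` is a hypothetical helper.

```python
from typing import Dict


def attention_scores(n: int, window: int, L: int, k: int) -> Dict[str, int]:
    """Count query-key pairs as a proxy for self-attention cost.

    - full:      full attention over all n tokens
    - truncated: full attention over only the first `window` tokens
    - hmt:       each segment attends over its own tokens plus up to k sensory tokens
                 (retrieved memory embeddings are ignored in this count)
    """
    hmt = 0
    for start in range(0, n, L):
        seg_len = min(L, n - start)
        sensory = min(k, start)  # 0 for the first segment
        hmt += (seg_len + sensory) ** 2
    return {"full": n * n, "truncated": min(n, window) ** 2, "hmt": hmt}


costs = attention_scores(n=4096, window=1024, L=512, k=32)
```

For a 4096-token article this gives 16,777,216 pairs for full attention, 1,048,576 for a 1024-token window (which sees only a quarter of the article), and 2,333,696 for HMT. That is about 7x fewer than full attention while still covering every token.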

In [None]:
# Compare coverage
article_len = len(article_0_clean)
gpt2_limit = 1024

print("\n⚔️  Standard Transformer vs HMT Comparison:")
print("="*80)

print(f"\nArticle length: {article_len} tokens\n")

# Standard GPT-2
print("❌ Standard GPT-2:")
if article_len > gpt2_limit:
    tokens_lost = article_len - gpt2_limit
    pct_lost = (tokens_lost / article_len) * 100
    print(f"  Processes:  {gpt2_limit:4d} tokens")
    print(f"  TRUNCATES:  {tokens_lost:4d} tokens ({pct_lost:.1f}%)")
    print(f"  Attention:  O(n²) = {gpt2_limit}² = {gpt2_limit**2:,} query-key pairs")
    print(f"  ⚠️  Tokens beyond the context window are dropped!")
else:
    print(f"  Processes:  {article_len:4d} tokens (fits in context)")
    print(f"  Attention:  O(n²) = {article_len}² = {article_len**2:,} query-key pairs")

# HMT
print(f"\n✅ HMT with Hierarchical Memory:")
num_segments = len(segments_info)
total_processed = sum(seg['length'] for seg in segments_info)
# Each segment attends over its own tokens plus its sensory tokens: O((L + k)²) per segment.
# Retrieved memory embeddings are ignored in this count.
hmt_ops = sum((seg['length'] + seg['sensory']) ** 2 for seg in segments_info)
print(f"  Processes:  {total_processed:4d} tokens (FULL article)")
print(f"  Segments:   {num_segments} segments of up to {config.segment_length} tokens")
print(f"  Attention:  O(L²) per segment, {hmt_ops:,} query-key pairs in total")
print(f"  Cache:      up to {config.num_memory_embeddings} memory embeddings")
print(f"  ✅ Earlier segments are compressed into memory instead of dropped!")

# Efficiency gain vs. full attention over the whole article
if article_len > gpt2_limit:
    print(f"\n🚀 Efficiency Gain:")
    full_ops = article_len ** 2
    speedup = full_ops / hmt_ops
    print(f"  Full attention over all {article_len} tokens: {full_ops:,} query-key pairs")
    print(f"  HMT needs ~{speedup:.1f}x fewer while still covering every token!")

In [None]:
# Visualize coverage
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 6))

# Standard GPT-2
ax1.barh(0, min(article_len, gpt2_limit), height=0.5, 
         color='blue', alpha=0.7, label='Processed')
if article_len > gpt2_limit:
    ax1.barh(0, article_len - gpt2_limit, left=gpt2_limit, 
             height=0.5, color='red', alpha=0.7, label='Truncated (LOST)')
ax1.set_xlim(0, article_len)
ax1.set_ylim(-0.5, 0.5)
ax1.set_xlabel('Token Position')
ax1.set_title('Standard GPT-2: Truncates at 1024 tokens')
ax1.legend(loc='upper right')
ax1.set_yticks([])
ax1.grid(alpha=0.3, axis='x')

# HMT
colors = plt.cm.viridis(np.linspace(0, 1, len(segments_info)))
for i, seg in enumerate(segments_info):
    ax2.barh(0, seg['length'], left=seg['start'], height=0.5,
             color=colors[i], alpha=0.8, edgecolor='black', linewidth=1,
             label=f"Segment {seg['id']}" if i < 5 else '')
    
    # Show sensory memory overlap
    if seg['sensory'] > 0:
        ax2.barh(0, seg['sensory'], left=seg['start'] - seg['sensory'], 
                height=0.3, color='orange', alpha=0.5)

ax2.set_xlim(0, article_len)
ax2.set_ylim(-0.5, 0.5)
ax2.set_xlabel('Token Position')
ax2.set_title('HMT: Processes full article via segmentation + memory')
if len(segments_info) <= 5:
    ax2.legend(loc='upper right')
ax2.set_yticks([])
ax2.grid(alpha=0.3, axis='x')

plt.tight_layout()
plt.show()

print("\n📊 Visualization:")
print("  Top: Standard GPT-2 (blue = processed, red = truncated)")
print("  Bottom: HMT segments (colored bars) with sensory memory overlap (orange)")

## 5. Understanding the Three-Level Memory Hierarchy

Let's trace what happens during HMT processing.
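The long-term memory level is essentially a bounded cache. The sketch below assumes that one memory embedding is added per segment and that the cache keeps the `N` most recent ones, evicting the oldest first (the FIFO policy is an assumption). It tracks how many memories each segment can search. `trace_cache_sizes` is a hypothetical helper that stores segment ids as stand-ins for embeddings.

```python
from collections import deque
from typing import Deque, List


def trace_cache_sizes(num_segments: int, N: int) -> List[int]:
    """Return how many memory embeddings are visible to each segment.

    Toy model: after each segment one new embedding is appended to a FIFO
    cache holding at most N entries (deque(maxlen=N) evicts the oldest).
    """
    cache: Deque[int] = deque(maxlen=N)
    available = []
    for seg_id in range(num_segments):
        available.append(len(cache))  # memories searchable while processing this segment
        cache.append(seg_id)          # stand-in for this segment's new memory embedding
    return available


sizes = trace_cache_sizes(num_segments=6, N=3)
```

With `N=300` and articles capped at 4096 tokens (at most 8 segments of 512), the cap never binds in this notebook. The count is simply `min(segment_id, N)`.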

In [None]:
def explain_memory_hierarchy(segment_id, config):
    """
    Explain what's in each memory level for a given segment.
    """
    print(f"\n🧠 Memory Hierarchy at Segment {segment_id}:")
    print("="*80)
    
    # 1. Sensory Memory
    print(f"\n1️⃣  SENSORY MEMORY (k={config.sensory_memory_size} tokens):")
    if segment_id == 0:
        print(f"    - EMPTY (first segment, no previous context)")
    else:
        print(f"    - Contains: Last {config.sensory_memory_size} tokens from Segment {segment_id-1}")
        print(f"    - Purpose: Provides local continuity between segments")
        print(f"    - Like: Short-term buffer, prevents abrupt context switches")
    
    # 2. Short-term Memory
    print(f"\n2️⃣  SHORT-TERM MEMORY (L={config.segment_length} tokens):")
    print(f"    - Contains: Current segment being processed")
    print(f"    - Purpose: Active processing by backbone transformer (GPT-2)")
    print(f"    - Like: Working memory, immediate focus of attention")
    
    # 3. Long-term Memory
    print(f"\n3️⃣  LONG-TERM MEMORY (N={config.num_memory_embeddings} embeddings):")
    if segment_id == 0:
        print(f"    - EMPTY (first segment, no history yet)")
    else:
        print(f"    - Contains: {segment_id} memory embeddings from past segments")
        print(f"    - Purpose: Compressed representation of ALL previous context")
        print(f"    - Retrieval: Cross-attention finds most relevant {min(segment_id, 5)} memories")
        print(f"    - Like: Episodic memory, recalls relevant past experiences")
    
    # Combined Context
    print(f"\n📋 TOTAL CONTEXT for Segment {segment_id}:")
    sensory_tokens = 0 if segment_id == 0 else config.sensory_memory_size
    current_tokens = config.segment_length
    memory_embeddings = segment_id
    
    print(f"    Sensory:    {sensory_tokens:3d} tokens")
    print(f"    Current:    {current_tokens:3d} tokens")
    print(f"    Retrieved:  {memory_embeddings:3d} memory embeddings")
    print(f"    → Effectively has access to information from ALL {segment_id * config.segment_length + current_tokens} past tokens!")

# Explain for a few representative segments (deduplicated, only those that exist)
last_id = len(segments_info) - 1
for seg_id in sorted({0, 1, 2, last_id}):
    if seg_id <= last_id:
        explain_memory_hierarchy(seg_id, config)

## 6. Device Compatibility Test

Let's verify everything works on your device (MPS/CUDA/CPU).
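A device helper typically checks backends in a fixed preference order. The sketch below assumes CUDA first, then Apple MPS, then CPU. `pick_device_name` is hypothetical, and the real `hmt.utils.get_device` may use a different order. It takes the availability flags as arguments so the logic can be checked without a GPU.

```python
def pick_device_name(cuda_available: bool, mps_available: bool) -> str:
    """Pick a backend name: CUDA first, then Apple MPS, then CPU (assumed order)."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"
```

In the notebook this would be called as `torch.device(pick_device_name(torch.cuda.is_available(), torch.backends.mps.is_available()))`.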

In [None]:
# Move batch to device
print(f"💻 Testing on device: {device}\n")

batch_on_device = {
    k: v.to(device) for k, v in batch.items()
}

print(f"✅ Batch moved to {device}")
print(f"  input_ids device: {batch_on_device['input_ids'].device}")
print(f"  Shape: {batch_on_device['input_ids'].shape}")

# Simple computation test
input_ids = batch_on_device['input_ids']
attention_mask = batch_on_device['attention_mask']

# Count tokens per article (respecting padding)
tokens_per_article = attention_mask.sum(dim=1)

print(f"\n  Computation test:")
print(f"    Tokens per article: {tokens_per_article.tolist()}")
print(f"    Computed on: {tokens_per_article.device}")

print(f"\n🚀 Data pipeline is ready for HMT training on {device}!")

## 7. Key Takeaways & Next Steps

**What you've learned:**

1. **Data Pipeline:**
   - `WikiTextDataset` loads and tokenizes articles without truncation
   - `LongContextDataLoader` batches variable-length sequences efficiently
   - Your pipeline preserves full articles for long-context processing

2. **HMT Segmentation:**
   - Long articles are split into manageable segments (L=512 tokens)
   - The backbone processes one segment at a time, so its attention window stays bounded
   - Memory mechanisms connect segments to maintain full context

3. **Three-Level Memory:**
   - **Sensory:** Last k=32 tokens of the previous segment, for local continuity
   - **Short-term:** Current L=512 token segment being processed
   - **Long-term:** Up to N=300 compressed embeddings from past segments

4. **Efficiency:**
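   - Standard transformers: O(n²) attention over the whole sequence, and a fixed context window forces truncation
   - HMT: O(L²) attention per segment, so total cost grows linearly with article length and inputs are not capped by the context window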
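The efficiency comparison can be made concrete by counting only self-attention query-key pairs and ignoring heads, layers, feed-forward cost, and the retrieved memory embeddings. Let $n$ be the article length, $L$ the segment length, and $k$ the sensory memory size:

$$
\text{cost}_{\text{full}}(n) \propto n^2,
\qquad
\text{cost}_{\text{HMT}}(n) \propto \left\lceil \frac{n}{L} \right\rceil (L + k)^2 \approx \frac{n\,(L + k)^2}{L} = O(nL) \quad \text{for fixed } k \le L.
$$

For $n = 4096$, $L = 512$, $k = 32$, this gives $4096^2 = 16{,}777{,}216$ pairs for full attention versus at most $8 \times 544^2 = 2{,}367{,}488$ for HMT. The bound is slightly loose because the first segment has no sensory tokens.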

**Next Steps:**
1. ✅ Data pipeline is ready!
2. 🚧 Implement the three memory components:
   - `RepresentationEncoder` - Summarize segments
   - `MemorySearch` - Retrieve relevant memories
   - `MemoryEmbeddingGenerator` - Create long-term memories
3. 🚧 Connect components in `HMT.forward()`
4. 🚧 Train on WikiText-103!
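To get a feel for what `MemorySearch` has to do, here is a toy version of memory retrieval as scaled dot-product cross-attention. It uses a single query (standing in for the current segment's representation embedding) over the cached memory embeddings. This is a simplified sketch: a real `MemorySearch` would typically add learned query/key projections and operate on tensors. `retrieve_memory` and its zero-vector behavior for an empty cache are hypothetical choices.

```python
import math
from typing import List


def retrieve_memory(query: List[float], memories: List[List[float]]) -> List[float]:
    """Single-query scaled dot-product cross-attention (toy, no learned projections).

    Scores each cached memory by q·m / sqrt(d), softmaxes the scores, and returns
    the attention-weighted average of the memories (keys = values = memories).
    """
    d = len(query)
    if not memories:
        # First segment: nothing cached yet, so return a zero vector (toy choice)
        return [0.0] * d
    scores = [sum(q * m for q, m in zip(query, mem)) / math.sqrt(d) for mem in memories]
    top = max(scores)                             # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    weights = [w / total for w in weights]
    return [sum(w * mem[i] for w, mem in zip(weights, memories)) for i in range(d)]


cached = [[1.0, 0.0], [0.0, 1.0]]
focused = retrieve_memory([4.0, 0.0], cached)    # query aligned with the first memory
balanced = retrieve_memory([1.0, 1.0], cached)   # query equally similar to both
```

The aligned query puts about 94% of the weight on the first memory, while the balanced query mixes both equally. Retrieval is soft: every cached memory contributes, in proportion to its relevance.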

Ready to build the memory components? 🧠

## 8. Optional: Explore Your Own Examples

Try modifying the config and see how it affects segmentation!

In [None]:
# Experiment: What if we use different segment lengths?

for L in [256, 512, 1024]:
    config_test = HMTConfig(segment_length=L, sensory_memory_size=32)
    num_segs = (len(article_0_clean) + L - 1) // L
    # Upper bound on attention pairs: every segment attends over L + k tokens
    attn_pairs = num_segs * (L + config_test.sensory_memory_size) ** 2
    
    print(f"\nSegment length = {L}:")
    print(f"  Number of segments: {num_segs}")
    print(f"  Complexity: {num_segs} × O({L}²) ≈ {attn_pairs:,} attention pairs (upper bound)")
    print(f"  Trade-off: longer segments mean fewer memory hand-offs but more attention per segment")

print("\n🤔 Question: What's the optimal segment length?")
print("   Answer: Depends on backbone model capacity and memory constraints!")