# Complete HMT Pipeline: Understanding the Three-Level Memory Hierarchy

**Paper Reference:** [Hierarchical Memory Transformer for Efficient Long Context Language Processing](https://arxiv.org/abs/2405.06067)

**Learning Objectives:**
1. Understand how **MemoryEmbeddingGenerator** creates compressed memory representations
2. See the complete **HMT forward pass** in action with all components integrated
3. Visualize the **three-level memory hierarchy**: sensory, short-term, and long-term
4. Process long sequences and observe memory cache evolution
5. Compare HMT vs standard transformer behavior

---

## Setup

In [None]:
import sys
sys.path.append('../src')

import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

from hmt import HMT, HMTConfig
from hmt.memory import MemoryEmbeddingGenerator
from hmt.utils import get_device

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

# Get device
device = get_device()
print(f"Using device: {device}")
print(f"Device type: {device.type if hasattr(device, 'type') else device}")

---

## Part 1: Memory Embedding Generation

**Paper Section 3.3:** Memory embeddings are compressed representations of processed segments.

### Equation 4: $m_n = \text{compress}(\text{BBM}([k_n || H_n || P_n]))$

Where:
- $k_n$: Sensory memory (last k tokens from previous segment)
- $H_n$: Current segment (L tokens)
- $P_n$: Retrieved memory from cache
- $m_n$: Generated memory embedding to store in cache

### Understanding Different Extraction Strategies

In [None]:
# Load a small model for demonstration
print("Loading GPT-2...")
backbone = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
backbone.eval()

# Create config
config = HMTConfig(
    segment_length=128,
    representation_length=64,
    num_memory_embeddings=50,
    sensory_memory_size=16,
    hidden_dim=768,  # GPT-2 hidden size
)

print(f"\nHMT Configuration:")
print(f"  Segment length (L): {config.segment_length}")
print(f"  Representation length (j): {config.representation_length}")
print(f"  Memory cache size (N): {config.num_memory_embeddings}")
print(f"  Sensory memory size (k): {config.sensory_memory_size}")
print(f"  Hidden dimension: {config.hidden_dim}")

In [None]:
# Initialize MemoryEmbeddingGenerator
mem_gen = MemoryEmbeddingGenerator(config).to(device)

# Simulate backbone output (hidden states)
batch_size = 1
seq_len = 100
simulated_hidden_states = torch.randn(batch_size, seq_len, config.hidden_dim).to(device)

print("Testing different extraction strategies:\n")

strategies = ["last", "mean", "max", "cls"]
strategy_embeddings = {}

with torch.no_grad():
    for strategy in strategies:
        embedding = mem_gen(simulated_hidden_states, extraction_strategy=strategy)
        strategy_embeddings[strategy] = embedding
        print(f"  {strategy:8s} → shape: {embedding.shape}, mean: {embedding.mean():.4f}, std: {embedding.std():.4f}")

print("\n📚 Strategy Explanations:")
print("  • 'last':  Uses final token (best for causal LMs like GPT)")
print("  • 'mean':  Average over all tokens (balanced representation)")
print("  • 'max':   Max pool (captures salient features)")
print("  • 'cls':   First token (BERT-style, less common for GPT)")

In [None]:
# Visualize the differences between strategies
fig, axes = plt.subplots(1, len(strategies), figsize=(16, 4))

for idx, (strategy, embedding) in enumerate(strategy_embeddings.items()):
    emb_np = embedding.cpu().numpy().flatten()[:100]  # Show first 100 dimensions
    axes[idx].bar(range(len(emb_np)), emb_np, alpha=0.7)
    axes[idx].set_title(f"{strategy.capitalize()} Strategy", fontsize=12, fontweight='bold')
    axes[idx].set_xlabel("Dimension")
    axes[idx].set_ylabel("Value")
    axes[idx].axhline(y=0, color='r', linestyle='--', alpha=0.3)

plt.suptitle("Memory Embedding Patterns Across Different Extraction Strategies", 
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

print("\n🔍 Observation: Different strategies produce different embedding patterns.")
print("   The 'last' strategy (paper's approach) captures final state after full context processing.")
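
The four strategies compared above are all ways of pooling a `[seq_len, hidden_dim]` block of hidden states into one vector. The following is a minimal plain-Python sketch of that pooling step only. The function name `extract` is hypothetical, and the sketch does not reproduce `MemoryEmbeddingGenerator`, which may also apply a learned compression layer.

```python
from typing import List

Vector = List[float]


def extract(hidden_states: List[Vector], strategy: str = "last") -> Vector:
    """Pool a [seq_len x hidden_dim] list of vectors into a single vector."""
    if not hidden_states:
        raise ValueError("hidden_states must contain at least one token")
    if strategy == "last":
        return list(hidden_states[-1])   # final token's hidden state
    if strategy == "cls":
        return list(hidden_states[0])    # first token's hidden state
    columns = list(zip(*hidden_states))  # transpose: one tuple per dimension
    if strategy == "mean":
        return [sum(col) / len(col) for col in columns]
    if strategy == "max":
        return [max(col) for col in columns]
    raise ValueError(f"unknown strategy: {strategy!r}")


states = [[1.0, 4.0], [3.0, 0.0], [2.0, 2.0]]  # 3 tokens, hidden_dim = 2
pooled = {name: extract(states, name) for name in ("last", "mean", "max", "cls")}
print(pooled)
```

Whatever the sequence length, every strategy returns a vector of size `hidden_dim`, which is why the memory embedding has a fixed size.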

---

## Part 2: Complete HMT Forward Pass

**Paper Algorithm 1:** Processing long sequences with hierarchical memory

### The Three-Level Memory Hierarchy

1. **Sensory Memory (k=16 tokens)**: 
   - Preserves last 16 tokens from previous segment
   - Provides local continuity across segments
   - Sliding window that moves with each segment

2. **Short-term Memory (L=128 tokens)**:
   - Current segment being processed
   - Standard transformer attention within this window
   - Processes current context

3. **Long-term Memory (N=50 embeddings)**:
   - Cache of compressed memory embeddings from past segments
   - Retrieved via cross-attention when relevant
   - Enables access to distant context without quadratic attention over the full sequence
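
The first two levels of the hierarchy can be illustrated with simple slicing. This is a sketch under stated assumptions: `iter_segments` is a hypothetical helper (not part of the `hmt` package), it works on raw token IDs rather than hidden states, and it assumes the first segment has no sensory memory.

```python
from typing import Iterator, List, Tuple


def iter_segments(tokens: List[int], segment_length: int,
                  sensory_size: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (sensory_memory, segment) pairs for a token sequence."""
    for start in range(0, len(tokens), segment_length):
        # Short-term memory: the current segment (the last one may be shorter).
        segment = tokens[start:start + segment_length]
        # Sensory memory: the last `sensory_size` tokens before this segment.
        # Empty for the first segment (assumption of this sketch).
        sensory = tokens[max(0, start - sensory_size):start]
        yield sensory, segment


tokens = list(range(10))
pairs = list(iter_segments(tokens, segment_length=4, sensory_size=2))
for sensory, segment in pairs:
    print(f"sensory={sensory} segment={segment}")
```

The sensory slice always comes from the tail of the preceding segment, which is the "sliding window" behavior described in the list above.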

In [None]:
# Initialize HMT
hmt = HMT(backbone, config).to(device)
hmt.eval()

print("HMT Model Initialized!")
print(f"\nMemory Components:")
print(f"  ✓ RepresentationEncoder")
print(f"  ✓ MemorySearch")
print(f"  ✓ MemoryEmbeddingGenerator")
print(f"  ✓ Memory Cache (FIFO queue, max size: {config.num_memory_embeddings})")
print(f"  ✓ Sensory Memory (last {config.sensory_memory_size} tokens)")

### Step-by-Step: Processing a Long Article

In [None]:
# Sample long text (from WikiText-like article)
long_text = """
The Hierarchical Memory Transformer is a novel architecture designed to handle long-context 
language processing efficiently. Traditional transformers face quadratic complexity with sequence 
length, making them impractical for very long documents. HMT addresses this by introducing a 
three-level memory hierarchy inspired by human cognition.

The first level, sensory memory, preserves the most recent tokens from the previous segment, 
ensuring local continuity. The second level, short-term memory, processes the current segment 
of fixed length L. The third level, long-term memory, maintains a cache of compressed embeddings 
from all previous segments, allowing the model to retrieve relevant distant context when needed.

This design reduces computational complexity from O(L²) to O(L) per segment, while still enabling 
access to the entire context through the memory retrieval mechanism. Cross-attention is used to 
search the long-term cache for relevant memories based on the current segment's representation.

The memory embeddings are generated by processing augmented segments through the backbone model 
and extracting compressed representations. These embeddings capture the essential information 
from each segment in a fixed-size vector, which is then stored in the cache for future retrieval.

Experimental results show that HMT achieves competitive performance on long-context language 
modeling tasks while using significantly less memory and computation than standard transformers. 
The architecture is plug-and-play, working with any pre-trained decoder-only transformer without 
modification to the backbone model.
"""

# Tokenize
inputs = tokenizer(long_text, return_tensors="pt", truncation=False)
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)

seq_length = input_ids.shape[1]
num_segments = (seq_length + config.segment_length - 1) // config.segment_length

print(f"Input Text Statistics:")
print(f"  Total tokens: {seq_length}")
print(f"  Number of segments: {num_segments}")
print(f"  Tokens per segment: {config.segment_length}")
print(f"  Last segment size: {seq_length % config.segment_length or config.segment_length}")

In [None]:
# Clear memory before processing
hmt.clear_memory()
print("Memory cleared. Starting fresh...\n")

# Process with HMT
with torch.no_grad():
    outputs = hmt(input_ids, attention_mask=attention_mask, use_memory=True)

# Get final memory stats
final_stats = hmt.get_memory_stats()

print("✅ Processing complete!\n")
print(f"Final Memory State:")
print(f"  Cache size: {final_stats['cache_size']} / {final_stats['max_cache_size']}")
print(f"  Sensory memory active: {final_stats['sensory_memory_active']}")
print(f"  Sensory memory size: {final_stats['sensory_memory_size']} tokens")

print(f"\nOutput shape: {outputs['logits'].shape}")
print(f"Expected: [batch_size=1, seq_len={seq_length}, vocab_size=50257]")

### Visualizing Memory Evolution

In [None]:
# Process again, tracking cache growth at each segment
hmt.clear_memory()
cache_sizes = []

# Feed one segment at a time and record the cache size after each call

for seg_idx in range(num_segments):
    start = seg_idx * config.segment_length
    end = min(start + config.segment_length, seq_length)
    
    # Process segment (simplified - just track cache growth)
    segment_ids = input_ids[:, start:end]
    
    with torch.no_grad():
        _ = hmt(segment_ids, use_memory=True)
    
    cache_sizes.append(len(hmt.memory_cache))

# Plot cache growth
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(range(1, num_segments + 1), cache_sizes, marker='o', linewidth=2, markersize=8)
plt.axhline(y=config.num_memory_embeddings, color='r', linestyle='--', 
            label=f'Max Cache Size (N={config.num_memory_embeddings})')
plt.xlabel('Segment Number', fontsize=12)
plt.ylabel('Cache Size', fontsize=12)
plt.title('Long-term Memory Cache Growth Over Segments', fontsize=14, fontweight='bold')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
segment_positions = np.arange(num_segments) * config.segment_length
segment_widths = [config.segment_length] * (num_segments - 1) + [seq_length % config.segment_length or config.segment_length]
colors = plt.cm.viridis(np.linspace(0, 1, num_segments))

for i, (pos, width) in enumerate(zip(segment_positions, segment_widths)):
    plt.barh(0, width, left=pos, height=0.5, color=colors[i], 
             label=f'Seg {i+1}' if i < 5 or i == num_segments-1 else None)

plt.xlabel('Token Position', fontsize=12)
plt.title('Sequence Segmentation', fontsize=14, fontweight='bold')
plt.yticks([])
plt.legend(loc='upper right', ncol=2)

plt.tight_layout()
plt.show()

print("\n📊 Insights:")
print(f"  • Each segment adds one memory embedding to the cache")
print(f"  • Once the cache holds N={config.num_memory_embeddings} embeddings, the oldest is evicted (FIFO), so the size stays at N")
print(f"  • Sequence is divided into {num_segments} segments for processing")

---

## Part 3: Ablation Study - With vs Without Memory

**Key Question:** How does memory retrieval affect HMT's processing?
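
Before running the ablation, it may help to see what retrieval computes. The sketch below shows scaled dot-product attention of one query vector over a list of cached embeddings. It is an illustration only: `retrieve` is a hypothetical name, the learned query/key projections are omitted, and it is not the `MemorySearch` implementation.

```python
import math
from typing import List

Vector = List[float]


def retrieve(query: Vector, cache: List[Vector]) -> Vector:
    """Softmax-weighted sum of cached embeddings, scored by scaled dot product."""
    if not cache:
        raise ValueError("cache is empty; nothing to retrieve")
    scale = math.sqrt(len(query))
    scores = [sum(q * m for q, m in zip(query, mem)) / scale for mem in cache]
    peak = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    weights = [w / total for w in weights]
    # Weighted sum over cached embeddings, one output value per dimension.
    return [sum(w * mem[d] for w, mem in zip(weights, cache))
            for d in range(len(query))]


cache = [[1.0, 0.0], [0.0, 1.0]]
retrieved = retrieve([4.0, 0.0], cache)
print(retrieved)
```

The query is most similar to the first cached embedding, so the result is dominated by it. With `use_memory=False`, this retrieval step is what the ablation skips.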

In [None]:
# Test shorter text for clearer comparison
test_text = "The transformer architecture revolutionized natural language processing by introducing "
test_text += "self-attention mechanisms. However, long sequences remain computationally expensive. "
test_text += "Hierarchical memory transformers address this challenge efficiently."

test_inputs = tokenizer(test_text, return_tensors="pt", truncation=False)
test_ids = test_inputs["input_ids"].to(device)

print(f"Test sequence length: {test_ids.shape[1]} tokens\n")

# Process WITH memory
hmt.clear_memory()
with torch.no_grad():
    outputs_with_mem = hmt(test_ids, use_memory=True)

stats_with = hmt.get_memory_stats()

# Process WITHOUT memory
hmt.clear_memory()
with torch.no_grad():
    outputs_without_mem = hmt(test_ids, use_memory=False)

stats_without = hmt.get_memory_stats()

print("\n📊 Comparison:")
print(f"\nWith Memory:")
print(f"  Cache size: {stats_with['cache_size']}")
print(f"  Sensory memory: {'Active' if stats_with['sensory_memory_active'] else 'Inactive'}")
print(f"  Output shape: {outputs_with_mem['logits'].shape}")

print(f"\nWithout Memory (Ablation):")
print(f"  Cache size: {stats_without['cache_size']}")
print(f"  Sensory memory: {'Active' if stats_without['sensory_memory_active'] else 'Inactive'}")
print(f"  Output shape: {outputs_without_mem['logits'].shape}")

# Check if outputs differ
output_diff = (outputs_with_mem['logits'] - outputs_without_mem['logits']).abs().mean().item()
print(f"\nMean absolute difference in logits: {output_diff:.6f}")
print(f"Outputs are {'identical' if output_diff < 1e-6 else 'different'} (memory {'does not affect' if output_diff < 1e-6 else 'affects'} processing)")

---

## Part 4: Understanding the Augmented Input

**Paper Section 3.3:** Each segment is augmented with:
- Sensory memory (k tokens from previous segment)
- Current segment (L tokens)
- Retrieved memory embedding (1 pseudo-token)
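
As a quick check of the arithmetic, the augmented length is the sum of the three components. The helper `augmented_length` below is hypothetical, and it assumes the first segment carries neither sensory memory nor a retrieved memory token; an implementation could instead pad those slots.

```python
def augmented_length(segment_tokens: int, sensory_tokens: int,
                     has_retrieved_memory: bool) -> int:
    """Length of the sequence fed to the backbone for one segment."""
    return sensory_tokens + segment_tokens + (1 if has_retrieved_memory else 0)


# First segment: nothing precedes it (assumption of this sketch).
first = augmented_length(128, sensory_tokens=0, has_retrieved_memory=False)
# Later segments: k=16 sensory tokens plus one retrieved memory pseudo-token.
later = augmented_length(128, sensory_tokens=16, has_retrieved_memory=True)
print(first, later)
```

The overhead per segment is k + 1 positions, which stays constant no matter how long the full document is.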

In [None]:
# Visualize augmented input structure
fig, ax = plt.subplots(figsize=(14, 6))

# Example for segment 3 (has sensory memory)
sensory_size = config.sensory_memory_size
segment_size = config.segment_length
memory_token = 1

total_augmented = sensory_size + segment_size + memory_token

# Plot boxes
ax.barh(0, sensory_size, left=0, height=0.6,
        color='lightblue', edgecolor='black', linewidth=2,
        label=f'Sensory Memory (k={sensory_size})')
ax.barh(0, segment_size, left=sensory_size, height=0.6,
        color='lightgreen', edgecolor='black', linewidth=2,
        label=f'Current Segment (L={segment_size})')
ax.barh(0, memory_token, left=sensory_size + segment_size, height=0.6, 
        color='coral', edgecolor='black', linewidth=2, label='Retrieved Memory (1 token)')

# Annotations
ax.text(sensory_size/2, 0, f'{sensory_size}', ha='center', va='center', 
        fontsize=12, fontweight='bold')
ax.text(sensory_size + segment_size/2, 0, f'{segment_size}', ha='center', va='center', 
        fontsize=12, fontweight='bold')
ax.text(sensory_size + segment_size + memory_token/2, 0, '1', ha='center', va='center', 
        fontsize=12, fontweight='bold')

ax.set_xlim(-5, total_augmented + 5)
ax.set_ylim(-0.5, 0.5)
ax.set_xlabel('Token Position', fontsize=12)
ax.set_title('Augmented Input Structure for Segment Processing', fontsize=14, fontweight='bold')
ax.legend(loc='upper right', fontsize=11)
ax.set_yticks([])
ax.grid(True, axis='x', alpha=0.3)

# Add arrows and labels
ax.annotate('', xy=(0, -0.35), xytext=(sensory_size, -0.35),
            arrowprops=dict(arrowstyle='<->', color='blue', lw=2))
ax.text(sensory_size/2, -0.45, 'Last k tokens\nfrom prev segment', 
        ha='center', va='top', fontsize=9, color='blue')

ax.annotate('', xy=(sensory_size, -0.35), xytext=(sensory_size + segment_size, -0.35),
            arrowprops=dict(arrowstyle='<->', color='green', lw=2))
ax.text(sensory_size + segment_size/2, -0.45, 'New input tokens', 
        ha='center', va='top', fontsize=9, color='green')

plt.tight_layout()
plt.show()

print(f"\n📝 Augmented Input Components:")
print(f"  1. Sensory Memory: {sensory_size} tokens (local continuity)")
print(f"  2. Current Segment: {segment_size} tokens (new input)")
print(f"  3. Retrieved Memory: {memory_token} pseudo-token (distant context)")
print(f"  ───────────────────────────────────────────────")
print(f"  Total augmented length: {total_augmented} tokens")
print(f"\n  The backbone processes this augmented input to generate outputs.")

---

## Part 5: Hands-On Exercises

### Exercise 1: Experiment with Different Segment Lengths

In [None]:
# TODO: Try different segment_length values (64, 128, 256)
# Observe how it affects number of segments and cache size

segment_lengths = [64, 128, 256]
results = []

sample_text = long_text  # Use the long text from earlier
sample_inputs = tokenizer(sample_text, return_tensors="pt", truncation=False)
sample_ids = sample_inputs["input_ids"].to(device)
total_tokens = sample_ids.shape[1]

print(f"Sample text length: {total_tokens} tokens\n")

for seg_len in segment_lengths:
    # Create config with different segment length
    test_config = HMTConfig(
        segment_length=seg_len,
        representation_length=seg_len // 2,
        num_memory_embeddings=50,
        sensory_memory_size=16,
        hidden_dim=768,
    )
    
    # Create HMT
    test_hmt = HMT(backbone, test_config).to(device)
    test_hmt.eval()
    
    # Process
    with torch.no_grad():
        _ = test_hmt(sample_ids, use_memory=True)
    
    stats = test_hmt.get_memory_stats()
    num_segs = (total_tokens + seg_len - 1) // seg_len
    
    results.append({
        'segment_length': seg_len,
        'num_segments': num_segs,
        'cache_size': stats['cache_size']
    })
    
    print(f"Segment Length = {seg_len}:")
    print(f"  Number of segments: {num_segs}")
    print(f"  Final cache size: {stats['cache_size']}")
    print()

# Visualize
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

seg_lens = [r['segment_length'] for r in results]
num_segs = [r['num_segments'] for r in results]
cache_sizes = [r['cache_size'] for r in results]

ax1.bar(range(len(seg_lens)), num_segs, tick_label=seg_lens, color='steelblue')
ax1.set_xlabel('Segment Length', fontsize=12)
ax1.set_ylabel('Number of Segments', fontsize=12)
ax1.set_title('Segments vs Segment Length', fontsize=14, fontweight='bold')
ax1.grid(True, axis='y', alpha=0.3)

ax2.bar(range(len(seg_lens)), cache_sizes, tick_label=seg_lens, color='coral')
ax2.set_xlabel('Segment Length', fontsize=12)
ax2.set_ylabel('Final Cache Size', fontsize=12)
ax2.set_title('Cache Size vs Segment Length', fontsize=14, fontweight='bold')
ax2.grid(True, axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

print("\n🔍 Observation: Smaller segments → More segments → Larger cache (up to max N)")

### Exercise 2: Analyze Memory Cache Capacity

In [None]:
# TODO: Test what happens when cache exceeds max size N
# Create a very long sequence that generates more than N memory embeddings

# Your code here:
# 1. Create a config with small N (e.g., N=5)
# 2. Create a long sequence that will generate 10+ segments
# 3. Process and observe cache behavior (FIFO eviction)

# Example solution:
small_cache_config = HMTConfig(
    segment_length=32,
    representation_length=16,  # Half the segment length, as in Exercise 1
    num_memory_embeddings=5,   # Small cache
    sensory_memory_size=8,
    hidden_dim=768,
)

small_cache_hmt = HMT(backbone, small_cache_config).to(device)
small_cache_hmt.eval()

# Long input: 10 segments of segment_length tokens each
num_tokens = 10 * small_cache_config.segment_length
long_input = torch.randint(0, 1000, (1, num_tokens), device=device)

# Track cache size after each segment
cache_evolution = []
for i in range(0, num_tokens, small_cache_config.segment_length):
    chunk = long_input[:, i:i + small_cache_config.segment_length]
    with torch.no_grad():
        _ = small_cache_hmt(chunk, use_memory=True)
    cache_evolution.append(len(small_cache_hmt.memory_cache))

# Plot
plt.figure(figsize=(10, 5))
plt.plot(range(1, len(cache_evolution)+1), cache_evolution, marker='o', linewidth=2)
plt.axhline(y=small_cache_config.num_memory_embeddings, color='r', linestyle='--', linewidth=2,
            label=f'Max Cache Size (N={small_cache_config.num_memory_embeddings})')
plt.xlabel('Segment Number', fontsize=12)
plt.ylabel('Cache Size', fontsize=12)
plt.title('FIFO Cache Behavior: Eviction When Exceeding Max Size', fontsize=14, fontweight='bold')
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.show()

print(f"\n✅ Cache Evolution: {cache_evolution}")
print(f"\n📚 Insight: Cache grows to N={small_cache_config.num_memory_embeddings}, then oldest memories are evicted (FIFO).")
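
The FIFO behavior can be reproduced in isolation with a bounded `collections.deque`. This is only an illustration: whether `hmt.memory_cache` is actually backed by a deque is an assumption, and the strings stand in for memory embeddings.

```python
from collections import deque

max_cache_size = 5
cache = deque(maxlen=max_cache_size)  # the oldest entry is dropped automatically
sizes = []

for segment_index in range(10):
    cache.append(f"m_{segment_index}")  # stand-in for a memory embedding
    sizes.append(len(cache))

print(sizes)        # cache size after each segment
print(list(cache))  # embeddings that survive after 10 segments
```

After ten segments only the five most recent embeddings remain, so anything stored for the first five segments can no longer be retrieved.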

### Exercise 3: Compare Computational Complexity

In [None]:
# TODO: Analyze complexity difference between standard transformer and HMT

sequence_lengths = [128, 256, 512, 1024, 2048]
segment_length = 128

standard_complexity = []  # O(n²) in the sequence length n
hmt_complexity = []       # O(n·L) for a fixed segment length L

for n in sequence_lengths:
    # Standard transformer: full self-attention over all n tokens
    standard_complexity.append(n * n)

    # HMT: O(segment_length²) attention per segment + O(N) memory retrieval per segment
    # (simplification: ignores the sensory tokens and memory token added to each segment)
    n_segs = (n + segment_length - 1) // segment_length
    hmt_ops = n_segs * segment_length * segment_length   # Attention within segments
    hmt_ops += n_segs * config.num_memory_embeddings     # Memory retrieval
    hmt_complexity.append(hmt_ops)

# Plot
plt.figure(figsize=(12, 6))
plt.plot(sequence_lengths, standard_complexity, marker='s', linewidth=3,
         label='Standard Transformer: O(n²)', color='red')
plt.plot(sequence_lengths, hmt_complexity, marker='o', linewidth=3,
         label=f'HMT: O(n·L), segment length L={segment_length}', color='green')

plt.xlabel('Sequence Length (n)', fontsize=12)
plt.ylabel('Approximate Operations', fontsize=12)
plt.title('Computational Complexity: Standard Transformer vs HMT', fontsize=14, fontweight='bold')
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.yscale('log')  # Log scale for better visualization
plt.tight_layout()
plt.show()

# Ratio of approximate attention operations (an estimate, not measured wall-clock time)
ratios = [s / h for s, h in zip(standard_complexity, hmt_complexity)]
print("\n📊 Operation ratio (Standard / HMT):")
for n, ratio in zip(sequence_lengths, ratios):
    print(f"  n={n:5d}: {ratio:.2f}x")

print("\n💡 Key Insight: HMT's complexity scales much better for long sequences!")
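
The ratio printed above can also be worked out by hand. Under the same simplified counting as the code (sequence length $n$ divisible by the segment length $L$, cache size $N$, sensory and memory tokens ignored):

$$
C_{\text{std}} = n^2, \qquad
C_{\text{HMT}} = \frac{n}{L}\left(L^2 + N\right) = nL + \frac{nN}{L}, \qquad
\frac{C_{\text{std}}}{C_{\text{HMT}}} = \frac{n}{L + N/L}
$$

With $L = 128$ and $N = 50$, the denominator is $128 + 50/128 \approx 128.39$. The ratio is therefore about $0.997$ at $n = 128$ (a single segment, so HMT gains nothing) and about $15.95$ at $n = 2048$. Since the ratio grows in proportion to $n$, doubling the sequence length doubles the estimated advantage.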

---

## Summary and Key Takeaways

### What We Learned:

1. **Memory Embedding Generation** (Section 3.3):
   - Different extraction strategies: last, mean, max, cls
   - Compression layer learns optimal memory representation
   - Memory embeddings are fixed-size regardless of segment length

2. **Complete HMT Pipeline**:
   - Segmentation of long sequences into manageable chunks
   - Three-level memory hierarchy working together:
     * Sensory (local continuity)
     * Short-term (current processing)
     * Long-term (distant context)
   - FIFO cache management for bounded memory usage

3. **Computational Efficiency**:
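   - For a sequence of n tokens, attention cost drops from O(n²) to roughly O(n·L) with a fixed segment length L
   - Cost grows linearly in n, so the advantage widens for very long sequences
   - Memory retrieval adds only O(N) operations per segment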

4. **Design Insights**:
   - The ablation compares logits with and without memory retrieval
   - Augmented input combines: sensory + segment + retrieved memory
   - Cache size impacts available long-term context

### Next Steps:

- **Phase 4**: Implement training with BPTT (Backpropagation Through Time)
- **Phase 5**: Fine-tune HMT on long-context tasks
- **Phase 6**: Evaluate on WikiText-103 and other benchmarks
- **Phase 7**: Scale to larger models (LLaMA, OPT-350M+)

---

### 🎓 Congratulations!

You now understand how the complete HMT system processes long sequences efficiently using hierarchical memory. The three-level architecture lets the model maintain both local and distant context while keeping attention cost linear in the sequence length.

**Paper:** [arXiv:2405.06067](https://arxiv.org/abs/2405.06067)  
**Implementation:** HMT-implementation (Phase 3.2 Complete)