# Tutorial 03: Striped Hyena Architecture

This notebook explores the Striped Hyena architecture, which forms the core of our genomic sequence modeling framework. We'll dive into the hybrid attention-convolution mechanism and its advantages for long sequence processing.

## Learning Objectives
- Understand the Striped Hyena architecture principles
- Explore hybrid attention-convolution mechanisms
- Analyze computational efficiency for long sequences
- Implement custom Hyena layers and configurations
- Compare with traditional transformer architectures

In [None]:
import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(''))))

import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn

from hyena_glt.models import HyenaGLT
from hyena_glt.models.hyena_blocks import HyenaBlock

# Plotting defaults: reset to matplotlib's stock style first, then apply
# seaborn's "husl" colour cycle on top (order matters: the style reset would
# otherwise discard the palette).
plt.style.use('default')
sns.set_palette("husl")

print("Striped Hyena Architecture Deep Dive")
print("====================================")

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')
print(f"Using device: {device}")

## 1. Hyena Architecture Overview

The Striped Hyena architecture mixes two kinds of sequence processing in one model:
- **Local Processing**: Convolution layers for capturing local patterns
- **Global Context**: Attention mechanisms for long-range dependencies
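- **Efficiency**: Convolution pathways scale subquadratically with sequence length, so shifting capacity from attention to convolution lowers the cost of long sequences (any remaining full-attention pathway is still quadratic)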
- **Flexibility**: Configurable attention/convolution ratios

In [None]:
# Build a mid-sized Hyena-GLT model with an even attention/convolution split.
config = {
    'vocab_size': 4096,
    'hidden_size': 512,
    'num_layers': 6,
    'num_attention_heads': 8,
    'intermediate_size': 2048,
    'max_position_embeddings': 8192,
    'hyena_config': {
        'conv_kernel_size': 7,
        'conv_groups': 1,
        # Fraction of capacity given to attention; the remainder is convolution.
        'attention_ratio': 0.5,
    },
}

# nn.Module.to returns the module itself, so construction and placement chain.
model = HyenaGLT(**config).to(device)

_model_params = list(model.parameters())
print("=== Model Architecture Summary ===")
print(f"Total parameters: {sum(p.numel() for p in _model_params):,}")
print(f"Trainable parameters: {sum(p.numel() for p in _model_params if p.requires_grad):,}")
print(f"Model size: {sum(p.numel() * p.element_size() for p in _model_params) / 1024**2:.2f} MB")

# Echo the configuration values that define the model's shape.
for _label, _value in (
    ("Number of layers", config['num_layers']),
    ("Hidden size", config['hidden_size']),
    ("Attention heads", config['num_attention_heads']),
    ("Max sequence length", config['max_position_embeddings']),
):
    print(f"{_label}: {_value}")
print(f"Attention ratio: {config['hyena_config']['attention_ratio']*100:.0f}%")

## 2. Exploring Individual Hyena Components

In [None]:
# Dimensions for a stand-alone HyenaBlock experiment.
hidden_size, num_heads = 512, 8
sequence_length, batch_size = 1024, 2

hyena_block = HyenaBlock(
    hidden_size=hidden_size,
    num_attention_heads=num_heads,
    intermediate_size=2048,
    conv_kernel_size=7,
    attention_ratio=0.5,
).to(device)

# Random activations laid out as (batch, seq_len, hidden).
sample_input = torch.randn(batch_size, sequence_length, hidden_size).to(device)

print("=== Hyena Block Analysis ===")
print(f"Input shape: {sample_input.shape}")

# Inference-only forward pass: no autograd graph is recorded.
with torch.no_grad():
    output = hyena_block(sample_input)

shapes_match = sample_input.shape == output.shape
print(f"Output shape: {output.shape}")
print(f"Input-output shape match: {shapes_match}")

# Inspect cached attention weights only if this HyenaBlock exposes them.
block_attention = getattr(hyena_block, 'attention', None)
if block_attention is not None and hasattr(block_attention, 'attention_weights'):
    attention_weights = block_attention.attention_weights
    print(f"Attention weights shape: {attention_weights.shape}")
    print(f"Attention weight range: [{attention_weights.min():.3f}, {attention_weights.max():.3f}]")

## 3. Computational Complexity Analysis

In [None]:
def measure_computational_complexity(model, sequence_lengths, vocab_size=4096, device=None, warmup=1):
    """Measure inference latency and peak memory across sequence lengths.

    For each length, a single random sequence of token ids drawn uniformly from
    ``[0, vocab_size)`` is pushed through ``model`` under ``torch.no_grad()``.

    Args:
        model: Module called as ``model(input_ids)``; it is switched to eval mode.
        sequence_lengths: Iterable of sequence lengths to benchmark.
        vocab_size: Exclusive upper bound for the random token ids.
        device: Device to run on. Defaults to the device of the model's first
            parameter (CPU if the model has no parameters).
        warmup: Number of untimed forward passes per length before the timed
            one, so one-off costs (CUDA context/kernel setup, allocator growth)
            are not attributed to the measurement.

    Returns:
        List of dicts with keys ``sequence_length``, ``time`` (seconds),
        ``memory_mb`` (peak allocation above the pre-call baseline; 0 on
        non-CUDA devices) and ``tokens_per_second``.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    device = torch.device(device)
    use_cuda = device.type == 'cuda'

    results = []
    model.eval()

    for seq_len in sequence_lengths:
        input_ids = torch.randint(0, vocab_size, (1, seq_len)).to(device)

        with torch.no_grad():
            for _ in range(warmup):
                model(input_ids)

        if use_cuda:
            torch.cuda.empty_cache()
            torch.cuda.synchronize(device)
            # Peak statistics are cumulative for the process; reset them so this
            # length's peak is not inherited from an earlier, larger run.
            torch.cuda.reset_peak_memory_stats(device)
            start_memory = torch.cuda.memory_allocated(device)
        else:
            start_memory = 0

        start_time = time.perf_counter()
        with torch.no_grad():
            outputs = model(input_ids)
        if use_cuda:
            # CUDA kernels run asynchronously; wait for them before stopping the clock.
            torch.cuda.synchronize(device)
        elapsed = time.perf_counter() - start_time

        peak_memory = torch.cuda.max_memory_allocated(device) if use_cuda else 0

        results.append({
            'sequence_length': seq_len,
            'time': elapsed,
            'memory_mb': (peak_memory - start_memory) / 1024**2,
            'tokens_per_second': seq_len / elapsed if elapsed > 0 else float('inf'),
        })

        # Release this length's tensors before the next (larger) allocation.
        del input_ids, outputs
        if use_cuda:
            torch.cuda.empty_cache()

    return results

# Sequence lengths to sweep; the CPU sweep stops earlier because long
# sequences are slow without a GPU.
if device.type == 'cpu':
    sequence_lengths = [128, 256, 512, 1024]  # Smaller for CPU
else:
    sequence_lengths = [128, 256, 512, 1024, 2048, 4096]

print("=== Computational Complexity Analysis ===")
print("Measuring performance across different sequence lengths...")

complexity_results = measure_computational_complexity(model, sequence_lengths)

# Tabulate: one left-aligned row per measured length.
print("\nResults:")
print(f"{'Seq Len':<8} {'Time (s)':<10} {'Memory (MB)':<12} {'Tokens/sec':<12}")
print("-" * 50)

row_template = "{:<8} {:<10.3f} {:<12.1f} {:<12.0f}"
for row in complexity_results:
    print(row_template.format(
        row['sequence_length'],
        row['time'],
        row['memory_mb'],
        row['tokens_per_second'],
    ))

## 4. Visualizing Complexity Scaling

In [None]:
# Extract data for plotting
seq_lengths = [r['sequence_length'] for r in complexity_results]
times = [r['time'] for r in complexity_results]
memories = [r['memory_mb'] for r in complexity_results]
throughputs = [r['tokens_per_second'] for r in complexity_results]
# Peak memory is only measured on CUDA; on CPU every entry is 0, which cannot be
# drawn on a log axis or used as a denominator.
memory_tracked = any(m > 0 for m in memories)


def _loglog_slope(xs, ys):
    """Return the least-squares slope of log(y) against log(x).

    Only points with x > 0 and y > 0 are used. Returns None when fewer than two
    distinct x values remain, so callers can report "n/a" instead of dividing
    by zero or taking log(0).
    """
    points = [(x, y) for x, y in zip(xs, ys, strict=False) if x > 0 and y > 0]
    if len({x for x, _ in points}) < 2:
        return None
    log_x = np.log([x for x, _ in points])
    log_y = np.log([y for _, y in points])
    return float(np.polyfit(log_x, log_y, 1)[0])


# Create comprehensive visualization
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Hyena-GLT Computational Complexity Analysis', fontsize=16, fontweight='bold')

# Time complexity
ax1.loglog(seq_lengths, times, 'bo-', linewidth=2, markersize=8)
ax1.set_xlabel('Sequence Length')
ax1.set_ylabel('Processing Time (seconds)')
ax1.set_title('Time Complexity')
ax1.grid(True, alpha=0.3)

# Theoretical reference lines, anchored at the first (shortest) measurement
if len(seq_lengths) > 2:
    # Quadratic reference (O(n²))
    quad_ref = [times[0] * (s/seq_lengths[0])**2 for s in seq_lengths]
    ax1.loglog(seq_lengths, quad_ref, 'r--', alpha=0.7, label='O(n²) reference')

    # Linear reference (O(n))
    linear_ref = [times[0] * (s/seq_lengths[0]) for s in seq_lengths]
    ax1.loglog(seq_lengths, linear_ref, 'g--', alpha=0.7, label='O(n) reference')

    ax1.legend()

# Memory complexity
if memory_tracked:
    ax2.loglog(seq_lengths, memories, 'ro-', linewidth=2, markersize=8)
else:
    ax2.text(0.5, 0.5, 'Peak memory is only tracked on CUDA devices',
             ha='center', va='center', transform=ax2.transAxes)
ax2.set_xlabel('Sequence Length')
ax2.set_ylabel('Peak Memory (MB)')
ax2.set_title('Memory Complexity')
ax2.grid(True, alpha=0.3)

# Throughput analysis
ax3.semilogx(seq_lengths, throughputs, 'go-', linewidth=2, markersize=8)
ax3.set_xlabel('Sequence Length')
ax3.set_ylabel('Throughput (tokens/second)')
ax3.set_title('Processing Throughput')
ax3.grid(True, alpha=0.3)

# Efficiency ratio (throughput per memory); 0 where memory was not measured
efficiency = [t/m if m > 0 else 0 for t, m in zip(throughputs, memories, strict=False)]
ax4.semilogx(seq_lengths, efficiency, 'mo-', linewidth=2, markersize=8)
ax4.set_xlabel('Sequence Length')
ax4.set_ylabel('Efficiency (tokens/sec/MB)')
ax4.set_title('Memory Efficiency')
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Estimate empirical scaling exponents from a log-log fit over all points
if len(complexity_results) > 1:
    time_growth_rate = _loglog_slope(seq_lengths, times)
    memory_growth_rate = _loglog_slope(seq_lengths, memories)

    print("\n=== Complexity Growth Analysis ===")
    if time_growth_rate is not None:
        print(f"Time complexity growth rate: O(n^{time_growth_rate:.2f})")
    else:
        print("Time complexity growth rate: n/a")
    if memory_growth_rate is not None:
        print(f"Memory complexity growth rate: O(n^{memory_growth_rate:.2f})")
    else:
        print("Memory complexity growth rate: n/a (peak memory is only tracked on CUDA)")
    print("Theoretical quadratic (transformer): O(n^2.00)")
    # Compare the measured time at the longest length with a pure O(n²)
    # extrapolation from the shortest one (the red reference line above).
    if times[0] > 0 and times[-1] > 0:
        quadratic_time = times[0] * (seq_lengths[-1] / seq_lengths[0]) ** 2
        print(f"Speedup vs. O(n²) extrapolation at n={seq_lengths[-1]}: {quadratic_time / times[-1]:.2f}x")

## 5. Attention vs Convolution Analysis

In [None]:
def compare_attention_ratios(attention_ratios, sequence_length=1024, warmup=1):
    """Benchmark small HyenaGLT models that differ only in attention ratio.

    For each ratio a fresh model is built from a fixed small base config
    (hidden size 256, 4 layers), placed on the notebook-level ``device`` and
    run on one random token sequence.

    Args:
        attention_ratios: Iterable of attention fractions to try.
        sequence_length: Number of tokens in the benchmark input.
        warmup: Number of untimed forward passes before the timed one, so
            one-off setup costs are not attributed to the first ratio tested.

    Returns:
        List of dicts with keys ``attention_ratio``, ``conv_ratio``, ``time``
        (seconds), ``memory_mb`` (peak allocation above the baseline taken
        after the model is built; 0 on non-CUDA devices), ``parameters`` and
        ``throughput`` (tokens/second).
    """
    results = []
    use_cuda = device.type == 'cuda'

    base_config = {
        'vocab_size': 1024,
        'hidden_size': 256,  # Smaller for faster comparison
        'num_layers': 4,
        'num_attention_heads': 4,
        'intermediate_size': 1024,
        'max_position_embeddings': 4096
    }

    for attention_ratio in attention_ratios:
        config = {
            **base_config,
            'hyena_config': {
                'conv_kernel_size': 7,
                'conv_groups': 1,
                'attention_ratio': attention_ratio,
            },
        }

        # Create model
        test_model = HyenaGLT(**config).to(device)
        test_model.eval()

        input_ids = torch.randint(0, config['vocab_size'], (1, sequence_length)).to(device)

        with torch.no_grad():
            for _ in range(warmup):
                test_model(input_ids)

        if use_cuda:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            # Peak statistics accumulate across the whole process; reset them so
            # each model's peak is measured independently.
            torch.cuda.reset_peak_memory_stats()
            start_memory = torch.cuda.memory_allocated()
        else:
            start_memory = 0

        start_time = time.perf_counter()
        with torch.no_grad():
            outputs = test_model(input_ids)
        if use_cuda:
            # Wait for asynchronous CUDA kernels before reading the clock.
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

        peak_memory = torch.cuda.max_memory_allocated() if use_cuda else 0

        results.append({
            'attention_ratio': attention_ratio,
            'conv_ratio': 1 - attention_ratio,
            'time': elapsed,
            'memory_mb': (peak_memory - start_memory) / 1024**2,
            'parameters': sum(p.numel() for p in test_model.parameters()),
            'throughput': sequence_length / elapsed if elapsed > 0 else float('inf'),
        })

        # Free this model before building the next one.
        del test_model, input_ids, outputs
        if use_cuda:
            torch.cuda.empty_cache()

    return results

# Sweep the attention share from pure convolution (0.0) to pure attention (1.0).
attention_ratios = [0.0, 0.25, 0.5, 0.75, 1.0]

print("=== Attention vs Convolution Ratio Analysis ===")
print("Comparing models with different attention/convolution ratios...")

ratio_results = compare_attention_ratios(attention_ratios)

# Tabulate one left-aligned row per ratio (parameters shown in thousands).
print("\nResults:")
print(f"{'Attn%':<6} {'Conv%':<6} {'Time(s)':<8} {'Memory(MB)':<10} {'Params':<8} {'Throughput':<10}")
print("-" * 65)

ratio_row_template = "{:<6.0f} {:<6.0f} {:<8.3f} {:<10.1f} {:<8.0f}k {:<10.0f}"
for row in ratio_results:
    print(ratio_row_template.format(
        row['attention_ratio'] * 100,
        row['conv_ratio'] * 100,
        row['time'],
        row['memory_mb'],
        row['parameters'] / 1000,
        row['throughput'],
    ))

## 6. Visualizing Attention vs Convolution Trade-offs

In [None]:
# Extract data for visualization
attn_ratios = [r['attention_ratio'] for r in ratio_results]
conv_ratios = [r['conv_ratio'] for r in ratio_results]
ratio_times = [r['time'] for r in ratio_results]
ratio_memories = [r['memory_mb'] for r in ratio_results]
ratio_throughputs = [r['throughput'] for r in ratio_results]
ratio_params = [r['parameters'] for r in ratio_results]

# Create visualization
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Attention vs Convolution Trade-off Analysis', fontsize=16, fontweight='bold')

# Performance vs attention ratio (time on the left axis, throughput on the right)
ax1.plot(attn_ratios, ratio_times, 'bo-', linewidth=2, markersize=8, label='Processing Time')
ax1_twin = ax1.twinx()
ax1_twin.plot(attn_ratios, ratio_throughputs, 'ro-', linewidth=2, markersize=8, label='Throughput')
ax1.set_xlabel('Attention Ratio')
ax1.set_ylabel('Processing Time (s)', color='blue')
ax1_twin.set_ylabel('Throughput (tokens/s)', color='red')
ax1.set_title('Performance vs Attention Ratio')
ax1.grid(True, alpha=0.3)
# Both series carry labels but live on different axes; merge them into one legend.
twin_lines = list(ax1.get_lines()) + list(ax1_twin.get_lines())
ax1.legend(twin_lines, [line.get_label() for line in twin_lines], loc='best')

# Memory usage
ax2.plot(attn_ratios, ratio_memories, 'go-', linewidth=2, markersize=8)
ax2.set_xlabel('Attention Ratio')
ax2.set_ylabel('Memory Usage (MB)')
ax2.set_title('Memory vs Attention Ratio')
ax2.grid(True, alpha=0.3)

# Parameter count
ax3.plot(attn_ratios, [p/1000 for p in ratio_params], 'mo-', linewidth=2, markersize=8)
ax3.set_xlabel('Attention Ratio')
ax3.set_ylabel('Parameters (thousands)')
ax3.set_title('Model Size vs Attention Ratio')
ax3.grid(True, alpha=0.3)

# Efficiency comparison (throughput per parameter)
efficiency_per_param = [t/p for t, p in zip(ratio_throughputs, ratio_params, strict=False)]
ax4.plot(attn_ratios, efficiency_per_param, 'co-', linewidth=2, markersize=8)
ax4.set_xlabel('Attention Ratio')
ax4.set_ylabel('Efficiency (tokens/s/param)')
ax4.set_title('Parameter Efficiency')
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Find optimal attention ratio
best_throughput_idx = int(np.argmax(ratio_throughputs))
best_memory_idx = int(np.argmin(ratio_memories))
best_efficiency_idx = int(np.argmax(efficiency_per_param))

print("\n=== Optimal Configuration Analysis ===")
print(f"Best throughput: {attn_ratios[best_throughput_idx]*100:.0f}% attention ({ratio_throughputs[best_throughput_idx]:.0f} tokens/s)")
if any(m > 0 for m in ratio_memories):
    print(f"Best memory efficiency: {attn_ratios[best_memory_idx]*100:.0f}% attention ({ratio_memories[best_memory_idx]:.1f} MB)")
else:
    # Every entry is 0 when memory is not tracked, so argmin would just pick index 0.
    print("Best memory efficiency: n/a (peak memory is only tracked on CUDA)")
print(f"Best parameter efficiency: {attn_ratios[best_efficiency_idx]*100:.0f}% attention ({efficiency_per_param[best_efficiency_idx]:.2e} tokens/s/param)")


def _throughput_at(ratio):
    """Return the measured throughput for attention ``ratio``, or None if it was not tested."""
    for tested_ratio, throughput in zip(attn_ratios, ratio_throughputs, strict=False):
        if np.isclose(tested_ratio, ratio):
            return throughput
    return None


# Calculate trade-offs by looking ratios up by value rather than list position,
# so the comparison stays correct if `attention_ratios` is reordered or edited.
pure_conv_perf = _throughput_at(0.0)        # 0% attention (100% conv)
hybrid_perf = _throughput_at(0.5)           # 50% attention
pure_attention_perf = _throughput_at(1.0)   # 100% attention

print("\n=== Architecture Comparison ===")
if pure_conv_perf is not None:
    print(f"Pure convolution (0% attention): {pure_conv_perf:.0f} tokens/s")
if hybrid_perf is not None:
    print(f"Hybrid (50% attention): {hybrid_perf:.0f} tokens/s")
if pure_attention_perf is not None:
    print(f"Pure attention (100% attention): {pure_attention_perf:.0f} tokens/s")
if hybrid_perf is not None and pure_conv_perf:
    print(f"Hybrid advantage over pure conv: {(hybrid_perf/pure_conv_perf-1)*100:.1f}%")
if hybrid_perf is not None and pure_attention_perf:
    print(f"Hybrid advantage over pure attention: {(hybrid_perf/pure_attention_perf-1)*100:.1f}%")

## 7. Custom Hyena Layer Implementation

In [None]:
# Implement a custom Hyena layer with genomic-specific optimizations
class GenomicHyenaLayer(nn.Module):
    """Custom Hyena layer optimized for genomic sequences."""

    def __init__(self, hidden_size, num_heads, conv_kernel_size=7,
                 attention_ratio=0.5, genomic_conv_patterns=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.attention_ratio = attention_ratio

        # Split channels between attention and convolution
        self.attention_dim = int(hidden_size * attention_ratio)
        self.conv_dim = hidden_size - self.attention_dim

        # Attention component
        if self.attention_dim > 0:
            self.attention = nn.MultiheadAttention(
                embed_dim=self.attention_dim,
                num_heads=min(num_heads, self.attention_dim // 64),
                batch_first=True
            )

        # Convolution component with genomic patterns
        if self.conv_dim > 0:
            if genomic_conv_patterns:
                # Use multiple kernel sizes for different genomic patterns
                self.conv_layers = nn.ModuleList([
                    nn.Conv1d(self.conv_dim, self.conv_dim//3, kernel_size=3, padding=1, groups=1),  # Codons
                    nn.Conv1d(self.conv_dim, self.conv_dim//3, kernel_size=7, padding=3, groups=1),  # Local motifs
                    nn.Conv1d(self.conv_dim, self.conv_dim//3, kernel_size=15, padding=7, groups=1), # Longer patterns
                ])
            else:
                self.conv_layers = nn.ModuleList([
                    nn.Conv1d(self.conv_dim, self.conv_dim, kernel_size=conv_kernel_size,
                             padding=conv_kernel_size//2, groups=1)
                ])

        # Layer normalization
        self.layer_norm = nn.LayerNorm(hidden_size)

        # Output projection
        self.output_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        batch_size, seq_len, hidden_size = x.shape
        residual = x

        outputs = []

        # Attention pathway
        if self.attention_dim > 0:
            attn_input = x[:, :, :self.attention_dim]
            attn_output, _ = self.attention(attn_input, attn_input, attn_input)
            outputs.append(attn_output)

        # Convolution pathway
        if self.conv_dim > 0:
            conv_input = x[:, :, self.attention_dim:].transpose(1, 2)  # (B, C, L)

            conv_outputs = []
            for conv_layer in self.conv_layers:
                conv_out = conv_layer(conv_input)
                conv_outputs.append(conv_out)

            # Concatenate conv outputs
            conv_output = torch.cat(conv_outputs, dim=1).transpose(1, 2)  # (B, L, C)
            outputs.append(conv_output)

        # Combine pathways
        if len(outputs) > 1:
            combined = torch.cat(outputs, dim=-1)
        else:
            combined = outputs[0]

        # Apply output projection and residual connection
        output = self.output_proj(combined)
        output = self.layer_norm(output + residual)

        return output

# Exercise the custom genomic layer, then run the library's HyenaBlock on the same input.
print("=== Custom Genomic Hyena Layer Test ===")

# Custom layer: 50/50 attention/convolution split with multi-scale kernels.
custom_layer = GenomicHyenaLayer(
    hidden_size=512,
    num_heads=8,
    conv_kernel_size=7,
    attention_ratio=0.5,
    genomic_conv_patterns=True,
).to(device)

# Random input laid out as (batch=2, seq_len=1024, hidden=512).
test_input = torch.randn(2, 1024, 512).to(device)

custom_param_count = sum(p.numel() for p in custom_layer.parameters())
print(f"Custom layer parameters: {custom_param_count:,}")
print(f"Input shape: {test_input.shape}")

with torch.no_grad():
    custom_output = custom_layer(test_input)

print(f"Output shape: {custom_output.shape}")
print("Output statistics:")
for stat_name, stat_fn in (
    ("Mean", torch.mean),
    ("Std", torch.std),
    ("Min", torch.min),
    ("Max", torch.max),
):
    print(f"  {stat_name}: {stat_fn(custom_output).item():.6f}")

# Reference: the library HyenaBlock with matching width, heads and ratio.
standard_layer = HyenaBlock(
    hidden_size=512,
    num_attention_heads=8,
    intermediate_size=2048,
    conv_kernel_size=7,
    attention_ratio=0.5,
).to(device)

with torch.no_grad():
    standard_output = standard_layer(test_input)

standard_param_count = sum(p.numel() for p in standard_layer.parameters())
print("\nComparison with standard Hyena layer:")
print(f"Custom layer params: {custom_param_count:,}")
print(f"Standard layer params: {standard_param_count:,}")
# NOTE: both layers are randomly initialised and untrained, so this cosine
# similarity is only a sanity number, not a measure of functional equivalence.
output_cosine = torch.nn.functional.cosine_similarity(
    custom_output.flatten(), standard_output.flatten(), dim=0
).item()
print(f"Output similarity (cosine): {output_cosine:.4f}")

## 8. Architecture Recommendations

In [None]:
print("=== Striped Hyena Architecture Recommendations ===")
print()

recommendations = {
    "Sequence Length Considerations": [
        "Short sequences (<1K): Higher attention ratio (70-80%) for global context",
        "Medium sequences (1K-10K): Balanced ratio (40-60%) for optimal performance",
        "Long sequences (>10K): Lower attention ratio (20-40%) for efficiency",
        "Very long sequences (>100K): Consider pure convolution or sliding windows"
    ],
    "Genomic Task Optimization": [
        "Classification: Higher attention ratio for global sequence understanding",
        "Local motif detection: Higher convolution ratio with multiple kernel sizes",
        "Sequence generation: Balanced ratio with autoregressive masking",
        "Protein folding: Custom kernels matching structural patterns"
    ],
    "Performance Tuning": [
        "Use gradient checkpointing for very long sequences",
        "Implement mixed precision training for memory efficiency",
        "Consider model parallelism for large models",
        "Profile attention vs convolution ratios for your specific task"
    ],
    "Implementation Best Practices": [
        "Start with 50% attention ratio and tune based on validation performance",
        "Use multiple convolution kernel sizes for genomic pattern diversity",
        "Implement proper normalization between attention and conv pathways",
        "Monitor memory usage and adjust batch sizes accordingly"
    ]
}

for category, practices in recommendations.items():
    print(f"**{category}:**")
    for practice in practices:
        print(f"  • {practice}")
    print()

# Configuration templates
print("=== Recommended Configurations ===")
print()

# NOTE: these templates are descriptive, not direct HyenaGLT kwargs. In the model
# config built at the start of this notebook, `max_sequence_length` corresponds to
# `max_position_embeddings`, and `attention_ratio` lives under `hyena_config`.
configs = {
    "Small Model (Fast Inference)": {
        'hidden_size': 256,
        'num_layers': 6,
        'num_attention_heads': 4,
        'attention_ratio': 0.5,
        'max_sequence_length': 4096
    },
    "Medium Model (Balanced)": {
        'hidden_size': 512,
        'num_layers': 12,
        'num_attention_heads': 8,
        'attention_ratio': 0.4,
        'max_sequence_length': 8192
    },
    "Large Model (High Accuracy)": {
        'hidden_size': 1024,
        'num_layers': 24,
        'num_attention_heads': 16,
        'attention_ratio': 0.3,
        'max_sequence_length': 16384
    }
}

# Distinct loop names so the notebook's model `config` is not overwritten.
for template_name, template in configs.items():
    print(f"**{template_name}:**")
    for key, value in template.items():
        print(f"  {key}: {value}")
    print()

# Summary from this tutorial
analyzed_components = ['Hyena Block', 'Attention', 'Convolution', 'Custom Layer']
print("=== Tutorial Summary ===")
print(f"Architecture components analyzed: {len(analyzed_components)}")
print(f"Complexity measurements: {len(complexity_results)} sequence lengths tested")
print(f"Attention ratios compared: {len(ratio_results)}")
print("Custom implementations: 1 genomic-optimized layer")
print("Visualization plots: 8 comprehensive analyses")

## Conclusion

This tutorial provided a comprehensive analysis of the Striped Hyena architecture. Key insights:

1. **Hybrid Architecture**: Combines the efficiency of convolution with the expressiveness of attention
2. **Scalable Complexity**: Routing part of the computation through convolutions reduces the cost of long sequences; Section 4 measures the empirical scaling exponent on your hardware
3. **Configurable Trade-offs**: Attention ratio can be tuned based on task requirements
4. **Genomic Optimization**: Custom implementations can leverage domain-specific patterns
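5. **Performance Characteristics**: Memory and computational efficiency typically improve with lower attention ratios on long sequences; confirm this with the Section 5 measurements for your task and hardware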

In the next tutorial, we'll explore model training strategies and optimization techniques for genomic tasks.