# Tutorial 02: Tokenization Deep Dive

This notebook explores the tokenization methods used in the Hyena-GLT framework: BLT (Byte Latent Transformer) tokenization and the specialized genomic tokenizers for DNA, RNA, and protein sequences.

## Learning Objectives
- Understand BLT tokenization principles and advantages
- Compare different genomic tokenizers (DNA, RNA, Protein)
- Explore tokenization strategies for different sequence types
- Analyze tokenization efficiency and compression ratios
- Implement custom tokenization workflows

In [None]:
import os
import sys

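# Add the directory two levels above the working directory to the import path so hyena_glt can be found
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(''))))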

import time
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from hyena_glt.tokenizers import (
    BLTTokenizer,
    DNATokenizer,
    ProteinTokenizer,
    RNATokenizer,
)

# Set up plotting style
plt.style.use('default')
sns.set_palette("husl")

print("Hyena-GLT Tokenization Deep Dive Tutorial")
print("=========================================")

## 1. BLT Tokenization Overview

BLT (Byte Latent Transformer) tokenization offers several advantages for genomic sequences:
- **Efficiency**: Operates directly on byte representations
- **Universality**: Handles any sequence type, since every input reduces to bytes and no character is out of vocabulary
- **Compression**: Learns compact byte-level representations
- **Flexibility**: Adaptable to different genomic data types
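
The universality point rests on a simple property: any string reduces to a sequence of bytes, so a 256-value base alphabet covers every input. The sketch below illustrates only that byte-level step using the Python standard library. It is not the `BLTTokenizer` implementation, and the helper names `bytes_to_ids` and `ids_to_string` are hypothetical.

```python
def bytes_to_ids(sequence: str) -> list[int]:
    """Map a string to one integer ID per UTF-8 byte (values 0-255)."""
    return list(sequence.encode("utf-8"))


def ids_to_string(ids: list[int]) -> str:
    """Invert bytes_to_ids by reassembling the bytes and decoding them."""
    return bytes(ids).decode("utf-8")


dna = "ATCGATCG"
ids = bytes_to_ids(dna)
print(ids)  # [65, 84, 67, 71, 65, 84, 67, 71]
print(ids_to_string(ids) == dna)  # True

# Characters outside the DNA alphabet still encode, one byte per ASCII character.
print(len(bytes_to_ids("ATCGN-*")))  # 7
```

How `BLTTokenizer` maps these bytes onto its `vocab_size` and `latent_dim` settings is not covered by this sketch.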

In [None]:
# Initialize BLT tokenizer
blt_tokenizer = BLTTokenizer(vocab_size=8192, latent_dim=256)

# Create sample genomic sequences
dna_sequences = [
    "ATCGATCGATCGATCGATCGATCGATCG",
    "GCTAGCTAGCTAGCTAGCTAGCTAGCTA",
    "TTAACCGGTTAACCGGTTAACCGGTTAA",
    "CGATCGATCGATCGATCGATCGATCGAT"
]

print("Sample DNA sequences:")
for i, seq in enumerate(dna_sequences[:2]):
    print(f"Sequence {i+1}: {seq}")

# Demonstrate BLT tokenization
print("\n=== BLT Tokenization ===")
for i, sequence in enumerate(dna_sequences[:2]):
    # Encode sequence
    tokens = blt_tokenizer.encode(sequence)

    # Decode back to verify
    decoded = blt_tokenizer.decode(tokens)

    print(f"\nSequence {i+1}:")
    print(f"Original:  {sequence}")
    print(f"Tokens:    {tokens[:10]}... ({len(tokens)} total)")
    print(f"Decoded:   {decoded}")
    print(f"Match:     {sequence == decoded}")
    print(f"Compression: {len(sequence)}/{len(tokens)} = {len(sequence)/len(tokens):.2f}x")

## 2. Specialized Genomic Tokenizers

Each genomic sequence type has its own alphabet, so Hyena-GLT provides a dedicated tokenizer for DNA, RNA, and protein sequences.
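
One way a specialized tokenizer can exploit a small alphabet is to group characters into fixed-length k-mers. Whether `DNATokenizer` works per character or per k-mer is not stated here, so treat the following as a sketch under that assumption; `kmer_tokenize` is a hypothetical helper, not part of `hyena_glt`. It shows how token granularity alone determines the characters-per-token ratio that the comparison in this section reports as compression.

```python
def kmer_tokenize(sequence: str, k: int = 3) -> list[str]:
    """Split a sequence into non-overlapping k-mers; the last chunk may be shorter."""
    if k < 1:
        raise ValueError("k must be at least 1")
    return [sequence[i:i + k] for i in range(0, len(sequence), k)]


dna = "ATCGATCGATCGATCGATCGATCGATCG"  # 28 bases, same as the DNA sample in this section
for k in (1, 3, 6):
    tokens = kmer_tokenize(dna, k)
    ratio = len(dna) / len(tokens)
    print(f"k={k}: {len(tokens)} tokens, {ratio:.2f} chars/token")
```

Larger k means fewer tokens per sequence, at the cost of a vocabulary that can grow up to 4^k entries for DNA.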

In [None]:
# Initialize specialized tokenizers
dna_tokenizer = DNATokenizer()
rna_tokenizer = RNATokenizer()
protein_tokenizer = ProteinTokenizer()

# Sample sequences for each type
sample_sequences = {
    'DNA': "ATCGATCGATCGATCGATCGATCGATCG",
    'RNA': "AUCGAUCGAUCGAUCGAUCGAUCGAUCG",
    'Protein': "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
}

tokenizers = {
    'DNA': dna_tokenizer,
    'RNA': rna_tokenizer,
    'Protein': protein_tokenizer
}

print("=== Specialized Tokenizer Comparison ===")
tokenization_results = {}

for seq_type, sequence in sample_sequences.items():
    tokenizer = tokenizers[seq_type]

    # Tokenize
    tokens = tokenizer.encode(sequence)
    decoded = tokenizer.decode(tokens)

    # Store results
    tokenization_results[seq_type] = {
        'original_length': len(sequence),
        'token_count': len(tokens),
        'vocab_size': getattr(tokenizer, 'vocab_size', 'N/A'),
        'compression_ratio': len(sequence) / len(tokens)
    }

    print(f"\n{seq_type} Tokenization:")
    print(f"  Original:  {sequence[:50]}{'...' if len(sequence) > 50 else ''}")
    print(f"  Tokens:    {tokens[:10]}... ({len(tokens)} total)")
    print(f"  Vocab size: {tokenization_results[seq_type]['vocab_size']}")
    print(f"  Compression: {tokenization_results[seq_type]['compression_ratio']:.2f}x")
    print(f"  Perfect reconstruction: {sequence == decoded}")

## 3. Tokenization Performance Analysis
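
The benchmark in this section reports two derived metrics: compression ratio (characters per token) and throughput (characters per second). A single timed pass can be noisy, so a common approach is to repeat the measurement and keep the fastest run. The sketch below assumes a hypothetical stand-in tokenizer, `char_encode`, so that it runs without `hyena_glt`; `best_of` is likewise an illustrative helper.

```python
import time


def best_of(func, *args, repeats: int = 5) -> float:
    """Return the fastest wall-clock time (seconds) over several calls to func(*args)."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        func(*args)
        timings.append(time.perf_counter() - start)
    return min(timings)


def char_encode(sequence: str) -> list[int]:
    """Hypothetical stand-in tokenizer: one token per character."""
    return [ord(ch) for ch in sequence]


sequence = "ATCG" * 25_000  # 100,000 characters
elapsed = best_of(char_encode, sequence)
tokens = char_encode(sequence)

print(f"Compression ratio: {len(sequence) / len(tokens):.2f}x")
# max(...) avoids a division by zero if the clock reports no elapsed time
print(f"Throughput: {len(sequence) / max(elapsed, 1e-9):,.0f} chars/sec")
```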

In [None]:
# Performance comparison
def benchmark_tokenizer(tokenizer, sequences, tokenizer_name):
    """Benchmark tokenization speed and compression over a list of sequences."""
    # perf_counter is a monotonic, high-resolution clock intended for timing code
    start_time = time.perf_counter()

    total_tokens = 0
    total_chars = 0

    for sequence in sequences:
        tokens = tokenizer.encode(sequence)
        total_tokens += len(tokens)
        total_chars += len(sequence)

    elapsed = time.perf_counter() - start_time

    return {
        'name': tokenizer_name,
        'time': elapsed,
        'total_chars': total_chars,
        'total_tokens': total_tokens,
        # Guard against empty input and a zero-length timing interval
        'compression_ratio': total_chars / total_tokens if total_tokens else float('nan'),
        'chars_per_second': total_chars / elapsed if elapsed > 0 else float('inf')
    }

# Generate longer test sequences (seeded so the benchmark inputs are reproducible)
rng = np.random.default_rng(42)
test_sequences = []
for _ in range(100):
    seq = ''.join(rng.choice(['A', 'T', 'C', 'G'], size=1000))
    test_sequences.append(seq)

# Benchmark all tokenizers
benchmark_results = []
benchmark_results.append(benchmark_tokenizer(blt_tokenizer, test_sequences, 'BLT'))
benchmark_results.append(benchmark_tokenizer(dna_tokenizer, test_sequences, 'DNA'))

print("=== Tokenization Performance Benchmark ===")
print(f"Test data: {len(test_sequences)} sequences, {benchmark_results[0]['total_chars']:,} total characters")
print()

for result in benchmark_results:
    print(f"{result['name']} Tokenizer:")
    print(f"  Time: {result['time']:.3f}s")
    print(f"  Tokens generated: {result['total_tokens']:,}")
    print(f"  Compression ratio: {result['compression_ratio']:.2f}x")
    print(f"  Speed: {result['chars_per_second']:,.0f} chars/sec")
    print()

## 4. Visualization of Tokenization Patterns
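
The token diversity figure printed at the end of this section is the number of unique tokens divided by the total, a ratio that shrinks as more text is sampled. Shannon entropy of the token distribution is a complementary measure that does not shrink simply because the sample grows. The sketch below computes both for any token list; `token_stats` is a hypothetical helper and the inputs are toy character streams, not tokenizer output.

```python
import math
from collections import Counter


def token_stats(tokens: list) -> dict:
    """Summarize a token stream: size, distinct tokens, diversity, and entropy in bits."""
    counts = Counter(tokens)
    total = len(tokens)
    if total == 0:
        return {"total": 0, "unique": 0, "diversity": 0.0, "entropy_bits": 0.0}
    entropy = -sum((c / total) * math.log2(c / total) for c in counts.values())
    return {
        "total": total,
        "unique": len(counts),
        "diversity": len(counts) / total,
        "entropy_bits": entropy,
    }


uniform = token_stats(list("ATCG" * 100))      # four tokens, equally frequent
skewed = token_stats(list("A" * 397 + "TCG"))  # four tokens, one dominant
print(uniform)
print(skewed)
```

Both streams have the same diversity (4 unique tokens out of 400), but the skewed stream has much lower entropy, which is the kind of difference the frequency histograms in this section make visible.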

In [None]:
# Analyze token distribution patterns
def analyze_token_patterns(tokenizer, sequences, tokenizer_name):
    """Analyze and visualize token distribution patterns."""
    all_tokens = []

    for sequence in sequences:
        tokens = tokenizer.encode(sequence)
        all_tokens.extend(tokens)

    # Count token frequencies
    token_counts = Counter(all_tokens)

    return {
        'name': tokenizer_name,
        'tokens': all_tokens,
        'unique_tokens': len(token_counts),
        'most_common': token_counts.most_common(10),
        'token_counts': token_counts
    }

# Analyze patterns for different tokenizers
pattern_analyses = []
pattern_analyses.append(analyze_token_patterns(blt_tokenizer, test_sequences[:10], 'BLT'))
pattern_analyses.append(analyze_token_patterns(dna_tokenizer, test_sequences[:10], 'DNA'))

# Create visualization
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Tokenization Pattern Analysis', fontsize=16, fontweight='bold')

# Token frequency distributions
for i, analysis in enumerate(pattern_analyses):
    # Most common tokens
    tokens, counts = zip(*analysis['most_common'], strict=False)
    axes[i, 0].bar(range(len(tokens)), counts)
    axes[i, 0].set_title(f'{analysis["name"]} - Most Common Tokens')
    axes[i, 0].set_xlabel('Token Rank')
    axes[i, 0].set_ylabel('Frequency')

    # Token distribution histogram
    all_counts = list(analysis['token_counts'].values())
    axes[i, 1].hist(all_counts, bins=20, alpha=0.7)
    axes[i, 1].set_title(f'{analysis["name"]} - Token Frequency Distribution')
    axes[i, 1].set_xlabel('Token Frequency')
    axes[i, 1].set_ylabel('Number of Tokens')
    axes[i, 1].set_yscale('log')

plt.tight_layout()
plt.show()

# Print summary statistics
print("=== Token Pattern Analysis Summary ===")
for analysis in pattern_analyses:
    print(f"\n{analysis['name']} Tokenizer:")
    print(f"  Total tokens: {len(analysis['tokens']):,}")
    print(f"  Unique tokens: {analysis['unique_tokens']:,}")
    print(f"  Token diversity: {analysis['unique_tokens']/len(analysis['tokens']):.3f}")
    print(f"  Most common token appears {analysis['most_common'][0][1]} times")

## 5. Advanced Tokenization Strategies
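
Sliding-window tokenization advances through the sequence in steps of `window_size - overlap`, so consecutive windows share `overlap` characters. The sketch below computes the window boundaries and checks that the windows can be stitched back into the original string. `window_spans` and `stitch` are hypothetical helpers that mirror the defaults used in this section (`window_size=100`, `overlap=20`).

```python
def window_spans(length: int, window_size: int = 100, overlap: int = 20) -> list[tuple[int, int]]:
    """Return (start, end) index pairs for overlapping windows covering range(length)."""
    if not 0 <= overlap < window_size:
        raise ValueError("overlap must satisfy 0 <= overlap < window_size")
    step = window_size - overlap
    # max(..., 1) keeps one window for inputs no longer than the overlap.
    stop = max(length - overlap, 1)
    return [(start, min(start + window_size, length)) for start in range(0, stop, step)]


def stitch(windows: list[str], overlap: int = 20) -> str:
    """Rebuild the original string by dropping each later window's overlapping prefix."""
    return windows[0] + "".join(w[overlap:] for w in windows[1:])


sequence = "ATCG" * 500  # 2,000 characters
spans = window_spans(len(sequence))
windows = [sequence[s:e] for s, e in spans]

print(len(spans), spans[:2], spans[-1])
print(stitch(windows) == sequence)
```

For a 2,000-character sequence these defaults give 25 windows, the last covering positions 1920 to 2000.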

In [None]:
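# Demonstrate sliding-window tokenization for long sequences
def sliding_window_tokenization(tokenizer, sequence, window_size=100, overlap=20):
    """Tokenize a long sequence as overlapping windows of at most window_size characters."""
    if not 0 <= overlap < window_size:
        raise ValueError("overlap must satisfy 0 <= overlap < window_size")

    windows = []
    token_windows = []

    # max(..., 1) keeps a single window for sequences no longer than the overlap
    for start in range(0, max(len(sequence) - overlap, 1), window_size - overlap):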
        end = min(start + window_size, len(sequence))
        window = sequence[start:end]
        tokens = tokenizer.encode(window)

        windows.append(window)
        token_windows.append(tokens)

    return windows, token_windows

# Create a long test sequence
long_sequence = ''.join(np.random.choice(['A', 'T', 'C', 'G'], size=2000))

print("=== Sliding Window Tokenization ===")
print(f"Long sequence length: {len(long_sequence)} characters")

# Apply sliding window tokenization
windows, token_windows = sliding_window_tokenization(blt_tokenizer, long_sequence)

print(f"Generated {len(windows)} windows")
print(f"Window sizes: {[len(w) for w in windows[:5]]}...")
print(f"Token counts per window: {[len(tw) for tw in token_windows[:5]]}...")

# Visualize tokenization across windows
window_lengths = [len(w) for w in windows]
token_counts = [len(tw) for tw in token_windows]
compression_ratios = [wl/tc for wl, tc in zip(window_lengths, token_counts, strict=False)]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(range(len(windows)), window_lengths, 'b-', label='Characters', alpha=0.7)
ax1.plot(range(len(windows)), token_counts, 'r-', label='Tokens', alpha=0.7)
ax1.set_xlabel('Window Index')
ax1.set_ylabel('Count')
ax1.set_title('Characters vs Tokens per Window')
ax1.legend()
ax1.grid(True, alpha=0.3)

ax2.plot(range(len(windows)), compression_ratios, 'g-', alpha=0.7)
ax2.set_xlabel('Window Index')
ax2.set_ylabel('Compression Ratio')
ax2.set_title('Compression Ratio per Window')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Average compression ratio: {np.mean(compression_ratios):.2f}x")
print(f"Compression ratio std: {np.std(compression_ratios):.3f}")

## 6. Custom Tokenization Workflows
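
`detect_sequence_type` in this section scores a sequence by its set of distinct characters. Because A, C, G, T, and N are all valid amino-acid codes, a DNA read containing a single ambiguous base `N` scores 0.8 for DNA and 1.0 for protein under that scheme, and is labeled protein. One possible alternative, sketched below, counts every character so that a rare ambiguity code carries little weight. `fraction_in` and `classify` are hypothetical helpers, and the 0.9 threshold is an assumption carried over from the original checks.

```python
DNA = set("ACGT")
RNA = set("ACGU")
PROTEIN = set("ACDEFGHIKLMNPQRSTVWY")


def fraction_in(sequence: str, alphabet: set) -> float:
    """Fraction of characters (counted with repeats) that belong to the alphabet."""
    return sum(ch in alphabet for ch in sequence) / len(sequence) if sequence else 0.0


def classify(sequence: str, threshold: float = 0.9) -> str:
    """Hypothetical frequency-based variant of detect_sequence_type."""
    seq = sequence.upper()
    has_t, has_u = "T" in seq, "U" in seq
    # Nucleotide checks run first because A, C, G, and T are also amino-acid codes.
    if has_t and not has_u and fraction_in(seq, DNA) >= threshold:
        return "dna"
    if has_u and not has_t and fraction_in(seq, RNA) >= threshold:
        return "rna"
    if len(set(seq)) > 4 and fraction_in(seq, PROTEIN) >= threshold:
        return "protein"
    return "blt"


for seq in ("ATCGATCGATCGATCGNATCG", "AUCGAUCGAUCGAUCG", "MKTVRQERLKSIVRIL", "ATCGXYZQWERTY123"):
    print(f"{seq}: {classify(seq)}")
```

A heuristic like this still misreads short peptides made only of A, C, G, and T, so an explicit sequence-type argument remains the safer choice whenever the type is known.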

In [None]:
# Implement a custom tokenization workflow
class AdaptiveTokenizer:
    """Adaptive tokenizer that chooses the best tokenizer based on sequence characteristics."""

    def __init__(self):
        self.tokenizers = {
            'blt': BLTTokenizer(vocab_size=4096),
            'dna': DNATokenizer(),
            'rna': RNATokenizer(),
            'protein': ProteinTokenizer()
        }

    def detect_sequence_type(self, sequence):
        """Detect the most likely sequence type."""
        sequence = sequence.upper()

        # Count character types
        dna_chars = set('ATCG')
        rna_chars = set('AUCG')
        protein_chars = set('ACDEFGHIKLMNPQRSTVWY')

        seq_chars = set(sequence)

        # Score each alphabet by the fraction of the sequence's distinct characters it contains
        dna_score = len(seq_chars.intersection(dna_chars)) / len(seq_chars) if seq_chars else 0
        rna_score = len(seq_chars.intersection(rna_chars)) / len(seq_chars) if seq_chars else 0
        protein_score = len(seq_chars.intersection(protein_chars)) / len(seq_chars) if seq_chars else 0

        # Specific checks
        has_T = 'T' in seq_chars
        has_U = 'U' in seq_chars

        if has_U and not has_T and rna_score > 0.9:
            return 'rna'
        elif has_T and not has_U and dna_score > 0.9:
            return 'dna'
        elif protein_score > 0.8 and len(seq_chars) > 4:
            return 'protein'
        else:
            return 'blt'  # Default to BLT for ambiguous cases

    def encode(self, sequence):
        """Encode with the tokenizer matching the detected type; returns (tokens, seq_type)."""
        seq_type = self.detect_sequence_type(sequence)
        tokenizer = self.tokenizers[seq_type]
        return tokenizer.encode(sequence), seq_type

    def decode(self, tokens, seq_type):
        """Decode tokens using the specified tokenizer type."""
        tokenizer = self.tokenizers[seq_type]
        return tokenizer.decode(tokens)

# Test adaptive tokenizer
adaptive_tokenizer = AdaptiveTokenizer()

test_sequences_adaptive = {
    'DNA': "ATCGATCGATCGATCG",
    'RNA': "AUCGAUCGAUCGAUCG",
    'Protein': "MKTVRQERLKSIVRIL",
    'Mixed': "ATCGXYZQWERTY123"
}

print("=== Adaptive Tokenization Results ===")
for label, sequence in test_sequences_adaptive.items():
    tokens, detected_type = adaptive_tokenizer.encode(sequence)
    decoded = adaptive_tokenizer.decode(tokens, detected_type)

    print(f"\n{label} sequence:")
    print(f"  Input: {sequence}")
    print(f"  Detected type: {detected_type}")
    print(f"  Tokens: {len(tokens)} tokens")
    print(f"  Decoded: {decoded}")
    print(f"  Perfect match: {sequence == decoded}")

## 7. Best Practices and Recommendations
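
Two of the recommendations in this section, caching repeated sequences and processing data in chunks, can be sketched with the standard library. The example uses a hypothetical character-level stand-in, `char_encode`, in place of a real tokenizer; `cached_encode` and `chunked` are illustrative names as well.

```python
from functools import lru_cache
from itertools import islice


def char_encode(sequence: str) -> tuple[int, ...]:
    """Hypothetical stand-in for tokenizer.encode: one token per character."""
    return tuple(ord(ch) for ch in sequence)


@lru_cache(maxsize=10_000)
def cached_encode(sequence: str) -> tuple[int, ...]:
    """Memoize encodings; results are tuples so cached values cannot be mutated."""
    return char_encode(sequence)


def chunked(iterable, size: int):
    """Yield lists of up to `size` items from any iterable."""
    iterator = iter(iterable)
    while chunk := list(islice(iterator, size)):
        yield chunk


sequences = ["ATCG" * 10, "GCTA" * 10, "ATCG" * 10]  # first and last are identical
for batch in chunked(sequences, size=2):
    encoded = [cached_encode(seq) for seq in batch]
    # Round-trip check: this only holds for the stand-in's character-level scheme.
    assert ["".join(map(chr, ids)) for ids in encoded] == batch

print(cached_encode.cache_info())
```

With a real tokenizer, the cached function would wrap `tokenizer.encode` and convert its result to a tuple; whether the Hyena-GLT tokenizers keep internal caches of their own is not covered here.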

In [None]:
print("=== Tokenization Best Practices ===")
print()

recommendations = {
    "Sequence Type Selection": [
        "Use DNATokenizer for pure DNA sequences (A, T, C, G only)",
        "Use RNATokenizer for RNA sequences (A, U, C, G)",
        "Use ProteinTokenizer for amino acid sequences",
        "Use BLTTokenizer for mixed or unknown sequence types"
    ],
    "Performance Optimization": [
        "Batch process multiple sequences for better efficiency",
        "Use sliding windows for very long sequences (>10k bp)",
        "Consider caching tokenized sequences for repeated use",
        "Profile different tokenizers for your specific use case"
    ],
    "Memory Management": [
        "Process sequences in chunks for large datasets",
        "Clear tokenizer caches periodically",
        "Use appropriate vocab_size for BLT tokenizers",
        "Monitor memory usage during batch processing"
    ],
    "Quality Assurance": [
        "Always verify encode/decode round-trip accuracy",
        "Test tokenizers on representative data samples",
        "Monitor compression ratios for efficiency",
        "Validate tokenizer output before model training"
    ]
}

for category, practices in recommendations.items():
    print(f"**{category}:**")
    for practice in practices:
        print(f"  • {practice}")
    print()

# Summary statistics from this tutorial
print("=== Tutorial Summary ===")
print(f"Tokenizers demonstrated: {len(tokenizers) + 1} (including BLT)")
print(f"Benchmark and adaptive test sequences: {len(test_sequences) + len(test_sequences_adaptive)}")
print(f"Performance benchmarks: {len(benchmark_results)}")
print("Visualization panels: 6 (across 2 figures)")
print("Advanced techniques: Sliding windows, Adaptive tokenization")

## Conclusion

This tutorial explored tokenization in the Hyena-GLT framework. Key takeaways:

1. **BLT Tokenization** applies to any sequence type because it works on bytes; measure its compression ratio on your own data
2. **Specialized Tokenizers** are tailored to the alphabet of a specific sequence type (DNA, RNA, or protein)
3. **Performance Analysis** (timing and compression ratio) helps you choose the right tokenizer for your use case
4. **Advanced Strategies** like sliding windows split long sequences into overlapping, independently tokenized windows
5. **Adaptive Approaches** select a tokenizer automatically from the characters in a sequence, falling back to BLT for ambiguous input

In the next tutorial, we'll explore the Striped Hyena architecture and how it processes tokenized genomic sequences.