# Introduction to Hyena-GLT: Genomic Language Transformer

Welcome to Hyena-GLT! This notebook provides a comprehensive introduction to the framework that combines BLT's byte latent tokenization with Savanna's Striped Hyena blocks for efficient genomic sequence modeling.

## 🎯 Learning Objectives

By the end of this notebook, you will:
- Understand the Hyena-GLT architecture
- Learn how to process genomic data
- Train your first genomic model
- Evaluate model performance
- Apply the model to real genomic tasks

## 📋 Prerequisites

- Basic Python knowledge
- Understanding of genomic sequences (DNA/RNA/protein)
- Familiarity with machine learning concepts

## 1. Installation and Setup

First, let's install and import the necessary packages:

In [None]:
# Install required packages (uncomment if needed)
# !pip install torch transformers numpy pandas matplotlib seaborn
# !pip install biopython scikit-learn plotly

import sys
import os

# Add the project root to Python path
project_root = os.path.abspath('../..')
if project_root not in sys.path:
    sys.path.append(project_root)

# Core imports
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Hyena-GLT imports
from hyena_glt.config import HyenaGLTConfig
from hyena_glt.data import (
    DNATokenizer, RNATokenizer, ProteinTokenizer,
    GenomicDataset, GenomicUtilities
)
from hyena_glt.model import HyenaGLT
from hyena_glt.training import HyenaGLTTrainer, TrainingConfig
from hyena_glt.evaluation import GenomicMetrics, ModelAnalyzer

# Set up plotting
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("🧬 Hyena-GLT setup complete!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

## 2. Understanding Genomic Data

Let's start by exploring different types of genomic sequences and how Hyena-GLT processes them.

In [None]:
# Sample genomic sequences
dna_sequence = "ATCGATCGTAGCTAGCTAGCGATCGATCGTAGCTAGC"
rna_sequence = "AUCGAUCGUAGCUAGCUAGCGAUCGAUCGUAGCUAGC"
protein_sequence = "MKTVRQERLKSIVRILKESSKGRPPPQDVTAKRAEQFVDQAQIILEQPKQRGFRFR"

print("Sample Sequences:")
print(f"DNA:     {dna_sequence}")
print(f"RNA:     {rna_sequence}")
print(f"Protein: {protein_sequence}")

# Initialize tokenizers
dna_tokenizer = DNATokenizer()
rna_tokenizer = RNATokenizer()
protein_tokenizer = ProteinTokenizer()

print("\nTokenizer Vocabularies:")
print(f"DNA vocab size:     {dna_tokenizer.vocab_size}")
print(f"RNA vocab size:     {rna_tokenizer.vocab_size}")
print(f"Protein vocab size: {protein_tokenizer.vocab_size}")

In [None]:
# Tokenize sequences
dna_tokens = dna_tokenizer.encode(dna_sequence)
rna_tokens = rna_tokenizer.encode(rna_sequence)
protein_tokens = protein_tokenizer.encode(protein_sequence)

print("Tokenized Sequences:")
print(f"DNA tokens:     {dna_tokens[:10]}... (length: {len(dna_tokens)})")
print(f"RNA tokens:     {rna_tokens[:10]}... (length: {len(rna_tokens)})")
print(f"Protein tokens: {protein_tokens[:10]}... (length: {len(protein_tokens)})")

# Demonstrate decoding
decoded_dna = dna_tokenizer.decode(dna_tokens)
print(f"\nDecoded DNA: {decoded_dna}")
print(f"Original DNA: {dna_sequence}")
print(f"Match: {decoded_dna == dna_sequence}")
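
The encode/decode round trip can be pictured with a toy character-level tokenizer. This is only a sketch under assumed conventions: `ToyDNATokenizer`, its ID layout (0 for padding, 1 for unknown symbols), and its treatment of `N` are hypothetical and are not taken from the `hyena_glt` `DNATokenizer`, which may tokenize quite differently.

```python
# Toy character-level tokenizer for illustration; NOT the hyena_glt DNATokenizer.
class ToyDNATokenizer:
    """Map each nucleotide to an integer ID and back."""

    def __init__(self, alphabet="ACGT", unk_token="N"):
        # Assumed layout: ID 0 is padding, ID 1 is "unknown", bases start at 2.
        self.pad_token_id = 0
        self.unk_token_id = 1
        self.unk_token = unk_token
        self.token_to_id = {base: i + 2 for i, base in enumerate(alphabet)}
        self.id_to_token = {i: base for base, i in self.token_to_id.items()}

    @property
    def vocab_size(self):
        return len(self.token_to_id) + 2  # bases + pad + unk

    def encode(self, sequence):
        # Symbols outside the alphabet collapse to the unknown ID.
        return [self.token_to_id.get(base, self.unk_token_id) for base in sequence.upper()]

    def decode(self, ids):
        # Padding is dropped; unknown IDs are rendered as the unknown symbol.
        return "".join(
            self.id_to_token.get(i, self.unk_token) for i in ids if i != self.pad_token_id
        )


toy = ToyDNATokenizer()
ids = toy.encode("ATCGATCG")
round_trip = toy.decode(ids)
print(ids, round_trip)
```

A round trip is lossless only while every input symbol is in the alphabet; anything mapped to the unknown ID cannot be recovered on decode.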

## 3. Model Configuration

Hyena-GLT models are configured through `HyenaGLTConfig`, which provides factory methods for specific tasks. Let's build configurations for three genomic tasks.

In [None]:
# Create configurations for different tasks
configs = {}

# DNA sequence classification
configs['dna_classification'] = HyenaGLTConfig.for_dna_classification(
    num_classes=5,  # e.g., promoter, enhancer, intron, exon, intergenic
    max_length=1024,
    hidden_size=256,
    num_layers=6
)

# Protein function prediction
configs['protein_function'] = HyenaGLTConfig.for_protein_function(
    num_functions=100,  # number of GO terms
    max_length=512,
    hidden_size=384,
    num_layers=8
)

# RNA secondary structure
configs['rna_structure'] = HyenaGLTConfig.for_rna_structure(
    max_length=256,
    hidden_size=256,
    num_layers=6
)

# Display configuration details
for task, config in configs.items():
    print(f"\n{task.upper()} Configuration:")
    print(f"  Sequence type: {config.sequence_type}")
    print(f"  Task type: {config.task_type}")
    print(f"  Max length: {config.max_length}")
    print(f"  Hidden size: {config.hidden_size}")
    print(f"  Layers: {config.num_layers}")
    print(f"  Hyena order: {config.hyena_order}")
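
To make the factory-method pattern concrete, here is a minimal sketch of how task-specific constructors can layer defaults under caller overrides. `ToyTaskConfig`, its fields, and its default values are assumptions made for illustration; they do not describe the real `HyenaGLTConfig`.

```python
from dataclasses import dataclass, asdict


@dataclass
class ToyTaskConfig:
    """Hypothetical stand-in for a task-aware model configuration."""

    sequence_type: str
    task_type: str
    num_labels: int
    max_length: int = 1024
    hidden_size: int = 256
    num_layers: int = 6

    def __post_init__(self):
        # Fail early on values that could never describe a valid model.
        for name in ("num_labels", "max_length", "hidden_size", "num_layers"):
            if getattr(self, name) <= 0:
                raise ValueError(f"{name} must be positive, got {getattr(self, name)}")

    @classmethod
    def for_dna_classification(cls, num_classes, **overrides):
        # Class-level defaults apply unless the caller overrides them.
        return cls(sequence_type="dna", task_type="classification",
                   num_labels=num_classes, **overrides)

    @classmethod
    def for_protein_function(cls, num_functions, **overrides):
        # Task-specific defaults first, then caller overrides on top.
        defaults = {"max_length": 512, "hidden_size": 384, "num_layers": 8}
        defaults.update(overrides)
        return cls(sequence_type="protein", task_type="multi_label",
                   num_labels=num_functions, **defaults)


dna_cfg = ToyTaskConfig.for_dna_classification(num_classes=5, hidden_size=128)
protein_cfg = ToyTaskConfig.for_protein_function(num_functions=100)
print(asdict(dna_cfg))
print(asdict(protein_cfg))
```

Keeping validation in `__post_init__` means every factory method gets the same checks without repeating them.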

## 4. Model Architecture Overview

Let's create and examine the Hyena-GLT model architecture.

In [None]:
# Create a model for DNA classification
config = configs['dna_classification']
model = HyenaGLT(config)

# Model summary
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

total_params = count_parameters(model)
print(f"Model Parameters: {total_params:,}")

# Model architecture breakdown
print("\nModel Components:")
for name, module in model.named_children():
    params = sum(p.numel() for p in module.parameters())
    print(f"  {name}: {params:,} parameters")

# Test forward pass
batch_size = 4
seq_length = 128
sample_input = torch.randint(0, config.vocab_size, (batch_size, seq_length))

with torch.no_grad():
    output = model(sample_input)
    
print(f"\nSample Forward Pass:")
print(f"Input shape: {sample_input.shape}")
print(f"Output shape: {output.logits.shape}")
print(f"Hidden states shape: {output.hidden_states.shape}")
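
For intuition about what a Hyena-style block computes, the sketch below alternates causal long convolutions with element-wise gating, loosely following the recurrence described in the Hyena literature. It is a pure-Python toy under stated assumptions: `causal_conv` and `toy_hyena_operator` are hypothetical names, the gates and filters are hand-written constants rather than learned projections, and production implementations typically use FFT-based convolutions. It says nothing about how `HyenaGLT` is actually built.

```python
def causal_conv(signal, kernel):
    """Causal convolution: out[t] = sum over k of kernel[k] * signal[t - k]."""
    out = []
    for t in range(len(signal)):
        acc = 0.0
        # Only positions at or before t contribute, so no future leakage.
        for k in range(min(t + 1, len(kernel))):
            acc += kernel[k] * signal[t - k]
        out.append(acc)
    return out


def toy_hyena_operator(v, gates, filters):
    """Apply z <- gate_n * (filter_n conv z) once per order, starting from z = v."""
    if len(gates) != len(filters):
        raise ValueError("need exactly one gate per filter")
    z = [float(x) for x in v]
    for gate, kernel in zip(gates, filters):
        convolved = causal_conv(z, kernel)
        z = [g * c for g, c in zip(gate, convolved)]
    return z


v = [1.0, 2.0, 3.0, 4.0]
# Order 2: an identity filter with an open gate, then a two-tap average with a sparse gate.
y = toy_hyena_operator(
    v,
    gates=[[1, 1, 1, 1], [1, 0, 1, 0]],
    filters=[[1.0], [0.5, 0.5]],
)
print(y)
```

The gating is what makes the operator data-dependent: in a real model the gates are computed from the input, so the same filter can be switched on or off per position.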

## 5. Data Preparation

Let's create synthetic genomic data for demonstration and learn how to prepare datasets.

In [None]:
# Generate synthetic genomic data
np.random.seed(42)
torch.manual_seed(42)

def generate_synthetic_data(n_samples=1000, seq_length=256):
    """Generate synthetic DNA sequences with labels."""
    # DNA alphabet
    bases = ['A', 'T', 'C', 'G']
    
    sequences = []
    labels = []
    
    for i in range(n_samples):
        # Generate random sequence
        seq = ''.join(np.random.choice(bases, seq_length))
        
        # Simple labeling rules (for demonstration)
        gc_content = (seq.count('G') + seq.count('C')) / len(seq)
        
        if gc_content < 0.3:
            label = 0  # AT-rich (e.g., intergenic)
        elif gc_content < 0.5:
            label = 1  # Moderate GC (e.g., intron)
        elif gc_content < 0.7:
            label = 2  # GC-rich (e.g., exon)
        else:
            label = 3  # Very GC-rich (e.g., promoter)
            
        # Motif-based override: a TATA box takes precedence over the GC-based label
        if 'TATAAA' in seq:  # TATA box motif
            label = 4  # TATA-box class (promoter-like)
            
        sequences.append(seq)
        labels.append(label)
    
    return sequences, labels

# Generate data
sequences, labels = generate_synthetic_data(1000, 256)

# Create dataset
tokenizer = DNATokenizer()
dataset = GenomicDataset(
    sequences=sequences,
    labels=labels,
    tokenizer=tokenizer,
    max_length=config.max_length
)

print(f"Dataset size: {len(dataset)}")
print(f"Label distribution:")
label_counts = pd.Series(labels).value_counts().sort_index()
for label, count in label_counts.items():
    print(f"  Class {label}: {count} samples")

# Examine a sample
sample = dataset[0]
print(f"\nSample data structure:")
print(f"  Input IDs shape: {sample['input_ids'].shape}")
print(f"  Attention mask shape: {sample['attention_mask'].shape}")
print(f"  Label: {sample['labels']}")

# Visualize GC content distribution
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
gc_contents = [(seq.count('G') + seq.count('C')) / len(seq) for seq in sequences]
plt.hist(gc_contents, bins=30, alpha=0.7, color='skyblue')
plt.xlabel('GC Content')
plt.ylabel('Frequency')
plt.title('Distribution of GC Content')

plt.subplot(1, 2, 2)
label_names = ['AT-rich', 'Moderate GC', 'GC-rich', 'Very GC-rich', 'TATA-box']
plt.bar(label_counts.index, label_counts.values, color='lightcoral')  # bar positions follow the class IDs actually present
plt.xlabel('Class')
plt.ylabel('Count')
plt.title('Class Distribution')
plt.xticks(range(len(label_names)), label_names, rotation=45)

plt.tight_layout()
plt.show()
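
The labeling rules have a consequence worth checking before training. Each base is drawn uniformly, so the GC fraction of a 256-base sequence has mean 0.5 and standard deviation sqrt(0.25 / 256) ≈ 0.031. The 0.3 and 0.7 thresholds sit more than six standard deviations from the mean, so classes 0 and 3 should be essentially empty. The sketch below re-implements the same rules with the standard library to confirm this; `label_sequence` is a hypothetical helper, and because it uses Python's `random` module rather than NumPy, its counts will not match the cell above exactly.

```python
import math
import random
from collections import Counter


def gc_content(seq):
    """Fraction of G and C bases in a sequence."""
    return (seq.count("G") + seq.count("C")) / len(seq)


def label_sequence(seq):
    """Reproduce the labeling rules of generate_synthetic_data for one sequence."""
    if "TATAAA" in seq:  # the motif overrides the GC-based label
        return 4
    gc = gc_content(seq)
    if gc < 0.3:
        return 0
    if gc < 0.5:
        return 1
    if gc < 0.7:
        return 2
    return 3


seq_length = 256
gc_std = math.sqrt(0.5 * 0.5 / seq_length)  # binomial std of the GC fraction

rng = random.Random(42)  # stdlib RNG, independent of the NumPy seed used above
samples = ["".join(rng.choice("ATCG") for _ in range(seq_length)) for _ in range(1000)]
counts = Counter(label_sequence(s) for s in samples)

print(f"GC fraction std: {gc_std:.4f}")
print({label: counts.get(label, 0) for label in range(5)})
```

In practice this means the model sees only three of the five classes, so metrics and plots that assume all five labels are present need to handle empty classes.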

## 6. Training Your First Model

Now let's train a Hyena-GLT model on our synthetic genomic data. We'll use a simple training setup to get started.

In [None]:
from torch.utils.data import DataLoader, random_split

# Split dataset
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Configure training
training_config = TrainingConfig(
    num_epochs=10,
    learning_rate=1e-4,
    warmup_steps=100,
    weight_decay=0.01,
    gradient_clip_norm=1.0,
    save_steps=200,
    eval_steps=200,
    logging_steps=50
)

# Initialize trainer
trainer = HyenaGLTTrainer(
    model=model,
    config=training_config,
    train_loader=train_loader,
    eval_loader=val_loader,
    output_dir="./notebook_training_output"
)

print("🚀 Starting training...")
history = trainer.train()
print("✅ Training completed!")
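
It helps to work out how many optimizer steps this setup implies before reading the curves. The arithmetic below assumes one optimizer step per batch (no gradient accumulation) and that the last partial batch is kept, which is the `DataLoader` default. `linear_warmup_lr` is a hypothetical schedule included only to show what `warmup_steps=100` could mean; the schedule `HyenaGLTTrainer` actually uses is not shown in this notebook.

```python
import math


def count_steps(num_samples, train_fraction, batch_size, num_epochs):
    """Return (train_size, steps_per_epoch, total_steps) for a simple training loop."""
    train_size = int(train_fraction * num_samples)
    steps_per_epoch = math.ceil(train_size / batch_size)  # the last partial batch is kept
    return train_size, steps_per_epoch, steps_per_epoch * num_epochs


def linear_warmup_lr(step, base_lr, warmup_steps):
    """Hypothetical schedule: ramp linearly to base_lr, then hold it constant."""
    if warmup_steps <= 0 or step >= warmup_steps:
        return base_lr
    return base_lr * (step + 1) / warmup_steps


train_size, steps_per_epoch, total_steps = count_steps(1000, 0.8, 16, 10)
# Steps at which a trainer evaluating every 200 steps would run validation.
eval_points = [s for s in range(1, total_steps + 1) if s % 200 == 0]

print(train_size, steps_per_epoch, total_steps)
print(eval_points)
print(linear_warmup_lr(0, 1e-4, 100), linear_warmup_lr(49, 1e-4, 100))
```

Under these assumptions the run has 500 steps, so a 100-step warmup covers the first fifth of training and `eval_steps=200` yields only two evaluation points.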

In [None]:
# Visualize training progress
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.plot(history['train_loss'], label='Train Loss', color='blue')
plt.plot(history['eval_loss'], label='Val Loss', color='red')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 3, 2)
plt.plot(history['train_accuracy'], label='Train Acc', color='blue')
plt.plot(history['eval_accuracy'], label='Val Acc', color='red')
plt.xlabel('Steps')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 3, 3)
plt.plot(history['learning_rate'], color='green')
plt.xlabel('Steps')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedule')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print final metrics
print(f"\n📊 Final Training Results:")
print(f"  Final train loss: {history['train_loss'][-1]:.4f}")
print(f"  Final val loss: {history['eval_loss'][-1]:.4f}")
print(f"  Final train accuracy: {history['train_accuracy'][-1]:.4f}")
print(f"  Final val accuracy: {history['eval_accuracy'][-1]:.4f}")

## 7. Model Evaluation and Analysis

Let's thoroughly evaluate our trained model and understand its performance.

In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import itertools

# Evaluate on validation set
model.eval()
val_predictions = []
val_true_labels = []
val_probabilities = []

with torch.no_grad():
    for batch in val_loader:
        outputs = model(batch['input_ids'], attention_mask=batch['attention_mask'])
        predictions = torch.argmax(outputs.logits, dim=-1)
        probabilities = torch.softmax(outputs.logits, dim=-1)
        
        val_predictions.extend(predictions.cpu().numpy())
        val_true_labels.extend(batch['labels'].cpu().numpy())
        val_probabilities.extend(probabilities.cpu().numpy())

# Convert to numpy arrays
val_predictions = np.array(val_predictions)
val_true_labels = np.array(val_true_labels)
val_probabilities = np.array(val_probabilities)

# Class names for better visualization
class_names = ['AT-rich', 'Moderate GC', 'GC-rich', 'Very GC-rich', 'TATA-box']

# Classification report
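print("📋 Classification Report:")
print(classification_report(
    val_true_labels, val_predictions,
    labels=list(range(len(class_names))),  # report every class, even one absent from the split
    target_names=class_names, zero_division=0
))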

In [None]:
# Confusion Matrix
cm = confusion_matrix(val_true_labels, val_predictions, labels=list(range(len(class_names))))  # fixed 5x5 shape

plt.figure(figsize=(15, 5))

# Plot confusion matrix
plt.subplot(1, 3, 1)
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
tick_marks = np.arange(len(class_names))
plt.xticks(tick_marks, class_names, rotation=45)
plt.yticks(tick_marks, class_names)

# Add text annotations
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    plt.text(j, i, format(cm[i, j], 'd'),
             horizontalalignment="center",
             color="white" if cm[i, j] > thresh else "black")

plt.ylabel('True Label')
plt.xlabel('Predicted Label')

# Class-wise accuracy
plt.subplot(1, 3, 2)
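row_totals = cm.sum(axis=1)
# Classes with no validation samples get 0 instead of a divide-by-zero NaN
class_accuracy = np.divide(cm.diagonal(), row_totals, out=np.zeros(len(row_totals)), where=row_totals > 0)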
plt.bar(range(len(class_names)), class_accuracy, color='lightgreen')
plt.xlabel('Class')
plt.ylabel('Accuracy')
plt.title('Per-Class Accuracy')
plt.xticks(range(len(class_names)), class_names, rotation=45)
plt.ylim(0, 1)

# Prediction confidence distribution
plt.subplot(1, 3, 3)
max_probs = np.max(val_probabilities, axis=1)
plt.hist(max_probs, bins=20, alpha=0.7, color='orange')
plt.xlabel('Prediction Confidence')
plt.ylabel('Frequency')
plt.title('Prediction Confidence Distribution')
plt.axvline(np.mean(max_probs), color='red', linestyle='--', label=f'Mean: {np.mean(max_probs):.3f}')
plt.legend()

plt.tight_layout()
plt.show()
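
The confusion matrix and the per-class bars can be reproduced by hand, which makes two details explicit: the matrix should have a fixed size even when a class never occurs, and the diagonal divided by the row sum is per-class recall. The helpers below are a hypothetical pure-Python sketch, not scikit-learn's implementation.

```python
def confusion_matrix_fixed(y_true, y_pred, num_classes):
    """Rows are true classes, columns are predicted classes; shape is always N x N."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def per_class_recall(cm):
    """Diagonal over row sum; None marks classes with no true samples."""
    recalls = []
    for i, row in enumerate(cm):
        total = sum(row)
        recalls.append(row[i] / total if total else None)
    return recalls


# Toy labels in which classes 0 and 3 never occur, as with the synthetic data.
y_true = [1, 1, 2, 2, 2, 4]
y_pred = [1, 2, 2, 2, 1, 4]

cm = confusion_matrix_fixed(y_true, y_pred, num_classes=5)
recalls = per_class_recall(cm)
for row in cm:
    print(row)
print(recalls)
```

Returning `None` for empty classes keeps "no data" distinct from "always wrong", which a plain 0.0 would blur.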

## 8. Model Interpretation and Analysis

Let's understand what our model has learned and how it makes predictions.

In [None]:
# Analyze model predictions on specific examples
def analyze_prediction(sequence, model, tokenizer, class_names):
    """Analyze model prediction for a specific sequence."""
    # Tokenize
    tokens = tokenizer.encode(sequence)
    input_ids = torch.tensor([tokens])
    
    # Get model output
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, output_attentions=True)
        
    # Get predictions and attention
    logits = outputs.logits[0]
    probabilities = torch.softmax(logits, dim=-1)
    predicted_class = torch.argmax(logits).item()
    
    # Attention weights from the last layer only, averaged over heads
    attention_weights = outputs.attentions[-1][0].mean(dim=0)  # Last layer, first sample, average heads
    attention_weights = attention_weights.mean(dim=0)  # Average across query positions
    
    return predicted_class, probabilities.numpy(), attention_weights.numpy()

# Test on some example sequences
test_sequences = [
    "ATATATATATATATATATATATAT",  # AT-rich
    "GCGCGCGCGCGCGCGCGCGCGCGC",  # GC-rich
    "ATCGATCGATCGATCGATCGATCG",  # Balanced
    "TATAAA" + "GCGCGC" * 10,      # TATA box + GC-rich
]

print("🔍 Analyzing Example Predictions:")
print("=" * 80)

for i, seq in enumerate(test_sequences):
    pred_class, probs, attention = analyze_prediction(seq, model, tokenizer, class_names)
    
    print(f"\nExample {i+1}: {seq[:30]}...")
    print(f"  GC Content: {(seq.count('G') + seq.count('C')) / len(seq):.3f}")
    print(f"  Predicted: {class_names[pred_class]} (confidence: {probs[pred_class]:.3f})")
    print(f"  All probabilities: {dict(zip(class_names, probs))}")
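
Attention maps are one way to look inside a model. A model-agnostic alternative is occlusion sensitivity: mask a short window, re-score the sequence, and record how much the score drops. The sketch below names that swapped-in technique plainly and demonstrates it with a toy scoring function; `occlusion_importance` and `toy_tata_score` are hypothetical. To use it with a real model, `score_fn` would wrap a call that returns one class probability, and whether the tokenizer accepts the mask character `N` is an assumption to verify.

```python
def occlusion_importance(sequence, score_fn, window=3, mask_char="N"):
    """Score drop caused by masking each window of the sequence in turn."""
    baseline = score_fn(sequence)
    importances = []
    for start in range(len(sequence) - window + 1):
        occluded = sequence[:start] + mask_char * window + sequence[start + window:]
        importances.append(baseline - score_fn(occluded))
    return importances


def toy_tata_score(seq):
    """Stand-in for a model's class probability: 1.0 if the TATA motif is intact."""
    return 1.0 if "TATAAA" in seq else 0.0


sequence = "GCGCTATAAAGCGC"  # motif occupies positions 4 to 9
importances = occlusion_importance(sequence, toy_tata_score, window=3)
print(importances)
```

Windows that overlap the motif produce a drop of 1.0 and all others produce 0.0, so the importance profile localizes the motif without any access to model internals.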

## 9. Practical Usage: Applying the Model

Now let's see how to use our trained model for practical genomic sequence analysis.

In [None]:
def predict_sequence_type(sequence, model, tokenizer, class_names):
    """Predict the type of a genomic sequence."""
    # Tokenize and prepare input
    tokens = tokenizer.encode(sequence)
    input_ids = torch.tensor([tokens])
    
    # Get prediction
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids)
        probabilities = torch.softmax(outputs.logits, dim=-1)[0]
        predicted_class = torch.argmax(probabilities).item()
    
    return predicted_class, probabilities.numpy()

def batch_predict(sequences, model, tokenizer, class_names, batch_size=32):
    """Predict types for multiple sequences efficiently."""
    results = []
    
    for i in range(0, len(sequences), batch_size):
        batch_seqs = sequences[i:i+batch_size]
        batch_tokens = [tokenizer.encode(seq) for seq in batch_seqs]
        
        # Pad sequences to same length
        max_len = max(len(tokens) for tokens in batch_tokens)
        padded_tokens = []
        attention_masks = []
        
        for tokens in batch_tokens:
            padding_length = max_len - len(tokens)
            padded = tokens + [tokenizer.pad_token_id] * padding_length
            mask = [1] * len(tokens) + [0] * padding_length
            
            padded_tokens.append(padded)
            attention_masks.append(mask)
        
        # Convert to tensors
        input_ids = torch.tensor(padded_tokens)
        attention_mask = torch.tensor(attention_masks)
        
        # Get predictions
        model.eval()
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
            probabilities = torch.softmax(outputs.logits, dim=-1)
            predictions = torch.argmax(probabilities, dim=-1)
        
        # Store results
        for seq, pred, probs in zip(batch_seqs, predictions, probabilities):
            results.append({
                'sequence': seq,
                'predicted_class': class_names[pred.item()],
                'confidence': probs[pred].item(),
                'all_probabilities': {name: prob.item() for name, prob in zip(class_names, probs)}
            })
    
    return results

# Test the prediction functions
test_sequences = [
    "AAAAAAAAAAAAAAAAAAAAAAAAAAAA",  # Very AT-rich
    "GGGGGGGGGGGGGGGGGGGGGGGGGGGG",  # Very GC-rich
    "ATCGATCGATCGATCGATCGATCGATCG",  # Balanced
    "TATAAAGCGCGCGCGCGCGCGCGCGCGC",  # TATA + GC-rich
    "CGATCGATCGATCGATCGATCGATCGAT",  # Random-looking
]

print("🎯 Testing Prediction Functions:")
print("=" * 60)

for i, seq in enumerate(test_sequences):
    pred_class, probs = predict_sequence_type(seq, model, tokenizer, class_names)
    print(f"Sequence {i+1}: {seq[:20]}...")
    print(f"  Prediction: {class_names[pred_class]} ({probs[pred_class]:.3f} confidence)")
    print()
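
The padding step inside `batch_predict` is easier to check in isolation. The sketch below performs the same right-padding and mask construction on plain lists; `pad_batch` is a hypothetical helper, and the padding ID of 0 is an assumption (the notebook reads the real value from `tokenizer.pad_token_id`).

```python
def pad_batch(token_lists, pad_token_id=0):
    """Right-pad token lists to a common length and build matching attention masks."""
    if not token_lists:
        raise ValueError("pad_batch needs at least one sequence")
    max_len = max(len(tokens) for tokens in token_lists)
    input_ids, attention_mask = [], []
    for tokens in token_lists:
        padding = max_len - len(tokens)
        input_ids.append(list(tokens) + [pad_token_id] * padding)
        # 1 marks a real token, 0 marks padding the model should ignore.
        attention_mask.append([1] * len(tokens) + [0] * padding)
    return input_ids, attention_mask


input_ids, attention_mask = pad_batch([[2, 3, 4], [5]])
print(input_ids)
print(attention_mask)
```

Every row ends up with the same length, which is what allows the batch to be stacked into a single tensor.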

## 10. Saving and Loading Models

Learn how to save your trained models and load them for future use.

In [None]:
# Save the trained model
model_save_path = "./my_first_hyena_glt_model"
os.makedirs(model_save_path, exist_ok=True)

# Save model and tokenizer
model.save_pretrained(model_save_path)
tokenizer.save_pretrained(model_save_path)

# Save additional metadata
import json
metadata = {
    "model_type": "hyena-glt",
    "task": "dna_classification",
    "classes": class_names,
    "training_samples": len(train_dataset),
    "validation_accuracy": float(history['eval_accuracy'][-1]),
    "sequence_type": "dna",
    "max_length": config.max_length
}

with open(f"{model_save_path}/metadata.json", 'w') as f:
    json.dump(metadata, f, indent=2)

print(f"✅ Model saved to: {model_save_path}")

# Demonstrate loading the model
print("\n🔄 Loading model from disk...")

# Load model and tokenizer
loaded_model = HyenaGLT.from_pretrained(model_save_path)
loaded_tokenizer = DNATokenizer.from_pretrained(model_save_path)

# Load metadata
with open(f"{model_save_path}/metadata.json", 'r') as f:
    loaded_metadata = json.load(f)

print(f"📋 Loaded model metadata:")
for key, value in loaded_metadata.items():
    print(f"  {key}: {value}")

# Test that loaded model works
test_seq = "ATCGATCGATCGATCGATCGATCG"
original_pred, original_probs = predict_sequence_type(test_seq, model, tokenizer, class_names)
loaded_pred, loaded_probs = predict_sequence_type(test_seq, loaded_model, loaded_tokenizer, class_names)

print(f"\n🧪 Verification Test:")
print(f"  Original model prediction: {class_names[original_pred]} ({original_probs[original_pred]:.4f})")
print(f"  Loaded model prediction: {class_names[loaded_pred]} ({loaded_probs[loaded_pred]:.4f})")
print(f"  Predictions match: {original_pred == loaded_pred}")
print(f"  Probabilities match: {np.allclose(original_probs, loaded_probs, atol=1e-6)}")
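
The metadata file is the part of the checkpoint that is easiest to get wrong silently, so it can be worth validating on load. The sketch below round-trips a metadata dictionary through a temporary directory and checks for required keys. `save_metadata`, `load_metadata`, and the `REQUIRED_KEYS` schema are hypothetical and independent of whatever `save_pretrained` writes.

```python
import json
import tempfile
from pathlib import Path

REQUIRED_KEYS = {"model_type", "task", "classes", "max_length"}  # assumed schema


def save_metadata(directory, metadata):
    """Write metadata.json into the directory and return its path."""
    directory = Path(directory)
    directory.mkdir(parents=True, exist_ok=True)
    path = directory / "metadata.json"
    path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
    return path


def load_metadata(directory):
    """Read metadata.json and fail loudly if a required key is missing."""
    path = Path(directory) / "metadata.json"
    metadata = json.loads(path.read_text(encoding="utf-8"))
    missing = REQUIRED_KEYS - metadata.keys()
    if missing:
        raise KeyError(f"metadata.json is missing keys: {sorted(missing)}")
    return metadata


original = {
    "model_type": "hyena-glt",
    "task": "dna_classification",
    "classes": ["AT-rich", "Moderate GC", "GC-rich", "Very GC-rich", "TATA-box"],
    "max_length": 1024,
}

with tempfile.TemporaryDirectory() as tmp:
    save_metadata(tmp, original)
    restored = load_metadata(tmp)

print(restored == original)
```

Storing the class names next to the weights matters because the model only outputs indices; without the list, a later change in label order would go unnoticed.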

## 6. Training Your First Model

Now let's train a Hyena-GLT model on our synthetic data.

In [None]:
# Split data into train/validation
from torch.utils.data import random_split, DataLoader

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Training configuration
training_config = TrainingConfig(
    learning_rate=1e-4,
    batch_size=8,
    num_epochs=3,  # Small number for demo
    warmup_steps=100,
    weight_decay=0.01,
    gradient_clipping=1.0,
    save_steps=100,
    eval_steps=50,
    logging_steps=25
)

print(f"\nTraining configuration:")
print(f"  Learning rate: {training_config.learning_rate}")
print(f"  Batch size: {training_config.batch_size}")
print(f"  Epochs: {training_config.num_epochs}")

In [None]:
# Initialize trainer
trainer = HyenaGLTTrainer(
    model=model,
    config=training_config,
    train_loader=train_loader,
    val_loader=val_loader,
    output_dir="./demo_output"
)

# Train the model
print("🚀 Starting training...")
training_history = trainer.train()

print("\n✅ Training completed!")
print(f"Final training loss: {training_history['train_loss'][-1]:.4f}")
print(f"Final validation loss: {training_history['val_loss'][-1]:.4f}")
print(f"Final validation accuracy: {training_history['val_accuracy'][-1]:.4f}")

## 7. Model Evaluation

Let's evaluate our trained model on the validation set using standard classification metrics: accuracy, precision, recall, and F1.

In [None]:
# Evaluate on validation set
model.eval()
all_predictions = []
all_labels = []
all_probabilities = []

with torch.no_grad():
    for batch in val_loader:
        outputs = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask']
        )
        
        predictions = torch.argmax(outputs.logits, dim=-1)
        probabilities = torch.softmax(outputs.logits, dim=-1)
        
        all_predictions.extend(predictions.cpu().numpy())
        all_labels.extend(batch['labels'].cpu().numpy())
        all_probabilities.extend(probabilities.cpu().numpy())

all_predictions = np.array(all_predictions)
all_labels = np.array(all_labels)
all_probabilities = np.array(all_probabilities)

print(f"Evaluation completed on {len(all_predictions)} samples")

In [None]:
# Calculate metrics
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix

accuracy = accuracy_score(all_labels, all_predictions)
precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_predictions, average='weighted')

print("📊 Evaluation Metrics:")
print(f"  Accuracy: {accuracy:.4f}")
print(f"  Precision: {precision:.4f}")
print(f"  Recall: {recall:.4f}")
print(f"  F1-score: {f1:.4f}")

# Per-class metrics
class_names = ['AT-rich', 'Moderate GC', 'GC-rich', 'Very GC-rich', 'Promoter']
precision_per_class, recall_per_class, f1_per_class, support = precision_recall_fscore_support(
    all_labels, all_predictions, average=None,
    labels=list(range(len(class_names))), zero_division=0  # one entry per class, even if absent
)

print("\n📈 Per-class Performance:")
for i, class_name in enumerate(class_names):
    if i < len(precision_per_class):
        print(f"  {class_name}:")
        print(f"    Precision: {precision_per_class[i]:.4f}")
        print(f"    Recall: {recall_per_class[i]:.4f}")
        print(f"    F1: {f1_per_class[i]:.4f}")
        print(f"    Support: {support[i]}")
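
The `average='weighted'` numbers can be derived by hand from the per-class values: each class's metric is weighted by its support (its number of true samples). The sketch below does this in pure Python on toy labels; `per_class_prf` and `weighted_average` are hypothetical helpers that treat undefined ratios as 0.0, mirroring the `zero_division=0` convention.

```python
def per_class_prf(y_true, y_pred, num_classes):
    """Return a (precision, recall, f1, support) tuple for every class index."""
    results = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        results.append((precision, recall, f1, tp + fn))
    return results


def weighted_average(results):
    """Average precision, recall, and F1 across classes, weighted by support."""
    total = sum(r[3] for r in results)
    return tuple(sum(r[i] * r[3] for r in results) / total for i in range(3))


y_true = [1, 1, 2, 2, 2, 4]
y_pred = [1, 2, 2, 2, 1, 4]

results = per_class_prf(y_true, y_pred, num_classes=5)
w_precision, w_recall, w_f1 = weighted_average(results)
print(results)
print(round(w_precision, 4), round(w_recall, 4), round(w_f1, 4))
```

Support-weighted recall always equals plain accuracy, because both reduce to correct predictions divided by total samples. That makes it a quick sanity check on the reported numbers.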

## 8. Visualization and Analysis

Let's create visualizations to better understand our model's performance.

In [None]:
# Plot training curves
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Training and validation loss
axes[0, 0].plot(training_history['train_loss'], label='Training Loss', color='blue')
axes[0, 0].plot(training_history['val_loss'], label='Validation Loss', color='red')
axes[0, 0].set_title('Training and Validation Loss')
axes[0, 0].set_xlabel('Step')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True)

# Validation accuracy
axes[0, 1].plot(training_history['val_accuracy'], label='Validation Accuracy', color='green')
axes[0, 1].set_title('Validation Accuracy')
axes[0, 1].set_xlabel('Step')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(True)

# Confusion matrix
cm = confusion_matrix(all_labels, all_predictions)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[1, 0])
axes[1, 0].set_title('Confusion Matrix')
axes[1, 0].set_xlabel('Predicted')
axes[1, 0].set_ylabel('Actual')

# Class distribution
unique_labels, counts = np.unique(all_labels, return_counts=True)
axes[1, 1].bar([class_names[i] for i in unique_labels], counts, color='skyblue')  # name each bar by its actual class ID
axes[1, 1].set_title('Class Distribution in Validation Set')
axes[1, 1].set_xlabel('Class')
axes[1, 1].set_ylabel('Count')
axes[1, 1].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

In [None]:
# Analyze model predictions
# Find examples where the model is most confident
max_probs = np.max(all_probabilities, axis=1)
confidence_threshold = 0.9
high_confidence_mask = max_probs > confidence_threshold

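print(f"High confidence predictions (>{confidence_threshold}): {high_confidence_mask.sum()}")
if high_confidence_mask.any():
    high_conf_accuracy = accuracy_score(all_labels[high_confidence_mask], all_predictions[high_confidence_mask])
    print(f"Accuracy on high confidence predictions: {high_conf_accuracy:.4f}")
else:
    print("No predictions exceeded the confidence threshold.")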

# Find examples where the model made mistakes
mistakes = all_predictions != all_labels
print(f"\nNumber of mistakes: {mistakes.sum()}")
print(f"Mistake rate: {mistakes.mean():.4f}")

# Analyze mistakes by class
print("\nMistakes by true class:")
for i, class_name in enumerate(class_names):
    class_mask = all_labels == i
    if class_mask.sum() > 0:
        class_mistakes = mistakes[class_mask].sum()
        class_total = class_mask.sum()
        mistake_rate = class_mistakes / class_total
        print(f"  {class_name}: {class_mistakes}/{class_total} ({mistake_rate:.4f})")

## 9. Model Interpretation

Let's explore what the model has learned by inspecting a single validation sample: its true label, the predicted class, and the full set of class probabilities.

In [None]:
# Analyze model's internal representations
model.eval()

# Get a sample for analysis
sample_idx = 0
sample_data = val_dataset[sample_idx]
sample_sequence = sequences[val_dataset.indices[sample_idx]]
sample_label = sample_data['labels'].item()

print(f"Analyzing sample {sample_idx}:")
print(f"Sequence length: {len(sample_sequence)}")
print(f"True label: {class_names[sample_label]}")
print(f"Sequence preview: {sample_sequence[:50]}...")

# Forward pass with attention
with torch.no_grad():
    input_ids = sample_data['input_ids'].unsqueeze(0)
    attention_mask = sample_data['attention_mask'].unsqueeze(0)
    
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_attentions=True
    )
    
    prediction = torch.argmax(outputs.logits, dim=-1).item()
    probabilities = torch.softmax(outputs.logits, dim=-1).squeeze().cpu().numpy()
    
print(f"\nPredicted label: {class_names[prediction]}")
print(f"Prediction confidence: {probabilities[prediction]:.4f}")

print("\nClass probabilities:")
for i, (class_name, prob) in enumerate(zip(class_names, probabilities)):
    print(f"  {class_name}: {prob:.4f}")

## 10. Practical Applications

Let's demonstrate how to apply the trained model to new genomic sequences.

In [None]:
def predict_sequence_type(sequence, model, tokenizer, class_names):
    """Predict the type of a genomic sequence."""
    model.eval()
    
    # Tokenize
    tokens = tokenizer.encode(sequence)
    input_ids = torch.tensor(tokens).unsqueeze(0)
    
    # Create attention mask
    attention_mask = torch.ones_like(input_ids)
    
    # Predict
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        
        probabilities = torch.softmax(outputs.logits, dim=-1).squeeze().cpu().numpy()
        prediction = np.argmax(probabilities)
    
    return prediction, probabilities

# Test on new sequences
test_sequences = [
    "ATATAAATCGATCGTAGCTAGC",  # Contains TATA box
    "GCGCGCGCGCGCGCGCGCGC",   # Very GC-rich
    "AAAAAATTTTTTAAAAATTTTT", # Very AT-rich
    "ATCGATCGATCGATCGATCG",   # Balanced
]

print("🔬 Testing on new sequences:")
print("=" * 50)

for i, seq in enumerate(test_sequences):
    prediction, probabilities = predict_sequence_type(
        seq, model, tokenizer, class_names
    )
    
    gc_content = (seq.count('G') + seq.count('C')) / len(seq)
    
    print(f"\nSequence {i+1}: {seq}")
    print(f"GC content: {gc_content:.2f}")
    print(f"Predicted type: {class_names[prediction]}")
    print(f"Confidence: {probabilities[prediction]:.4f}")
    
    # Show top 2 predictions
    top_indices = np.argsort(probabilities)[::-1][:2]
    print("Top predictions:")
    for idx in top_indices:
        print(f"  {class_names[idx]}: {probabilities[idx]:.4f}")
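
The step from logits to a ranked list of class probabilities can be written without any framework. The sketch below implements a numerically stable softmax and a top-k ranking in pure Python; `softmax` and `top_k` are hypothetical helpers, and the logits are made-up numbers rather than model output.

```python
import math


def softmax(logits):
    """Convert logits to probabilities; subtracting the max avoids overflow in exp."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probabilities, class_names, k=2):
    """Return the k most probable (class_name, probability) pairs, best first."""
    ranked = sorted(zip(class_names, probabilities), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


class_names = ["AT-rich", "Moderate GC", "GC-rich", "Very GC-rich", "Promoter"]
logits = [2.0, 1.0, 0.1, -1.0, 0.5]  # illustrative values only

probabilities = softmax(logits)
best_two = top_k(probabilities, class_names, k=2)
for name, prob in best_two:
    print(f"{name}: {prob:.4f}")
```

Softmax preserves the ordering of the logits, so the top class can be read from either. The probabilities are needed only when the confidence values themselves are reported.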

## 11. Model Saving and Loading

Learn how to save and load your trained models for future use.

In [None]:
# Save the model and configuration
save_dir = Path("./saved_models/dna_classifier")
save_dir.mkdir(parents=True, exist_ok=True)

# Save model state
torch.save(model.state_dict(), save_dir / "model.pt")

# Save configuration
config.save_json(save_dir / "config.json")

# Save tokenizer vocabulary
import json
with open(save_dir / "tokenizer_vocab.json", 'w') as f:
    json.dump(tokenizer.vocab, f)

print(f"✅ Model saved to: {save_dir}")

# Demonstrate loading
print("\n🔄 Loading model...")

# Load configuration
loaded_config = HyenaGLTConfig.from_json(save_dir / "config.json")

# Create new model
loaded_model = HyenaGLT(loaded_config)

# Load weights
loaded_model.load_state_dict(torch.load(save_dir / "model.pt"))

print("✅ Model loaded successfully!")

# Verify loaded model works
test_seq = "ATCGATCGTAGCTAGC"
pred1, prob1 = predict_sequence_type(test_seq, model, tokenizer, class_names)
pred2, prob2 = predict_sequence_type(test_seq, loaded_model, tokenizer, class_names)

print(f"\nVerification:")
print(f"Original model prediction: {class_names[pred1]} ({prob1[pred1]:.4f})")
print(f"Loaded model prediction: {class_names[pred2]} ({prob2[pred2]:.4f})")
print(f"Predictions match: {pred1 == pred2}")
print(f"Probabilities match: {np.allclose(prob1, prob2)}")

## 12. Next Steps and Advanced Topics

Congratulations! You've successfully trained and evaluated your first Hyena-GLT model. Here are some next steps to explore:

### 🎯 Immediate Next Steps

1. **Try different genomic tasks**:
   - RNA secondary structure prediction
   - Protein function classification
   - Variant effect prediction

2. **Experiment with model architectures**:
   - Different hidden sizes and layer counts
   - Various Hyena orders
   - Custom positional encodings

3. **Improve data quality**:
   - Use real genomic datasets
   - Implement data augmentation
   - Balance class distributions

### 🚀 Advanced Applications

4. **Multi-task learning**: Train on multiple genomic tasks simultaneously
5. **Transfer learning**: Fine-tune pre-trained models on specific tasks
6. **Model optimization**: Quantization, pruning, and knowledge distillation
7. **Distributed training**: Scale to larger datasets and models

### 📚 Additional Resources

- Check out other notebooks in this directory
- Read the comprehensive documentation
- Explore the example scripts
- Join the community discussions

### 🛠️ Useful Functions for Your Projects

In [None]:
# Utility functions you can use in your own projects

def quick_train(sequences, labels, task_type='classification', epochs=5):
    """Quickly train a model on your data."""
    # Determine sequence type
    sample_seq = sequences[0]
    if set(sample_seq.upper()).issubset({'A', 'T', 'C', 'G'}):
        sequence_type = 'dna'
        tokenizer = DNATokenizer()
    elif set(sample_seq.upper()).issubset({'A', 'U', 'C', 'G'}):
        sequence_type = 'rna'
        tokenizer = RNATokenizer()
    else:
        sequence_type = 'protein'
        tokenizer = ProteinTokenizer()
    
    # Create config
    if task_type == 'classification':
        num_classes = len(set(labels))
        if sequence_type == 'dna':
            config = HyenaGLTConfig.for_dna_classification(num_classes=num_classes)
        elif sequence_type == 'rna':
            config = HyenaGLTConfig.for_rna_structure()
        else:
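            config = HyenaGLTConfig.for_protein_function(num_functions=num_classes)
    else:
        # Without this branch, `config` would be undefined for any other task_type
        raise ValueError(f"Unsupported task_type: {task_type!r}")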
    
    # Create dataset and model
    dataset = GenomicDataset(sequences, labels, tokenizer, config.max_length)
    model = HyenaGLT(config)
    
    # Quick training setup
    train_loader = DataLoader(dataset, batch_size=8, shuffle=True)
    training_config = TrainingConfig(num_epochs=epochs, learning_rate=1e-4)
    
    trainer = HyenaGLTTrainer(
        model=model,
        config=training_config,
        train_loader=train_loader,
        output_dir="./quick_train_output"
    )
    
    # Train
    history = trainer.train()
    
    return model, tokenizer, config, history

def analyze_sequences(sequences, model, tokenizer, class_names):
    """Analyze a list of sequences with a trained model."""
    results = []
    
    for seq in sequences:
        pred, probs = predict_sequence_type(seq, model, tokenizer, class_names)
        
        result = {
            'sequence': seq,
            'length': len(seq),
            'gc_content': (seq.count('G') + seq.count('C')) / len(seq),
            'predicted_class': class_names[pred],
            'confidence': probs[pred],
            'all_probabilities': probs
        }
        results.append(result)
    
    return pd.DataFrame(results)

print("🛠️ Utility functions defined!")
print("You can now use:")
print("  - quick_train(sequences, labels) for rapid prototyping")
print("  - analyze_sequences(sequences, model, tokenizer, class_names) for batch analysis")
print("  - predict_sequence_type(sequence, model, tokenizer, class_names) for single predictions")
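
`quick_train` infers the sequence type from the first sequence only, which can misfire on mixed or short inputs. The sketch below shows a stricter variant that inspects every sequence and rejects inputs mixing T and U. `detect_sequence_type` is a hypothetical helper, and accepting `N` as an unknown base is an assumption that goes beyond the notebook's original check.

```python
DNA_SYMBOLS = set("ACGTN")  # N (unknown base) is an added assumption
RNA_SYMBOLS = set("ACGUN")


def detect_sequence_type(sequences):
    """Classify a collection of sequences as 'dna', 'rna', or 'protein'."""
    symbols = set()
    for seq in sequences:
        symbols.update(seq.upper())
    if not symbols:
        raise ValueError("need at least one non-empty sequence")
    if symbols <= DNA_SYMBOLS:
        return "dna"  # sequences without T or U default to DNA by convention
    if symbols <= RNA_SYMBOLS:
        return "rna"
    if symbols <= DNA_SYMBOLS | RNA_SYMBOLS:
        raise ValueError("sequences mix T and U; split DNA and RNA first")
    return "protein"


print(detect_sequence_type(["ATCG", "GGCC"]))
print(detect_sequence_type(["ACG", "AUG"]))
print(detect_sequence_type(["MKTVRQERLK"]))
```

The heuristic remains ambiguous for short peptides made only of A, C, G, and T, so passing the sequence type explicitly is safer whenever it is known.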

## Summary

In this notebook, you learned:

✅ **Hyena-GLT Basics**: Understanding the architecture and configuration system  
✅ **Data Processing**: Tokenizing genomic sequences and creating datasets  
✅ **Model Training**: Training a model on synthetic genomic data  
✅ **Evaluation**: Comprehensive performance analysis and visualization  
✅ **Model Interpretation**: Understanding what the model learns  
✅ **Practical Usage**: Applying trained models to new sequences  
✅ **Model Persistence**: Saving and loading models  
✅ **Utility Functions**: Reusable code for your projects  

**What's Next?**

- Explore other notebooks for specific genomic tasks
- Try training on real genomic datasets
- Experiment with different model architectures
- Implement custom genomic tasks

Happy genomic modeling with Hyena-GLT! 🧬🚀