# LMU-based ASR System: Complete Demo

This notebook demonstrates the complete workflow for training and using the LMU-based Automatic Speech Recognition system, from setup to inference.

## Overview

1. **Environment Setup**: Configure the environment and imports
2. **Data Preparation**: Load and preprocess LibriSpeech dataset
3. **Model Configuration**: Set up the LMU ASR model
4. **Training**: Train the model with proper logging
5. **Evaluation**: Evaluate model performance
6. **Inference**: Run inference on new audio samples

## Prerequisites

- PyTorch (CUDA support recommended; the notebook falls back to CPU if no GPU is found)
- All dependencies from requirements.txt
- pytorch-lmu library (located in ../pytorch-lmu/)


## 1. Environment Setup

In [None]:
# Import necessary libraries
import os
import sys
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm import tqdm
import json
from omegaconf import DictConfig, OmegaConf

# Set up plotting
%matplotlib inline
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Add src directory to path
sys.path.append('../src')

# Print system information
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Import Project Components

In [None]:
# Import project modules
from config.config import Config, ModelConfig, DataConfig, TrainingConfig, create_config_from_dict
from models.asr_model import create_model, LMUASRModel
from data.dataset import create_dataloaders, HuggingFaceLibriSpeechDataset
from data.preprocessing import AudioPreprocessor, TextPreprocessor, SpecAugment
from training.trainer import Trainer
from training.utils import (
    decode_predictions, decode_targets, compute_wer, compute_cer,
    count_parameters
)

print("✅ All imports successful!")

## 3. Configuration Setup

In [None]:
# Create configuration for demo (smaller model for faster training)
config = Config(
    model=ModelConfig(
        input_size=80,           # Mel spectrogram features
        hidden_size=256,         # Reduced for demo
        memory_size=128,         # Reduced for demo
        num_lmu_layers=2,        # Reduced for demo
        theta=1000.0,
        dropout=0.1,
        use_fft_lmu=False,       # Use standard LMU
        vocab_size=29            # Will be updated based on actual vocab
    ),
    data=DataConfig(
        dataset="librispeech",
        subset="clean",
        sample_rate=16000,
        n_mels=80,
        max_seq_len=800,         # Reduced for demo
        augment=True,
        num_workers=2            # Reduced for demo
    ),
    training=TrainingConfig(
        batch_size=8,            # Reduced for demo
        lr=1e-3,
        max_epochs=3,            # Reduced for demo
        patience=5,
        mixed_precision=torch.cuda.is_available(),  # AMP only helps on a CUDA device
        gradient_clip_norm=1.0,
        accumulate_grad_batches=1
    )
)

print("Configuration created:")
print(f"  Model: {config.model.num_lmu_layers} LMU layers, {config.model.hidden_size} hidden units")
print(f"  Data: {config.data.max_seq_len} max sequence length, batch size {config.training.batch_size}")
print(f"  Training: {config.training.max_epochs} epochs, LR {config.training.lr}")

## 4. Data Preparation

In [None]:
# Create data loaders
print("Loading LibriSpeech dataset...")
train_loader, val_loader, vocab = create_dataloaders(config.data, use_huggingface=True)

# Update vocab size in config
config.model.vocab_size = vocab['vocab_size']

print(f"✅ Data loaded successfully!")
print(f"  Training samples: {len(train_loader.dataset)}")
print(f"  Validation samples: {len(val_loader.dataset)}")
print(f"  Vocabulary size: {vocab['vocab_size']}")
print(f"  Batch size: {config.training.batch_size}")
print(f"  Training batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")

### Explore the Data

In [None]:
# Get a sample batch
sample_batch = next(iter(train_loader))
spectrograms, texts, input_lengths, target_lengths = sample_batch

print("Sample batch shapes:")
print(f"  Spectrograms: {spectrograms.shape}")
print(f"  Texts: {texts.shape}")
print(f"  Input lengths: {input_lengths.shape}")
print(f"  Target lengths: {target_lengths.shape}")

# Visualize a spectrogram
plt.figure(figsize=(12, 4))
spec_sample = spectrograms[0, :input_lengths[0], :].numpy()
plt.imshow(spec_sample.T, aspect='auto', origin='lower', cmap='viridis')
plt.colorbar()
plt.title('Sample Mel Spectrogram')
plt.xlabel('Time Steps')
plt.ylabel('Mel Frequency Bins')
plt.tight_layout()
plt.show()

# Show some text samples
print("\nSample texts:")
for i in range(min(3, len(texts))):
    text_indices = texts[i][:target_lengths[i]].tolist()
    decoded_text = ''.join([vocab['idx_to_char'][idx] for idx in text_indices if idx in vocab['idx_to_char']])
    print(f"  {i+1}: '{decoded_text}'")
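
The configuration above sets `augment=True`, and the project imports a `SpecAugment` class whose implementation is not shown in this notebook. As a rough illustration of the idea only, the sketch below zeroes one random band of mel bins and one random span of frames in a `(time, n_mels)` spectrogram stored as nested lists. The function name, the mask limits, and the choice of zero as the fill value are assumptions for this example and may not match what the project does.

```python
import random

def spec_augment(spec, max_freq_mask=10, max_time_mask=20, rng=None):
    """Return a copy of `spec` (list of frames, each a list of mel bins)
    with one random frequency band and one random time span set to zero."""
    if rng is None:
        rng = random.Random()
    num_frames, num_bins = len(spec), len(spec[0])
    out = [list(frame) for frame in spec]  # copy so the input stays untouched

    # Frequency mask: zero `f` consecutive mel bins in every frame.
    f = rng.randint(0, min(max_freq_mask, num_bins))
    f0 = rng.randint(0, num_bins - f)
    for frame in out:
        for b in range(f0, f0 + f):
            frame[b] = 0.0

    # Time mask: zero `t` consecutive frames across all mel bins.
    t = rng.randint(0, min(max_time_mask, num_frames))
    t0 = rng.randint(0, num_frames - t)
    for i in range(t0, t0 + t):
        out[i] = [0.0] * num_bins

    return out

# Toy input: 100 frames of 80 mel bins, all ones.
toy_spec = [[1.0] * 80 for _ in range(100)]
augmented = spec_augment(toy_spec, rng=random.Random(0))
masked = sum(v == 0.0 for frame in augmented for v in frame)
print(f"Masked {masked} of {100 * 80} values")
```

Masking is normally applied to training batches only, so validation metrics stay comparable across runs.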

## 5. Model Creation and Analysis

In [None]:
# Create the model
print("Creating LMU ASR model...")
model = create_model(config.model).to(device)

# Analyze model
total_params, trainable_params = count_parameters(model)
print(f"\nModel created successfully!")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Model size (float32, 4 bytes/param): {total_params * 4 / 1e6:.2f} MB")

# Test forward pass
model.eval()
with torch.no_grad():
    test_input = spectrograms[:2].to(device)
    test_lengths = input_lengths[:2].to(device)
    
    log_probs, memory_states = model(test_input, test_lengths)
    print(f"\nForward pass test:")
    print(f"  Input shape: {test_input.shape}")
    print(f"  Output shape: {log_probs.shape}")
    print(f"  Memory states: {len(memory_states)} layers")
    
    # Test decoding
    predictions = model.decode(log_probs, test_lengths)
    print(f"  Decoded predictions: {len(predictions)} sequences")
    
    # Show sample predictions (before training)
    pred_texts = decode_predictions(predictions, vocab)
    print(f"\nSample predictions (before training):")
    for i, pred in enumerate(pred_texts[:2]):
        print(f"  {i+1}: '{pred}'")

print("\n✅ Model ready for training!")
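
How `model.decode` turns frame-level log-probabilities into label sequences is not shown here. A common choice for CTC models is greedy (best-path) decoding: take the argmax at every frame, merge consecutive repeats, then drop blanks. The sketch below illustrates that rule on a toy sequence; the function name, the toy vocabulary, and `blank_id=0` are assumptions for the example, not the project's API.

```python
def ctc_greedy_decode(frame_ids, blank_id=0):
    """Collapse a per-frame argmax sequence into a label sequence (CTC rule)."""
    decoded = []
    previous = None
    for idx in frame_ids:
        # Keep a label only when it differs from the previous frame and is not blank.
        if idx != previous and idx != blank_id:
            decoded.append(idx)
        previous = idx
    return decoded

# Hypothetical toy vocabulary: 0 is the blank, 1..3 map to characters.
idx_to_char = {1: "c", 2: "a", 3: "t"}
frames = [0, 1, 1, 0, 2, 2, 2, 0, 0, 3, 3]
print("".join(idx_to_char[i] for i in ctc_greedy_decode(frames)))  # prints "cat"
```

A blank between two identical labels keeps both, which is how CTC represents double letters: `[1, 0, 1]` decodes to `[1, 1]`.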

## 6. Training Setup

In [None]:
# Create trainer
log_dir = './demo_logs'
trainer = Trainer(model, config.training, device, log_dir)

print(f"Trainer created with log directory: {log_dir}")
print(f"Training configuration:")
print(f"  Epochs: {config.training.max_epochs}")
print(f"  Batch size: {config.training.batch_size}")
print(f"  Learning rate: {config.training.lr}")
print(f"  Mixed precision: {config.training.mixed_precision}")
print(f"  Early stopping patience: {config.training.patience}")
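
The `patience` setting suggests patience-based early stopping: training stops once the validation metric has failed to improve for that many consecutive epochs. The `Trainer` implementation is not shown, so the sketch below is only a generic version of the rule; the class name and the `min_delta` parameter are assumptions for illustration.

```python
class EarlyStopping:
    """Signal a stop after `patience` consecutive epochs without improvement."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

stopper = EarlyStopping(patience=2)
losses = [1.0, 0.8, 0.9, 0.85, 0.7]
stopped_at = next((epoch for epoch, loss in enumerate(losses) if stopper.step(loss)), None)
print(stopped_at)  # prints 3
```

Under a rule like this one, the demo settings (`max_epochs=3`, `patience=5`) would never trigger a stop, because training ends before five non-improving epochs can accumulate.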

## 7. Training

In [None]:
# Train the model
print("Starting training...")
print("Note: This is a demo with reduced epochs and model size.")
print("For production use, increase epochs, model size, and dataset size.")

# Start training
trainer.fit(train_loader, val_loader, vocab)

print("\n✅ Training completed!")

## 8. Training Metrics Visualization

In [None]:
# Load and visualize training metrics
metrics_path = os.path.join(log_dir, 'metrics', 'training_metrics.json')
if os.path.exists(metrics_path):
    with open(metrics_path, 'r') as f:
        metrics = json.load(f)
    
    # Plot training metrics
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Training and validation loss
    if 'train_loss' in metrics and 'val_loss' in metrics:
        axes[0, 0].plot(metrics['train_loss'], label='Training Loss', alpha=0.7)
        axes[0, 0].plot(metrics['val_loss'], label='Validation Loss', alpha=0.7)
        axes[0, 0].set_title('Training and Validation Loss')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)
    
    # Word Error Rate
    if 'val_wer' in metrics:
        axes[0, 1].plot(metrics['val_wer'], label='Validation WER', color='red', alpha=0.7)
        axes[0, 1].set_title('Word Error Rate')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('WER')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)
    
    # Character Error Rate
    if 'val_cer' in metrics:
        axes[1, 0].plot(metrics['val_cer'], label='Validation CER', color='green', alpha=0.7)
        axes[1, 0].set_title('Character Error Rate')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('CER')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)
    
    # Learning Rate
    if 'learning_rate' in metrics:
        axes[1, 1].plot(metrics['learning_rate'], label='Learning Rate', color='purple', alpha=0.7)
        axes[1, 1].set_title('Learning Rate Schedule')
        axes[1, 1].set_xlabel('Step')
        axes[1, 1].set_ylabel('Learning Rate')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print final metrics
    print("Final Training Metrics:")
    if 'val_loss' in metrics and metrics['val_loss']:
        print(f"  Final Validation Loss: {metrics['val_loss'][-1]:.4f}")
    if 'val_wer' in metrics and metrics['val_wer']:
        print(f"  Final Validation WER: {metrics['val_wer'][-1]:.4f}")
    if 'val_cer' in metrics and metrics['val_cer']:
        print(f"  Final Validation CER: {metrics['val_cer'][-1]:.4f}")
else:
    print("No metrics file found. Training may not have completed.")
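
The summary above reports the last recorded value of each metric, which is not necessarily the best one. Assuming the metrics file holds one list per key with one entry per epoch, as the plotting code implies, a small helper can locate the best epoch. The helper name and the example numbers below are made up for illustration.

```python
def best_epoch(metrics, key="val_wer"):
    """Return (epoch_index, value) of the lowest recorded value for `key`, or None."""
    values = metrics.get(key) or []
    if not values:
        return None
    index = min(range(len(values)), key=values.__getitem__)
    return index, values[index]

# Made-up numbers, only to show the expected shape of the metrics dict.
example_metrics = {"val_wer": [0.92, 0.71, 0.74]}
print(best_epoch(example_metrics))  # prints (1, 0.71)
```

The returned index is zero-based; whether the trainer numbers epochs from 0 or 1 is not shown here.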

## 9. Model Evaluation

In [None]:
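# Load the best checkpoint if available. The best epoch is not known in advance,
# so search for any matching file and take the most recently written one.
# (Assumes the trainer names best checkpoints 'checkpoint_epoch_<N>_best.pt'.)
best_checkpoints = sorted(
    Path(log_dir, 'checkpoints').glob('checkpoint_epoch_*_best.pt'),
    key=lambda p: p.stat().st_mtime,
)
if best_checkpoints:
    best_model_path = best_checkpoints[-1]
    print(f"Loading best model from {best_model_path}...")
    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from epoch {checkpoint['epoch']}")
else:
    print("Using current model state (best model checkpoint not found)")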

# Evaluate on validation set
print("\nEvaluating model on validation set...")
val_loss, val_wer, val_cer = trainer.validate(val_loader, vocab)

print(f"\nValidation Results:")
print(f"  Loss: {val_loss:.4f}")
print(f"  WER: {val_wer:.4f} ({val_wer*100:.2f}%)")
print(f"  CER: {val_cer:.4f} ({val_cer*100:.2f}%)")
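
WER and CER are both edit distances normalized by the length of the reference, counted in words and characters respectively. The project's `compute_wer` and `compute_cer` are not shown, so the sketch below is an independent, minimal version of the standard definition; the function names and the corpus-level pooling are assumptions for this example.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (substitutions, insertions, deletions)."""
    previous = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        current = [i]
        for j, h in enumerate(hyp, start=1):
            current.append(min(
                previous[j] + 1,             # deletion
                current[j - 1] + 1,          # insertion
                previous[j - 1] + (r != h),  # substitution or match
            ))
        previous = current
    return previous[-1]

def error_rate(references, hypotheses, unit="word"):
    """Corpus-level WER (unit="word") or CER (unit="char")."""
    split = str.split if unit == "word" else list
    errors = sum(edit_distance(split(r), split(h)) for r, h in zip(references, hypotheses))
    total = sum(len(split(r)) for r in references)
    return errors / total if total else 0.0

print(error_rate(["the cat sat"], ["the cat sit"]))          # 1 error over 3 words
print(error_rate(["the cat sat"], ["the cat sit"], "char"))  # 1 error over 11 characters
```

Because insertions count as errors, both rates can exceed 1.0 when the hypothesis is much longer than the reference, which is common for an untrained model.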

## 10. Inference Examples

In [None]:
# Run inference on validation samples
model.eval()
num_examples = 5

print(f"Running inference on {num_examples} validation samples...")
print("=" * 80)

with torch.no_grad():
    # Only the first validation batch is needed
    spectrograms, texts, input_lengths, target_lengths = next(iter(val_loader))

    # Move to device
    spectrograms = spectrograms.to(device)
    texts = texts.to(device)
    input_lengths = input_lengths.to(device)
    target_lengths = target_lengths.to(device)

    # Forward pass
    log_probs, _ = model(spectrograms, input_lengths)

    # Decode predictions
    predictions = model.decode(log_probs, input_lengths)

    # Convert to text
    pred_texts = decode_predictions(predictions, vocab)
    target_texts = decode_targets(texts, target_lengths, vocab)

    # Display results (at most num_examples, fewer if the batch is smaller)
    for j in range(min(num_examples, len(pred_texts))):
        print(f"Example {j+1}:")
        print(f"  Target:     '{target_texts[j]}'")
        print(f"  Prediction: '{pred_texts[j]}'")

        # Calculate individual WER and CER (skipped for empty references)
        if target_texts[j].strip():
            individual_wer = compute_wer([pred_texts[j]], [target_texts[j]])
            individual_cer = compute_cer([pred_texts[j]], [target_texts[j]])
            print(f"  WER: {individual_wer:.4f}, CER: {individual_cer:.4f}")
        print()

## 11. Custom Inference Function

In [None]:
def transcribe_audio(model, audio_processor, text_processor, audio_path_or_tensor, device):
    """
    Transcribe audio using the trained model.
    
    Args:
        model: Trained ASR model
        audio_processor: AudioPreprocessor instance
        text_processor: TextPreprocessor instance
        audio_path_or_tensor: Path to audio file or audio tensor
        device: Device to run inference on
    
    Returns:
        transcription: Transcribed text
    """
    model.eval()
    
    with torch.no_grad():
        # Process audio
        if isinstance(audio_path_or_tensor, (str, Path)):
            # Load from file (accepts both str and pathlib.Path)
            features = audio_processor.preprocess_audio(str(audio_path_or_tensor))
        else:
            # Process tensor
            features = audio_processor.extract_mel_features(audio_path_or_tensor)
        
        # Add batch dimension
        features = features.unsqueeze(0).to(device)
        input_lengths = torch.tensor([features.shape[1]], device=device)
        
        # Forward pass
        log_probs, _ = model(features, input_lengths)
        
        # Decode
        predictions = model.decode(log_probs, input_lengths)
        
        # Convert to text
        vocab_dict = {
            'idx_to_char': text_processor.idx_to_char,
            'char_to_idx': text_processor.char_to_idx,
            'vocab_size': text_processor.get_vocab_size(),
            'blank_token_id': text_processor.get_blank_token_id()
        }
        
        transcription = decode_predictions(predictions, vocab_dict)[0]
        
        return transcription

# Create processors for inference
audio_processor = AudioPreprocessor(
    sample_rate=config.data.sample_rate,
    n_mels=config.data.n_mels,
    normalize=True
)

text_processor = TextPreprocessor()

print("✅ Custom inference function ready!")
print("You can now use transcribe_audio() to transcribe new audio files.")

## 12. Model Analysis and Insights

In [None]:
# Analyze model architecture
print("Model Architecture Analysis:")
print("=" * 50)

# Count parameters by component
encoder_params = sum(p.numel() for p in model.encoder.parameters())
decoder_params = sum(p.numel() for p in model.decoder.parameters())

print(f"Encoder parameters: {encoder_params:,} ({encoder_params/total_params*100:.1f}%)")
print(f"Decoder parameters: {decoder_params:,} ({decoder_params/total_params*100:.1f}%)")

# Analyze LMU layers
print(f"\nLMU Layer Configuration:")
print(f"  Number of layers: {len(model.encoder.lmu_layers)}")
print(f"  Hidden size: {model.encoder.hidden_size}")
print(f"  Memory size: {model.encoder.memory_size}")
print(f"  Theta (timescale): {model.encoder.theta}")
print(f"  Dropout: {model.encoder.dropout}")

# Vocabulary analysis
print(f"\nVocabulary Analysis:")
print(f"  Vocabulary size: {vocab['vocab_size']}")
print(f"  Characters: {list(vocab['char_to_idx'].keys())}")
print(f"  Blank token ID: {vocab['blank_token_id']}")

# Memory usage (if on GPU)
if torch.cuda.is_available():
    print(f"\nGPU Memory Usage:")
    print(f"  Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"  Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
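
For context on the `memory_size` and `theta` values printed above: in the original LMU formulation, the memory is a `d`-dimensional state `m(t)` that evolves as `theta * dm/dt = A m(t) + B u(t)`, where `A` and `B` are fixed matrices derived from Legendre polynomials and `theta` is the length of the sliding window being remembered. The sketch below builds the continuous-time matrices in plain Python, assuming that the encoder's `memory_size` and `theta` correspond to `d` and `theta` in that formulation. Implementations discretize these matrices before use, and how `pytorch-lmu` does so is not shown here.

```python
def lmu_state_matrices(memory_size, theta):
    """Continuous-time (A, B) of the LMU memory, already scaled by 1/theta."""
    A = [[0.0] * memory_size for _ in range(memory_size)]
    B = [0.0] * memory_size
    for i in range(memory_size):
        B[i] = (2 * i + 1) * (-1.0) ** i / theta
        for j in range(memory_size):
            # -1 above the diagonal, alternating sign on and below it.
            sign = -1.0 if i < j else (-1.0) ** (i - j + 1)
            A[i][j] = (2 * i + 1) * sign / theta
    return A, B

A, B = lmu_state_matrices(memory_size=3, theta=1.0)
for row in A:
    print(row)
print(B)
```

A larger `memory_size` lets the memory represent the window in finer detail, while `theta` sets how many time steps the window spans.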

## 13. Saving and Loading Models

In [None]:
# Save final model
final_model_path = os.path.join(log_dir, 'final_model.pt')
torch.save({
    'model_state_dict': model.state_dict(),
    'config': config,
    'vocab': vocab,
    'model_params': total_params,
    'final_metrics': {
        'val_loss': val_loss,
        'val_wer': val_wer,
        'val_cer': val_cer
    }
}, final_model_path)

print(f"✅ Final model saved to: {final_model_path}")

# Demonstrate loading
def load_trained_model(model_path, device):
    """Load a trained model for inference."""
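    # The checkpoint stores the Config object, not just tensors, so full unpickling is
    # needed (PyTorch 2.6+ defaults to weights_only=True). Only load files you trust.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)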
    
    # Recreate model
    config = checkpoint['config']
    model = create_model(config.model).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    vocab = checkpoint['vocab']
    
    return model, config, vocab

# Test loading
loaded_model, loaded_config, loaded_vocab = load_trained_model(final_model_path, device)
print(f"✅ Model loaded successfully!")
print(f"  Parameters: {sum(p.numel() for p in loaded_model.parameters()):,}")
print(f"  Vocabulary size: {loaded_vocab['vocab_size']}")

## 14. Next Steps and Improvements

### Production Improvements

To improve this demo model for production use:

1. **Scale up the model**:
   - Increase `hidden_size` to 512-1024
   - Increase `memory_size` to 256-512
   - Use 4-8 LMU layers
   - Increase `max_seq_len` to 1000-2000

2. **Extended training**:
   - Train for 50-100 epochs
   - Use larger batch sizes (16-32)
   - Implement learning rate scheduling
   - Use the full LibriSpeech dataset

3. **Advanced techniques**:
   - Language model integration
   - Beam search decoding
   - Sub-word tokenization
   - Multi-task learning

4. **Distributed training**:
   - Use multiple GPUs with `train_distributed.py`
   - Use gradient accumulation (set `accumulate_grad_batches` above 1)
   - Use advanced optimizers (AdamW, etc.)

5. **Evaluation**:
   - Test on multiple datasets
   - Implement confidence scoring
   - Add real-time inference capabilities
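
As one concrete example of the learning rate scheduling mentioned in the list above, the sketch below computes a linear warmup followed by cosine decay as a pure function of the step number. The function name and all default values are assumptions chosen for illustration; the project's trainer may already use a different schedule.

```python
import math

def lr_at_step(step, base_lr=1e-3, warmup_steps=1000, total_steps=100_000, min_lr=1e-5):
    """Linear warmup to `base_lr`, then cosine decay to `min_lr`."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

for step in (0, 999, 1000, 50_500, 100_000):
    print(step, f"{lr_at_step(step):.2e}")
```

A function of this shape can be wrapped in `torch.optim.lr_scheduler.LambdaLR` by returning the ratio `lr_at_step(step) / base_lr` from the lambda.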

### Usage Examples

```bash
# For production training
python scripts/train.py --config-name=base_config

# For distributed training
torchrun --nproc_per_node=4 scripts/train_distributed.py

# For evaluation
python scripts/evaluate.py --checkpoint_path=path/to/checkpoint.pt
```

### Key Takeaways

1. **LMU layers** provide effective temporal modeling for speech
2. **CTC loss** enables alignment-free training
3. **Mixed precision** training reduces memory usage
4. **Proper data augmentation** improves robustness
5. **Distributed training** scales to larger models and datasets

This demo is a starting point for building production-ready ASR systems with the LMU architecture.