# Day 22: Deep Speech 2 — Interactive Walkthrough

> Amodei et al. (2015), *Deep Speech 2: End-to-End Speech Recognition in English and Mandarin*

This notebook walks through the core components of Deep Speech 2:
1. Spectrogram feature extraction
2. The DS2 model architecture
3. CTC loss and greedy decoding
4. Training on synthetic data

See `README.md` for background and `paper_notes.md` for detailed analysis.

```python
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# Components from this day's local `implementation` module
from implementation import (
    DeepSpeech2, CharEncoder, ClippedReLU,
    compute_spectrogram, generate_synthetic_audio,
    generate_synthetic_dataset, collate_batch,
    greedy_decode, word_error_rate, train_step,
    sortagrad_sampler
)

torch.manual_seed(42)
np.random.seed(42)
print('Imports OK')
```

## 1. Spectrogram Features

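The model takes log power spectrograms as input (Section 3.1 of the paper).
The 1D audio signal is sliced into 20 ms windows with a 10 ms stride. Taking the log power of each window's FFT turns the signal into a 2D (frequency x time) representation.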
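For intuition, here is a minimal sketch of such a feature extractor built on `torch.stft`. It rests on several assumptions:

- 16 kHz audio.
- A Hann window of 320 samples (20 ms) and a hop of 160 samples (10 ms).
- Global per-utterance mean/variance normalization.

`log_spectrogram` is a hypothetical helper, not necessarily how `compute_spectrogram` works (the cell below passes `n_fft=256`, for instance).

```python
import torch

def log_spectrogram(audio, sample_rate=16000, win_ms=20, hop_ms=10, eps=1e-10):
    """Illustrative log power spectrogram with per-utterance normalization."""
    win = int(sample_rate * win_ms / 1000)    # 320 samples at 16 kHz
    hop = int(sample_rate * hop_ms / 1000)    # 160 samples at 16 kHz
    stft = torch.stft(audio, n_fft=win, hop_length=hop, win_length=win,
                      window=torch.hann_window(win), center=True,
                      return_complex=True)    # complex, (win // 2 + 1, n_frames)
    log_power = torch.log(stft.abs() ** 2 + eps)
    # Normalize the whole utterance to zero mean and unit variance
    return (log_power - log_power.mean()) / (log_power.std() + 1e-5)

noise = torch.randn(16000)                    # 1 s of noise as a stand-in signal
demo_spec = log_spectrogram(noise)
print(demo_spec.shape)                        # torch.Size([161, 101])
```

With `center=True`, one second of audio yields 1 + 16000 // 160 = 101 frames. A 320-point FFT gives 320 // 2 + 1 = 161 frequency bins.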

```python
# Generate synthetic audio for the word 'cat'
audio = generate_synthetic_audio('cat', sample_rate=16000)
spec = compute_spectrogram(audio, sample_rate=16000, n_fft=256)

fig, axes = plt.subplots(1, 2, figsize=(12, 3))
axes[0].plot(audio.numpy(), linewidth=0.5)
axes[0].set_title('Waveform')
axes[0].set_xlabel('Sample')
axes[1].imshow(spec.numpy(), aspect='auto', origin='lower', cmap='viridis')
axes[1].set_title(f'Log Spectrogram: {spec.shape}')
axes[1].set_xlabel('Time')
axes[1].set_ylabel('Frequency')
plt.tight_layout()
plt.show()
```

## 2. Character Encoding and CTC

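The English output alphabet has 29 symbols: the 26 letters a-z, space, apostrophe, and the CTC blank.
Greedy CTC decoding first collapses consecutive repeated characters, then removes blanks. A blank between two identical characters therefore keeps both, as in the "ll" of "hello".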
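The collapse rule is short enough to write out. This sketch assumes blank index 0 and a hypothetical mapping: 1-26 for a-z, 27 for space, 28 for apostrophe. `CharEncoder` may order its symbols differently, so its output for the same indices can differ.

```python
BLANK = 0
ALPHABET = "abcdefghijklmnopqrstuvwxyz '"   # hypothetical: index i + 1 -> ALPHABET[i]

def ctc_collapse(indices, blank=BLANK):
    """Merge consecutive repeats, then drop blanks."""
    out, prev = [], None
    for idx in indices:
        if idx != prev and idx != blank:
            out.append(idx)
        prev = idx                           # a blank resets the repeat check
    return out

def indices_to_text(indices):
    return ''.join(ALPHABET[i - 1] for i in indices)

print(indices_to_text(ctc_collapse([0, 0, 8, 8, 0, 5, 5, 5, 0, 12, 12, 0])))  # hel
print(ctc_collapse([12, 0, 12]))             # [12, 12]: the blank separates two l's
print(ctc_collapse([12, 12]))                # [12]: adjacent repeats merge
```

A blank between two identical symbols keeps both, while adjacent repeats merge. This is why any valid CTC path for "hello" must put a blank between its two l's.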

In [None]:
encoder = CharEncoder('english')
print(f'Vocabulary size: {encoder.vocab_size}')
print(f'Blank index: {encoder.blank_idx}')

# Encode and decode
text = 'hello world'
encoded = encoder.encode(text)
decoded = encoder.decode(encoded)
print(f'Encode: "{text}" -> {encoded}')
print(f'Decode: {encoded} -> "{decoded}"')

# CTC decoding example
raw_ctc = [0, 0, 8, 8, 0, 5, 5, 5, 0, 12, 12, 0]
ctc_decoded = encoder.ctc_decode(raw_ctc)
print(f'CTC decode: {raw_ctc} -> "{ctc_decoded}"')

## 3. Building the Model

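Architecture: 2D convolution over (frequency, time) -> stacked bidirectional GRU layers -> fully connected layer -> log-softmax over the vocabulary. The configuration below is deliberately small: one conv layer and two bidirectional GRU layers with 128 hidden units.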
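The following is a minimal PyTorch sketch of the same pipeline, not the `DeepSpeech2` class from `implementation`. These details are illustrative assumptions:

- The class name `TinyDS2`.
- A single 11x11 conv with stride 2 and 32 channels.
- No sequence packing.

It uses the paper's clipped ReLU, min(max(x, 0), 20). It returns log-probabilities in the (time, batch, vocab) layout that the forward-pass section prints.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyDS2(nn.Module):
    """Hypothetical DS2-style model: conv -> BiGRU -> FC -> log-softmax."""
    def __init__(self, n_freq=129, vocab_size=29, rnn_hidden=128, n_rnn=2):
        super().__init__()
        k, s, p = 11, 2, 5                   # kernel, stride, padding (both axes)
        self.k, self.s, self.p = k, s, p
        self.conv = nn.Conv2d(1, 32, kernel_size=k, stride=s, padding=p)
        conv_freq = (n_freq + 2 * p - k) // s + 1
        self.rnn = nn.GRU(32 * conv_freq, rnn_hidden, num_layers=n_rnn,
                          bidirectional=True)          # expects (time, batch, feat)
        self.fc = nn.Linear(2 * rnn_hidden, vocab_size)

    def forward(self, x, lengths):
        # x: (batch, freq, time); add a channel axis for Conv2d
        x = self.conv(x.unsqueeze(1))
        x = torch.clamp(x, min=0.0, max=20.0)          # clipped ReLU
        b, c, f, t = x.shape
        x = x.permute(3, 0, 1, 2).reshape(t, b, c * f)  # (time', batch, c * f)
        out_lengths = (lengths + 2 * self.p - self.k) // self.s + 1
        x, _ = self.rnn(x)                             # (time', batch, 2 * hidden)
        return F.log_softmax(self.fc(x), dim=-1), out_lengths

tiny = TinyDS2()
feats = torch.randn(4, 129, 100)                       # (batch, freq, time)
demo_log_probs, demo_lens = tiny(feats, torch.tensor([100, 90, 80, 70]))
print(demo_log_probs.shape)                            # torch.Size([50, 4, 29])
print(demo_lens.tolist())                              # [50, 45, 40, 35]
```

Padded frames are not packed here, so the backward GRU direction also reads padding. A fuller version would wrap the RNN input with `torch.nn.utils.rnn.pack_padded_sequence`.

The returned lengths matter for CTC. The stride-2 conv halves the time axis, so the loss needs the post-convolution lengths rather than the raw feature lengths.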

```python
# Build a small DS2 model
model = DeepSpeech2(
    n_freq=129,       # n_fft // 2 + 1 frequency bins for n_fft=256
    vocab_size=encoder.vocab_size,
    n_conv=1,
    n_rnn=2,
    rnn_hidden=128,
    rnn_type='gru'
)

n_params = sum(p.numel() for p in model.parameters())
print(f'Parameters: {n_params:,}')
print(f'Architecture:\n{model}')
```

## 4. Forward Pass

Run a batch through the model and inspect outputs.
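CTC loss consumes exactly this (time, batch, vocab) layout. As a self-contained illustration, the sketch below feeds random log-probabilities into PyTorch's built-in `torch.nn.functional.ctc_loss`. The target indices assume a hypothetical mapping with blank = 0 and a-z at 1-26. Nothing here claims to reproduce what `train_step` does internally.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, B, V = 50, 2, 29                       # time steps, batch size, vocab size
# Stand-in for model output: log-probabilities in (time, batch, vocab) layout
fake_log_probs = torch.randn(T, B, V).log_softmax(dim=-1)

# Hypothetical encodings (blank = 0, a-z = 1-26): "cat" -> [3, 1, 20], "hi" -> [8, 9]
targets = torch.tensor([3, 1, 20, 8, 9])  # all targets concatenated into one 1D tensor
target_lengths = torch.tensor([3, 2])
input_lengths = torch.tensor([50, 40])    # valid (unpadded) output frames per utterance

loss = F.ctc_loss(fake_log_probs, targets, input_lengths, target_lengths,
                  blank=0, reduction='mean',
                  zero_infinity=True)     # zero out infinite losses (inputs too short)
print(f'CTC loss on random outputs: {loss.item():.3f}')
```

`input_lengths` should be the model's output lengths (`out_lens` in the cell below), since convolution striding shortens the time axis. Each must also be at least the target length plus one frame per pair of adjacent repeated characters.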

```python
# Generate data and batch
dataset = generate_synthetic_dataset(8, min_words=1, max_words=2)
batch = collate_batch(dataset, encoder, n_fft=256)

print(f'Features: {batch.features.shape}')
print(f'Feature lengths: {batch.feature_lengths.tolist()}')
print(f'Target lengths: {batch.target_lengths.tolist()}')

# Forward pass
model.eval()
with torch.no_grad():
    log_probs, out_lens = model(batch.features, batch.feature_lengths)
    print(f'Output: {log_probs.shape} (time, batch, vocab)')

    # Greedy decode (the model is untrained, so the output is essentially random)
    decoded = greedy_decode(log_probs, encoder, out_lens)
    for i in range(min(4, len(decoded))):
        print(f'  Target: "{dataset[i].transcript}" | Decoded: "{decoded[i]}"')
```

## 5. Quick Training Loop

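Train the model for three epochs on synthetic data, in minibatches of 8.
`sortagrad_sampler` implements SortaGrad: epoch 0 visits utterances in order of increasing length, and later epochs use a random order. The CTC loss should decrease over training; Section 6 checks the decoded output.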
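A SortaGrad ordering can be sketched in a few lines. The hypothetical `sortagrad_order` works on per-utterance lengths.

- In epoch 0 it sorts utterances shortest-first, so consecutive minibatches grow in length.
- In later epochs it shuffles utterances.

The paper shuffles minibatches rather than utterances, and `sortagrad_sampler` may differ from this sketch in either respect.

```python
import random

def sortagrad_order(lengths, epoch, seed=0):
    """Utterance indices: shortest-first in epoch 0, shuffled afterwards."""
    indices = list(range(len(lengths)))
    if epoch == 0:
        return sorted(indices, key=lambda i: lengths[i])
    rng = random.Random(seed + epoch)   # a different, reproducible shuffle per epoch
    rng.shuffle(indices)
    return indices

lengths = [120, 45, 300, 80]            # hypothetical frame counts
print(sortagrad_order(lengths, 0))      # [1, 3, 0, 2]
print(sortagrad_order(lengths, 1))      # some permutation of [0, 1, 2, 3]
```

Starting with short utterances gives early updates on easy, cheap examples, before long sequences can produce large or unstable CTC gradients.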

```python
# Training setup
train_data = generate_synthetic_dataset(40, min_words=1, max_words=2)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
batch_size = 8

losses = []
for epoch in range(3):
    model.train()
    # SortaGrad: length-sorted order in epoch 0, random order afterwards
    indices = sortagrad_sampler(train_data, epoch)

    epoch_losses = []
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        samples = [train_data[i] for i in batch_idx]
        batch = collate_batch(samples, encoder, n_fft=256)
        loss = train_step(model, batch, optimizer)  # returns this batch's CTC loss
        epoch_losses.append(loss)
        losses.append(loss)

    order = '(sorted)' if epoch == 0 else '(random)'
    print(f'Epoch {epoch + 1} {order}: loss={np.mean(epoch_losses):.4f}')

plt.plot(losses)
plt.xlabel('Step')
plt.ylabel('CTC Loss')
plt.title('Training Loss')
plt.grid(True, alpha=0.3)
plt.show()
```

## 6. Evaluate

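Check predictions after training. These five samples come from the training set, so the WER below measures how well the model fits its training data, not how well it generalizes.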
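WER is the word-level Levenshtein distance (substitutions + deletions + insertions) divided by the number of reference words. The sketch below assumes whitespace tokenization and pools errors over the whole corpus. This notebook does not show whether `word_error_rate` pools over the corpus or averages per utterance.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two token lists."""
    prev = list(range(len(hyp) + 1))       # distances from an empty reference prefix
    for i, r in enumerate(ref, start=1):
        cur = [i]                          # delete all i reference tokens
        for j, h in enumerate(hyp, start=1):
            cur.append(min(prev[j] + 1,              # deletion
                           cur[j - 1] + 1,           # insertion
                           prev[j - 1] + (r != h)))  # substitution or match
        prev = cur
    return prev[-1]

def corpus_wer(refs, hyps):
    """Total word edits divided by total reference words."""
    errors = sum(edit_distance(r.split(), h.split()) for r, h in zip(refs, hyps))
    n_words = sum(len(r.split()) for r in refs)
    return errors / max(n_words, 1)

print(corpus_wer(['the cat sat'], ['the cat sit']))  # 1 substitution / 3 words
```

Because insertions count as errors, WER can exceed 100% when the hypothesis is much longer than the reference.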

```python
model.eval()
with torch.no_grad():
    test_samples = train_data[:5]
    batch = collate_batch(test_samples, encoder, n_fft=256)
    log_probs, out_lens = model(batch.features, batch.feature_lengths)
    decoded = greedy_decode(log_probs, encoder, out_lens)

    refs = [s.transcript for s in test_samples]
    wer = word_error_rate(refs, decoded)

    print(f'WER: {wer:.2%}')
    print()
    for ref, hyp in zip(refs, decoded):
        match = 'OK' if ref == hyp else 'MISS'
        print(f'  [{match}] Target: "{ref}" | Output: "{hyp}"')
```

## Summary

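- A single neural network (conv + BiGRU + CTC) replaces the hand-engineered components of a traditional ASR pipeline (the paper still adds an n-gram language model during beam-search decoding)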
- Spectrogram features: 20ms windows, log power, normalized per utterance
- CTC loss handles variable-length alignment without frame-level labels
- Sequence-wise BatchNorm computes stats over (batch x time), not per-timestep
- SortaGrad sorts by length in epoch 0 only, then reverts to random
- Deeper models (5-7 RNN layers) beat wider ones of similar parameter count
- See `exercises/` to build each component from scratch
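Sequence-wise BatchNorm is easy to express by folding time into the batch axis before a standard `nn.BatchNorm1d`. The `SequenceWiseBatchNorm` module below is a hypothetical sketch for (time, batch, features) tensors. The paper applies the normalization only to the layer-input term of each recurrent layer, not to the recurrent connection. As written, the sketch also includes padded frames in its statistics.

```python
import torch
import torch.nn as nn

class SequenceWiseBatchNorm(nn.Module):
    """Normalize each feature with statistics pooled over batch and time."""
    def __init__(self, n_features):
        super().__init__()
        self.bn = nn.BatchNorm1d(n_features)

    def forward(self, x):
        # x: (time, batch, features) -> (time * batch, features) -> back
        t, b, f = x.shape
        return self.bn(x.reshape(t * b, f)).reshape(t, b, f)

torch.manual_seed(0)
x = torch.randn(10, 4, 8) * 3 + 5        # (time, batch, features), mean 5, std 3
sbn = SequenceWiseBatchNorm(8)           # training mode: uses minibatch statistics
y = sbn(x)
print(y.shape)                           # torch.Size([10, 4, 8])
```

A per-timestep alternative would compute separate statistics at each of the 10 steps from only 4 samples each. Pooling over all 40 (batch x time) positions gives more stable estimates and handles utterances of different lengths.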