In [1]:
!pip install torch torchaudio datasets jiwer


## A simple cross attention transformer

As a recap, we use cross-attention transformers when the input and output are different sequences, often of different types (e.g. audio in, text out).

For example, in machine translation from English to Japanese, the Japanese output being generated attends to the English input. In ASR, the output text attends to the input audio.

In self-attention, we capture intra‑sequence dependencies (e.g. which English word helps predict the next English word).

In cross-attention, we capture inter‑sequence dependencies (e.g. which audio frame aligns to this text token).
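The only difference between the two is where the queries, keys and values come from. Below is a minimal single-head sketch of scaled dot-product attention. It leaves out the learned Q/K/V projections a real layer has, and it uses random tensors as stand-ins for text and audio representations:

```python
import math
import torch


def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """q: (T_q, D); k, v: (T_k, D). Returns (output (T_q, D), weights (T_q, T_k))."""
    scores = q @ k.T / math.sqrt(q.size(-1))   # similarity of every query to every key
    weights = torch.softmax(scores, dim=-1)    # each row is a distribution over keys
    return weights @ v, weights


torch.manual_seed(0)
D = 8
text = torch.randn(3, D)    # 3 target positions (e.g. characters)
audio = torch.randn(5, D)   # 5 source positions (e.g. encoder frames)

# Self-attention: queries, keys and values all come from the same sequence.
self_out, self_w = attention(text, text, text)
# Cross-attention: queries from the text, keys/values from the audio.
cross_out, cross_w = attention(text, audio, audio)

print(self_w.shape)     # torch.Size([3, 3])  text-to-text
print(cross_w.shape)    # torch.Size([3, 5])  text-to-audio
print(cross_out.shape)  # torch.Size([3, 8])  one vector per text position
```

The self-attention weights form a 3×3 matrix (text to text). The cross-attention weights form a 3×5 matrix (each text position over every audio position).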

In the context of ASR:

*   Source: raw audio waveform → converted to mel-spectrogram frames → embedded and passed through the encoder.
*   Encoder memory: a sequence of vectors representing spectral patterns over time.
*   Decoder self-attention: the decoder (which produces text) attends over the tokens it has already produced. For example, it can learn that a preceding “h” matters more than an earlier “t” when deciding what comes next.
*   Decoder cross-attention: each decoder step asks “which audio frames in the spectrogram correspond to the next character or word I should produce?”

In short, cross‑attention learns alignments between sound patterns and textual units.
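In PyTorch, `nn.MultiheadAttention` can return these alignment weights directly. The sketch below uses random tensors as stand-ins for encoder memory (50 frames) and decoder states (7 characters). Because the layer is untrained, the weights are meaningless, but the shapes show what gets learned. The frame count, character count and sizes are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, nhead = 16, 4
n_frames, n_chars, batch = 50, 7, 2

memory = torch.randn(n_frames, batch, d_model)   # stand-in for encoder output (S, B, D)
queries = torch.randn(n_chars, batch, d_model)   # stand-in for decoder states (T, B, D)

cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead)
out, weights = cross_attn(queries, memory, memory, need_weights=True)

print(out.shape)      # torch.Size([7, 2, 16])
print(weights.shape)  # torch.Size([2, 7, 50]): per batch item, characters x frames
best_frame = weights.argmax(dim=-1)  # most-attended frame for each character, shape (2, 7)
```

By default the weights are averaged over heads, which gives one (characters × frames) matrix per batch item, and each row sums to 1. In a trained ASR model, one would expect `best_frame` to move roughly forward in time as the character index increases.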

----

In the example below, we implement all three sub-layers of a Transformer decoder layer, not just the cross-attention itself:
1. Self‑attention over the target sequence
2. Cross‑attention to the encoder’s outputs (the “memory”)
3. Position‑wise feed‑forward network
each followed by a residual connection and layer normalization (post-norm).

A full decoder (not shown below) would also need:
1. A token embedding + positional encoding step before feeding tokens into this block
2. A causal mask on the self-attention so that, during training, each position can attend only to itself and earlier tokens
3. A stack of N such CrossAttentionBlocks (typically 6–12 layers)
4. A final linear projection to vocabulary logits for each position, followed by a softmax to get token probabilities
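The following is a minimal sketch of items 1, 2 and 4 wrapped around a stack of layers. To stay self-contained, it uses PyTorch's `nn.TransformerDecoderLayer` as a stand-in for `CrossAttentionBlock`. The built-in layer has the same three sub-layers, but its feed-forward width defaults to 2048 rather than 4·d_model. `TinyDecoder`, `sinusoidal_positional_encoding` and `causal_mask` are hypothetical names chosen for illustration. Sinusoidal encoding is one common choice of positional encoding; learned position embeddings are another.

```python
import math
import torch
import torch.nn as nn


def sinusoidal_positional_encoding(seq_len: int, d_model: int) -> torch.Tensor:
    """Sinusoidal position codes of shape (seq_len, 1, d_model); assumes an even d_model."""
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)   # (T, 1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                         * (-math.log(10000.0) / d_model))               # (D/2,)
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions
    return pe.unsqueeze(1)                         # broadcasts over the batch axis


def causal_mask(seq_len: int) -> torch.Tensor:
    """Additive mask: 0 where attention is allowed, -inf above the diagonal (future tokens)."""
    return torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)


class TinyDecoder(nn.Module):
    """Embedding + positions -> N decoder layers -> per-position vocabulary logits."""

    def __init__(self, vocab_size: int, d_model: int = 16, nhead: int = 4, num_layers: int = 2):
        super().__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model)
        # Stand-in for "a stack of N CrossAttentionBlocks"; like the block above,
        # the built-in layer expects (T, B, D) inputs by default.
        self.layers = nn.ModuleList(
            [nn.TransformerDecoderLayer(d_model, nhead) for _ in range(num_layers)]
        )
        self.out_proj = nn.Linear(d_model, vocab_size)

    def forward(self, tokens: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        # tokens: (T, B) token ids; memory: (S, B, D) encoder output
        T = tokens.size(0)
        x = self.embed(tokens) * math.sqrt(self.d_model)                     # (T, B, D)
        x = x + sinusoidal_positional_encoding(T, self.d_model).to(x.device)
        mask = causal_mask(T).to(x.device)
        for layer in self.layers:
            x = layer(x, memory, tgt_mask=mask)
        # Raw logits; the softmax is usually left to nn.CrossEntropyLoss or to decoding.
        return self.out_proj(x)                                              # (T, B, vocab)


torch.manual_seed(0)
decoder = TinyDecoder(vocab_size=10).eval()   # eval() disables dropout
memory = torch.randn(5, 2, 16)                 # (S=5, B=2, D=16)
tokens = torch.randint(1, 10, (3, 2))          # (T=3, B=2)
with torch.no_grad():
    logits = decoder(tokens, memory)
print(logits.shape)  # torch.Size([3, 2, 10])
```

Because of the causal mask, changing the last input token leaves the logits at earlier positions unchanged. This property is what lets teacher-forced training match step-by-step generation.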

In [7]:
import torch
import torch.nn as nn

# -----------------------------------------
# Cross-Attention Transformer
# -----------------------------------------
# This module demonstrates a minimal cross-attention block,
# showing how a target sequence "attends" to a source sequence.
# We use PyTorch's built-in MultiheadAttention for simplicity.

class CrossAttentionBlock(nn.Module):
    def __init__(self, d_model: int, nhead: int, dropout: float = 0.1):
        """
        Initializes the CrossAttentionBlock.

        Args:
            d_model (int): Dimensionality of input embeddings.
            nhead (int): Number of attention heads.
            dropout (float): Dropout probability.
        """
        super().__init__()
        # 1) Self-attention layer: target attends to itself
        self.self_attn  = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, dropout=dropout)
        # 2) Cross-attention layer: target attends to source (memory)
        self.cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, dropout=dropout)
        # Position-wise feedforward network
        self.linear1    = nn.Linear(d_model, d_model * 4)
        self.linear2    = nn.Linear(d_model * 4, d_model)
        # Layer normalization for residual connections
        self.norm1      = nn.LayerNorm(d_model)
        self.norm2      = nn.LayerNorm(d_model)
        self.norm3      = nn.LayerNorm(d_model)
        # Dropout for regularization
        self.dropout    = nn.Dropout(dropout)

    def forward(self, tgt: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the cross-attention block.

        Args:
            tgt (Tensor): Target embeddings of shape (T, B, D)
                          where T=target sequence length,
                                B=batch size,
                                D=embedding dimension.
            memory (Tensor): Source/memory embeddings of shape (S, B, D)
                              where S=source sequence length.

        Returns:
            Tensor: Output embeddings of shape (T, B, D).
        """
        # ----- 1) Self-Attention on target -----#
        # q, k, v all come from tgt
        # Each target position attends to every position in the same sequence (no causal mask here)
        tgt2, _ = self.self_attn(tgt, tgt, tgt)
        # Residual connection + normalization
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm1(tgt)

        # ----- 2) Cross-Attention: tgt queries, memory keys/values -----#
        # The query comes from the updated target (after self attention)
        # The keys and values come from the memory (the source, i.e. the encoder's output)
        tgt2, _ = self.cross_attn(query=tgt, key=memory, value=memory)
        # Residual connection + normalization
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm2(tgt)

        # ----- 3) Feed-Forward Network -----#
        # A 2-layer multilayer perceptron (MLP), applied independently at each position.
        # linear1 expands d_model -> 4 * d_model and is followed by ReLU and dropout;
        # linear2 projects back from 4 * d_model to d_model so the residual add works.
        ff = self.linear2(self.dropout(nn.functional.relu(self.linear1(tgt))))
        # Residual + normalization
        tgt = tgt + self.dropout(ff)
        tgt = self.norm3(tgt)

        return tgt


def demo_cross_attention():
    """
    Demonstrates the CrossAttentionBlock with random tensors.
    """
    # Configuration
    batch_size      = 2      # number of examples in a batch
    seq_len_source  = 5      # length of the source (memory) sequence
    seq_len_target  = 3      # length of the target sequence
    d_model         = 16     # embedding size (must be divisible by nhead)
    nhead           = 4      # number of attention heads

    # Create random source and target sequences
    # PyTorch MultiheadAttention expects shape (seq_len, batch, embed_dim)
    source = torch.randn(seq_len_source, batch_size, d_model)
    target = torch.randn(seq_len_target, batch_size, d_model)

    # Instantiate the cross-attention block
    cross_block = CrossAttentionBlock(d_model=d_model, nhead=nhead)

    # Forward pass: target attends to source
    output = cross_block(target, source)

    print(f"Input target shape: {target.shape}")
    print(f"Input source shape: {source.shape}")
    print(f"Output embeddings shape: {output.shape}")

Now, let's run `demo_cross_attention()`. It creates a `source` with 5 “memory” positions and a `target` with 3 “query” positions, both with batch size 2 and hidden dim 16.

The block then self-attends over the 3 target vectors, cross-attends from each of those 3 positions over the 5 source vectors, and passes the result through the feed-forward network.

We expect the output to have the same shape as the target, (3, 2, 16). Cross-attention changes what each target position contains, not how many positions there are.

In [3]:
demo_cross_attention()

Input target shape: torch.Size([3, 2, 16])
Input source shape: torch.Size([5, 2, 16])
Output embeddings shape: torch.Size([3, 2, 16])


## A cross-attention transformer for ASR, with a simple training run

In the following code, we use PyTorch's built-in `nn.TransformerEncoder` and `nn.TransformerDecoder` modules. Each decoder layer already implements cross-attention.
Just like above, we have:
1. Self-attention on the encoder input AND on the decoder input
2. Cross-attention against the encoder memory (the encoder's output)
3. A position-wise feed-forward network. "Position-wise" means the same MLP is applied independently at each sequence position, so it does not mix information across tokens (unlike attention, which provides cross-token context). It is a traditional MLP that applies a non-linear transformation per position.
4. Residual connections and layer norm around every sub-layer.
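The following sketch runs two quick checks of the description above. First, it confirms that the feed-forward network really is position-wise: applying it to the whole sequence gives the same result as applying it one position at a time. Second, it shows that `nn.TransformerDecoderLayer` exposes the same sub-layers as the hand-written `CrossAttentionBlock`. The FFN is a stand-in with the block's 4× width, and the sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, seq_len, batch = 16, 5, 2

# Stand-in FFN with the same shape as CrossAttentionBlock's: d_model -> 4*d_model -> d_model.
ffn = nn.Sequential(
    nn.Linear(d_model, 4 * d_model),
    nn.ReLU(),
    nn.Linear(4 * d_model, d_model),
)
x = torch.randn(seq_len, batch, d_model)   # (T, B, D)

with torch.no_grad():
    whole = ffn(x)                                                   # all positions at once
    per_position = torch.stack([ffn(x[t]) for t in range(seq_len)])  # one position at a time
print(torch.allclose(whole, per_position, atol=1e-5))  # True: no mixing across positions

# The built-in decoder layer bundles the same three sub-layers.
layer = nn.TransformerDecoderLayer(d_model, nhead=4)
print(type(layer.self_attn).__name__)       # MultiheadAttention (self-attention)
print(type(layer.multihead_attn).__name__)  # MultiheadAttention (cross-attention)
print(layer.linear1.out_features)           # 2048 (dim_feedforward default)
print(layer.norm_first)                     # False: post-norm
```

The built-in layer's feed-forward width defaults to `dim_feedforward=2048` rather than 4 × d_model. Like the block above, it uses post-norm by default.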




In [4]:
!pip install soundfile librosa
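The evaluation loop in the next cell takes `argmax` over teacher-forced logits, so every prediction is conditioned on the ground-truth prefix. At inference time, the model has to consume its own predictions instead. Below is a minimal greedy-decoding sketch, assuming a model with the same `model(specs, tokens_in) -> (B, T_in, V)` interface as `ASRModel`. The character tokenizer used there has no BOS/EOS tokens, so `bos_id` and `eos_id` are hypothetical ids you would need to add to the vocabulary. `CountingModel` is a toy stand-in used only to exercise the loop.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def greedy_decode(model: nn.Module, specs: torch.Tensor, bos_id: int, eos_id: int,
                  max_len: int = 200) -> torch.Tensor:
    """Autoregressive greedy decoding.

    Assumes model(specs, tokens_in) returns logits of shape (B, T_in, vocab_size)
    and that the vocabulary contains (hypothetical) BOS/EOS ids.
    """
    batch = specs.size(0)
    tokens = torch.full((batch, 1), bos_id, dtype=torch.long, device=specs.device)
    finished = torch.zeros(batch, dtype=torch.bool, device=specs.device)
    for _ in range(max_len):
        logits = model(specs, tokens)               # re-runs the full prefix (no caching)
        next_tok = logits[:, -1].argmax(dim=-1)     # only the last position predicts the new token
        # Sequences that already emitted EOS keep emitting EOS.
        next_tok = torch.where(finished, torch.full_like(next_tok, eos_id), next_tok)
        tokens = torch.cat([tokens, next_tok.unsqueeze(1)], dim=1)
        finished |= next_tok == eos_id
        if finished.all():
            break
    return tokens[:, 1:]                            # drop the BOS token


class CountingModel(nn.Module):
    """Toy stand-in model: always predicts (last input token + 1), capped at id 5."""

    def forward(self, specs: torch.Tensor, tokens_in: torch.Tensor) -> torch.Tensor:
        next_ids = (tokens_in + 1).clamp(max=5)
        return nn.functional.one_hot(next_ids, num_classes=6).float()


decoded = greedy_decode(CountingModel(), torch.zeros(2, 10, 128), bos_id=1, eos_id=5)
print(decoded.tolist())  # [[2, 3, 4, 5], [2, 3, 4, 5]]
```

The loop re-runs the decoder on the whole prefix at every step. This is simple but quadratic in output length. After the ids are mapped back to characters, these are the predictions to pass to `wer`. The loop only produces sensible transcripts once the decoder has been trained with a causal mask.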



In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from torchaudio.transforms import MelSpectrogram
from jiwer import wer

# -----------------------------------------------------------------------------
# 1) Data preparation with HuggingFace datasets
# -----------------------------------------------------------------------------
ds = load_dataset(
    "hf-internal-testing/librispeech_asr_demo",
    "clean",
    split="validation",
    trust_remote_code=True
)

# Do an 80/20 split for train/val
splits = ds.train_test_split(test_size=0.2, seed=42)
train_ds, val_ds = splits["train"], splits["test"]

# Character tokenizer
chars   = list("abcdefghijklmnopqrstuvwxyz' ")
char2idx = {c: i+1 for i,c in enumerate(chars)}  # 0 reserved for padding
vocab_size = len(char2idx) + 1

mel_transform = MelSpectrogram(sample_rate=16_000, n_mels=128)

def collate_batch(batch):
    specs, labels = [], []
    for ex in batch:
        # Make sure we load as float32, not the default float64
        # MelSpectrogram expects float32
        waveform = torch.tensor(ex["audio"]["array"], dtype=torch.float32)

        # Now the mel transform will run without complaint
        spec = mel_transform(waveform.unsqueeze(0))      # -> (1, n_mels, T)
        spec = spec.squeeze(0).transpose(0,1)            # -> (T, n_mels)
        specs.append(spec)

        # Text → token IDs; characters outside `chars` map to 0 (the padding id), so the loss ignores them
        ids = torch.tensor([char2idx.get(c,0) for c in ex["text"].lower()],
                           dtype=torch.long)
        labels.append(ids)

    # Pad and batch
    specs = nn.utils.rnn.pad_sequence(specs, batch_first=True)   # (B, T, n_mels)
    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True) # (B, L)
    return specs, labels

train_dl = DataLoader(train_ds, batch_size=4, shuffle=True,
                      collate_fn=collate_batch)
val_dl   = DataLoader(val_ds,   batch_size=4, shuffle=False,
                      collate_fn=collate_batch)


# -----------------------------------------------------------------------------
# 2) Build a tiny encoder–decoder ASR model
# -----------------------------------------------------------------------------
class ASRModel(nn.Module):
    def __init__(self, d_model=256, nhead=8, num_layers=3):
        super().__init__()
        # audio encoder: project 128→d_model, then self-attend
        self.audio_proj = nn.Linear(128, d_model)
        enc_layer = nn.TransformerEncoderLayer(d_model, nhead)  # nhead-head self-attention + FFN
        self.encoder  = nn.TransformerEncoder(enc_layer, num_layers)

        # text decoder: embed + cross‑attention blocks
        self.text_emb  = nn.Embedding(vocab_size, d_model)
        dec_layer      = nn.TransformerDecoderLayer(d_model, nhead)  # self-attention, cross-attention to memory, FFN
        self.decoder   = nn.TransformerDecoder(dec_layer, num_layers)

        # final output projection
        self.out_proj  = nn.Linear(d_model, vocab_size)

    def forward(self, specs, tokens_in):
        """
        specs: (B, T_src, 128)
        tokens_in: (B, T_tgt)  — teacher‑forcing inputs
        """
        # Encode audio
        x = self.audio_proj(specs)               # (B, T_src, d_model)
        x = x.permute(1,0,2)                     # (T_src, B, d_model)
        memory = self.encoder(x)                 # same shape

        # Prepare decoder input
        y = self.text_emb(tokens_in)             # (B, T_tgt, d_model)
        y = y.permute(1,0,2)                     # (T_tgt, B, d_model)
        out = self.decoder(y, memory)            # (T_tgt, B, d_model)
        out = out.permute(1,0,2)                 # (B, T_tgt, d_model)

        logits = self.out_proj(out)              # (B, T_tgt, vocab_size)
        return logits


# -----------------------------------------------------------------------------
# 3) Training loop
# -----------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ASRModel().to(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=0)

def train_epoch(dl):
    model.train()
    total_loss = 0
    for specs, labels in dl:
        specs, labels = specs.to(device), labels.to(device)
        # Teacher forcing: the decoder sees all but the last token
        # and is trained to predict the sequence shifted left by one
        inp = labels[:, :-1]
        tgt = labels[:, 1:]
        logits = model(specs, inp)            # (B, T-1, V)
        loss = criterion(logits.reshape(-1, logits.size(-1)),
                         tgt.reshape(-1))
        optim.zero_grad(); loss.backward(); optim.step()
        total_loss += loss.item()
    return total_loss / len(dl)

def eval_epoch(dl):
    model.eval()
    total_loss = 0
    preds, refs = [], []
    with torch.no_grad():
        for specs, labels in dl:
            specs, labels = specs.to(device), labels.to(device)
            inp = labels[:, :-1]
            tgt = labels[:, 1:]
            logits = model(specs, inp)
            loss = criterion(logits.reshape(-1, logits.size(-1)),
                             tgt.reshape(-1))
            total_loss += loss.item()

            # Per-position argmax under teacher forcing (not autoregressive greedy decoding)
            out_ids = logits.argmax(-1).cpu().tolist()
            ref_ids = tgt.cpu().tolist()
            for o, r in zip(out_ids, ref_ids):
                preds.append("".join(chars[i-1] for i in o if i>0))
                refs.append("".join(chars[i-1] for i in r if i>0))

    avg_loss = total_loss / len(dl)
    avg_wer  = wer(refs, preds)
    return avg_loss, avg_wer

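# Run 20 epochs for the demo.
# train_loss keeps decreasing; val_loss bottoms out around epoch 10 and then
# creeps back up, a sign of overfitting on this very small demo split.
# val_WER stays above 1.0 and does not improve. Likely contributing causes:
#   - the decoder gets no causal tgt_mask, so each position also attends to future tokens
#   - neither the audio frames nor the text tokens get positional encodings
#   - eval_epoch scores per-position argmax under teacher forcing rather than
#     autoregressive decoding, and keeps predictions at padded positions,
#     which adds insertions (one way WER can exceed 1.0)
#   - the mel features are raw power values rather than log-mel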
for epoch in range(1,21):
    train_loss = train_epoch(train_dl)
    val_loss, val_wer = eval_epoch(val_dl)
    print(f"Epoch {epoch} ▶ train_loss={train_loss:.3f}  val_loss={val_loss:.3f}  val_WER={val_wer:.3f}")

Epoch 1 ▶ train_loss=2.931  val_loss=2.757  val_WER=1.412
Epoch 2 ▶ train_loss=2.626  val_loss=2.585  val_WER=1.255
Epoch 3 ▶ train_loss=2.470  val_loss=2.518  val_WER=1.278
Epoch 4 ▶ train_loss=2.403  val_loss=2.482  val_WER=1.176
Epoch 5 ▶ train_loss=2.380  val_loss=2.459  val_WER=1.188
Epoch 6 ▶ train_loss=2.339  val_loss=2.456  val_WER=1.212
Epoch 7 ▶ train_loss=2.318  val_loss=2.448  val_WER=1.251
Epoch 8 ▶ train_loss=2.303  val_loss=2.452  val_WER=1.145
Epoch 9 ▶ train_loss=2.284  val_loss=2.445  val_WER=1.224
Epoch 10 ▶ train_loss=2.260  val_loss=2.444  val_WER=1.161
Epoch 11 ▶ train_loss=2.236  val_loss=2.448  val_WER=1.235
Epoch 12 ▶ train_loss=2.219  val_loss=2.448  val_WER=1.200
Epoch 13 ▶ train_loss=2.204  val_loss=2.451  val_WER=1.192
Epoch 14 ▶ train_loss=2.188  val_loss=2.463  val_WER=1.231
Epoch 15 ▶ train_loss=2.158  val_loss=2.470  val_WER=1.204
Epoch 16 ▶ train_loss=2.141  val_loss=2.474  val_WER=1.204
Epoch 17 ▶ train_loss=2.117  val_loss=2.478  val_WER=1.329
Epoch 