# Advanced Architectures in ASR
In this tutorial, we will cover several architectures used in ASR and implement simple versions of each.

Do note that this section is highly optional: we recommend spending more time finding pretrained models and fine-tuning them to improve accuracy instead, as training these models from scratch requires significant data, compute and time.

This section is for those curious about how the internals work, and how to run inference using newer models.

In [None]:
!pip install transformers
!pip install datasets
!pip install librosa
!pip install soundfile
!pip install evaluate


In [None]:
# Sample we will use throughout the run
from datasets import load_dataset
from IPython.display import Audio
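# 1. Stream the LibriSpeech "clean" test split (test-clean) and take the first example
ds = load_dataset(
    "librispeech_asr", "clean",
    split="test", streaming=True,  # streams one file at a time instead of downloading the split
    trust_remote_code=True,        # the dataset uses a loading script; avoids the interactive prompt
)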
sample = next(iter(ds))
audio = sample["audio"]["array"]
sr = sample["audio"]["sampling_rate"]

print("Transcript:",sample['text'])

Audio(audio,rate=sr)

Transcript: CONCORD RETURNED TO ITS PLACE AMIDST THE TENTS


## Self-Supervised ASR Models


Traditional ASR systems require extensive datasets of transcribed speech, which limits them to widely spoken languages or to domains where plenty of transcribed audio exists.

Self-supervised models instead learn an internal representation of audio from raw, untranscribed recordings.

This lets the model "understand" how audio is structured, but it cannot perform ASR until it is fine-tuned on transcribed data.

### Wav2Vec and Wav2Vec2

Wav2Vec and Wav2Vec2 are a family of self-supervised models that learn representations directly from raw audio. Unlike the other models we have seen so far, the base Wav2Vec and Wav2Vec2 models cannot generate transcripts on their own: they only provide latent audio representations, which downstream models (an nn.Linear head, BERT, etc.) then map to ASR transcripts.

#### Wait, isn't this just feature extraction?

It is! The difference is that the extractor is a neural network whose features are learned from data rather than hand-designed (like MFCCs), so it can capture more complex properties of the audio such as pitch, tone, and other auditory characteristics.


### So how does Wav2Vec actually perform ASR?

Because Wav2Vec only outputs representations, we attach a downstream decoder or linear head that translates the learned audio representations into words.

As a result, self-supervised ASR models like Wav2Vec have a two-stage training pipeline:

1. Pretraining
   - The model learns to create meaningful audio representations from raw, untranscribed audio. At this stage, only the Wav2Vec Encoder and Context Network are trained.

2. Finetuning
   - The learned representations can then be used for all sorts of downstream audio tasks. For ASR, the pretrained layers are often frozen (fully or partly) while a small network is trained to map the vectors to transcripts. This stage is supervised and requires audio with transcripts.
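Pretraining needs a training signal that does not require transcripts. The original Wav2Vec uses a contrastive task: from the context vector at time t, predict the latent k steps ahead and tell it apart from "distractor" latents sampled from elsewhere in the same utterance. The function below is a simplified sketch of that idea for a single step k. The name `contrastive_step_loss`, the single linear layer `proj` standing in for the step-specific transform, and averaging over distractors are our own simplifications, not fairseq's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def contrastive_step_loss(z, c, proj, k=1, num_negatives=10):
    """
    Simplified Wav2Vec-style contrastive loss for one prediction step k (a sketch).
    z, c: (batch, channels, time) latents from the encoder and the context network.
    proj: linear layer standing in for the step-specific transform h_k.
    """
    B, C, T = z.shape
    # Predict the latent at t + k from the context vector at t
    pred = proj(c[:, :, : T - k].transpose(1, 2))        # (B, T-k, C)
    target = z[:, :, k:].transpose(1, 2)                 # (B, T-k, C)
    pos_logits = (pred * target).sum(dim=-1)             # (B, T-k)

    # Distractors: latents from random time steps of the same utterance
    # (they may occasionally coincide with the true target; acceptable for a sketch)
    idx = torch.randint(0, T, (B, T - k, num_negatives))
    z_t = z.transpose(1, 2)                              # (B, T, C)
    negatives = z_t[torch.arange(B)[:, None, None], idx]        # (B, T-k, N, C)
    neg_logits = (pred.unsqueeze(2) * negatives).sum(dim=-1)    # (B, T-k, N)

    # Push the true future latent's score up and the distractors' scores down
    return -(F.logsigmoid(pos_logits).mean() + F.logsigmoid(-neg_logits).mean())


torch.manual_seed(0)
z = torch.randn(2, 64, 50)                          # fake encoder output
c = torch.randn(2, 64, 50, requires_grad=True)      # fake context-network output
proj = nn.Linear(64, 64)
loss = contrastive_step_loss(z, c, proj, k=2)
loss.backward()
print("contrastive loss:", loss.item())
```

No transcript appears anywhere: the "labels" are simply the model's own future latents, which is what makes this stage self-supervised.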

Because the model already has an internal understanding of audio, the second stage needs far less transcribed data than training a standard supervised ASR model from scratch.
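A minimal sketch of the fine-tuning stage, assuming a frozen pretrained encoder and a trainable linear head trained with CTC loss (the same kind of head that `Wav2Vec2ForCTC`, used later in this notebook, adds). The two-layer `encoder` is a hypothetical stand-in for a real pretrained network, and the vocabulary size and fake labelled batch are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a pretrained encoder (in practice: a pretrained Wav2Vec model)
encoder = nn.Sequential(
    nn.Conv1d(1, 64, kernel_size=10, stride=5),
    nn.ReLU(),
    nn.Conv1d(64, 64, kernel_size=8, stride=4),
    nn.ReLU(),
)
# Stage 2: freeze the pretrained layers so only the new head is updated
for p in encoder.parameters():
    p.requires_grad = False

vocab_size = 32                    # illustrative: id 0 is the CTC blank, the rest are characters
head = nn.Linear(64, vocab_size)   # the only trainable part
ctc_loss = nn.CTCLoss(blank=0)
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)

# Fake labelled batch: two 1-second clips at 16 kHz with 12-character transcripts
audio = torch.randn(2, 16000)
targets = torch.randint(1, vocab_size, (2, 12))           # never uses the blank id
target_lengths = torch.full((2,), 12, dtype=torch.long)

with torch.no_grad():
    feats = encoder(audio.unsqueeze(1))                    # (batch, 64, frames)
logits = head(feats.transpose(1, 2))                       # (batch, frames, vocab)
log_probs = logits.log_softmax(dim=-1).transpose(0, 1)     # CTCLoss expects (frames, batch, vocab)
input_lengths = torch.full((2,), log_probs.size(0), dtype=torch.long)

weights_before = encoder[0].weight.clone()
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("CTC loss:", loss.item(), "| frames per clip:", log_probs.size(0))
```

Only the small head's parameters receive updates, which is why this stage can get by with much less labelled audio.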

### How does Wav2Vec work?
- The original Wav2Vec consists of a CNN encoder and a context network.
- The CNN encoder takes the raw audio waveform and turns it into a sequence of latent representations (think of each one as a compressed summary of a short window of audio, roughly 30ms long).
- The context network then aggregates several consecutive latents and outputs a continuous vector for every 10ms of audio.
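The "30ms window" and "every 10ms" figures follow directly from the convolution kernel sizes and strides. The short calculation below (the helper name `receptive_field` is ours) uses the same kernel sizes and strides as the `Wave2VecEncoder` and `Wave2VecContext` cells below, assuming 16 kHz audio.

```python
def receptive_field(kernel_sizes, strides):
    """Return (receptive field, total stride) of stacked 1D convolutions, in input steps."""
    field, jump = 1, 1
    for k, s in zip(kernel_sizes, strides):
        field += (k - 1) * jump   # each layer widens the view by (k - 1) steps of the layer below
        jump *= s                 # spacing between neighbouring outputs, in input samples
    return field, jump

sample_rate = 16000

# Encoder: kernel sizes [10, 8, 4, 4, 4] and strides [5, 4, 2, 2, 2]
enc_field, enc_stride = receptive_field([10, 8, 4, 4, 4], [5, 4, 2, 2, 2])
print(enc_field, enc_stride)                  # 465 160
print(enc_field * 1000 / sample_rate, "ms")   # 29.0625 ms (about 30 ms of audio per latent)
print(enc_stride * 1000 / sample_rate, "ms")  # 10.0 ms between latents

# Context network: 9 layers, kernel 3, stride 1, operating on the 10 ms encoder frames
ctx_field, _ = receptive_field([3] * 9, [1] * 9)    # 19 encoder frames
total = enc_field + (ctx_field - 1) * enc_stride    # 465 + 18 * 160 = 3345 samples
print(total * 1000 / sample_rate, "ms")             # 209.0625 ms (about 210 ms)
```

The total stride of 160 samples also explains why 1 second (16000 samples) of audio becomes 100 latent vectors in the example output further down.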

Below is an approximate implementation of the original Wav2Vec. The actual fairseq implementation of the Wav2Vec family can be found [here](https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2_asr.py).



In [None]:
### How does Wav2Vec Work?
# This is only an approximate implementation for those curious.
# We REALLY recommend visiting the GitHub link shown above, which takes
# you to fairseq and the ACTUAL underlying implementation of Wav2Vec.

# What the code below provides:
# Starter code to understand the underlying architecture of the Wav2Vec system

# What it does not provide:
# Masking and pretraining logic
# Adapters, EMA, FSDP support
# CTC heads and decoder logic, which modern Hugging Face versions of Wav2Vec2 also wrap for you

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalConv1d(nn.Module):
    """
    1D causal convolution: pads input on the left so that output at time t
    depends only on inputs at ≤ t.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1):
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation
        )

    def forward(self, x):
        # Compute left padding size for causality
        pad = (self.kernel_size - 1) * self.dilation
        # F.pad expects (pad_left, pad_right)
        x = F.pad(x, (pad, 0))
        return self.conv(x)



class Wave2VecEncoder(nn.Module):
    """
    Encoder network f: X -> Z
    Five-layer causal convolutional network with kernel sizes [10,8,4,4,4]
    and strides [5,4,2,2,2], each followed by GroupNorm (1 group) and ReLU.
    """
    def __init__(self, input_channels=1, conv_channels=512):
        super().__init__()
        kernel_sizes = [10, 8, 4, 4, 4]
        strides = [5, 4, 2, 2, 2]
        layers = []
        in_ch = input_channels
        for k, s in zip(kernel_sizes, strides):
            layers.append(nn.Sequential(
                CausalConv1d(in_ch, conv_channels, kernel_size=k, stride=s),
                nn.GroupNorm(num_groups=1, num_channels=conv_channels),
                nn.ReLU()
            ))
            in_ch = conv_channels
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """
        x: Tensor of shape (batch_size, sequence_length)
        returns z: Tensor of shape (batch_size, conv_channels, seq_len_downsampled)
        """
        x = x.unsqueeze(1)  # add channel dimension
        z = self.net(x)
        return z

class Wave2VecContext(nn.Module):
    """
    Context network g: Z -> C
    Nine-layer causal convolutional network (kernel size 3, stride 1),
    each followed by GroupNorm (1 group) and ReLU. Receptive field ~210ms.
    """
    def __init__(self, channels=512, num_layers=9):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            layers.append(nn.Sequential(
                CausalConv1d(channels, channels, kernel_size=3, stride=1),
                nn.GroupNorm(num_groups=1, num_channels=channels),
                nn.ReLU()
            ))
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """
        z: Tensor from encoder (batch_size, channels, seq_len)
        returns c: Tensor of same shape (batch_size, channels, seq_len)
        """
        c = self.net(z)
        return c




In [None]:
class Wave2VecModel(nn.Module):
    """
    Full Wave2Vec model combining encoder and context networks.
    """
    def __init__(self):
        super().__init__()
        self.encoder = Wave2VecEncoder()
        self.context = Wave2VecContext()

    def forward(self, x):
        """
        x: Raw audio (batch_size, sequence_length)
        returns:
          z: Latent features (batch_size, channels, seq_len_downsampled)
          c: Contextualized features (batch_size, channels, seq_len_downsampled)
        """
        z = self.encoder(x)
        c = self.context(z)
        return z, c

# Example usage
batch_size = 2
seq_len = 16000  # 1 second of audio at 16kHz
dummy_audio = torch.randn(batch_size, seq_len)
model = Wave2VecModel()
z, c = model(dummy_audio)
print("Encoder output shape (z):", z.shape)
print("Context output shape (c):", c.shape)



Encoder output shape (z): torch.Size([2, 512, 100])
Context output shape (c): torch.Size([2, 512, 100])


### Wav2Vec 2
Wav2Vec 2 introduced several breakthroughs to the framework: instead of relying only on convolutional networks, it adds a Transformer encoder that captures information from the entire sequence.

Moreover, it discretizes the output of the feature encoder z into a finite set of learned speech representations. During pretraining, these quantized vectors serve as the targets the model learns to identify.

#### Discretize? Why do we need it?
Let's first talk about discrete vs continuous:
A continuous output can be any vector in the representation space. A discrete output must be one of a fixed set of vectors, for example 512 or 1024 specific entries in a learned codebook.

This makes the set of possible outputs SIGNIFICANTLY smaller, which makes the model's job much easier.

Moreover, spoken language is built from a small inventory of phonetic or sub-phonetic sounds, like "ah", "t" and "s", known as phonemes. Forcing the model to use a fixed number of representations gives it a finite inventory of its own, pushing it to make those representations capture useful attributes of the audio, such as phonemes.
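To make "discretize" concrete, here is a minimal sketch that maps each continuous frame vector to its nearest entry in a codebook. This plain nearest-neighbour lookup is a simpler stand-in (closer to classic vector quantization) for the Gumbel-softmax quantizer Wav2Vec 2 actually uses, and the codebook size and dimensions are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
num_codes, dim = 320, 256                 # illustrative sizes
codebook = torch.randn(num_codes, dim)    # in a real model these entries are learned

def quantize(z, codebook):
    """Map each continuous frame vector in z (frames, dim) to its nearest codebook entry."""
    dists = torch.cdist(z, codebook)      # (frames, num_codes) Euclidean distances
    ids = dists.argmin(dim=-1)            # index of the closest entry for each frame
    return codebook[ids], ids

z = torch.randn(1000, dim)                # 1000 frames of continuous encoder output
q, ids = quantize(z, codebook)
print("distinct vectors before:", z.unique(dim=0).shape[0])
print("distinct vectors after: ", q.unique(dim=0).shape[0])   # at most 320
```

Because `argmin` is not differentiable, a real quantizer needs a trick such as the Gumbel softmax so the codebook can be trained end to end.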

To keep the notebook simple, we will not implement the Transformer or the Gumbel softmax here.

You can find it [here](https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py)

### Inference using Wav2Vec 2

Notice that we import `Wav2Vec2ForCTC` in the code below. As mentioned above, the base Wav2Vec model cannot do ASR directly, since it only outputs learned audio representations.

`Wav2Vec2ForCTC` attaches a linear layer, trained with CTC loss, on top of the model, which lets us run ASR inference directly.

In [None]:
import torch
from datasets import load_dataset
from transformers import (
    Wav2Vec2Processor, Wav2Vec2ForCTC,
    WhisperProcessor, WhisperForConditionalGeneration
)

# 2. Inference with Wav2Vec2
processor_w2v = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model_w2v     = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

inputs = processor_w2v(audio, sampling_rate=sr, return_tensors="pt", padding=True)
with torch.no_grad():
    logits = model_w2v(inputs.input_values).logits
pred_ids = torch.argmax(logits, dim=-1)
text_w2v = processor_w2v.batch_decode(pred_ids)[0]
print("Wav2Vec2 ➞", text_w2v)




Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Wav2Vec2 ➞ CONCORD RETURNED TO ITS PLACE AMIDST THE TENTS
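The `argmax` and `batch_decode` steps above perform CTC greedy decoding: the model emits one prediction per frame (about every 20 ms for Wav2Vec2), and the per-frame ids are collapsed into text by merging consecutive repeats and dropping the blank symbol. In `Wav2Vec2ForCTC` the blank is the padding token, and the tokenizer's word delimiter `|` marks spaces. The sketch below reimplements that collapsing rule on a hypothetical toy vocabulary; it is illustrative and not the tokenizer's own code.

```python
def ctc_greedy_decode(frame_ids, id_to_char, blank_id=0, word_delimiter="|"):
    """Collapse consecutive repeats, drop blanks, then turn word delimiters into spaces."""
    chars = []
    prev = None
    for i in frame_ids:
        if i != prev and i != blank_id:
            chars.append(id_to_char[i])
        prev = i
    return "".join(chars).replace(word_delimiter, " ").strip()

# Hypothetical toy vocabulary and per-frame argmax ids
id_to_char = {0: "<blank>", 1: "|", 2: "C", 3: "A", 4: "T", 5: "S", 6: "O"}
frames = [2, 2, 0, 3, 3, 4, 0, 1, 1, 5, 3, 0, 0, 4]
print(ctc_greedy_decode(frames, id_to_char))                      # CAT SAT

# A blank between two identical ids preserves a genuine double letter
print(ctc_greedy_decode([5, 5, 0, 2, 6, 0, 6, 4], id_to_char))    # SCOOT
```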


## Supervised ASR Models

Supervised ASR models are trained on large datasets of audio paired with the corresponding text transcripts.

The labelled data teaches the model to associate speech patterns with specific words and phrases.

High-quality, diverse datasets are crucial for supervised ASR model performance.

Because every clip is paired with a transcript, supervised ASR models can be trained directly on the ASR task, without the separate self-supervised pretraining stage that Wav2Vec uses.


One example of such a model is Whisper, an ASR model developed by OpenAI.

### Whisper

Whisper is a model developed by OpenAI. It is multilingual, accepting speech in many languages, and can perform two tasks: transcribe the audio in its original language, or translate it into English.

It was trained on a massive 680,000 hours of multilingual data, far more labelled audio than most ASR models are trained on.

It uses an encoder-decoder transformer architecture, which encodes the audio into learnt representations (similar to Wav2Vec!) but has a decoder built in to translate those representations into words directly.
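To see what the encoder actually receives, the arithmetic below follows the input pipeline described for Whisper: audio is resampled to 16 kHz, padded or trimmed to 30-second windows, and converted to an 80-channel log-Mel spectrogram with a 10 ms hop, after which a stride-2 convolution in the encoder's stem halves the time axis. The variable names are ours.

```python
sample_rate = 16000      # Whisper resamples all audio to 16 kHz
chunk_seconds = 30       # inputs are padded or trimmed to 30-second windows
hop_length = 160         # 10 ms between spectrogram frames at 16 kHz
n_mels = 80              # log-Mel channels for most checkpoints (large-v3 uses 128)

num_samples = sample_rate * chunk_seconds    # samples per window
mel_frames = num_samples // hop_length       # spectrogram frames per window
encoder_positions = mel_frames // 2          # after the stride-2 convolution

print(num_samples, mel_frames, encoder_positions)       # 480000 3000 1500
print("encoder input shape:", (n_mels, mel_frames))     # (80, 3000)
```

This is why the Whisper processor produces input features with 3000 time frames even for a short LibriSpeech clip: shorter audio is padded up to the full 30-second window.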

Contrary to popular belief, Whisper is not only available through OpenAI's paid API: its code and model weights are open source and can be found [here](https://github.com/openai/whisper).

<img src='https://raw.githubusercontent.com/openai/whisper/main/approach.png'/>

In [None]:
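### Whisper Implementation
# A heavily simplified, Whisper-style sketch; see the official repository above for the real model

import math

import torch
import torch.nn as nn
import torch.nn.functional as F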

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Register as a buffer so the table moves with the model on .to(device)
        self.register_buffer("pe", pe.unsqueeze(1))  # (max_len, 1, d_model)

    def forward(self, x):
        # x shape: (seq_len, batch, d_model)
        x = x + self.pe[:x.size(0)]
        return x

class FeatureEncoder(nn.Module):
    """
    Convolutional stem that turns a log-Mel spectrogram into features for the encoder.
    As in Whisper, the second convolution has stride 2, halving the time axis.
    """
    def __init__(self, n_mels=80, d_model=256):
        super().__init__()
        self.conv1 = nn.Conv1d(n_mels, d_model, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=3, stride=2, padding=1)
        self.gelu = nn.GELU()

    def forward(self, x):
        # x: (batch, n_mels, time)
        x = self.gelu(self.conv1(x))
        x = self.gelu(self.conv2(x))
        return x  # (batch, d_model, ceil(time / 2))

class WhisperLite(nn.Module):
    """
    Whisper Lite model
    Audio gets fed into the feature extractor and then the encoder
    The encoder representations then get fed into the decoder, which transcribes text
    """
    def __init__(self, n_mels=80, d_model=256, n_enc_layers=2, n_dec_layers=2,
                 n_heads=4, vocab_size=1000, dim_ff=512, max_len=500):
        super().__init__()
        self.feature_encoder = FeatureEncoder(n_mels, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)

        encoder_layer = nn.TransformerEncoderLayer(d_model, n_heads, dim_ff)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_enc_layers)

        self.token_emb = nn.Embedding(vocab_size, d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, n_heads, dim_ff)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_dec_layers)

        self.output_proj = nn.Linear(d_model, vocab_size)

    def forward(self, mel_specs, tgt_tokens):
        # mel_specs: (batch, n_mels, time)
        # tgt_tokens: (batch, tgt_len)
        # Encode features
        feats = self.feature_encoder(mel_specs)  # (batch, d_model, time)
        feats = feats.permute(2, 0, 1)           # (time, batch, d_model)
        enc_in = self.pos_enc(feats)            # add positional encoding
        memory = self.encoder(enc_in)           # (time, batch, d_model)

        # Prepare decoder inputs
        tgt_emb = self.token_emb(tgt_tokens).permute(1, 0, 2)  # (tgt_len, batch, d_model)
        tgt_emb = self.pos_enc(tgt_emb)

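        # Causal mask: target position i may only attend to tokens 0..i (no peeking ahead)
        tgt_len = tgt_emb.size(0)
        tgt_mask = torch.triu(
            torch.full((tgt_len, tgt_len), float("-inf"), device=tgt_emb.device), diagonal=1
        )
        memory_mask = None  # every target position may attend to every audio frame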

        # Decode
        dec_out = self.decoder(tgt_emb, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
        logits = self.output_proj(dec_out)  # (tgt_len, batch, vocab_size)
        return logits

# Dummy input test
if __name__ == "__main__":
    import math
    batch_size = 1
    n_mels = 80
    time_steps = 100
    tgt_len = 10
    vocab_size = 1000

    # Random mel-spectrogram
    dummy_mel = torch.randn(batch_size, n_mels, time_steps)
    # Random target token sequence
    dummy_tgt = torch.randint(0, vocab_size, (batch_size, tgt_len))

    # Instantiate and run model
    model = WhisperLite(n_mels=n_mels, d_model=256, n_enc_layers=2, n_dec_layers=2,
                        n_heads=4, vocab_size=vocab_size, dim_ff=512, max_len=500)
    output_logits = model(dummy_mel, dummy_tgt)
    print("Output logits shape:", output_logits.shape)  # (tgt_len, batch, vocab_size)



Output logits shape: torch.Size([10, 1, 1000])


### Whisper Inference

Possible reasons why Whisper may do worse than Wav2Vec2 on this sample:

1. We use whisper-base, a much smaller model than the largest checkpoint, whisper-large-v3.
2. The Wav2Vec2 checkpoint we used was fine-tuned on the 960 hours of LibriSpeech training data, so it is already specialized for this kind of audio and transcript style.

In [None]:
# 3. Inference with Whisper
processor_whisper = WhisperProcessor.from_pretrained("openai/whisper-base")
model_whisper     = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
# force English transcription
forced_ids = processor_whisper.get_decoder_prompt_ids(language="en", task="transcribe")

inputs = processor_whisper(audio, sampling_rate=sr, return_tensors="pt")
with torch.no_grad():
    gen_ids = model_whisper.generate(
        inputs.input_features,
        forced_decoder_ids=forced_ids
    )
text_whisper = processor_whisper.batch_decode(gen_ids, skip_special_tokens=True)[0]
print("Whisper:", text_whisper)



The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Whisper:  Concorde returned to its place amidst the tents.
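The two transcripts differ from the reference in casing and punctuation as well as in actual words, so a fair comparison requires normalizing the text first. Below is a minimal sketch of word error rate (WER) with a simple normalizer of our own choosing (lowercase, strip punctuation except apostrophes), applied to the reference and the two outputs printed in this notebook.

```python
import re

def normalize(text):
    """Lowercase, drop punctuation (keeping apostrophes) and collapse whitespace."""
    text = re.sub(r"[^\w\s']", "", text.lower())
    return " ".join(text.split())

def wer(reference, hypothesis):
    """Word error rate: word-level edit distance divided by the number of reference words."""
    ref, hyp = reference.split(), hypothesis.split()
    dp = list(range(len(hyp) + 1))   # distances between an empty ref and each prefix of hyp
    for i in range(1, len(ref) + 1):
        prev_diag, dp[0] = dp[0], i
        for j in range(1, len(hyp) + 1):
            above = dp[j]
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dp[j] = min(above + 1,          # deletion
                        dp[j - 1] + 1,      # insertion
                        prev_diag + cost)   # substitution or match
            prev_diag = above
    return dp[-1] / len(ref)

reference = "CONCORD RETURNED TO ITS PLACE AMIDST THE TENTS"
wav2vec2_out = "CONCORD RETURNED TO ITS PLACE AMIDST THE TENTS"
whisper_out = " Concorde returned to its place amidst the tents."

print(wer(reference, whisper_out))                          # 1.0 (every word differs in case)
print(wer(normalize(reference), normalize(whisper_out)))    # 0.125 (only "concorde" is wrong)
print(wer(normalize(reference), normalize(wav2vec2_out)))   # 0.0
```

Without normalization Whisper looks completely wrong; after it, the single real error is the spelling of "Concord". The `evaluate` package installed at the top also offers a ready-made `wer` metric (it relies on the `jiwer` package).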


### Further resources:
- Insanely Fast Whisper: https://github.com/Vaibhavs10/insanely-fast-whisper
- WhisperX: https://github.com/m-bain/whisperX
- Finetuning Whisper models: https://huggingface.co/blog/fine-tune-whisper
- Finetuning Wav2Vec2-BERT: https://huggingface.co/blog/fine-tune-w2v2-bert