# Loss functions for ASR

Automatic speech recognition (ASR) systems convert spoken language into text. Most ASR systems rely heavily on deep learning models that learn from audio data by minimizing a loss function. (If you're not sure what a loss function is, check out [Unit 1.1.2.2](https://docs.google.com/presentation/d/1cX3o6SH1fc4cTzUmhZvyYti557utwgRlgetn6ao_6tg/view#slide=id.g34ea0365488_0_117).)

Designing a loss function for ASR is complicated by the fact that the input audio and output text almost always have different lengths and are not aligned: there is no one-to-one correspondence between audio samples and output words or characters. Two common loss functions get around these complications in different ways.
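To make the length mismatch concrete, here is a back-of-the-envelope calculation. It assumes 16 kHz audio framed with a 25 ms window and a 10 ms hop (typical MFCC settings chosen for illustration, not taken from any specific model) and a hypothetical two-word transcript:

```python
# Illustrative assumptions: 16 kHz audio, 25 ms window, 10 ms hop.
sample_rate = 16_000
window = sample_rate * 25 // 1000  # 400 samples per frame
hop = sample_rate * 10 // 1000     # 160 samples between frame starts

num_samples = 2 * sample_rate                   # 2 seconds of audio
num_frames = 1 + (num_samples - window) // hop  # frames, no padding at the edges
transcript = "hello world"                      # hypothetical transcript

print(num_samples, num_frames, len(transcript))  # 32000 198 11
```

Even after converting raw samples into feature frames, the input is roughly 18 times longer than the transcript, and a typical ASR dataset does not record which frames correspond to which characters.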

## 1. Connectionist temporal classification (CTC) loss

CTC loss is designed for sequence-to-sequence tasks where:
- there is a many-to-one mapping of input tokens to output tokens, and
- the relative ordering of pairs of corresponding input and output tokens is preserved.

ASR satisfies both of these conditions, because:
- multiple input audio samples / spectrogram timesteps map to a single word or character, and
- both input audio and output text are ordered by time.

CTC loss is used in several well-known ASR models, such as DeepSpeech and wav2vec 2.0 (when fine-tuned for ASR).

### How does CTC loss work?

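1. The many-to-one mapping from input to output is turned into a one-to-one mapping by having the model predict one token for every input timestep, where each token is either a label from the vocabulary or a special blank token. A predicted path is collapsed into the final text by first merging repeated tokens and then removing blanks, so `h h _ e l _ l o` (with `_` as the blank) becomes `hello`. The blank lets the model output nothing at a timestep and separate genuine double letters.
2. CTC loss is then the negative log of the total probability of all alignments (paths) that collapse to the target sequence. This sum is computed efficiently with dynamic programming rather than by listing every path.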
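The two steps can be checked on a toy example. The sketch below implements the collapse rule, then computes the CTC loss by brute force: it enumerates every possible per-frame path for 5 frames and a 3-symbol vocabulary, keeps the paths that collapse to the target, and sums their probabilities. It assumes blank index 0 (PyTorch's default); the function names are our own, and the random log-probabilities stand in for model output. The result is compared against `torch.nn.functional.ctc_loss`. Brute force is only feasible at toy sizes, because there are V^T paths.

```python
import itertools
import math

import torch
import torch.nn.functional as F


def collapse(path, blank=0):
    """CTC collapse: merge consecutive repeats, then drop blanks."""
    out = []
    prev = None
    for tok in path:
        if tok != prev and tok != blank:
            out.append(tok)
        prev = tok
    return out


def brute_force_ctc_nll(log_probs, target, blank=0):
    """-log of the summed probability of every path that collapses to target.

    log_probs: (T, V) per-frame log-probabilities for a single utterance.
    """
    T, V = log_probs.shape
    total = 0.0
    for path in itertools.product(range(V), repeat=T):
        if collapse(path, blank) == list(target):
            # Frames are treated as independent given the input,
            # so a path's probability is the product over frames.
            total += math.exp(sum(log_probs[t, k].item() for t, k in enumerate(path)))
    return -math.log(total)


torch.manual_seed(0)
T, V = 5, 3    # 5 frames; vocabulary = {blank, "a", "b"}
target = [1, 2]  # "ab"
log_probs = torch.randn(T, V).log_softmax(dim=-1)

print(collapse([1, 1, 0, 2, 2]))  # [1, 2]

brute = brute_force_ctc_nll(log_probs, target)
reference = F.ctc_loss(
    log_probs.unsqueeze(1),  # (T, N=1, V)
    torch.tensor([target]),  # (N=1, S)
    input_lengths=torch.tensor([T]),
    target_lengths=torch.tensor([len(target)]),
    reduction="sum",
).item()
print(brute, reference)  # should agree up to floating-point precision
```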

For a detailed, illustrated explanation, see [Sequence Modeling with CTC](https://distill.pub/2017/ctc/) on Distill.
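Summing over every alignment directly costs V^T work, so practical implementations use dynamic programming: the forward (alpha) recursion described in the linked article. Below is a minimal, unoptimized sketch for a single utterance. It works in log space for numerical stability, assumes blank index 0 (PyTorch's default), and uses a function name of our own (`ctc_forward_nll`); its result is compared against `torch.nn.functional.ctc_loss`.

```python
import math

import torch
import torch.nn.functional as F


def ctc_forward_nll(log_probs, target, blank=0):
    """CTC negative log-likelihood via the forward (alpha) recursion.

    log_probs: (T, V) per-frame log-probabilities; target: list of label ids.
    """
    T = log_probs.shape[0]
    # Interleave blanks: [a, b] -> [blank, a, blank, b, blank]
    ext = [blank]
    for label in target:
        ext += [label, blank]
    S = len(ext)

    lp = log_probs.double()
    alpha = torch.full((T, S), float("-inf"), dtype=torch.float64)
    # A valid path starts with either a blank or the first label
    alpha[0, 0] = lp[0, ext[0]]
    if S > 1:
        alpha[0, 1] = lp[0, ext[1]]

    for t in range(1, T):
        for s in range(S):
            candidates = [alpha[t - 1, s]]  # stay on the same symbol
            if s >= 1:
                candidates.append(alpha[t - 1, s - 1])  # advance one state
            # Skip a blank, allowed only between two different labels
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                candidates.append(alpha[t - 1, s - 2])
            alpha[t, s] = torch.logsumexp(torch.stack(candidates), dim=0) + lp[t, ext[s]]

    # A valid path ends on the last label or on the trailing blank
    final = alpha[T - 1, S - 1]
    if S > 1:
        final = torch.logaddexp(final, alpha[T - 1, S - 2])
    return -final.item()


torch.manual_seed(0)
T, V = 6, 4
target = [1, 2, 2]  # contains a repeat, which forces a blank between the 2s
log_probs = torch.randn(T, V).log_softmax(dim=-1)

dp = ctc_forward_nll(log_probs, target)
reference = F.ctc_loss(
    log_probs.unsqueeze(1), torch.tensor([target]),
    input_lengths=torch.tensor([T]), target_lengths=torch.tensor([len(target)]),
    reduction="sum",
).item()
print(dp, reference)  # should agree up to floating-point precision
```

The recursion touches each of the 2U+1 states of the blank-extended target once per frame, so its cost grows as O(T·U) rather than exponentially.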

### Training a model using CTC loss

PyTorch provides a CTC loss implementation at `torch.nn.CTCLoss`. Let's use this to train a simple LSTM model for ASR.

```python
import torch
import torch.nn as nn
import torch.optim as optim


# Create a simple LSTM model to be trained using CTC loss.
class SimpleLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # LSTM output shape is (batch, sequence, hidden_size)
        outputs, _ = self.lstm(x)
        # Project every timestep to class scores: (batch, sequence, num_classes)
        outputs = self.fc(outputs)
        # CTCLoss expects log-probabilities, so apply log softmax over the classes
        return outputs.log_softmax(dim=-1)


# Initialize the model, loss function, and optimizer
input_size = 13  # input feature dimension, e.g. number of MFCCs
hidden_size = 128  # hidden dimension of LSTM model
num_layers = 2
num_classes = 20  # including the blank label (index 0) for CTC

model = SimpleLSTM(input_size, hidden_size, num_layers, num_classes)
ctc_loss = nn.CTCLoss(blank=0)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Generate some random data
batch_size = 16
sequence_length = 50
inputs = torch.randn(batch_size, sequence_length, input_size)  # (batch, sequence, feature)
input_lengths = torch.full((batch_size,), sequence_length, dtype=torch.long)
# Targets are padded to length 30; labels start at 1 because 0 is reserved for the blank
targets = torch.randint(1, num_classes, (batch_size, 30), dtype=torch.long)
target_lengths = torch.randint(10, 30, (batch_size,), dtype=torch.long)

# Forward pass: compute per-frame log-probabilities
log_probs = model(inputs)  # (batch, sequence, num_classes)
log_probs = log_probs.transpose(0, 1)  # CTCLoss expects (sequence, batch, num_classes)

# Calculate loss
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)

# Backward pass: clear old gradients, then compute the gradient of the loss
# with respect to the model parameters
optimizer.zero_grad()
loss.backward()

# Perform a single optimization step (parameter update)
optimizer.step()

print("CTC loss:", loss.item())
```

```text
CTC loss: 5.526454448699951
```
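Because CTC emits at most one label per input frame, a target that is longer than the input has no valid alignment at all. The sketch below uses random log-probabilities (an assumption standing in for model output) to show that `nn.CTCLoss` then returns an infinite loss, and that its `zero_infinity=True` option replaces such losses (and their gradients) with zero, which can keep training stable when a batch contains a few unusable utterances.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
T, V = 3, 5  # only 3 input frames
log_probs = torch.randn(T, 1, V).log_softmax(dim=-1)  # (T, N=1, V)
targets = torch.tensor([[1, 2, 3, 4]])  # 4 labels: more labels than frames
input_lengths = torch.tensor([T])
target_lengths = torch.tensor([4])

strict_loss = nn.CTCLoss(reduction="sum")(
    log_probs, targets, input_lengths, target_lengths
).item()
lenient_loss = nn.CTCLoss(reduction="sum", zero_infinity=True)(
    log_probs, targets, input_lengths, target_lengths
).item()

print(strict_loss)   # inf
print(lenient_loss)  # 0.0
```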


## 2. Cross-entropy (CE) loss

CE loss is used for classification tasks, where a model predicts a probability distribution over output classes (e.g. a vocabulary of words or characters). CE loss measures how far the predicted distribution is from the ground truth; with a single correct class, it is simply the negative log-probability the model assigns to that class, so it is low when the model is confident and right.
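As a quick check of that definition, the sketch below computes CE by hand for one prediction with made-up scores over a hypothetical 4-token vocabulary and compares it with `nn.CrossEntropyLoss`. Note that `nn.CrossEntropyLoss` takes raw scores (logits) and applies log softmax internally, so the model should not apply softmax itself.

```python
import math

import torch
import torch.nn as nn

# Made-up scores (logits) for one prediction over a 4-token vocabulary
logits = torch.tensor([[2.0, 0.5, -1.0, 0.0]])
target = torch.tensor([0])  # index of the correct token

# CE = -log p(correct class), with p given by a softmax over the logits
manual = -logits.log_softmax(dim=-1)[0, target[0]]
builtin = nn.CrossEntropyLoss()(logits, target)

print(manual.item(), builtin.item())  # both ≈ 0.3423
```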

### How does CE loss work for ASR?

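Recall that CTC loss handles the alignment problem by letting the model emit one token, possibly a blank, for every input timestep.

However, no explicit alignment is needed if we use CE loss with an attention-based encoder-decoder model such as a **transformer**. The encoder processes the whole input, and the decoder then predicts the output sequence one token at a time, attending to the encoded input at every step. Since the decoder makes exactly one prediction per output token, the number of prediction steps does not depend on the input length, and ordinary CE loss can be applied at each output position.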
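Predicting one token per output position is usually trained with *teacher forcing*: the decoder is fed the ground-truth transcript shifted right by one position and learns to predict the next token at every position, while a causal mask stops each position from looking ahead. A minimal sketch of preparing those tensors; the start (`BOS`) and end (`EOS`) token ids and the transcript ids are hypothetical values chosen for illustration:

```python
import torch

# Hypothetical special-token ids; real tokenizers define their own.
BOS, EOS = 1, 2
transcript_ids = torch.tensor([[5, 9, 7, 3]])  # (batch=1, length=4), made-up ids

# Decoder reads BOS + tokens and is trained to output tokens + EOS.
decoder_input = torch.cat([torch.full((1, 1), BOS, dtype=torch.long), transcript_ids], dim=1)
labels = torch.cat([transcript_ids, torch.full((1, 1), EOS, dtype=torch.long)], dim=1)

# Causal mask (additive): 0 where attention is allowed, -inf for future positions.
n = decoder_input.size(1)
causal_mask = torch.triu(torch.full((n, n), float("-inf")), diagonal=1)

print(decoder_input.tolist())  # [[1, 5, 9, 7, 3]]
print(labels.tolist())         # [[5, 9, 7, 3, 2]]
```

At position i the decoder has seen `decoder_input[:, : i + 1]` and is scored against `labels[:, i]`, which is the next token of the transcript.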

Transformer models also offer other advantages, for instance:
- Each output token is predicted using information from the previous output tokens. CTC-based models instead assume that the outputs at different timesteps are conditionally independent given the input.
- Transformer models can generate output sequences that are longer than the input, whereas CTC needs at least as many input timesteps as output tokens. (This is rarely a limitation for ASR.)
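Both points follow from how each approach factorizes the probability of a transcript. Writing $\mathbf{x}$ for the $T$ input frames, $\mathbf{y} = (y_1, \dots, y_U)$ for the $U$ target tokens, and $\mathcal{B}$ for the CTC collapse function (merge repeats, then drop blanks), the two models can be written as:

$$
P_{\text{CTC}}(\mathbf{y} \mid \mathbf{x}) = \sum_{\mathbf{a} \in \mathcal{B}^{-1}(\mathbf{y})} \; \prod_{t=1}^{T} p(a_t \mid \mathbf{x}),
\qquad
P_{\text{enc-dec}}(\mathbf{y} \mid \mathbf{x}) = \prod_{u=1}^{U} p(y_u \mid y_{<u}, \mathbf{x})
$$

Both are trained by minimizing $-\log P(\mathbf{y} \mid \mathbf{x})$. In the CTC product each per-frame term $p(a_t \mid \mathbf{x})$ ignores the other outputs, which is the conditional-independence assumption, while the encoder-decoder product conditions every token on $y_{<u}$. Because a CTC path $\mathbf{a}$ has exactly $T$ symbols and each pair of adjacent identical labels needs a blank between them, CTC requires $T \ge U + r$, where $r$ is the number of such adjacent repeats in $\mathbf{y}$.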

State-of-the-art ASR models trained with CE loss include Whisper, VITA, Ichigo and Moshi.

### Training a model using CE loss

PyTorch provides a CE loss implementation at `torch.nn.CrossEntropyLoss`. Let's use this to train a simple transformer model for ASR.

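```python
import torch
import torch.nn as nn
import torch.optim as optim


# Create a simple encoder-decoder transformer to be trained using CE loss.
# Positional encodings are omitted for brevity; a real model needs them.
class SimpleTransformer(nn.Module):
    def __init__(self, num_classes, input_dim, hidden_dim):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, hidden_dim)  # audio features -> model dim
        self.embedding = nn.Embedding(num_classes, hidden_dim)  # token ids -> model dim
        self.transformer = nn.Transformer(d_model=hidden_dim, batch_first=True)
        self.output_proj = nn.Linear(hidden_dim, num_classes)

    def forward(self, inputs, decoder_inputs):
        # Causal mask: each output position may only attend to earlier positions
        tgt_len = decoder_inputs.size(1)
        tgt_mask = torch.triu(torch.full((tgt_len, tgt_len), float("-inf")), diagonal=1)
        transformer_out = self.transformer(
            self.input_proj(inputs), self.embedding(decoder_inputs), tgt_mask=tgt_mask,
        )
        return self.output_proj(transformer_out)  # (batch, tgt_len, num_classes)


# Initialize the model, loss function, and optimizer
input_dim = 16  # input feature dimension, e.g. number of MFCCs
hidden_dim = 512  # hidden dimension of transformer model
num_classes = 10  # vocabulary size; token 0 is used as the start token here

model = SimpleTransformer(num_classes, input_dim, hidden_dim)
ce_loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Generate some random data
batch_size = 16
input_length = 50  # number of audio frames
target_length = 20  # number of output tokens
inputs = torch.randn(batch_size, input_length, input_dim)
targets = torch.randint(1, num_classes, (batch_size, target_length))

# Teacher forcing: the decoder sees the targets shifted right by one step
# (prefixed with the start token), so it learns to predict the *next* token.
start = torch.zeros(batch_size, 1, dtype=torch.long)
decoder_inputs = torch.cat([start, targets[:, :-1]], dim=1)

# Forward pass: compute predicted outputs by passing inputs to the model
logits = model(inputs, decoder_inputs)  # (batch, target_length, num_classes)

# Calculate loss; CrossEntropyLoss expects (N, num_classes), so flatten batch and time
loss = ce_loss(logits.reshape(-1, num_classes), targets.reshape(-1))

# Backward pass: clear old gradients, then compute the gradient of the loss
# with respect to the model parameters
optimizer.zero_grad()
loss.backward()

# Perform a single optimization step (parameter update)
optimizer.step()

print("CE loss:", loss.item())
```

```text
CE loss: 2.4389243125915527
```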
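Real transcripts in a batch have different lengths, so targets are usually padded to a common length, which plays the role that `target_lengths` plays for `nn.CTCLoss`. One common approach, sketched below with made-up logits and labels, is to fill padded positions with `-100`, the default `ignore_index` of `nn.CrossEntropyLoss`, so those positions are skipped and the mean is taken over real tokens only.

```python
import math

import torch
import torch.nn as nn

PAD = -100  # nn.CrossEntropyLoss skips this target value by default (ignore_index=-100)
batch, max_len, num_classes = 2, 4, 10

torch.manual_seed(0)
logits = torch.randn(batch, max_len, num_classes)  # made-up (batch, target_len, num_classes)
labels = torch.tensor([[3, 1, 4, 1],
                       [5, 9, PAD, PAD]])  # second transcript is only 2 tokens long

# CrossEntropyLoss wants (N, C) scores, so flatten the batch and time dimensions.
loss = nn.CrossEntropyLoss()(logits.reshape(-1, num_classes), labels.reshape(-1))

# The same value computed by hand over the non-padded positions only
mask = labels != PAD
log_probs = logits.log_softmax(dim=-1)
picked = log_probs[mask].gather(1, labels[mask].unsqueeze(1))  # (6, 1)
manual = -picked.mean()

print(loss.item(), manual.item())  # should match
```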


## Further resources

### Other loss functions
- [RNN-T (transducer) loss](https://lorenlugosch.github.io/posts/2020/11/transducer/)

### Model architectures
- [Explanation of CTC architectures](https://huggingface.co/learn/audio-course/en/chapter3/ctc)
- [Explanation of Seq2Seq architectures](https://huggingface.co/learn/audio-course/en/chapter3/seq2seq)