# ASR I: Homework

In [2]:
# %%capture pip_install_requirements_output
%pip install --quiet --upgrade -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [3]:
import os
import urllib
from collections import defaultdict
from typing import List, Tuple, TypeVar, Optional

import arpa
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchaudio

data_directory = './week_04_data'

In [4]:
%load_ext autoreload
%autoreload 2

import utils

In [5]:
base_url = 'https://cloud-api.yandex.net/v1/disk/public/resources/download?'
public_key = 'https://disk.yandex.ru/d/KqloT1zKr_2VaA'
final_url = base_url + urllib.parse.urlencode(dict(public_key=public_key))
response = requests.get(final_url)
download_url = response.json()['href']
!wget -O week_04_data.tar.gz "{download_url}"
!mkdir -p week_04_data
!tar -xf week_04_data.tar.gz -C week_04_data

--2026-02-17 04:26:49--  https://downloader.disk.yandex.ru/disk/f1f7600b35c478437e766a77c61df11b8b96f2d2fae68457b67509a0597792d6/69942649/M9j18MQZ-8i3RDc98f4DPQz3agkvwYC__XSrv3IqDdcQHAW6Svf8h88_ai8PZuB2XCM1UOIYWQfCrW7gBoYKuw%3D%3D?uid=0&filename=week_04_data.tar.gz&disposition=attachment&hash=z3vgabxT1gNetAOEHhGtJ5CD8KklW6IThKZ1X/QJIM%2BOn1wH6vX4YAgcMTofdqm3q/J6bpmRyOJonT3VoXnDag%3D%3D%3A&limit=0&content_type=application%2Fx-gzip&owner_uid=1031186776&fsize=75981469&hid=10606d3228fe034e4cd3ae29e83fc5c2&media_type=compressed&tknv=v3
Resolving downloader.disk.yandex.ru (downloader.disk.yandex.ru)... 77.88.21.127, 2a02:6b8::2:127
Connecting to downloader.disk.yandex.ru (downloader.disk.yandex.ru)|77.88.21.127|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://s1093sas.storage.yandex.net/rdisk/f1f7600b35c478437e766a77c61df11b8b96f2d2fae68457b67509a0597792d6/69942649/M9j18MQZ-8i3RDc98f4DPQz3agkvwYC__XSrv3IqDdcQHAW6Svf8h88_ai8PZuB2XCM1UOIYWQfCrW7gBoYKuw==?u

### Implementing a Decoder for CTC model (5 points)

Before you can start having fun with a CTC ASR model, you first need to make sure that you can correctly "decode" or generate text from a working model. This can be done in two ways - using a Greedy Decoder, which is simple and fast, or using a Prefix Beam Search decoder, which is slower, but takes advantage of the fact that multiple paths through a CTC trellis can map to the same sentence. In the following exercise you will implement both decoders.

In [6]:
# Constants re-exported from the local `utils` module so the cells below can use short names.
# NEG_INF is used below as the log-domain stand-in for probability zero.
NEG_INF = utils.NEG_INF
BLANK_SYMBOL = utils.BLANK_SYMBOL

# Global tokenizer used by the CTC decoders in this section (text <-> symbol indices).
tokenizer = utils.CTCTokenizer()

#### Greedy Best-Path Decoder (1 point)

After we’ve trained the model, we’d like to use it to find a likely output for a given input. Your goal is to implement a Greedy Best-Path decoder. Remember that in CTC the joint distribution over states factors out into a product of marginals:

$${\tt P}(\mathbf{z}_{1:T}|\mathbf{X}_{1:T},\mathbf{\theta}) = \prod_{t = 1}^T{\tt P}(z_t|\mathbf{X}_{1:T},\mathbf{\theta})$$

We can take the most likely output at each time-step, which gives us the alignment with the highest probability:

$$\mathbf{\pi}^*_{1:T} = \arg \max_{\mathbf{\pi}_{1:T} } \prod_{t=1}^T {\tt P}(z_t = \pi_t|\mathbf{X}_{1:T})$$

Then merge repeats and remove blanks.

In [7]:
def greedy_decoder(output: torch.Tensor, labels: List[torch.Tensor],
                   label_lengths: List[int], collapse_repeated: bool = True) -> Tuple[List[str], List[str]]:
    """
    Greedy (best-path) CTC decoding: pick the most likely class at every frame,
    optionally merge consecutive repeats, then drop blanks.

    :param output: torch.Tensor of Probs or Log-Probs of shape [batch, time, classes]
    :param labels: list of label indices converted to torch.Tensors
    :param label_lengths: list of label lengths (without padding)
    :param collapse_repeated: whether the repeated characters should be deduplicated;
        when False every non-blank frame contributes a symbol
    :return: (decodes, targets) -- two lists of strings with one entry per batch element:
        the decoded hypotheses and the reference texts
    """
    blank_label = tokenizer.get_symbol_index(BLANK_SYMBOL)

    # Best class per frame, converted once to nested Python lists so the inner loop
    # compares plain ints instead of 0-d tensors.
    best_paths = torch.argmax(output, dim=2).tolist()

    decodes = []
    targets = []

    for i, frame_indices in enumerate(best_paths):
        true_labels = labels[i][:label_lengths[i]].tolist()
        targets.append(tokenizer.indices_to_text(true_labels))

        decode = []
        previous = None  # class chosen at the previous frame (blank included)
        for index in frame_indices:
            # A repeat is judged against the raw previous frame, so "a, blank, a"
            # still yields two symbols while "a, a" yields one.
            is_repeat = collapse_repeated and index == previous
            if index != blank_label and not is_repeat:
                decode.append(index)
            previous = index

        decodes.append(tokenizer.indices_to_text(decode))
    return decodes, targets

Testing the greedy decoding

In [8]:
# Load the reference output matrix from disk and add a leading batch axis,
# giving the layout [batch, classes, time].
matrix = np.loadtxt(os.path.join(data_directory, 'test_matrix.txt'))[np.newaxis, :, :]

# Swap the last two axes to get the [batch, time, classes] layout the decoder expects.
matrix = torch.Tensor(matrix).transpose(1, 2)

# Reference transcript as token indices (torch.Tensor stores them as floats).
labels_indices = torch.Tensor(tokenizer.text_to_indices('there seems no good reason for believing that it will change'))

# Run the greedy decoder on the single-utterance batch.
decodes, targets = greedy_decoder(matrix, [labels_indices], [len(labels_indices)])

# The greedy hypothesis is expected to contain recognition errors; the target must round-trip exactly.
assert decodes[0] == 'there se ms no good reason for believing that twillc ange'
assert targets[0] == 'there seems no good reason for believing that it will change'

#### Prefix (Beam Search) Decoding With LM (4 points)

The greedy decoder doesn't take into account the fact that a single output can have many alignments. For example, imagine that the true label for a phoneme sequence is $[a]$. Assume that alignments $[a, a, \epsilon]$ and $[a, a, a]$ individually have lower probability than the probability $[b, b, b]$, but the sum of their probabilities is higher. In this case, the greedy decoder would choose the wrong alignment $[b, b, b]$ and propose a wrong hypothesis $[b]$ instead of $[a]$.

Prefix decoding considers the probabilities of multiple paths and merges them. It can also incorporate an external language model.

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1_X9NfoSe8HLKfAErDtr0rBsIxoejA1kq" height="500px" width="900px">  -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/beam_search.png" height="500px" width="900px"> 

Prefix decoding algorithm has 3 nested loops:
- over time - we extend prefixes up to T times
- over prefixes in the beam
- over possible extensions of a prefix

Each prefix can be extended in three possible ways:
- with a blank
- with a repeating character
- with a non-repeating character

We must keep track of two probabilities per prefix:
- The probability of prefix ending with blank $P_b(t, s)$. 
- The probability of prefix not ending with blank $P_{nb}(t, s)$

Here $t$ denotes time step and $s$ denotes a prefix we got after $t$ time steps.

We start with an empty string prefix: 

$$
    P_b(0, \text{""}) = 1
$$
$$
    P_{nb}(0, \text{""}) = 0
$$

If we extend $s$ with a blank, update the probability of ending with a blank:

$$
    P_b(t, s) = P(\epsilon | x_t) \cdot (P_b(t - 1, s) + P_{nb}(t - 1, s))
$$

The prefix $s$ is not updated because blanks are eliminated in the end.

If we extend with a repeat character $c$, there are two options:
1. The previous symbol is a blank, and now we extend the prefix
2. The previous symbol is not a blank, so we don't extend the prefix (repeats are merged)

In this case, the probability $P_{nb}$ is updated as follows:

$$
    P_{nb}(t, s + c) = P(c | x_t) \cdot P_b(t - 1, s)
$$
$$
    P_{nb}(t, s) = P(c | x_t) \cdot P_{nb}(t - 1, s)
$$

Finally, consider extending $s$ at time $t$ with a non-repeat character. It can follow both blank and non-blank characters, so the probability $P_{nb}$ is updated as follows:

$$
    P_{nb}(t, s + c) = P(c | x_t) \cdot (P_b(t - 1, s) + P_{nb}(t - 1, s))
$$

We may also want to apply a language model during decoding, but only in the case we have a new complete word. This happens when the current symbol is a non-repeat space. As CTC is a discriminative model, LMs can only be integrated as a heuristic:

$$
    \mathbf{w}^* = \arg \max_\mathbf{w} \underbrace{P(\mathbf{w} | \mathbf{X}_{1:T})}_{\text{CTC prob}} \cdot \underbrace{P(\mathbf{w})^{\alpha}}_{\text{LM prob}} \cdot \underbrace{|\mathbf{w}|^\beta}_{\text{Length correction}}
$$

The formula for an update of $P_{nb}$ when LM is used and the current symbol is a non-repeat space:

$$
    P_{nb}(t, s + c) = P_{\text{LM}}(s)^\alpha \cdot |s|^\beta \cdot P(c | x_t) \cdot (P_b(t - 1, s) + P_{nb}(t - 1, s))
$$

In [9]:
# Placeholder type used only to annotate language-model arguments in the helpers below.
LanguageModel = TypeVar("LanguageModel")

class Beam:
    """
    Bookkeeping for one time step of CTC prefix beam search.

    Every candidate prefix (a tuple of symbol indices) maps to a pair of log-scores
    ``(p_blank, p_not_blank)``: the score of the prefix ending in a blank and the score
    of it ending in a non-blank symbol.
    """

    def __init__(self, beam_size: int) -> None:
        """
        :param beam_size: maximum number of prefixes kept by `update_top_candidates_list`
        """
        self.beam_size = beam_size

        # Prefixes that were never updated have zero probability in both states.
        self.candidates = defaultdict(lambda: (NEG_INF, NEG_INF))
        # Decoding starts from the empty prefix: P_b = 1 (log 0.0), P_nb = 0 (NEG_INF).
        self.top_candidates_list = [(tuple(), (0.0, NEG_INF))]

    def get_probs_for_prefix(self, prefix: Tuple[int, ...]) -> Tuple[float, float]:
        """
        Return ``(p_blank, p_not_blank)`` accumulated so far for `prefix`.

        Looking up an unseen prefix registers it with (NEG_INF, NEG_INF) because
        `candidates` is a defaultdict.
        """
        p_blank, p_not_blank = self.candidates[prefix]
        return p_blank, p_not_blank

    def update_probs_for_prefix(self, prefix: Tuple[int, ...], next_p_blank: float,
                                next_p_not_blank: float) -> None:
        """Overwrite the stored pair of log-scores for `prefix`."""
        self.candidates[prefix] = (next_p_blank, next_p_not_blank)

    def update_top_candidates_list(self) -> None:
        """Keep the `beam_size` candidates with the highest combined score (blank + non-blank)."""
        top_candidates = sorted(
            self.candidates.items(),
            key=lambda item: utils.logsumexp(*item[1]),
            reverse=True
        )
        self.top_candidates_list = top_candidates[:self.beam_size]


def calculate_probability_score_with_lm(lm: LanguageModel, prefix: str) -> float:
    """
    Score a decoded prefix with the language model and return the score as a natural log.

    :param lm: language model exposing `log_p(text)`
    :param prefix: the prefix to score, in the form accepted by `tokenizer.indices_to_text`
    :return: the LM score of the upper-cased, stripped prefix text, rescaled from base 10 to base e
    """
    transcript = tokenizer.indices_to_text(prefix)
    normalized = transcript.upper().strip()
    log10_prob = lm.log_p(normalized)
    # The LM score is treated as a base-10 log: log10(x) / log10(e) == ln(x).
    return log10_prob / np.log10(np.e)

In [10]:
def decode(probs: np.ndarray, beam_size: int = 5, lm: Optional[LanguageModel] = None,
           prune: float = 1e-5, alpha: float = 0.1, beta: float = 2):
    """
    CTC prefix beam search for a single utterance, with optional language-model rescoring.

    All scores are kept in the log domain. For every prefix two scores are tracked:
    the prefix ending in a blank (`p_blank`) and ending in a non-blank symbol (`p_not_blank`).

    :param probs: A matrix of shape (T, K) with probability distributions over symbols at each moment of time.
    :param beam_size: the number of prefixes kept after every time step
    :param lm: optional language model; when given, it rescores a prefix each time a
        non-repeated space is appended to a non-empty prefix
    :param prune: the minimal probability for a symbol at which it can be added to a prefix
        (0.0 disables pruning)
    :param alpha: the weight of the LM log-score
    :param beta: the weight of the log-length term, where length is the number of
        characters in the decoded prefix text
    :return: tuple (best prefix as a tuple of symbol indices,
        negated `utils.logsumexp` of its `p_blank` and `p_not_blank` scores)
    """
    T, S = probs.shape
    # Switch to log-probabilities; the pruning threshold is converted to the same scale.
    probs = np.log(probs)
    blank = tokenizer.get_symbol_index(BLANK_SYMBOL)
    space = tokenizer.get_symbol_index(" ")
    prune = NEG_INF if prune == 0.0 else np.log(prune)

    beam = Beam(beam_size)
    for t in range(T):
        # Scores for time step t are accumulated into a fresh beam.
        next_beam = Beam(beam_size)

        for s in range(S):
            p = probs[t, s]
            if p < prune:    # Prune the vocab
                continue

            for prefix, (p_blank, p_not_blank) in beam.top_candidates_list:
                if s == blank:
                    # Blank: the prefix is unchanged, only its "ends in blank" score grows.
                    p_b, p_nb = next_beam.get_probs_for_prefix(prefix)
                    next_beam.update_probs_for_prefix(
                        prefix=prefix,
                        next_p_blank=utils.logsumexp(p_b, p + utils.logsumexp(p_blank, p_not_blank)),
                        next_p_not_blank=p_nb,
                    )
                    continue

                end_t = prefix[-1] if prefix else None
                n_prefix = prefix + (s,)

                if s == end_t:
                    # Repeat after a blank: a genuinely new symbol, so the prefix is extended.
                    p_b, p_nb = next_beam.get_probs_for_prefix(n_prefix)
                    next_beam.update_probs_for_prefix(
                        prefix=n_prefix,
                        next_p_blank=p_b,
                        next_p_not_blank=utils.logsumexp(p_nb, p + p_blank),
                    )

                    # Repeat directly after the same symbol: merged, so the prefix stays as is.
                    p_b, p_nb = next_beam.get_probs_for_prefix(prefix)
                    next_beam.update_probs_for_prefix(
                        prefix=prefix,
                        next_p_blank=p_b,
                        next_p_not_blank=utils.logsumexp(p_nb, p + p_not_blank),
                    )
                elif s == space and end_t is not None and lm is not None:
                    # A non-repeated space closes a word: add the LM and length corrections.
                    p_b, p_nb = next_beam.get_probs_for_prefix(n_prefix)
                    score = calculate_probability_score_with_lm(lm, prefix)
                    length = len(tokenizer.indices_to_text(prefix))

                    # Guard against log(0) when the decoded text is empty.
                    lm_correction = (alpha * score + beta * np.log(length)) if length > 0 else 0

                    next_beam.update_probs_for_prefix(
                        prefix=n_prefix,
                        next_p_blank=p_b,
                        next_p_not_blank=utils.logsumexp(p_nb, lm_correction + p + utils.logsumexp(p_blank, p_not_blank)),
                    )
                else:
                    # Any other non-blank symbol extends the prefix from either state.
                    p_b, p_nb = next_beam.get_probs_for_prefix(n_prefix)
                    next_beam.update_probs_for_prefix(
                        prefix=n_prefix,
                        next_p_blank=p_b,
                        next_p_not_blank=utils.logsumexp(p_nb, p + utils.logsumexp(p_blank, p_not_blank)),
                    )

        # Keep only the best `beam_size` prefixes for the next time step.
        next_beam.update_top_candidates_list()
        beam = next_beam

    # NOTE(review): if every symbol is pruned at some time step the beam becomes empty
    # and the indexing below raises IndexError.
    best = beam.top_candidates_list[0]
    return best[0], -utils.logsumexp(*best[1])


def beam_search_decoder(probs: torch.Tensor, labels: List[torch.Tensor], label_lengths: List[int],
                        input_lengths: List[int], lm: Optional[LanguageModel] = None, beam_size: int = 5,
                        prune: float = 1e-3, alpha: float = 0.1, beta: float = 0.1) -> Tuple[List[str], List[str]]:
    """
    Run prefix beam search (`decode`) over every utterance of a batch.

    :param probs: per-frame symbol probabilities of shape [batch, time, classes];
        a torch.Tensor or an equivalent NumPy array
    :param labels: reference label indices, one tensor per utterance
    :param label_lengths: list of label lengths (without padding)
    :param input_lengths: number of valid frames per utterance; frames beyond it are ignored
    :param lm: optional language model forwarded to `decode` (None disables rescoring)
    :param beam_size: the size of beams
    :param prune: the minimal probability for a symbol at which it can be added to a prefix
    :param alpha: the parameter to de-weight the LM probability
    :param beta: the parameter to up-weight the length correction term
    :return: (decodes, targets) -- decoded hypotheses and reference texts, one string per utterance
    """
    # `decode` works on NumPy arrays; accept tensors (possibly on GPU / tracked by autograd) too.
    if isinstance(probs, torch.Tensor):
        probs = probs.detach().cpu().numpy()

    decodes, targets = [], []
    for i, prob in enumerate(probs):
        targets.append(tokenizer.indices_to_text(labels[i][:label_lengths[i]].tolist()))
        int_seq, _ = decode(prob[:input_lengths[i]], lm=lm, beam_size=beam_size,
                            prune=prune, alpha=alpha, beta=beta)
        decodes.append(tokenizer.indices_to_text(int_seq))

    return decodes, targets

In [11]:
# Create LM
# Load the ARPA file from the data directory and keep the first entry of the returned collection.
alm = arpa.loadf(os.path.join(data_directory, '3-gram.pruned.1e-7.arpa'))[0]
# Overrides a private attribute of the model -- presumably the token used for
# out-of-vocabulary words; TODO confirm against the `arpa` package.
alm._unk = '<UNK>'

Testing prefix (beam search) decoding

In [12]:
# Load the reference output matrix and add a leading batch axis: [batch, classes, time].
matrix = np.loadtxt(os.path.join(data_directory, 'test_matrix.txt'))[np.newaxis, :, :]

# Swap the last two axes to get the [batch, time, classes] layout the decoder expects.
matrix = torch.Tensor(matrix).transpose(1, 2)

# Reference transcript as token indices.
labels_indices = torch.Tensor(tokenizer.text_to_indices('there seems no good reason for believing that it will change'))

# Run the decoder without a language model; all frames of the single utterance are valid.
decodes, targets = beam_search_decoder(
    matrix, [labels_indices], [len(labels_indices)], [matrix.size()[1]],
    lm=None, beam_size=5, prune=1e-3, alpha=0.1, beta=0.3
)

# For comparison, the greedy decoder yields: 'there se ms no good reason for believing that twillc ange'
assert decodes[0] == 'there se ms no good reason for believing that twil c ange'
assert targets[0] == 'there seems no good reason for believing that it will change'

# Same input, now rescored with the ARPA language model.
decodes, targets = beam_search_decoder(
    matrix, [labels_indices], [len(labels_indices)], [matrix.size()[1]],
    lm=alm, beam_size=5, prune=1e-3, alpha=0.1, beta=0.3
)

# With the LM the hypothesis is expected to recover "seems".
print(f"Actual output with LM: {decodes[0]}")
print(f"Expected output:      there seems no good reason for believing that twil c ange")
assert decodes[0] == 'there seems no good reason for believing that twil c ange'
assert targets[0] == 'there seems no good reason for believing that it will change'

Actual output with LM: there seems no good reason for believing that twil c ange
Expected output:      there seems no good reason for believing that twil c ange


### Examples

- Jasper https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/jasper.html
- DeepSpeech2 https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html
- VGGTransformer https://github.com/facebookresearch/fairseq/blob/main/examples/speech_recognition/models/vggtransformer.py

## RNN-Transducer

### Lecture recap

#### Alignment

Let $\mathbf{x} = (x_1, x_2, \ldots, x_T)$ be an input sequence of arbitrary length $T$ belonging to the set $X^*$ of all sequences over some input space $X$. Let $\mathbf{y} = (y_1, \ldots, y_U)$ be a length $U$ output sequence belonging to the set $Y^*$ of all sequences over some output space $Y$.

Define the *extended output space* $\overline Y$ as $Y \cup \{\emptyset\}$, where $\emptyset$ denotes the null output. The intuitive meaning of $\emptyset$ is 'output nothing'. The sequence $(y_1, \emptyset, \emptyset, y_2, \emptyset, y_3) \in \overline Y^*$ is therefore equivalent to $(y_1, y_2, y_3) \in Y^*$. We refer to the elements $\mathbf{a} \in \overline Y^*$ as *alignments*, since the location of the null symbols determines an alignment between the input and output sequences.

As we saw in CTC, various alignments can be represented in the form of a table called trellis. An example of how an RNN-T trellis may look like:

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1CfXfkePAESz2n20AABVUw9SaZ_xszxwf"> -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/rnnt_trellis_1.png">
    
    
Possible alignments in that trellis:
    
<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1ipRlSrznwmoD5gCk7k6G06JeUtqPzDQq"> -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/rnnt_trellis_2.png">
    
The final label can be determined by simply removing the blank character:
    
$$
    C \emptyset \emptyset A \emptyset T \emptyset \to CAT
$$
$$
    \emptyset \emptyset \emptyset C A T \emptyset \to CAT
$$
    
Given $\mathbf{x}$, the RNN transducer defines a conditional distribution $P(\mathbf{a} \in \overline Y^* | \mathbf{x})$. This distribution is then collapsed onto the following distribution over $Y^*$:
    
$$
    P(\mathbf y \in Y^* | \mathbf x) = \sum_{\mathbf a \in \mathcal{B}^{-1}(\mathbf y)} P(\mathbf a | \mathbf x),
$$
    
where $\mathcal B: \overline Y^* \mapsto Y^*$ is a function that removes the null symbols from the alignments in $\overline Y^*$.


#### Architecture

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1P2aztCi9Z7ookMbHmWBcGtSmG_JHIiMj"> -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/rnnt_arch.png">

The RNN-T model consists of three neural networks: Encoder, Predictor and Joiner. The Encoder converts the acoustic feature $x_t$ into a high-level representation $f_t$, where $t$ is time index:

$$
    f_t = \mathrm{Encoder}(x_t)
$$

The Predictor works like an RNN language model, which produces a high-level representation $g_u$ by conditioning on the previous non-blank target $y_{u - 1}$ predicted by the RNN-T model, where $u$ is output label index:

$$
    g_u = \mathrm{Predictor}(y_{u - 1})
$$

Note that the input sequence for the predictor **is prepended with the special symbol** $\langle s \rangle$ that defines the start of a sentence.

The Joiner is a feed forward network that combines the Encoder output $f_t$ and the Predictor output $g_u$ as

$$
    h_{t, u} = \mathrm{Joiner}(f_t, g_u) = \mathrm{FeedForward}(\mathrm{ReLU}(f_t + g_u))
$$

The final posterior for each output token $y$ is obtained after applying the softmax operation:

$$
    P(y | t, u) = \mathrm{softmax}(h_{t, u})
$$
    
where $P(y | t, u)$ is a distribution of probabilities to emit $y \in \overline Y$ at time step $t$ after $u$ previously generated characters, $t \in [1, T], u \in [0, U]$.

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1tn1wS3fCVFJGwrYumf5Im6gOFZsxRMV-"> -->
<p style="text-align:center;"><img src="./images/rnnt_probs.png">

We will further need to work with probabilities of individual tokens $y$ for different $t$ and $u$. Instead of writing each time something like $P(y = C | t = 1, u = 0)$, we will, for the sake of simplicity, write it as $P(C | 1, 0)$.

#### Training: forward-backward algorithm

The loss function of RNN-T is the negative log posterior of output label sequence $\mathbf y$ given acoustic feature $\mathbf x$:

$$
    \mathcal L = -\ln P(\mathbf y \in Y^* | \mathbf x) = -\ln \sum_{\mathbf a \in \mathcal{B}^{-1}(\mathbf y)} P(\mathbf a | \mathbf x)
$$

To determine $P(\mathbf a | \mathbf x)$ for an arbitrary alignment $\mathbf a$, we need to multiply the probabilities $P(y | t, u)$ of each symbol across the path:

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1O-aykP5Wods7ZESCJDBsBw2MeBo5egW4"> -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/rnnt_trellis_probs.png">

$$
    \mathbf a = C \emptyset \emptyset A \emptyset T \emptyset
$$
    
$$
    P(\mathbf a | \mathbf x) = P(C | 1, 0) \cdot P(\emptyset | 1, 1) \cdot P(\emptyset | 2, 1) \cdot P(A | 3, 1) \cdot P(\emptyset | 3, 2) \cdot P(T | 4, 2) \cdot P(\emptyset | 4, 3)
$$

There are usually too many possible alignments to compute the loss function by just adding them all up directly. We will use dynamic programming to make this computation feasible.

Define the *forward variable* $\alpha(t, u)$ as the probability of outputting $\mathbf y_{[1:u]}$ during $\mathbf f_{[1:t]}$. The forward variables for all $1 \le t \le T$ and $0 \le u \le U$ can be calculated recursively using

$$
    \alpha(t, u) = \alpha(t - 1, u) P(\emptyset | t - 1, u) + \alpha(t, u - 1) P(y_{u - 1} | t, u - 1)
$$

with initial condition $\alpha(1, 0) = 1$. Here $y_{u - 1}$ is the $(u - 1)$-th symbol from the ground truth label $\mathbf y$.

The total output sequence probability is equal to the forward variable at the terminal node multiplied by the probability of the final blank:

$$
    P(\mathbf y | \mathbf x) = \alpha(T, U) P(\emptyset | T, U)
$$

Define the *backward variable* $\beta(t, u)$ as the probability of outputting $\mathbf y_{[u + 1: U]}$ during $\mathbf f_{[t:T]}$. Then

$$
    \beta(t, u) = \beta(t + 1, u) P(\emptyset | t, u) + \beta(t, u + 1) P(y_u | t, u)
$$

with initial condition $\beta(T, U) = P(\emptyset | T, U)$. The final value is $\beta(1, 0)$.

From the definition of the forward and backward variables it follows that their product $\alpha(t, u) \beta(t, u)$ at any point $(t, u)$ in the output lattice is equal to the probability of emitting the complete output sequence *if $y_u$ is emitted during transcription step $t$*.

### RNN-T Forward-Backward Algorithm (2 points)

Implement forward and backward passes.


#### Implementation tips

- Note that all indices in the arrays you will work with in your code start with zeros. So, the initial condition for forward algorithm will be $\alpha(0, 0) = 1$ (and $\log \alpha(0, 0) = 0$) and the output value for backward algorithm will be $\beta(0, 0)$. The recurrent formulas stay the same. Also, don't be confused with the terminal node: you don't have to add it to $\alpha$- and $\beta$-arrays. The dynamic starts in the upper left corner for forward variables and in the lower right corner for backward variables.
- You will need to do everything in the log domain for the calculations to be numerically stable. The function [np.logaddexp](https://numpy.org/doc/stable/reference/generated/numpy.logaddexp.html) might help you with it.

In [13]:
def forward(log_probs: torch.FloatTensor, targets: torch.LongTensor,
            blank: int = -1) -> Tuple[np.ndarray, np.floating]:
    """
    RNN-T forward pass: fill the lattice of log forward variables.

    :param log_probs: model outputs after applying log_softmax, shape (T, U + 1, D);
        a torch tensor or an equivalent NumPy array
    :param targets: the target sequence of tokens, represented as integer indexes
    :param blank: the index of blank symbol
    :return: Tuple[ln alpha, -(ln alpha(T, U) + ln P(blank | T, U))]. The first item is a
        NumPy array of shape (T, U + 1); the latter term is loss value, which is -ln P(y | x)
    """
    # Run the whole recursion in NumPy. Feeding 0-d torch tensors to NumPy ufuncs
    # (np.logaddexp) is slow and emits warnings on recent NumPy versions.
    if isinstance(log_probs, torch.Tensor):
        log_probs = log_probs.detach().cpu().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.detach().cpu().numpy()

    max_T, max_U, _ = log_probs.shape

    # here the alpha variable contains logarithm of the alpha variable from the formulas above;
    # alpha[0, 0] = ln 1 = 0 is the initial condition
    alpha = np.zeros((max_T, max_U), dtype=np.float32)

    # First column: only blanks have been emitted so far.
    for t in range(1, max_T):
        alpha[t, 0] = alpha[t - 1, 0] + log_probs[t - 1, 0, blank]

    # First row: every label emitted so far was emitted at the first frame.
    for u in range(1, max_U):
        alpha[0, u] = alpha[0, u - 1] + log_probs[0, u - 1, targets[u - 1]]

    # Interior nodes are reached either by a blank from (t - 1, u) or by a label from (t, u - 1).
    for t in range(1, max_T):
        for u in range(1, max_U):
            no_emit = alpha[t - 1, u] + log_probs[t - 1, u, blank]
            emit = alpha[t, u - 1] + log_probs[t, u - 1, targets[u - 1]]
            alpha[t, u] = np.logaddexp(no_emit, emit)

    # The path must finish with a blank emitted from the terminal node.
    cost = -(alpha[max_T - 1, max_U - 1] + log_probs[max_T - 1, max_U - 1, blank])
    return alpha, cost


def backward(log_probs: torch.FloatTensor, targets: torch.LongTensor,
             blank: int = -1) -> Tuple[np.ndarray, np.floating]:
    """
    RNN-T backward pass: fill the lattice of log backward variables.

    :param log_probs: model outputs after applying log_softmax, shape (T, U + 1, D);
        a torch tensor or an equivalent NumPy array
    :param targets: the target sequence of tokens, represented as integer indexes
    :param blank: the index of blank symbol
    :return: Tuple[ln beta, -ln beta(0, 0)]. The first item is a NumPy array of shape
        (T, U + 1); the latter term is loss value, which is -ln P(y | x)
    """
    # Run the whole recursion in NumPy. Feeding 0-d torch tensors to NumPy ufuncs
    # (np.logaddexp) is slow and emits warnings on recent NumPy versions.
    if isinstance(log_probs, torch.Tensor):
        log_probs = log_probs.detach().cpu().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.detach().cpu().numpy()

    max_T, max_U, _ = log_probs.shape

    # here the beta variable contains logarithm of the beta variable from the formulas above
    beta = np.zeros((max_T, max_U), dtype=np.float32)
    # Initial condition: the terminal node can only emit the final blank.
    beta[-1, -1] = log_probs[-1, -1, blank]

    # Last column: all labels are already emitted, only blanks remain.
    for t in reversed(range(max_T - 1)):
        beta[t, max_U - 1] = beta[t + 1, max_U - 1] + log_probs[t, max_U - 1, blank]

    # Last row: no frames left to skip, the remaining labels are emitted at the last frame.
    for u in reversed(range(max_U - 1)):
        beta[max_T - 1, u] = beta[max_T - 1, u + 1] + log_probs[max_T - 1, u, targets[u]]

    # Interior nodes continue either with a blank to (t + 1, u) or with a label to (t, u + 1).
    for t in reversed(range(max_T - 1)):
        for u in reversed(range(max_U - 1)):
            no_emit = beta[t + 1, u] + log_probs[t, u, blank]
            emit = beta[t, u + 1] + log_probs[t, u, targets[u]]
            beta[t, u] = np.logaddexp(no_emit, emit)

    cost = -beta[0, 0]
    return beta, cost

In [14]:
def run_test(logits: torch.FloatTensor, targets: torch.LongTensor,
             ref_costs: torch.FloatTensor, blank: int = -1) -> None:
    """
    Check the forward and backward passes against each other and against reference costs.

    :param logits: model outputs
    :param targets: the target sequence of tokens, represented as integer indexes
    :param ref_costs: the true values of RNN-T costs for test inputs
    :param blank: the index of blank symbol
    """
    log_probs = F.log_softmax(logits, dim=-1)
    n_samples = log_probs.shape[0]
    cost = np.zeros(n_samples)

    for sample_idx in range(n_samples):
        _, forward_cost = forward(log_probs[sample_idx], targets[sample_idx], blank=blank)
        _, backward_cost = backward(log_probs[sample_idx], targets[sample_idx], blank=blank)
        # Both passes compute -ln P(y | x), so they must agree with each other...
        np.testing.assert_almost_equal(forward_cost, backward_cost, decimal=2)
        cost[sample_idx] = backward_cost

    # ...and with the reference values.
    np.testing.assert_almost_equal(cost, ref_costs, decimal=2)

In [15]:
# Tests for the RNN-T forward/backward passes against reference loss values.

'''
All logits in tests have shapes in the form (B, T, U, D) where

B: batch size
T: maximum source sequence length in batch
U: maximum target sequence length in batch
D: feature dimension of each source sequence element
'''

# test 1: one utterance, lattice of 2 frames x 3 label positions (two target tokens),
# 5 classes, blank is the last class
logits = torch.FloatTensor([
    0.1, 0.6, 0.1, 0.1, 0.1,
    0.1, 0.1, 0.6, 0.1, 0.1,
    0.1, 0.1, 0.2, 0.8, 0.1,
    0.1, 0.6, 0.1, 0.1, 0.1,
    0.1, 0.1, 0.2, 0.1, 0.1,
    0.7, 0.1, 0.2, 0.1, 0.1,
]).reshape(1, 2, 3, 5)

targets = torch.LongTensor([[1, 2]])
ref_costs = torch.FloatTensor([5.09566688538])

run_test(
    logits=logits,
    targets=targets,
    ref_costs=ref_costs,
    blank=-1
)

# test 2: batch of two utterances, lattice of 4 frames x 3 label positions,
# 3 classes, blank is class 0
logits = torch.FloatTensor([
    0.065357, 0.787530, 0.081592, 0.529716, 0.750675, 0.754135, 0.609764, 0.868140,
    0.622532, 0.668522, 0.858039, 0.164539, 0.989780, 0.944298, 0.603168, 0.946783,
    0.666203, 0.286882, 0.094184, 0.366674, 0.736168, 0.166680, 0.714154, 0.399400,
    0.535982, 0.291821, 0.612642, 0.324241, 0.800764, 0.524106, 0.779195, 0.183314,
    0.113745, 0.240222, 0.339470, 0.134160, 0.505562, 0.051597, 0.640290, 0.430733,
    0.829473, 0.177467, 0.320700, 0.042883, 0.302803, 0.675178, 0.569537, 0.558474,
    0.083132, 0.060165, 0.107958, 0.748615, 0.943918, 0.486356, 0.418199, 0.652408,
    0.024243, 0.134582, 0.366342, 0.295830, 0.923670, 0.689929, 0.741898, 0.250005,
    0.603430, 0.987289, 0.592606, 0.884672, 0.543450, 0.660770, 0.377128, 0.358021,
]).reshape(2, 4, 3, 3)

targets = torch.LongTensor([[1, 2], [1, 1]])
ref_costs = torch.FloatTensor([4.2806528590890736, 3.9384369822503591])

run_test(
    logits=logits,
    targets=targets,
    ref_costs=ref_costs,
    blank=0
)

  alpha[t, u] = np.logaddexp(no_emit, emit)
  beta[t, u] = np.logaddexp(no_emit, emit)


#### Utilities

In [17]:
# Start-of-sentence symbol fed to the RNN-T predictor at the first decoding step.
BOS = utils.BOS
# NOTE: this rebinds the global `tokenizer` that the CTC decoder functions above also read,
# so calling them after this cell uses the RNN-T tokenizer instead of the CTC one.
tokenizer = utils.RNNTTokenizer()  # added <BOS> token

In [18]:
# Download LibriSpeech test dataset

# Make sure the download directory exists (a no-op when it is already there).
os.makedirs("./data", exist_ok=True)

# LibriSpeech "test-clean" split, downloaded into ./data if it is missing.
test_dataset = torchaudio.datasets.LIBRISPEECH("./data", url="test-clean", download=True)
# 80-bin mel spectrogram features for 16 kHz audio.
test_transforms = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=80)

In [19]:
def collator_fn(data, transforms) -> Tuple[torch.Tensor, torch.IntTensor, torch.IntTensor, torch.IntTensor]:
    """
    Turn a list of LIBRISPEECH samples into a padded batch.

    :param data: an iterable of LIBRISPEECH samples (6-tuples); only the waveform
        (item 0) and the utterance text (item 2) are used
    :param transforms: a callable mapping a waveform to a spectrogram of shape (1, n_mels, T)
    :return: tuple of
        spectrograms, shape: (B, T, n_mels)
        labels, shape: (B, U)
        input_lengths -- the length of each spectrogram in the batch, shape: (B,)
        label_lengths -- the length of each text label in the batch, shape: (B,)
        where
        B: batch size
        T: maximum source sequence length in batch
        U: maximum target sequence length in batch
        n_mels: feature dimension of each source sequence element
    """
    spectrograms = []
    labels = []
    input_lengths = []
    label_lengths = []
    for (waveform, _, utterance, _, _, _) in data:
        # (1, n_mels, T) -> (T, n_mels): time becomes the leading axis so it can be padded.
        spec = transforms(waveform).squeeze(0).transpose(0, 1)
        spectrograms.append(spec)
        label = torch.IntTensor(tokenizer.text_to_indices(utterance.lower()))
        labels.append(label)
        # Unpadded lengths are recorded before padding.
        input_lengths.append(spec.shape[0])
        label_lengths.append(len(label))

    # Zero-pad along the time / label axis up to the longest item in the batch.
    spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
    # pad_sequence keeps the int32 dtype of the per-utterance labels, so no re-wrapping is needed.
    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)

    return spectrograms, labels, torch.IntTensor(input_lengths), torch.IntTensor(label_lengths)

# Single-argument collate callable: binds the test-time spectrogram transform to `collator_fn`.
test_collator_fn = lambda data: collator_fn(data, test_transforms)

### Implementing a greedy decoder (2 points)

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1tHsoq0ZH0tHSHYlYlw00y8ksF-wHmrmC"> -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/rnnt_greedy.png">

Now we know how to train a Transducer, but how do we infer it? Our task is to generate an output sequence $\mathbf y$ given an input acoustic sequence $\mathbf x$.

Here we will index the encoder outputs $f_t$ starting from zero, because it is more convenient when describing an algorithm.

The greedy decoding procedure is as follows:
1. Compute $\{f_0, \ldots, f_{T - 1}\}$ using $\mathbf x$.
2. Set $t = 0$, $u = 0$, $\mathbf y = []$, $\mathrm{iteration} = 0$.
3. If $u = 0$, set $g_0 = \mathrm{Predictor}(\langle s \rangle)$. If $u > 0$, compute $g_u$ using the last predicted token $\mathbf y[-1]$.
4. Compute $P(y | t, u)$ using $f_t$ and $g_u$.
5. If argmax of $P(y | t, u)$ is a label, set $u = u + 1$ and append the new label to $\mathbf y$. 
6. If argmax of $P(y | t, u)$ is $\emptyset$, set $t = t + 1$.
7. If $t = T$ or $\mathrm{iteration} = \mathrm{max\_iterations}$, we are done. Else, set $\mathrm{iteration} = \mathrm{iteration + 1}$ and go to step 3.

In [20]:
@torch.no_grad()
def greedy_decode(model: 'RNNTransducer', encoder_output: torch.Tensor, max_steps: int = 2000) -> torch.Tensor:
    """
    Greedily decode a single utterance with an RNN-T model.

    At every step the joiner scores the current encoder frame f_t against the current
    decoder output g_u. A non-blank argmax is appended to the hypothesis and fed back
    into the decoder (u advances); a blank argmax moves on to the next frame (t advances).
    Decoding stops when all frames are consumed or after `max_steps` joiner evaluations.

    :param model: an RNN-T model in eval mode
    :param encoder_output: the output of the encoder part of RNN-T, shape: (T, encoder_output_dim)
    :param max_steps: the maximum number of decoding steps
    :return: the predicted labels as a 1-D LongTensor (empty when nothing was emitted or T == 0)
    """
    blank = tokenizer.get_symbol_index(BLANK_SYMBOL)
    max_time_steps = encoder_output.size(0)
    pred_tokens = []
    t = 0

    # g_0: prime the decoder with the <BOS> token and no previous hidden state.
    decoder_input = encoder_output.new_tensor([[tokenizer.get_symbol_index(BOS)]], dtype=torch.long)
    decoder_output, hidden_state = model.decoder(decoder_input, hidden_states=None)

    for _ in range(max_steps):
        # Checked before indexing, so an empty encoder output (T == 0) yields an empty
        # prediction instead of an IndexError on encoder_output[0].
        if t >= max_time_steps:
            break

        logits = model.joiner(encoder_output[t].unsqueeze(0), decoder_output)
        # Shape (1, 1): exactly what the decoder expects as its next input.
        best = torch.argmax(logits, dim=-1).reshape(1, 1)
        token = best.item()

        if token == blank:
            # Blank: nothing is emitted, move on to the next acoustic frame.
            t += 1
        else:
            # Label: emit it and advance the decoder state; the frame index stays the same.
            pred_tokens.append(token)
            decoder_output, hidden_state = model.decoder(best, hidden_states=hidden_state)

    return torch.LongTensor(pred_tokens)


@torch.no_grad()
def recognize(model: 'RNNTransducer', inputs: torch.Tensor, input_lengths: torch.Tensor) -> List[torch.Tensor]:
    """
    Run the encoder on a batch and greedily decode every utterance in it.

    :param model: an RNN-T model in eval mode
    :param inputs: spectrograms, shape: (B, T, n_mels)
    :param input_lengths: the lengths of the spectrograms in the batch, shape: (B,)
    :return: a list with the predicted labels, one tensor per utterance
    """
    encoder_outputs, _ = model.encoder(inputs, input_lengths)

    # NOTE(review): every row is decoded over its full time axis; `input_lengths` is only
    # passed to the encoder, so frames past an utterance's own length (if the encoder
    # returns any) are decoded as well - confirm this is intended.
    return [greedy_decode(model, utterance_frames) for utterance_frames in encoder_outputs]


def get_transducer_predictions(
        transducer: 'RNNTransducer', inputs: torch.Tensor, input_lengths: torch.Tensor,
        targets: torch.Tensor, target_lengths: torch.Tensor
    ) -> pd.DataFrame:
    """
    Decode a batch with the transducer and score every hypothesis against its transcript.

    :param transducer: an RNN-T model in eval mode
    :param inputs: spectrograms, shape: (B, T, n_mels)
    :param input_lengths: the lengths of the spectrograms in the batch, shape: (B,)
    :param targets: labels, shape: (B, U)
    :param target_lengths: the lengths of the text labels in the batch, shape: (B,)
    :return: a pd.DataFrame with one row per utterance and the columns
        `ground_truth`, `prediction`, `cer` and `wer`
    """
    def to_text(indices) -> str:
        # The tokenizer is given plain Python ints rather than tensor elements.
        return tokenizer.indices_to_text([int(index) for index in indices])

    records = []
    hypotheses = recognize(transducer, inputs, input_lengths)
    for hypothesis, padded_reference, reference_len in zip(hypotheses, targets, target_lengths):
        # Strip the padding from the reference before converting it to text.
        reference_text = to_text(padded_reference[:reference_len])
        hypothesis_text = to_text(hypothesis)
        records.append(dict(
            ground_truth=reference_text,
            prediction=hypothesis_text,
            cer=utils.cer(reference_text, hypothesis_text),
            wer=utils.wer(reference_text, hypothesis_text),
        ))

    return pd.DataFrame.from_records(records)


In [21]:
# Load the TorchScript-serialized RNN-T checkpoint from the downloaded data directory.
model = torch.jit.load(os.path.join(data_directory, 'model_scripted_epoch_5.pt'))
# Switch to inference mode; the value returned by eval() is what the cell displays.
model.eval()

RecursiveScriptModule(
  original_name=RNNTransducer
  (encoder): RecursiveScriptModule(
    original_name=EncoderRNNT
    (lstm): RecursiveScriptModule(original_name=LSTM)
    (output_proj): RecursiveScriptModule(original_name=Linear)
  )
  (decoder): RecursiveScriptModule(
    original_name=DecoderRNNT
    (embedding): RecursiveScriptModule(original_name=Embedding)
    (lstm): RecursiveScriptModule(original_name=LSTM)
    (output_proj): RecursiveScriptModule(original_name=Linear)
  )
  (joiner): RecursiveScriptModule(
    original_name=Joiner
    (linear): RecursiveScriptModule(original_name=Linear)
  )
)

In [22]:
# Refresh the apt package index and install FFmpeg plus its libav* development packages
# system-wide (needs sudo and an apt-based distribution).
# NOTE(review): presumably required for audio decoding by torchaudio - confirm, and consider
# moving this next to the other setup cells so that a fresh "Run All" installs it up front.
!sudo apt-get update && sudo apt-get install -y ffmpeg libavcodec-dev libavformat-dev libavutil-dev

Get:1 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Hit:2 http://archive.ubuntu.com/ubuntu jammy InRelease   
Get:3 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Get:4 http://archive.ubuntu.com/ubuntu jammy-backports InRelease [127 kB]
Get:5 http://security.ubuntu.com/ubuntu jammy-security/restricted amd64 Packages [6,477 kB]
Get:6 http://archive.ubuntu.com/ubuntu jammy-updates/main amd64 Packages [4,059 kB]
Get:7 http://security.ubuntu.com/ubuntu jammy-security/multiverse amd64 Packages [62.6 kB]
Get:8 http://security.ubuntu.com/ubuntu jammy-security/main amd64 Packages [3,717 kB]
Get:9 http://security.ubuntu.com/ubuntu jammy-security/universe amd64 Packages [1,300 kB]
Get:10 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 Packages [1,613 kB]
Get:11 http://archive.ubuntu.com/ubuntu jammy-updates/multiverse amd64 Packages [70.9 kB]
Get:12 http://archive.ubuntu.com/ubuntu jammy-updates/restricted amd64 Packages [6,721 kB]
Get:13 ht

In [26]:
# Take the first batch of five test items; shuffle=False keeps the dataset order, so the
# batch is the same on every run.
loader = data.DataLoader(test_dataset, batch_size=5, shuffle=False, collate_fn=test_collator_fn)
spectrograms, labels, input_lengths, label_lengths = next(iter(loader))
# Decode the batch with the loaded model and collect transcripts, hypotheses, CER and WER.
predictions = get_transducer_predictions(
    model, spectrograms, input_lengths,
    labels, label_lengths
)
# Last expression of the cell: rendered as the output table.
predictions

Unnamed: 0,ground_truth,prediction,cer,wer
0,he hoped there would be stew for dinner turnip...,he hoped there would be stew for dinner turnip...,0.132911,0.25
1,stuff it into you his belly counselled him,stuffed into you his belly counciled him,0.142857,0.375
2,after early nightfall the yellow lamps would l...,after early night fall the yellow lamps would ...,0.096154,0.333333
3,hello bertie any good in your mind,her about he and he good in your mind,0.352941,0.714286
4,number ten fresh nelly is waiting on you good ...,none but den fresh now as waiting on you could...,0.254237,0.545455


In [27]:
# Expected transcripts ("gt") and expected decoder hypotheses ("prediction"), listed in the
# same order as the rows of `predictions` they are compared against below.
reference_values = [
    {
        "gt": "he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick peppered flour fattened sauce",
        "prediction": "he hoped there would be stew for dinner turnips and characts and bruised potatoes and fat much and pieces to be lateled out in the thick peppered flowerfacton sauce"
    },
    {
        "gt": "stuff it into you his belly counselled him",
        "prediction": "stuffed into you his belly counciled him"
    },
    {
        "gt": "after early nightfall the yellow lamps would light up here and there the squalid quarter of the brothels",
        "prediction": "after early night fall the yellow lamps would lie how peer and there the squalit quarter of the brothels"
    },
    {
        "gt": "hello bertie any good in your mind",
        "prediction": "her about he and he good in your mind"
    },
    {
        "gt": "number ten fresh nelly is waiting on you good night husband",
        "prediction": "none but den fresh now as waiting on you could night husband"
    }
]

# Row-by-row comparison of the decoded batch with the reference transcripts and hypotheses.
for index in range(5):
    row, expected = predictions.iloc[index], reference_values[index]
    gt, prediction = row.ground_truth, row.prediction
    assert gt == expected["gt"]
    assert prediction == expected["prediction"]

#### RNN-T module (1 point)

In [28]:
class EncoderRNNT(nn.Module):
    def __init__(self, input_dim: int, hidden_size: int, output_dim: int, n_layers: int,
                 dropout: float = 0.2, bidirectional: bool = True):
        """
        Acoustic encoder of the RNN-T: a stack of LSTM layers over the input audio features,
        followed by a fully-connected projection to the output dimension.

        :param input_dim: the number of mel-spectrogram features
        :param hidden_size: the number of features in the hidden states in LSTM layers
        :param output_dim: the output dimension
        :param n_layers: the number of stacked LSTM layers
        :param dropout: the dropout probability for LSTM layers
        :param bidirectional: If True, each LSTM layer becomes bidirectional
        """
        super().__init__()

        # Dropout is only passed on when there is more than one LSTM layer.
        inter_layer_dropout = dropout if n_layers > 1 else 0
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
            bidirectional=bidirectional,
        )

        # A bidirectional LSTM concatenates both directions, doubling the feature size.
        lstm_output_dim = hidden_size * 2 if bidirectional else hidden_size
        self.output_proj = nn.Linear(lstm_output_dim, output_dim)

    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Encode a padded batch of spectrograms.

        :param inputs: spectrograms, shape: (B, T, n_mels)
        :param input_lengths: the lengths of the spectrograms in the batch, shape: (B,)
        :return: outputs of the projection layer and the hidden states returned by the LSTM
        """
        # Pack the batch so the LSTM does not run over the padding; lengths must live on the CPU.
        packed = nn.utils.rnn.pack_padded_sequence(
            inputs, input_lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_features, hidden = self.lstm(packed)
        features, _ = nn.utils.rnn.pad_packed_sequence(packed_features, batch_first=True)

        return self.output_proj(features), hidden

In [29]:
def get_pseudo_batch():
    """
    Build a random two-utterance batch for shape checks.

    The utterances have 835 and 800 frames of 80 features and 158 and 150 labels; both the
    spectrograms and the labels are zero-padded to the longer utterance. Label values are
    drawn from [2, len(tokenizer.char_map)), i.e. the first two token indices are never used.

    :return: (spectrograms, labels, input_lengths, label_lengths)
    """
    frame_counts = [835, 800]
    label_counts = [158, 150]

    spectrograms = nn.utils.rnn.pad_sequence(
        [torch.rand((n_frames, 80)) for n_frames in frame_counts],
        batch_first=True,
    )
    labels = nn.utils.rnn.pad_sequence(
        [torch.randint(len(tokenizer.char_map) - 2, (n_labels,)) + 2 for n_labels in label_counts],
        batch_first=True,
    )

    return spectrograms, labels, torch.IntTensor(frame_counts), torch.IntTensor(label_counts)

In [30]:
# Smoke test: EncoderRNNT output shapes on a random padded two-utterance batch.
encoder = EncoderRNNT(
    input_dim=80,
    hidden_size=320,
    output_dim=512,
    n_layers=4,
    dropout=0.2,
    bidirectional=True
)

spectrograms, labels, input_lengths, label_lengths = get_pseudo_batch()
# Call the module itself rather than `.forward()` so that registered hooks are not bypassed.
logits, hidden_states = encoder(spectrograms, input_lengths)

assert spectrograms.shape == torch.Size([2, 835, 80])
assert logits.shape == torch.Size([2, 835, 512])
# The LSTM state is a pair; 4 layers x 2 directions give 8 along its first axis.
assert len(hidden_states) == 2
assert hidden_states[0].shape == torch.Size([8, 2, 320])

In [31]:
class DecoderRNNT(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int, output_dim: int, n_layers: int, dropout: float = 0.2):
        """
        A simple RNN-based autoregressive language model that takes as input previously generated text tokens
        and outputs a hidden representation of the next token

        :param hidden_size: the number of features in the hidden states in LSTM layers
        :param vocab_size: the number of text tokens in the dictionary
        :param output_dim: the output dimension
        :param n_layers: the number of stacked LSTM layers
        :param dropout: the dropout probability for LSTM layers
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        # Dropout is only passed on when there is more than one LSTM layer.
        self.lstm = nn.LSTM(hidden_size, hidden_size, n_layers, batch_first=True, dropout=dropout if n_layers > 1 else 0)
        self.output_proj = nn.Linear(hidden_size, output_dim)

    def forward(self, inputs: torch.Tensor, input_lengths: Optional[torch.Tensor] = None,
                hidden_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Embed the input tokens, run them through the LSTM and project the result.

        :param inputs: labels, shape: (B, U)
        :param input_lengths: the lengths of the text labels in the batch, shape: (B,).
            When given, the batch is packed so the LSTM skips the padding (training);
            when None, `inputs` is processed as is (step-by-step inference).
        :param hidden_states: optional initial LSTM state to continue from; None starts
            from the LSTM's default initial state. Honoured in both modes.
        :return: outputs of the projection layer and the hidden states returned by the LSTM
        """
        embed_inputs = self.embedding(inputs)

        if input_lengths is not None:
            # Training phase: pack the padded batch (lengths must live on the CPU).
            packed_inputs = nn.utils.rnn.pack_padded_sequence(embed_inputs, input_lengths.cpu(), batch_first=True, enforce_sorted=False)
            # The initial state is forwarded here as well, so it is no longer silently dropped.
            packed_outputs, hidden = self.lstm(packed_inputs, hidden_states)
            outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True)
        else:
            # Testing phase: continue from the state of the previous decoding step.
            outputs, hidden = self.lstm(embed_inputs, hidden_states)

        outputs = self.output_proj(outputs)
        return outputs, hidden

In [32]:
# Smoke test: DecoderRNNT output shapes on a random padded two-utterance batch.
decoder = DecoderRNNT(
    hidden_size=512,
    vocab_size=len(tokenizer.char_map),
    output_dim=512,
    n_layers=1,
    dropout=0.2
)

spectrograms, labels, input_lengths, label_lengths = get_pseudo_batch()
# Call the module itself rather than `.forward()` so that registered hooks are not bypassed.
logits, hidden_states = decoder(labels, label_lengths)

assert labels.shape == torch.Size([2, 158])
assert logits.shape == torch.Size([2, 158, 512])
# The LSTM state is a pair; a single unidirectional layer gives 1 along its first axis.
assert len(hidden_states) == 2
assert hidden_states[0].shape == torch.Size([1, 2, 512])

In [33]:
class Joiner(nn.Module):
    def __init__(self, joiner_dim: int, num_outputs: int):
        """
        Joint network of the RNN-T: sums the encoder and decoder outputs, applies ReLU
        and maps the result to output logits with a fully connected layer.

        :param joiner_dim: the dimension of the encoder and decoder outputs
        :param num_outputs: the number of text tokens in the dictionary
        """
        super().__init__()
        self.linear = nn.Linear(joiner_dim, num_outputs)

    def forward(self, encoder_outputs: torch.Tensor, decoder_outputs: torch.Tensor) -> torch.Tensor:
        """
        Compute Linear(ReLU(f_t + g_u)).

        :param encoder_outputs: the encoder outputs (f_t), shape: (B, T, joiner_dim) or (joiner_dim,)
        :param decoder_outputs: the decoder outputs (g_u), shape: (B, U, joiner_dim) or (joiner_dim,)
        :return: output logits; when both inputs are 3-D every frame is combined with every
            label position, otherwise the inputs are summed with ordinary broadcasting
        """
        both_batched = encoder_outputs.dim() == 3 and decoder_outputs.dim() == 3
        if both_batched:
            # Training phase: (B, T, D) -> (B, T, 1, D) and (B, U, D) -> (B, 1, U, D),
            # so that the sum below broadcasts to (B, T, U, D).
            encoder_outputs = encoder_outputs[:, :, None, :]
            decoder_outputs = decoder_outputs[:, None, :, :]

        joint = torch.relu(encoder_outputs + decoder_outputs)
        return self.linear(joint)

In [37]:
class RNNTransducer(torch.nn.Module):
    def __init__(self,
        num_classes: int,
        input_dim: int,
        num_encoder_layers: int = 4,
        num_decoder_layers: int = 1,
        encoder_hidden_state_dim: int = 320,
        decoder_hidden_state_dim: int = 512,
        output_dim: int = 512,
        encoder_is_bidirectional: bool = True,
        encoder_dropout_p: float = 0.2,
        decoder_dropout_p: float = 0.2
    ):
        """
        RNN-Transducer: an acoustic encoder, a label decoder and a joiner that combines them.

        :param num_classes: the number of text tokens in the dictionary
        :param input_dim: the number of mel-spectrogram features
        :param num_encoder_layers: the number of LSTM layers in the encoder
        :param num_decoder_layers: the number of LSTM layers in the decoder
        :param encoder_hidden_state_dim: the number of features in the hidden states for the encoder
        :param decoder_hidden_state_dim: the number of features in the hidden states for the decoder
        :param output_dim: the output dimension
        :param encoder_is_bidirectional: whether to use bidirectional LSTM's in the encoder
        :param encoder_dropout_p: the dropout probability for the encoder
        :param decoder_dropout_p: the dropout probability for the decoder
        """
        super().__init__()
        self.encoder = EncoderRNNT(input_dim, encoder_hidden_state_dim, output_dim, num_encoder_layers, encoder_dropout_p, encoder_is_bidirectional)
        self.decoder = DecoderRNNT(decoder_hidden_state_dim, num_classes, output_dim, num_decoder_layers, decoder_dropout_p)
        self.joiner = Joiner(output_dim, num_classes)

    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor,
            targets: torch.Tensor, target_lengths: torch.Tensor) -> torch.Tensor:
        """
        Compute the joint logits for every (frame, label position) pair of a training batch.

        :param inputs: spectrograms, shape: (B, T, n_mels)
        :param input_lengths: the lengths of the spectrograms in the batch, shape: (B,)
        :param targets: labels, shape: (B, U)
        :param target_lengths: the lengths of the text labels in the batch, shape: (B,)
        :return: output logits, shape: (B, T, U + 1, num_classes)
        """
        # Encoder output: (B, T, output_dim)
        encoder_outputs, _ = self.encoder(inputs, input_lengths)

        # The decoder takes the input <BOS> + the original sequence: shift the labels right by
        # one position with F.pad. The start token must be the same one `greedy_decode` primes
        # the decoder with, otherwise training and inference condition g_0 on different inputs.
        bos_index = tokenizer.get_symbol_index(BOS)
        decoder_inputs = F.pad(targets, (1, 0), value=bos_index)

        # Decoder output: (B, U + 1, output_dim); every sequence grew by one token.
        decoder_outputs, _ = self.decoder(decoder_inputs, target_lengths + 1)

        # Joiner output: (B, T, U + 1, num_classes)
        return self.joiner(encoder_outputs, decoder_outputs)


In [38]:
# Smoke test: RNNTransducer output shape on a random padded two-utterance batch.
transducer = RNNTransducer(
    num_classes=len(tokenizer.char_map),
    input_dim=80,
    num_encoder_layers=4,
    num_decoder_layers=1,
    encoder_hidden_state_dim=320,
    decoder_hidden_state_dim=512,
    output_dim=512,
    encoder_is_bidirectional=True,
    encoder_dropout_p=0.2,
    decoder_dropout_p=0.2
)

spectrograms, labels, input_lengths, label_lengths = get_pseudo_batch()
# Call the module itself rather than `.forward()` so that registered hooks are not bypassed.
result = transducer(spectrograms, input_lengths, labels, label_lengths)

assert spectrograms.shape == torch.Size([2, 835, 80])
assert labels.shape == torch.Size([2, 158])
# (B, T, U + 1, num_classes): 158 labels plus the prepended start token give 159.
assert result.shape == torch.Size([2, 835, 159, 30])

### Examples

- NVIDIA Parakeet RNN-T 1.1B (model card): https://huggingface.co/nvidia/parakeet-rnnt-1.1b
- Streaming (online) ASR tutorial for torchaudio: https://pytorch.org/audio/main/tutorials/online_asr_tutorial.html