**Description**

Implement CTC as this paper describes. Your implementation should support both forward and 
backward propagation operations. 


### Import Packages

In [58]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchaudio

import numpy as np

import editdistance

import os
from torch.utils.data import Dataset

import soundfile as sf

## Data Preparation

In [87]:
class TIMITDataset(Dataset):
    def __init__(self, root_dir, subset='TRAIN', transform=None):
        """
        Args:
            root_dir (str): Root directory of the TIMIT dataset.
            subset (str): 'TRAIN' or 'TEST'.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.root_dir = os.path.join(root_dir, subset)
        self.transform = transform

        # Collect all file paths
        self.items = []
        for root, _, files in os.walk(self.root_dir):
            for file in files:
                if file.endswith(".WAV"):  # case-insensitive check
                    wav_path = os.path.join(root, file)
                    
                    # Expected parallel paths
                    phn_path = wav_path.replace(".WAV", ".PHN")
                    txt_path = wav_path.replace(".WAV", ".TXT")
                    
                    self.items.append((wav_path, phn_path, txt_path))
                   
            
    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        wav_path, phn_path, txt_path = self.items[idx]

        # === Load audio using soundfile ===
        try:
            waveform, sample_rate = sf.read(wav_path)
        except Exception as e:
            print(f"Error loading {wav_path}: {e}")
            return None

        # === Preprocess: Frame, MFCC, Log-Energy, Delta Coefficients ===
        features = self.preprocess(waveform, sample_rate)

        # === Load phoneme labels ===
        phonemes = []
        with open(phn_path, 'r') as f:
            for line in f:
                start, end, label = line.strip().split()
                phonemes.append((int(start), int(end), label))

        # === Load text transcription ===
        with open(txt_path, 'r') as f:
            transcript = f.readline().strip()

        # === Create the sample dictionary ===
        sample = {
            'features': features,            # <--- New: Preprocessed Features
            'phonemes': phonemes,
            'transcript': transcript,
            'wav_path': wav_path
        }

        if self.transform:
            sample = self.transform(sample)
        
        return sample
    
    def preprocess(self, waveform, sample_rate):
        """
        Preprocesses the audio:
        1. Frame into 10ms windows with 5ms overlap.
        2. Extract 12 MFCC coefficients.
        3. Compute log-energy.
        4. Compute first derivatives (deltas).
        5. Normalize.
        """
        # === Frame the signal ===
        frame_length = int(0.01 * sample_rate)  # 10 ms
        frame_step = int(0.005 * sample_rate)   # 5 ms

        # === Compute MFCCs ===
        mfcc = torchaudio.transforms.MFCC(
            sample_rate=sample_rate,
            n_mfcc=12,
            melkwargs={'n_fft': 400, 'hop_length': frame_step, 'n_mels': 26}
        )(torch.tensor(waveform, dtype=torch.float32).unsqueeze(0))

        # === Compute log-energy ===
        log_energy = torch.log(torch.clamp(mfcc.pow(2).sum(dim=1), min=1e-10)).unsqueeze(0)

        # === Compute deltas (first derivatives) ===
        delta = torchaudio.functional.compute_deltas(torch.cat([mfcc, log_energy], dim=1))

        # === Concatenate MFCC + log-energy + delta ===
        features = torch.cat([mfcc, log_energy, delta], dim=1)  # Shape: (1, 26, T)

        # === Normalize each feature dimension to mean 0 and std 1 ===
        mean = features.mean(dim=2, keepdim=True)
        std = features.std(dim=2, keepdim=True)
        features = (features - mean) / (std + 1e-10)

        return features.squeeze(0).transpose(0, 1)  # Final shape: (T, 26)

In [88]:
# Initialize the dataset
train_dataset = TIMITDataset(root_dir='TIMIT_dataset/data', subset='TRAIN/DR1/FCJF0')

# Check the number of samples
print(f"Number of samples: {len(train_dataset)}")

# Get the first sample
sample = train_dataset[0]
print("Feature shape:", sample['features'].shape)  # (T, 26)
print("Phonemes:", sample['phonemes'])
print("Transcript:", sample['transcript'])

Number of samples: 10
Feature shape: torch.Size([290, 26])
Phonemes: [(0, 1960, 'h#'), (1960, 2170, 'dh'), (2170, 2616, 'ix'), (2616, 3905, 'm'), (3905, 5639, 'iy'), (5639, 6182, 'dx'), (6182, 7400, 'iy'), (7400, 8293, 'ng'), (8293, 9364, 'ih'), (9364, 10160, 'z'), (10160, 10960, 'epi'), (10960, 11198, 'n'), (11198, 13707, 'aw'), (13707, 14400, 'ix'), (14400, 15200, 'dcl'), (15200, 16072, 'jh'), (16072, 18800, 'er'), (18800, 20137, 'n'), (20137, 20490, 'dcl'), (20490, 20887, 'd'), (20887, 23040, 'h#')]
Transcript: 0 23143 The meeting is now adjourned.


In [71]:
class BLSTM_CTC(nn.Module):
    def __init__(self, input_dim=26, hidden_dim=100, output_dim=62, num_layers=1):
        super(BLSTM_CTC, self).__init__()
        
        # Bidirectional LSTM
        self.blstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True
        )
        
        # Fully connected layer maps 2*hidden_dim (for bidirection) to output_dim
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        
        # Activation function
        self.log_softmax = nn.LogSoftmax(dim=2)

    def forward(self, x):
        """
        Forward pass for the BLSTM-CTC network.

        Args:
            x (Tensor): Input tensor of shape (batch_size, time_steps, input_dim)

        Returns:
            Tensor: Log-softmax probabilities of shape (batch_size, time_steps, output_dim) 
        """
        # BLSTM forward
        x, _ = self.blstm(x)  # x shape: (batch, time_steps, hidden_dim * 2)

        # Fully connected to output layer
        x = self.fc(x)  # shape: (batch, time_steps, output_dim)

        # Apply log softmax for CTC Loss compatibility
        x = self.log_softmax(x)

        return x


## Computing $p(l | x)$

### Forward variables

$$ \alpha_t(s) := \sum_{\pi \in L'^T \text{ s.t. } \mathcal{B}(\pi_{1:t} = l_{1:s})} \ \prod_{t' = 1}^t \ y_{\pi_{t'}}^{t'}$$

The forward variable $\alpha_t(s)$ represents the probability of correctly aligning the prefix of the modified label sequence $l'$ (up to position $s$) with the input sequence (up to time step $t$).

#### Initialization for Dynamic Programming

In [None]:
def compute_forward_values(T, target_sequence, probs):
    """
    Compute the forward values for the CTC loss using dynamic programming.

    Args:
        T (int): The number of time steps in the input sequence.
        target_sequence (list): The target sequence of labels.
        probs (np.ndarray): probs[t, s] is the probability of the t-th time step being the s-th label.

    Returns:
        alpha: The forward values of shape (T, len(target_sequence)).
        C: The vector used to normalize the values of alpha
    """

    len_target = len(target_sequence)

    # Initialize alpha and C
    alpha = np.zeros((T, len_target))
    C = np.zeros(T)

    # Base cases
    alpha[0,0] = probs[0,0] # label 0 is blank
    alpha[0,1] = probs[0,target_sequence[0]]

    # Recursion
    for t in range(1, T):
        for s in range(len_target): 
            alpha[t,s] += alpha[t-1, s]
            if s >= 1:
                alpha[t,s] += alpha[t-1, s-1]
            if s >= 2 and target_sequence[s] != target_sequence[s-2]:
                alpha[t,s] += alpha[t-1, s-2]
            alpha[t,s] *= probs[t, target_sequence[s]]
            C[t] += alpha[t,s]
        
        # Rescale row
        if C[t] != 0:
            alpha[t,:] /= C[t]

    return alpha, C

In [None]:
def compute_backward_values(T, target_sequence, probs):
    """
    Compute the backward values for the CTC loss using dynamic programming.

    Args:
        T (int): The number of time steps in the input sequence.
        target_sequence (list): The target sequence of labels.
        probs (np.ndarray): probs[t, s] is the probability of the t-th time step being the s-th label.

    Returns:
        np.ndarray: The backward values of shape (T, len(target_sequence)).
    """

    len_target = len(target_sequence)

    # Initialize beta matrix
    beta = np.zeros((T, len_target))
    D = np.zeros(T)

    # Base cases
    beta[T,len_target] = probs[T,0] # label 0 is blank
    beta[T,len_target-1] = probs[0,target_sequence[len_target]]

    # Recursion
    for t in range(T-1, -1, -1):
        for s in range(len_target -1, -1, -1): 
            beta[t,s] += beta[t+1, s]
            if s <= len_target - 2:
                beta[t,s] += beta[t+1, s+1]
            if s <= len_target - 3 and target_sequence[s] != target_sequence[s+2]:
                beta[t,s] += beta[t+1, s+2]
            beta[t,s] *= probs[t, target_sequence[s]]
            D[t] += beta[t,s]

        # Rescale row
        if D[t] != 0:
            beta[t,:] /= D[t]

    return beta, D

In [None]:
def compute_total_probability(alpha):
    """
    Compute the total path probability from the forward variables (alpha).

    Args:
        alpha (numpy.ndarray or torch.Tensor): Forward variables of shape (T, S)

    Returns:
        float: Total path probability P(y|x)
    """
    _, S = alpha.shape  # Get the number of time steps and states

    # Sum of probabilities from the last time step (T-1) at the final two states (S-1 and S-2)
    total_prob = alpha[-1, S - 1] + alpha[-1, S - 2]
    return total_prob

In [None]:
def lab(sequence, k):
    """
    Returns the set of indices of the sequence where the value is equal to k.

    Args:
        sequence (list): The input sequence of labels.
        k (int): The value to search for in the sequence.

    Returns:
        list: A list of indices where the value is equal to k.
    """
    indices = []
    for i, value in enumerate(sequence):
        if value == k:
            indices.append(i)
    return indices

In [None]:
def ctc_loss_gradiend(probs, target_sequence, alpha, C, beta, D, K):
    """
    Compute the gradient of the CTC loss with respect to the input probabilities.

    Args:
        probs (np.ndarray): The input probabilities of shape (T, S).
        target_sequence (list): The target sequence of labels.
        alpha (np.ndarray): The forward variables of shape (T, S).
        C (np.ndarray): The normalization vector of shape (T,).
        beta (np.ndarray): The backward variables of shape (T, S).
        D (np.ndarray): The normalization vector of shape (T,).
        K (int): The number of classes.

    Returns:
        np.ndarray: The gradient of the CTC loss with respect to the input probabilities.
    """
    T, S = probs.shape  # Get the number of time steps and states

    # Initialize Q vector: Q[t] = D[t] * (D[t+1]/C[t+1]) * (D[t+2]/C[t+2]) * ... * (D[T-1]/C[T-1]
    Q = np.zeros(T)

    for t in range(T):
        Q[t] = D[t]
        for tp in range(t+1, T):
            Q[t] *= D[tp]/C[tp]

    # Initialize gradient matrix
    grad = np.zeros((T, S))

    for t in range(T):
        for k in range(K):
            grad[t,k] = probs[t,k] - Q[t]/probs[t,k] * sum([alpha[t,s] * beta[t,s] for s in lab(target_sequence, k)])

    return grad

### Final result, putting everything together

In [None]:
# Initialize the model, optimizer, and learning rate
model = BLSTM_CTC(input_dim=26, hidden_dim=100, output_dim=62)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Hyperparameters
num_epochs = 10
print_every = 1

# Training loop
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    model.train()

    for batch_idx, (input, target_sequence, input_length) in enumerate(train_loader):
        # Forward pass
        probs = model(input)

        # Compute forward & backward variables
        alpha, C = compute_forward_values(input_length, target_sequence, probs)
        beta, D  = compute_backward_values(input_length, target_sequence, probs)

        # Compute total probability and gradients
        total_prob = compute_total_probability(alpha)
        ctc_grads = ctc_loss_gradiend(probs, target_sequence, alpha, C, beta, D, probs.size(2))

        # Backpropagation
        probs.backward(gradient=torch.tensor(ctc_grads, dtype=torch.float32))
        
        # === Update parameters immediately (Online Update) ===
        optimizer.step()
        optimizer.zero_grad()  # Clear gradients right after the update

        # === Logging ===
        if (batch_idx + 1) % print_every == 0:
            print(f"Batch {batch_idx + 1} - CTC Loss: {-torch.log(total_prob).item()}")

    print(f"Finished epoch {epoch + 1}")
print("Training complete!")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Initialize the model, optimizer, and data loader
model = ...  # Your model initialization here
optimizer = ...  # Your optimizer initialization here
train_loader = ...  # Your data loader here

# Hyperparameters
num_epochs = ...  # Number of epochs
print_every = ...  # Print frequency

# Training loop
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    model.train()

    for batch_idx, (input, target_sequence) in enumerate(train_loader):

        # NOTE:
        # target_sequence needs to be a list of integers representing the target labels,
        # some preprocessing may be needed to convert it to the required format.
    
        # Forward pass
        probs = model(input)

        # Compute forward & backward variables
        alpha, C = compute_forward_values(probs, target_sequence)
        beta, D  = compute_backward_values(probs, target_sequence)

        # Compute total probability and gradients
        total_prob = compute_total_probability(alpha)
        ctc_grads = ctc_loss_gradient(probs, target_sequence, alpha, C, beta, D, probs.size(2))

        # Backpropagation
        probs.backward(gradient=torch.tensor(ctc_grads, dtype=torch.float32))
        
        # === Update parameters immediately (Online Update) ===
        optimizer.step()
        optimizer.zero_grad()  # Clear gradients right after the update

        # === Logging ===
        if (batch_idx + 1) % print_every == 0:
            print(f"Batch {batch_idx + 1} - CTC Loss: {-torch.log(total_prob).item()}")

    print(f"Finished epoch {epoch + 1}")
print("Training complete!")