In [None]:
!pip install kaggle==1.5.12

In [None]:
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json  # Set permissions

In [None]:
!mkdir datasets

In [None]:
!kaggle datasets download -d mozillaorg/common-voice -p /content/datasets --force

In [None]:
%cd datasets

/content/datasets


In [None]:
# !unzip common-voice.zip

In [None]:
import pandas as pd

In [None]:
dev_df = pd.read_csv('cv-valid-train.csv')
# dev_df.head()

In [None]:
# dev_df.info()

In [None]:
dev_df = dev_df[:6000]

In [None]:
# from IPython.display import Audio, display
# import os

# def display_audio(audio_path):
#   display(Audio(audio_path))

# audio_directory = 'cv-valid-dev/cv-valid-dev'
# audio_files = [audio for audio in os.listdir(audio_directory)]

# for audio_file in audio_files[2000:2050]:
#   audio_path = os.path.join(audio_directory, audio_file)
#   display_audio(audio_path)


In [None]:
# dev_df['down_votes'].value_counts()

In [None]:
# dev_df[dev_df['down_votes'] == 0][:10]

 => The difference between up-votes and down-votes does not seem to relate to audio quality, so the vote columns are not used as a quality filter.

**Preprocessing steps**



In [None]:
dev_df = dev_df.drop(columns=dev_df.columns[dev_df.columns.get_loc('up_votes') : dev_df.columns.get_loc('duration') + 1])

In [None]:
# dev_df.head()

In [None]:
import torch
import os
import torchaudio
import warnings
warnings.filterwarnings('ignore')

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
audio_directory = 'cv-valid-train'

# 80/20 positional split. The test slice starts exactly where the train slice
# ends, so every row of dev_df lands in exactly one of the two splits.
train_size = int(0.8 * len(dev_df))
train_df = dev_df.iloc[:train_size]
test_df = dev_df.iloc[train_size:]
assert len(train_df) + len(test_df) == len(dev_df)

# Filenames stay relative; audio_directory is prepended when the datasets are built.
train_audio_files = train_df['filename'].tolist()
train_texts = train_df['text'].tolist()

test_audio_files = test_df['filename'].tolist()
test_texts = test_df['text'].tolist()

In [None]:
# train_df.info()

In [None]:
import re

In [None]:
def clean_text(text):
    """Normalize a transcript before tokenization.

    Lower-cases the text and replaces every run of characters that is not an
    ASCII letter or digit (punctuation, apostrophes, whitespace, non-ASCII
    characters) with a single space. The result is then trimmed, so it never
    starts or ends with a space.

    Apostrophes are NOT preserved, so contractions split into separate words:
    "Don't stop, Bob!" -> "don t stop bob".

    Args:
        text: Raw transcript. Non-string values (e.g. ``None`` or a pandas NaN
            for a missing transcript) produce an empty string.

    Returns:
        The normalized transcript.
    """
    if not isinstance(text, str):
        return ""
    # After lower() no ASCII upper-case letters remain, so [a-z0-9] is enough.
    text = re.sub(r"[^a-z0-9]+", " ", text.lower())
    return text.strip()

In [None]:
from tokenizers import Tokenizer, models, trainers
from tokenizers import pre_tokenizers

In [None]:
from tokenizers import decoders

# Byte-level BPE tokenizer for the transcripts.
# The ByteLevel pre-tokenizer maps every byte to one of the symbols in
# pre_tokenizers.ByteLevel.alphabet() (used as the trainer's initial alphabet
# below) and splits text into words, so BPE merges stay inside words. Without a
# pre-tokenizer the whole sentence would be a single pre-token.
# The ByteLevel decoder maps those symbols back to plain text in
# tokenizer.decode(); without a decoder, decode() just joins tokens with spaces.
tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tokenizer.decoder = decoders.ByteLevel()

trainer = trainers.BpeTrainer(
    special_tokens=["<pad>", "<unk>", "<bos>", "<eos>", "<blank>"],
    vocab_size=10000,  # large enough to hold most whole words
    min_frequency=3,
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
)

# Train on cleaned transcripts (see clean_text).
tokenizer.train_from_iterator([clean_text(t) for t in train_texts], trainer)
tokenizer.save("cv_tokenizer.json")

In [None]:
# tokenizer.get_vocab_size()

In [None]:
import math

In [None]:
def preprocess_for_rnnt_torchaudio(file_paths, target_sr=16000, max_duration=1.5, device='cuda'):
    """Load audio files into a padded batch of normalized log-Mel features for RNN-T.

    Per file:
    1. Load the audio and mix it down to mono.
    2. Resample to ``target_sr``.
    3. Normalize the waveform to zero mean / unit variance.
    4. Compute an 80-band log-Mel spectrogram (n_fft=400, hop_length=160).
    5. Normalize each Mel band over time.
    6. Keep at most ``max_duration`` seconds of frames.

    NOTE(review): audio beyond ``max_duration`` is discarded while the
    transcript is not, so long clips train against text they no longer contain.

    Args:
        file_paths: Iterable of audio file paths loadable by ``torchaudio.load``.
        target_sr: Sample rate the audio is resampled to.
        max_duration: Maximum number of seconds of audio kept per clip.
        device: Device on which the features are computed and returned.

    Returns:
        log_mel: (batch_size, T, 80) tensor on ``device``. T is the frame count
            of the longest clip in the batch, capped at
            ``int(max_duration * target_sr / 160) + 1``. Shorter clips are
            zero-padded at the end.
        lengths: (batch_size,) int32 tensor (CPU) with each clip's real,
            unpadded frame count. max(lengths) == T.
    """
    hop_length = 160
    # Build the Mel transform once per call, not once per file.
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sr,
        n_fft=400,
        hop_length=hop_length,
        win_length=400,
        n_mels=80,
    ).to(device)

    # MelSpectrogram defaults to center=True, so exactly max_duration seconds of
    # audio produce int(max_duration * target_sr / hop_length) + 1 frames.
    max_frames = int(max_duration * target_sr / hop_length) + 1

    features = []
    for file_path in file_paths:
        waveform, sr = torchaudio.load(file_path)  # (channels, samples)
        # Mix multi-channel audio to mono so every clip yields one (T, 80) matrix.
        waveform = waveform.mean(dim=0, keepdim=True).to(device)

        if sr != target_sr:
            waveform = torchaudio.functional.resample(waveform, sr, target_sr)

        waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-8)

        log_mel = torch.log(mel_transform(waveform) + 1e-8)  # (1, n_mels, time)
        mean = log_mel.mean(dim=2, keepdim=True)
        std = log_mel.std(dim=2, keepdim=True) + 1e-8
        log_mel = ((log_mel - mean) / std)[0].transpose(0, 1)  # (time, n_mels)

        features.append(log_mel[:max_frames])

    # Lengths are measured before padding, so they are the real frame counts.
    lengths = torch.tensor([feat.size(0) for feat in features], dtype=torch.int32)
    mel_specs = torch.nn.utils.rnn.pad_sequence(features, batch_first=True)  # zero-padded

    return mel_specs, lengths

In [None]:
class Encoder(torch.nn.Module):
    """Bidirectional LSTM acoustic encoder.

    Maps feature frames of shape (B, T, input_dim) to (B, T, hidden_dim).
    A stacked BiLSTM produces 2 * hidden_dim features per frame. These are
    dropped out and projected back to hidden_dim by a linear layer.
    """

    def __init__(self, input_dim=80, hidden_dim=128, num_layers=2):
        super().__init__()
        # LSTM inter-layer dropout only applies between stacked layers, so it
        # is turned off when there is a single layer.
        inter_layer_dropout = 0.1 if num_layers > 1 else 0
        self.lstm = torch.nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
            bidirectional=True,
        )
        self.linear = torch.nn.Linear(2 * hidden_dim, hidden_dim)
        self.dropout = torch.nn.Dropout(0.1)

    def forward(self, x):
        """Encode ``x`` of shape (B, T, input_dim) into (B, T, hidden_dim)."""
        frame_states, _ = self.lstm(x)  # (B, T, 2 * hidden_dim)
        return self.linear(self.dropout(frame_states))

class Decoder(torch.nn.Module):
    """RNN-T prediction network: label embedding followed by a single LSTM.

    Maps label ids of shape (B, U) to hidden states of shape (B, U, hidden_dim).
    The LSTM state is not returned, so each call processes the given label
    sequence from its start.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=128):
        super().__init__()
        self.embed = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.lstm = torch.nn.LSTM(input_size=embed_dim, hidden_size=hidden_dim, batch_first=True)
        self.dropout = torch.nn.Dropout(p=0.1)

    def forward(self, y):
        """Run label ids ``y`` of shape (B, U) through embedding + LSTM."""
        label_states, _ = self.lstm(self.embed(y))  # (B, U, hidden_dim)
        return self.dropout(label_states)

class JointNetwork(torch.nn.Module):
    """RNN-T joint network.

    For every (frame t, label position u) pair, the encoder state at t and the
    decoder state at u are concatenated. The result is mapped to vocabulary
    logits by Linear(2H -> H) -> Tanh -> Linear(H -> V).
    """

    def __init__(self, hidden_dim, vocab_size):
        super().__init__()
        self.joint = torch.nn.Sequential(
            torch.nn.Linear(2 * hidden_dim, hidden_dim),
            torch.nn.Tanh(),
            torch.nn.Linear(hidden_dim, vocab_size),
        )

    def forward(self, h_enc, h_dec):
        """Combine h_enc (B, T, H) and h_dec (B, U, H) into logits (B, T, U, V)."""
        num_frames, num_labels = h_enc.size(1), h_dec.size(1)
        # Broadcast both streams onto a shared (B, T, U, H) grid.
        enc_grid = h_enc[:, :, None, :].expand(-1, num_frames, num_labels, -1)
        dec_grid = h_dec[:, None, :, :].expand(-1, num_frames, num_labels, -1)
        return self.joint(torch.cat((enc_grid, dec_grid), dim=-1))

class RNNTransducer(torch.nn.Module):
    """RNN-T model wiring an Encoder, a Decoder and a JointNetwork together.

    Args:
        vocab_size: Number of output labels, including special tokens.
        encoder_dim: Output width of the encoder.
        decoder_dim: Output width of the decoder.

    Raises:
        ValueError: if ``encoder_dim != decoder_dim``. The joint network is
            built as ``JointNetwork(decoder_dim, ...)`` and concatenates both
            streams, so it needs them to have the same width. Without this
            check a mismatch would only surface as a shape error at the first
            forward pass.
    """

    def __init__(self, vocab_size, encoder_dim=128, decoder_dim=128):
        super().__init__()
        if encoder_dim != decoder_dim:
            raise ValueError(
                f"encoder_dim ({encoder_dim}) must equal decoder_dim ({decoder_dim}): "
                "the joint network concatenates both streams and expects equal widths"
            )
        self.encoder = Encoder(hidden_dim=encoder_dim)
        self.decoder = Decoder(vocab_size, hidden_dim=decoder_dim)
        self.joint = JointNetwork(decoder_dim, vocab_size)
        self.vocab_size = vocab_size

    def forward(self, x, y):
        """Compute joint logits (B, T, U, V) for features x (B, T, F) and labels y (B, U)."""
        h_enc = self.encoder(x)  # (B, T, H)
        h_dec = self.decoder(y)  # (B, U, H)
        return self.joint(h_enc, h_dec)  # (B, T, U, V)



# First, let's prepare the dataset class
class AudioTextDataset(torch.utils.data.Dataset):
    """Map-style dataset of (audio file path, transcript) pairs.

    Items are plain dicts. Audio featurization, tokenization and padding happen
    batch-wise in the collate function. Each item therefore also carries the
    tokenizer and the per-dataset settings that the collate function reads.

    Args:
        audio_files: Audio file paths (converted to ``str``).
        texts: Transcripts, one per audio file.
        tokenizer: Tokenizer handed through to the collate function.
        max_text_length: Padded/truncated token length used at collate time.
        target_sr: Sample rate for audio preprocessing.
        max_duration: Maximum seconds of audio kept per clip.

    Raises:
        ValueError: if ``audio_files`` and ``texts`` differ in length.
    """

    def __init__(self, audio_files, texts, tokenizer, max_text_length=100,
                 target_sr=16000, max_duration=1.5):
        self.audio_files = [str(f) for f in audio_files]  # ensure string paths
        self.texts = texts
        if len(self.audio_files) != len(self.texts):
            raise ValueError(
                f"got {len(self.audio_files)} audio files but {len(self.texts)} texts"
            )
        self.tokenizer = tokenizer
        self.max_text_length = max_text_length
        self.target_sr = target_sr
        self.max_duration = max_duration

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        """Return the raw path/transcript for ``idx`` plus collate-time settings."""
        return {
            'audio_file': self.audio_files[idx],
            'text': self.texts[idx],
            'tokenizer': self.tokenizer,
            'max_text_length': self.max_text_length,
            # collate_fn reads these with batch[0].get(key, default); emitting
            # them makes this dataset's settings take effect.
            'target_sr': self.target_sr,
            'max_duration': self.max_duration,
        }

def collate_fn(batch, device='cuda'):
    """Collate AudioTextDataset items into a padded batch on ``device``.

    1. Featurizes all audio files with ``preprocess_for_rnnt_torchaudio``.
    2. Cleans each transcript with ``clean_text`` (the normalization the
       tokenizer is trained on), tokenizes it and wraps it as
       ``<bos> ... <eos>``.
    3. Computes each item's real RNN-T target length and pads with ``<pad>``.

    The training loop uses ``text[:, :-1]`` as decoder input and
    ``text[:, 1:-1]`` as targets. An item's targets are therefore its text
    tokens plus ``<eos>``, capped at ``max_text_length - 2`` positions.

    torchaudio's ``rnnt_loss`` expects targets of width ``max(target_lengths)``.
    The text tensor is therefore ``max(target_lengths) + 2`` columns wide:
    ``<bos>``, the targets, and one trailing column that ``[1:-1]`` drops.

    Returns:
        dict with
        - 'audio' (B, T, 80)
        - 'text' (B, max(target_lengths) + 2), int32
        - 'target_lengths' (B,), int32
        - 'audio_lengths' (B,), int32
        all on ``device``.
    """
    audio_files = [item['audio_file'] for item in batch]
    audio_features, audio_lengths = preprocess_for_rnnt_torchaudio(
        audio_files,
        target_sr=batch[0].get('target_sr', 16000),
        max_duration=batch[0].get('max_duration', 1.5),
        device=device,
    )  # (B, T, 80), (B,)

    token_rows = []
    target_lengths = []
    for item in batch:
        tokenizer = item['tokenizer']
        token_ids = (
            [tokenizer.token_to_id("<bos>")]
            + tokenizer.encode(clean_text(item['text'])).ids
            + [tokenizer.token_to_id("<eos>")]
        )
        # Everything after <bos> is a target, capped at max_text_length - 2.
        target_lengths.append(min(len(token_ids) - 1, item['max_text_length'] - 2))
        token_rows.append((token_ids, tokenizer.token_to_id("<pad>")))

    width = max(target_lengths) + 2
    text_sequences = []
    for token_ids, pad_id in token_rows:
        row = token_ids[:width]
        row = row + [pad_id] * (width - len(row))
        text_sequences.append(torch.tensor(row, dtype=torch.int32))

    return {
        'audio': audio_features.to(device),                   # (B, T, 80)
        'text': torch.stack(text_sequences).to(device),       # (B, U_max + 2)
        'target_lengths': torch.tensor(target_lengths, dtype=torch.int32, device=device),
        'audio_lengths': audio_lengths.to(device),            # (B,), int32
    }

In [None]:
def _rnnt_batch_loss(model, audio, text, input_lengths, target_lengths, blank_id):
    """Run one batch through encoder/decoder/joint and return the mean RNN-T loss.

    ``text`` rows are ``<bos>`` + targets + padding. The decoder reads
    ``<bos>`` + targets and the loss uses the targets alone. torchaudio's
    rnnt_loss expects logits of shape
    (B, max(input_lengths), max(target_lengths) + 1, V) and targets of width
    max(target_lengths). Any extra padding beyond those maxima is trimmed here.
    """
    max_t = int(input_lengths.max())
    max_u = int(target_lengths.max())

    encoder_output = model.encoder(audio[:, :max_t])       # (B, T, H)
    h_dec = model.decoder(text[:, :max_u + 1])             # (B, U+1, H): <bos> + targets
    logits = model.joint(encoder_output, h_dec)            # (B, T, U+1, V)
    targets = text[:, 1:max_u + 1].contiguous()            # drop <bos>

    # rnnt_loss applies log_softmax itself (fused_log_softmax=True by default),
    # so raw logits are passed; float() keeps the loss in fp32 under autocast.
    return torchaudio.functional.rnnt_loss(
        logits=logits.float(),
        targets=targets,
        logit_lengths=input_lengths,
        target_lengths=target_lengths,
        blank=blank_id,
        reduction='mean',
    )


def train_epoch(model, dataloader, optimizer, scheduler, device, tokenizer, mixed_precision=True):
    """Train ``model`` for one pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: RNN-T model with ``encoder``, ``decoder`` and ``joint`` submodules.
        dataloader: Yields dicts with 'audio', 'text', 'audio_lengths' and
            'target_lengths'.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler stepped once per batch, or ``None``.
        device: Device the batches live on. Mixed precision is only used when
            this is a CUDA device.
        tokenizer: Provides the ``<blank>`` token id.
        mixed_precision: Enable autocast + gradient scaling (CUDA only).

    Returns:
        Mean loss over the epoch's batches, or NaN if the loader is empty.
    """
    model.train()

    device_type = torch.device(device).type
    use_amp = mixed_precision and device_type == 'cuda'
    scaler = torch.cuda.amp.GradScaler() if use_amp else None
    blank_id = tokenizer.token_to_id('<blank>')

    total_loss = 0.0
    steps = 0

    for batch in dataloader:
        optimizer.zero_grad()

        with torch.autocast(device_type=device_type, enabled=use_amp):
            loss = _rnnt_batch_loss(
                model,
                audio=batch['audio'],                    # (B, T, 80)
                text=batch['text'],                      # (B, U)
                input_lengths=batch['audio_lengths'],
                target_lengths=batch['target_lengths'],
                blank_id=blank_id,
            )

        if use_amp:
            scaler.scale(loss).backward()
            # Unscale first so gradient clipping sees the true gradient norms.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        total_loss += loss.item()
        steps += 1

    return total_loss / steps if steps else float('nan')

In [None]:
tokenizer = Tokenizer.from_file("cv_tokenizer.json")

In [None]:
# Create datasets
train_dataset = AudioTextDataset(
    audio_files=[os.path.join(audio_directory, f) for f in train_audio_files],
    texts=train_texts,
    tokenizer=tokenizer,
)

In [None]:
test_dataset = AudioTextDataset(
    audio_files=[os.path.join(audio_directory, f) for f in test_audio_files],
    texts=test_texts,
    tokenizer=tokenizer,
)

In [None]:
# Batches are featurized inside collate_fn, which also places them on `device`.
# num_workers is left at its default (0), so collation runs in the main process.
# NOTE(review): collate_fn issues CUDA calls; confirm that works in worker
# processes before enabling num_workers > 0.
train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=2,
        shuffle=True,
        collate_fn=lambda batch: collate_fn(batch, device=device)
    )

# Test batches keep dataset order (shuffle=False).
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=lambda batch: collate_fn(batch, device=device)
)

In [None]:
# Initialize model, optimizer and LR schedule
NUM_EPOCHS = 20  # shared by OneCycleLR and the loop; OneCycleLR errors if stepped past its total
SAVE_DIR = None  # set to a directory path to save a checkpoint after every epoch

model = RNNTransducer(vocab_size=tokenizer.get_vocab_size()).to(device)
# OneCycleLR overrides this lr with max_lr / div_factor, so 5e-4 is only a placeholder.
optimizer = torch.optim.AdamW(
    model.parameters(), lr=5e-4, weight_decay=1e-4, betas=(0.9, 0.999)
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=20e-3,
    steps_per_epoch=len(train_loader),
    epochs=NUM_EPOCHS,
    pct_start=0.2,        # warm up over the first 20% of all steps
    div_factor=10,        # initial LR = max_lr / div_factor
    final_div_factor=50,  # final LR = initial LR / final_div_factor
)

# Training loop (train_epoch steps the scheduler once per batch)
for epoch in range(NUM_EPOCHS):
    train_loss = train_epoch(
        model=model,
        dataloader=train_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        tokenizer=tokenizer,
        mixed_precision=True,
    )
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS} - Train Loss: {train_loss:.4f}")

    if SAVE_DIR:
        os.makedirs(SAVE_DIR, exist_ok=True)
        checkpoint_path = os.path.join(SAVE_DIR, f"rnnt_epoch_{epoch + 1}.pt")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': train_loss,
        }, checkpoint_path)
        print(f"Saved checkpoint to {checkpoint_path}")

print("Training complete!")

Train Loss: 104.7497
Train Loss: 47.8600
Train Loss: 43.7264
Train Loss: 37.5052
Train Loss: 31.5323
Train Loss: 26.5488
Train Loss: 22.9217
Train Loss: 20.2552
Train Loss: 17.8292
Train Loss: 15.9396
Train Loss: 14.6135
Train Loss: 13.3380
Train Loss: 12.5421
Train Loss: 11.7528
Train Loss: 11.2299


In [None]:
from typing import List, Union

In [None]:
def predict_rnnt(
    model: torch.nn.Module,
    audio_features: torch.Tensor,
    tokenizer: object,
    device: Union[str, torch.device],
    max_decode_len: int = 100,
    method: str = "greedy",
    beam_width: int = 5,
    max_symbols_per_frame: int = 5,
) -> str:
    """Greedy, frame-synchronous RNN-T decoding of a single utterance.

    Walks the encoder output one frame at a time. At each frame the joint
    network is evaluated on (current frame, decoder state) and the argmax
    label is taken:
    - ``<blank>`` advances to the next frame;
    - ``<eos>`` stops decoding;
    - any other label is appended to the hypothesis, the decoder state is
      refreshed, and the same frame is tried again.

    The decoder is run on the full ``<bos>`` + hypothesis prefix, mirroring
    training, where it reads ``<bos>`` followed by the target labels.

    Args:
        model: RNN-T model exposing ``encoder``, ``decoder`` and ``joint``.
            It is put in eval mode and moved to ``device``.
        audio_features: (T, D) or (1, T, D) feature tensor.
        tokenizer: Provides ``token_to_id`` for ``<bos>``/``<eos>``/``<blank>``
            and ``decode``.
        device: Device to run inference on.
        max_decode_len: Maximum number of labels emitted.
        method: Only ``"greedy"`` is supported.
        beam_width: Unused; kept for interface compatibility.
        max_symbols_per_frame: Maximum labels emitted before forcing a move to
            the next frame (guards against never emitting ``<blank>``).

    Returns:
        The decoded transcript (without ``<bos>``/``<eos>``).

    Raises:
        ValueError: if ``method`` is not ``"greedy"``.
    """
    if method != "greedy":
        raise ValueError(f"Unknown decoding method: {method}")

    model.eval()
    model.to(device)

    # Add batch dimension if missing (shape: [1, T, D])
    if audio_features.dim() == 2:
        audio_features = audio_features.unsqueeze(0)
    audio_features = audio_features.to(device)

    bos_id = tokenizer.token_to_id("<bos>")
    eos_id = tokenizer.token_to_id("<eos>")
    blank_id = tokenizer.token_to_id("<blank>")

    hypothesis = [bos_id]
    with torch.no_grad():
        h_enc = model.encoder(audio_features)  # [1, T, H]

        def decoder_state():
            # Decoder output after reading the whole current prefix.
            prefix = torch.tensor([hypothesis], device=device)
            return model.decoder(prefix)[:, -1:, :]  # [1, 1, H]

        dec_out = decoder_state()
        finished = False
        for t in range(h_enc.size(1)):
            frame = h_enc[:, t:t + 1, :]  # [1, 1, H]
            for _ in range(max_symbols_per_frame):
                if len(hypothesis) - 1 >= max_decode_len:
                    finished = True
                    break
                logits = model.joint(frame, dec_out)  # [1, 1, 1, V]
                # argmax of the logits equals argmax of their log-softmax.
                next_token = int(logits.argmax(dim=-1).item())
                if next_token == blank_id:
                    break  # move on to the next frame
                if next_token == eos_id:
                    finished = True
                    break
                hypothesis.append(next_token)
                dec_out = decoder_state()
            if finished:
                break

    return tokenizer.decode(hypothesis[1:])


In [None]:
def predict_rnnt_beam_search(
    model: torch.nn.Module,
    audio_features: torch.Tensor,
    tokenizer: object,
    device: Union[str, torch.device],
    beam_width: int = 5,
    max_decode_len: int = 100,
) -> str:
    """Frame-synchronous beam search for an RNN-T model.

    Hypotheses are label prefixes starting with ``<bos>``. At every encoder
    frame each hypothesis is expanded by the joint network in two ways:
    - emitting ``<blank>`` keeps the prefix;
    - emitting one of its top labels extends the prefix by that label.

    At most one label is emitted per frame per hypothesis. Different
    alignments that yield the same prefix are merged by summing their
    probabilities (log-sum-exp). Only the ``beam_width`` best prefixes survive
    each frame. Because every hypothesis consumes all frames, scores are
    directly comparable. The best prefix is cut at its first ``<eos>``.

    Args:
        model: RNN-T model exposing ``encoder``, ``decoder`` and ``joint``.
            It is put in eval mode and moved to ``device``.
        audio_features: (T, D) or (1, T, D) feature tensor.
        tokenizer: Provides ``token_to_id`` for ``<bos>``/``<eos>``/``<blank>``
            and ``decode``.
        device: Device to run inference on.
        beam_width: Number of hypotheses kept after each frame.
        max_decode_len: Maximum number of labels a hypothesis may contain.

    Returns:
        The best decoded transcript (without ``<bos>``/``<eos>``).
    """
    model.eval()
    model.to(device)

    # Add batch dimension if missing (shape: [1, T, D])
    if audio_features.dim() == 2:
        audio_features = audio_features.unsqueeze(0)
    audio_features = audio_features.to(device)

    bos_token = tokenizer.token_to_id("<bos>")
    eos_token = tokenizer.token_to_id("<eos>")
    blank_token = tokenizer.token_to_id("<blank>")

    with torch.no_grad():
        h_enc = model.encoder(audio_features)  # [1, T, H]

        decoder_cache = {}

        def decoder_output(prefix):
            # Decoder output after reading `prefix`, memoized per prefix.
            if prefix not in decoder_cache:
                ids = torch.tensor([list(prefix)], device=device)
                decoder_cache[prefix] = model.decoder(ids)[:, -1:, :]  # [1, 1, H]
            return decoder_cache[prefix]

        def add_candidate(pool, prefix, score):
            # Merge alignments producing the same prefix: log(exp(a) + exp(b)).
            if prefix in pool:
                hi, lo = max(pool[prefix], score), min(pool[prefix], score)
                score = hi + math.log1p(math.exp(lo - hi))
            pool[prefix] = score

        beams = {(bos_token,): 0.0}
        for t in range(h_enc.size(1)):
            frame = h_enc[:, t:t + 1, :]  # [1, 1, H]
            candidates = {}
            for prefix, score in beams.items():
                logits = model.joint(frame, decoder_output(prefix))  # [1, 1, 1, V]
                log_probs = torch.log_softmax(logits.float(), dim=-1).reshape(-1)

                # Blank: stay on this prefix and move to the next frame.
                add_candidate(candidates, prefix, score + log_probs[blank_token].item())

                if len(prefix) - 1 >= max_decode_len:
                    continue
                # One extra slot in case <blank> is among the top entries.
                k = min(beam_width + 1, log_probs.numel())
                top_log_probs, top_ids = log_probs.topk(k)
                for token_log_prob, token_id in zip(top_log_probs.tolist(), top_ids.tolist()):
                    if token_id != blank_token:
                        add_candidate(candidates, prefix + (token_id,), score + token_log_prob)

            ranked = sorted(candidates.items(), key=lambda kv: kv[1], reverse=True)
            beams = dict(ranked[:beam_width])

        best_prefix = max(beams.items(), key=lambda kv: kv[1])[0]

    tokens = list(best_prefix[1:])  # drop <bos>
    if eos_token in tokens:
        tokens = tokens[:tokens.index(eos_token)]
    return tokenizer.decode(tokens)

In [None]:
len(test_audio_files)

In [None]:
# Decode one held-out clip with both decoders on the notebook's `device`
# (preprocess_for_rnnt_torchaudio would otherwise default to device='cuda').
model = model.eval()
sample_idx = 1
sample_path = os.path.join(audio_directory, test_audio_files[sample_idx])
audio_features, _ = preprocess_for_rnnt_torchaudio([sample_path], device=device)  # Shape: [1, T, 80]

with torch.no_grad():
    text_greedy = predict_rnnt(model, audio_features, tokenizer, device)
    print(f"Greedy: {text_greedy}")
    text_beam = predict_rnnt_beam_search(model, audio_features, tokenizer, device)
    print(f"Beam: {text_beam}")
    print(f"Reference: {test_texts[sample_idx]}")

In [None]:
test_texts[1]