In [1]:
# !pip install kaggle==1.5.12

In [2]:
# !mkdir -p ~/.kaggle
# !cp kaggle.json ~/.kaggle/
# !chmod 600 ~/.kaggle/kaggle.json  # Set permissions

In [3]:
# !mkdir datasets

In [4]:
# !kaggle datasets download -d mozillaorg/common-voice -p /content/datasets --force

In [5]:
%cd datasets

/content/datasets


In [6]:
# !unzip common-voice.zip

In [7]:
import pandas as pd

In [8]:
dev_df = pd.read_csv('cv-valid-dev.csv')
# dev_df.head()

In [9]:
# dev_df.info()

In [10]:
# !pip install pydub

In [11]:
# from IPython.display import Audio, display
# import os

# def display_audio(audio_path):
#   display(Audio(audio_path))

# audio_directory = 'cv-valid-dev/cv-valid-dev'
# audio_files = [audio for audio in os.listdir(audio_directory)]

# for audio_file in audio_files[2000:2050]:
#   audio_path = os.path.join(audio_directory, audio_file)
#   display_audio(audio_path)


In [12]:
# dev_df['down_votes'].value_counts()

In [13]:
# dev_df[dev_df['down_votes'] == 0][:10]

 => It seems that the difference between upvotes and downvotes doesn't relate to the quality of audios

**Preprocessing steps**



In [14]:
dev_df = dev_df.drop(columns=dev_df.columns[dev_df.columns.get_loc('up_votes') : dev_df.columns.get_loc('duration') + 1])

In [15]:
# dev_df.head()

In [16]:
import torch
import os
import torchaudio
import warnings
warnings.filterwarnings('ignore')

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [18]:
audio_directory = 'cv-valid-dev'

train_size = int(0.8 * len(dev_df))
train_df = dev_df.iloc[:train_size]
test_df = dev_df.iloc[train_size+1:]

train_audio_files = [os.path.join(f) for f in train_df['filename']]
train_texts = train_df['text'].tolist()

test_audio_files = [os.path.join(f) for f in test_df['filename']]
test_texts = test_df['text'].tolist()

In [19]:
import re

In [20]:
def clean_text(text):
    # Keep original case - REMOVED .lower()
    text = text.lower().strip()

    # Handle apostrophes/contractions carefully
    text = re.sub(r"([!\"'#$%&()*\+,-./:;<=>?@\\\[\]^_`{|}~])", r" \1 ", text)
    text = re.sub("[^A-Za-z0-9]+", " ", text)
    text = re.sub(" +", " ", text)

    return text

In [21]:
from tokenizers import Tokenizer, models, trainers
from tokenizers import pre_tokenizers

In [22]:
# Initialize with byte-level BPE (better for names/contractions)
tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))

# Trainer with larger vocab
trainer = trainers.BpeTrainer(
    special_tokens=["<pad>", "<unk>", "<bos>", "<eos>", "<blank>"],
    vocab_size=6150,  # Increased for better word coverage
    min_frequency=3,
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet()  # Better for special characters
)

# Train on properly cleaned text
tokenizer.train_from_iterator([clean_text(t) for t in train_texts], trainer)
tokenizer.save("cv_tokenizer.json")

In [23]:
import math

In [24]:
def preprocess_for_rnnt_torchaudio(file_paths, target_sr=16000, max_duration=1.5, device='cuda'):
    """GPU-accelerated batch preprocessing pipeline for RNN-T
    Returns:
        log_mel: (batch_size, T, 80) tensor of log Mel spectrograms
        lengths: (batch_size,) tensor of actual lengths in frames
    """
    # Initialize mel transform once
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sr,
        n_fft=400,
        hop_length=160,
        win_length=400,
        n_mels=80,
    ).to(device)

    mel_specs = []
    audio_lengths = []

    # Calculate max_len in frames based on max_duration
    max_len = int(max_duration * target_sr / 160)  # hop_length is 160

    for file_path in file_paths:
        # 1. Load and resample
        waveform, sr = torchaudio.load(file_path)
        waveform = waveform.to(device)

        if sr != target_sr:
            waveform = torchaudio.functional.resample(waveform, sr, target_sr)

        # 2. Normalize
        waveform = torch.nn.functional.normalize(waveform, dim=-1)

        # 3. Mel transform
        mel_spec = mel_transform(waveform)  # (channels, n_mels, time)
        log_mel = torch.log(mel_spec + 1e-6)
        log_mel = (log_mel - log_mel.mean()) / (log_mel.std() + 1e-6)

        # Permute to (channels, time, n_mels)
        log_mel = log_mel.permute(0, 2, 1)
        # print("Log mel: ", log_mel.shape)
        # 4. Padding/trimming
        if log_mel.size(1) > max_len:
            log_mel = log_mel[:, :max_len+1, :]
        else:
            pad_amount = max_len + 1 - log_mel.size(1)
            log_mel = torch.nn.functional.pad(log_mel, (0, 0, 0, pad_amount))

        mel_specs.append(log_mel.squeeze(0))  # Remove channel dimension if single channel
        audio_lengths.append(log_mel.size(1))

    # Stack into batch (batch_size, T, n_mels)
    mel_specs = torch.stack(mel_specs)
    audio_lengths = torch.tensor(audio_lengths, dtype=torch.int32)

    return mel_specs, audio_lengths

In [25]:
class Encoder(torch.nn.Module):
    def __init__(self, input_dim=80, hidden_dim=256, num_layers=3):
        super().__init__()
        self.lstm = torch.nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
            dropout=0.1 if num_layers > 1 else 0
        )
        self.linear = torch.nn.Linear(hidden_dim * 2, hidden_dim)
        self.dropout = torch.nn.Dropout(0.1)

    def forward(self, x):
        # x: (B, T, input_dim)
        x, _ = self.lstm(x)  # (B, T, 2*H)
        x = self.dropout(x)
        x = self.linear(x)   # (B, T, H)
        return x

class Decoder(torch.nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, embed_dim)
        self.lstm = torch.nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.dropout = torch.nn.Dropout(0.1)

    def forward(self, y):
        # y: (B, U)
        y = self.embed(y)     # (B, U, E)
        y, _ = self.lstm(y)   # (B, U, H)
        y = self.dropout(y)
        return y

class JointNetwork(torch.nn.Module):
    def __init__(self, hidden_dim, vocab_size):
        super().__init__()
        # Use a more robust approach with a multi-layer network
        self.joint = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim * 2, hidden_dim),
            torch.nn.Tanh(),
            torch.nn.Linear(hidden_dim, vocab_size)
        )

    def forward(self, h_enc, h_dec):
        # h_enc: (B, T, H), h_dec: (B, U, H)

        # Get batch size and sequence lengths
        batch_size = h_enc.size(0)
        T = h_enc.size(1)  # Audio sequence length
        U = h_dec.size(1)  # Text sequence length

        # Expand dimensions for broadcasting
        h_enc = h_enc.unsqueeze(2)  # (B, T, 1, H)
        h_dec = h_dec.unsqueeze(1)  # (B, 1, U, H)

        # Expand to create the full cartesian product for alignment
        h_enc = h_enc.expand(-1, -1, U, -1)  # (B, T, U, H)
        h_dec = h_dec.expand(-1, T, -1, -1)  # (B, T, U, H)

        # Concatenate features
        joint = torch.cat([h_enc, h_dec], dim=-1)  # (B, T, U, 2H)

        # Apply joint network to get logits
        logits = self.joint(joint)  # (B, T, U, V)

        return logits

class RNNTransducer(torch.nn.Module):
    def __init__(self, vocab_size, encoder_dim=256, decoder_dim=256):
        super().__init__()
        self.encoder = Encoder(hidden_dim=encoder_dim)
        self.decoder = Decoder(vocab_size, hidden_dim=decoder_dim)
        self.joint = JointNetwork(decoder_dim, vocab_size)
        self.vocab_size = vocab_size

    def forward(self, x, y):
        # x: (B, T, features), y: (B, U)
        h_enc = self.encoder(x)  # (B, T, H)
        h_dec = self.decoder(y)  # (B, U, H)

        # Both should now have the same hidden dimension
        logits = self.joint(h_enc, h_dec)  # (B, T, U, V)
        return logits



# First, let's prepare the dataset class
class AudioTextDataset(torch.utils.data.Dataset):
    def __init__(self, audio_files, texts, tokenizer, max_text_length=100,
                 target_sr=16000, max_duration=1.5):
        """
        Dataset class that handles:
        - File paths for batch audio processing
        - Text sequences with tokenization
        - Automatic audio preprocessing in batches
        """
        self.audio_files = [str(f) for f in audio_files]  # Ensure string paths
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_text_length = max_text_length
        self.target_sr = target_sr
        self.max_duration = max_duration

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        """
        Returns single item with:
        - audio_file: Path for batch processing
        - text: Raw text string for tokenization in collate
        """
        return {
            'audio_file': self.audio_files[idx],
            'text': self.texts[idx],
            'tokenizer': self.tokenizer,  # Pass tokenizer for collate_fn
            'max_text_length': self.max_text_length
        }

def collate_fn(batch, device='cuda'):
    """
    Custom collate function that:
    1. Processes audio files in batch using preprocess_for_rnnt_torchaudio
    2. Tokenizes and pads text sequences
    3. Handles audio length preservation
    """
    # Extract batch components
    audio_files = [item['audio_file'] for item in batch]
    texts = [item['text'] for item in batch]

    # Batch process audio files
    audio_features, audio_lengths = preprocess_for_rnnt_torchaudio(
        audio_files,
        target_sr=batch[0].get('target_sr', 16000),
        max_duration=batch[0].get('max_duration', 1.5),
        device=device
    )  # (B, T, 80)

    target_lengths = [item['max_text_length'] - 2 for item in batch]

    # Tokenize and pad texts
    text_sequences = []
    for item in batch:
        encoding = item['tokenizer'].encode(item['text'])
        text_ids = encoding.ids

        # Add special tokens
        text_ids = [
            item['tokenizer'].token_to_id("<bos>")
        ] + text_ids + [
            item['tokenizer'].token_to_id("<eos>")
        ]

        # Pad or truncate
        max_len = item['max_text_length']
        if len(text_ids) < max_len:
            text_ids = text_ids + [item['tokenizer'].token_to_id("<pad>")] * (max_len - len(text_ids))
        else:
            text_ids = text_ids[:max_len]

        text_sequences.append(torch.tensor(text_ids, dtype=torch.int32))

    target_lengths_tensor = torch.tensor(target_lengths, dtype=torch.int32)
    # Stack text sequences
    text_tensor = torch.stack(text_sequences)

    return {
        'audio': audio_features.to(device),  # (B, T, 80)
        'text': text_tensor.to(device),     # (B, U_max)
        'target_lengths': target_lengths_tensor.to(device),
        'audio_lengths': torch.tensor(audio_lengths, device=device) # (B,)
    }

In [26]:
def rnnt_loss(logits, targets, input_lengths, target_lengths, blank):
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    loss = torchaudio.functional.rnnt_loss(
        log_probs,
        targets,
        input_lengths,
        target_lengths,
        blank=blank,
        reduction='mean'
    )
    return loss

def train_epoch(model, dataloader, optimizer, device, tokenizer):
    model.train()
    total_loss = 100

    for batch_idx, batch in enumerate(dataloader):
        # Get batch data
        audio = batch['audio']  # (B, T, 80)
        text = batch['text']    # (B, U)

        # Get audio lengths - must match the time dimension of the encoder output
        input_lengths = batch['audio_lengths']
        target_lengths = batch['target_lengths']

        # Ensure input lengths don't exceed the encoder output length
        encoder_output = model.encoder(audio)
        # Forward pass
        optimizer.zero_grad()

        # Pass input sequence and labels (excluding EOS for the decoder input)
        # Use encoder_output directly since we already computed it
        h_dec = model.decoder(text[:, :-1])  # (B, U-1, H)
        logits = model.joint(encoder_output, h_dec)  # (B, T, U-1, V)

        # For targets, use labels shifted right (exclude BOS)
        targets = text[:, 1:-1].contiguous() # Remove BOS and EOS for RNN-T targets

        # Compute loss
        loss = rnnt_loss(
            logits=logits,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            blank=0
        )

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
        optimizer.step()

        total_loss += loss.item()
        print("Loss item: ", loss.item())

    return total_loss / max(1, len(dataloader))

In [27]:
tokenizer = Tokenizer.from_file("cv_tokenizer.json")

In [28]:
# Create datasets
train_dataset = AudioTextDataset(
    audio_files=[os.path.join(audio_directory, f) for f in train_audio_files],
    texts=train_texts,
    tokenizer=tokenizer,
)

In [29]:
test_dataset = AudioTextDataset(
    audio_files=[os.path.join(audio_directory, f) for f in test_audio_files],
    texts=test_texts,
    tokenizer=tokenizer,
)

In [30]:
train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=4,
        shuffle=True,
        collate_fn=lambda batch: collate_fn(batch, device=device)
    )

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=lambda batch: collate_fn(batch, device=device)
)

In [None]:
 # Initialize model
model = RNNTransducer(vocab_size=tokenizer.get_vocab_size()).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Training loop
for epoch in range(5):
    # Train for one epoch
    train_loss = train_epoch(
        model=model,
        dataloader=train_loader,
        optimizer=optimizer,
        device=device,
        tokenizer=tokenizer
    )

    print(f"Train Loss: {train_loss:.4f} ({(1-train_loss)/100:.4f})")

    # Save checkpoint
    # if args.save_dir:
    #     os.makedirs(args.save_dir, exist_ok=True)
    #     checkpoint_path = os.path.join(args.save_dir, f"rnnt_epoch_{epoch+1}.pt")
    #     torch.save({
    #         'epoch': epoch,
    #         'model_state_dict': model.state_dict(),
    #         'optimizer_state_dict': optimizer.state_dict(),
    #         'loss': train_loss,
    #     }, checkpoint_path)
    #     print(f"Saved checkpoint to {checkpoint_path}")

print("Training complete!")

Train Loss: -29.1708 (0.3017)
Train Loss: -101.1672 (1.0217)
Train Loss: -103.0302 (1.0403)
Train Loss: -104.3939 (1.0539)


In [None]:
from typing import List, Union

In [None]:
def predict_rnnt(
    model: torch.nn.Module,
    audio_features: torch.Tensor,
    tokenizer: object,
    device: Union[str, torch.device],
    max_decode_len: int = 100,
    method: str = "greedy",
    beam_width: int = 5,
) -> str:
    model.eval()
    model.to(device)

    # Add batch dimension if missing (shape: [1, T, D])
    if len(audio_features.shape) == 2:
        audio_features = audio_features.unsqueeze(0).to(device)
    else:
        audio_features = audio_features.to(device)

    # Initialize with BOS token
    hyps = [[tokenizer.token_to_id("<bos>")]]

    with torch.no_grad():
        h_enc = model.encoder(audio_features)  # [1, T, H]
        # print("h_enc shape: ", h_enc.shape)

        if method == "greedy":
            for _ in range(max_decode_len):
                last_token = torch.tensor([hyps[0][-1]], device=device).unsqueeze(0)
                # print("last token shape: ", last_token.shape)
                h_dec = model.decoder(last_token)  # [1, 1, H]
                # print("h_dec shape", h_dec.shape)
                logits = model.joint(h_enc, h_dec)  # [1, T, 1, V]

                # --- FIX: Sum over time before argmax ---
                log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
                scores = log_probs.squeeze(2)
                # print("log probs: ", log_probs)
                # scores = log_probs.sum(dim=1)  # [1, 1, V]
                # print("scores: ", scores)
                next_token = scores.argmax(-1)[0, -1].item()  # Now a scalar
                # ---------------------------------------
                # print("Next token: ", next_token)
                if next_token == tokenizer.token_to_id("<eos>"):
                    break
                if next_token != 0:  # Skip blank tokens
                    hyps[0].append(next_token)

            print("hyps: ", hyps)
            return tokenizer.decode(hyps[0][1:])

        elif method == "beam_search":
            beam_hyps = hyps
            beam_scores = torch.zeros(1, device=device)

            for _ in range(max_decode_len):
                candidates = []
                for i, hyp in enumerate(beam_hyps):
                    last_token = torch.tensor([hyp[-1]], device=device).unsqueeze(0)
                    h_dec = model.decoder(last_token)
                    logits = model.joint(h_enc, h_dec)  # [1, T, 1, V]
                    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

                    # --- FIX: Sum over time before top-k ---
                    scores = log_probs.sum(dim=1)  # [1, 1, V]
                    topk_scores, topk_tokens = scores.topk(beam_width, dim=-1)
                    # ---------------------------------------

                    for j in range(beam_width):
                        new_hyp = hyp + [topk_tokens[0, 0, j].item()]
                        new_score = beam_scores[i] + topk_scores[0, 0, j]
                        candidates.append((new_hyp, new_score))

                candidates.sort(key=lambda x: x[1], reverse=True)
                beam_hyps = [x[0] for x in candidates[:beam_width]]
                beam_scores = torch.tensor([x[1] for x in candidates[:beam_width]], device=device)

                if all(h[-1] == tokenizer.token_to_id("<eos>") for h in beam_hyps):
                    break

            return tokenizer.decode(beam_hyps[0][1:])
        else:
            raise ValueError(f"Unknown decoding method: {method}")


In [None]:
# Load model, tokenizer, and features
model = model.eval()  # Your RNN-T model
audio_features, _ = preprocess_for_rnnt_torchaudio([os.path.join(audio_directory, test_audio_files[-1])])  # Shape: [T, D]

print("Audio shape", audio_features.shape)
# Greedy decoding (fast)
text_greedy = predict_rnnt(
    model, audio_features, tokenizer, "cuda", method="greedy"
)

# Beam search (better accuracy)
text_beam = predict_rnnt(
    model, audio_features, tokenizer, "cuda", method="beam_search", beam_width=5
)

print(f"Greedy: {text_greedy}\nBeam Search: {text_beam}")