In [1]:
# !pip install kaggle==1.5.12

In [2]:
# !mkdir -p ~/.kaggle
# !cp kaggle.json ~/.kaggle/
# !chmod 600 ~/.kaggle/kaggle.json  # Set permissions

In [3]:
# !mkdir datasets

In [4]:
# !kaggle datasets download -d mozillaorg/common-voice -p /content/datasets --force

In [5]:
%cd datasets

/content/datasets


In [6]:
# !unzip common-voice.zip

In [7]:
import pandas as pd

In [8]:
dev_df = pd.read_csv('cv-valid-dev.csv')
# dev_df.head()

In [9]:
# dev_df.info()

In [10]:
# !pip install pydub

In [11]:
# from IPython.display import Audio, display
# import os

# def display_audio(audio_path):
#   display(Audio(audio_path))

# audio_directory = 'cv-valid-dev/cv-valid-dev'
# audio_files = [audio for audio in os.listdir(audio_directory)]

# for audio_file in audio_files[:5]:
#   audio_path = os.path.join(audio_directory, audio_file)
#   display_audio(audio_path)


In [12]:
# dev_df['down_votes'].value_counts()

In [13]:
# dev_df[dev_df['down_votes'] == 0][:10]

 => It seems that the difference between upvotes and downvotes doesn't relate to the quality of audios

**Preprocessing steps**



In [14]:
dev_df = dev_df.drop(columns=dev_df.columns[dev_df.columns.get_loc('up_votes') : dev_df.columns.get_loc('duration') + 1])

In [15]:
# dev_df.head()

In [16]:
import torch
import os
import torchaudio
import warnings
warnings.filterwarnings('ignore')

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [18]:
audio_directory = 'cv-valid-dev'

train_size = int(0.8 * len(dev_df))
train_df = dev_df.iloc[:train_size]
test_df = dev_df.iloc[train_size:]

train_audio_files = [os.path.join(f) for f in train_df['filename']]
train_texts = train_df['text'].tolist()

test_audio_files = [os.path.join(f) for f in test_df['filename']]
test_texts = test_df['text'].tolist()

In [19]:
# train_df, test_df = torch.utils.data.random_split(dev_df, [0.8, 0.2])

In [20]:
# train_texts = [dev_df.loc[idx, 'text'] for idx in train_df.indices]
# test_texts = [dev_df.loc[idx, 'text'] for idx in test_df.indices]

In [21]:
# texts[:5]

In [22]:
import re

In [23]:
def clean_text(text):
    # Keep original case - REMOVED .lower()
    text = text.lower().strip()

    # Handle apostrophes/contractions carefully
    text = re.sub(r"([!\"'#$%&()*\+,-./:;<=>?@\\\[\]^_`{|}~])", r" \1 ", text)
    text = re.sub("[^A-Za-z0-9]+", " ", text)
    text = re.sub(" +", " ", text)

    return text

In [24]:
from tokenizers import Tokenizer, models, trainers
from tokenizers import pre_tokenizers

In [25]:
# Initialize with byte-level BPE (better for names/contractions)
tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))

# Trainer with larger vocab
trainer = trainers.BpeTrainer(
    special_tokens=["<pad>", "<unk>", "<bos>", "<eos>", "<blank>"],
    vocab_size=6150,  # Increased for better word coverage
    min_frequency=3,
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet()  # Better for special characters
)

# Train on properly cleaned text
tokenizer.train_from_iterator([clean_text(t) for t in train_texts], trainer)
tokenizer.save("cv_tokenizer.json")

In [26]:
import math

In [27]:
def preprocess_for_rnnt_torchaudio(file_paths, target_sr=16000, max_duration=5, device='cpu'):
    """GPU-accelerated batch preprocessing pipeline for RNN-T
    Returns:
        log_mel: (batch_size, T, 80) tensor of log Mel spectrograms
        lengths: (batch_size,) tensor of actual lengths in frames
    """
    waveforms = []
    original_lengths = []

    for file_path in file_paths:
        # Load and process each audio file
        waveform, sr = torchaudio.load(file_path)
        waveform = waveform.to(device)

        if sr != target_sr:
            waveform = torchaudio.functional.resample(waveform, sr, target_sr)

        # waveform = torchaudio.functional.vad(waveform, sample_rate=target_sr, trigger_level=20)
        waveform = torch.nn.functional.normalize(waveform, dim=-1)

        # Track original length after VAD but before padding/trimming
        original_length = waveform.size(-1)

        # Fixed-length processing
        max_samples = target_sr * max_duration
        if waveform.size(-1) > max_samples:
            waveform = waveform[..., :max_samples]
        else:
            pad_amount = max_samples - waveform.size(-1)
            waveform = torch.nn.functional.pad(waveform, (0, pad_amount))

        waveforms.append(waveform.squeeze(0))  # Remove channel dimension
        original_lengths.append(original_length)

    # Stack waveforms into (batch_size, num_samples)
    waveforms = torch.stack(waveforms).to(device)
    # original_lengths = torch.tensor(original_lengths, device=device)

    # Mel spectrogram transform
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sr,
        n_fft=400,
        hop_length=160,
        win_length=400,
        n_mels=80,
    ).to(device)

    hop_length = 160
    frame_lengths = torch.tensor([
        math.ceil(length / hop_length) for length in original_lengths
    ], dtype=torch.int32, device=device)

    # Process entire batch
    mel_spec = mel_transform(waveforms)  # (batch_size, 80, T)
    log_mel = torch.log(mel_spec + 1e-6)

    return log_mel.permute(0, 2, 1), frame_lengths  # (batch_size, T, 80), (batch_size,)

In [28]:
# # Load tokenizer
# tokenizer = Tokenizer.from_file("cv_tokenizer.json")

# # Test encoding
# sample_text = "hello mr. sunshine!"
# cleaned_sample = clean_text(sample_text)
# print("Cleaned sample: ", cleaned_sample)
# encoding = tokenizer.encode(sample_text.lower())
# print("Tokens:", encoding.tokens)
# print("Length: ", len(encoding.tokens))
# print("IDs:", encoding.ids)
# print("Length: ", len(encoding.ids))

In [29]:
class Encoder(torch.nn.Module):
    def __init__(self, input_dim=80, hidden_dim=256, num_layers=3):
        super().__init__()
        self.lstm = torch.nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
            dropout=0.1 if num_layers > 1 else 0
        )
        self.linear = torch.nn.Linear(hidden_dim * 2, hidden_dim)
        self.dropout = torch.nn.Dropout(0.1)

    def forward(self, x):
        # x: (B, T, input_dim)
        x, _ = self.lstm(x)  # (B, T, 2*H)
        x = self.dropout(x)
        x = self.linear(x)   # (B, T, H)
        return x

class Decoder(torch.nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, embed_dim)
        self.lstm = torch.nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.dropout = torch.nn.Dropout(0.1)

    def forward(self, y):
        # y: (B, U)
        y = self.embed(y)     # (B, U, E)
        y, _ = self.lstm(y)   # (B, U, H)
        y = self.dropout(y)
        return y

class JointNetwork(torch.nn.Module):
    def __init__(self, hidden_dim, vocab_size):
        super().__init__()
        # Use a more robust approach with a multi-layer network
        self.joint = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim * 2, hidden_dim),
            torch.nn.Tanh(),
            torch.nn.Linear(hidden_dim, vocab_size)
        )

    def forward(self, h_enc, h_dec):
        # h_enc: (B, T, H), h_dec: (B, U, H)

        # Get batch size and sequence lengths
        batch_size = h_enc.size(0)
        T = h_enc.size(1)  # Audio sequence length
        U = h_dec.size(1)  # Text sequence length

        # Expand dimensions for broadcasting
        h_enc = h_enc.unsqueeze(2)  # (B, T, 1, H)
        h_dec = h_dec.unsqueeze(1)  # (B, 1, U, H)

        # Expand to create the full cartesian product for alignment
        h_enc = h_enc.expand(-1, -1, U, -1)  # (B, T, U, H)
        h_dec = h_dec.expand(-1, T, -1, -1)  # (B, T, U, H)

        # Concatenate features
        joint = torch.cat([h_enc, h_dec], dim=-1)  # (B, T, U, 2H)

        # Apply joint network to get logits
        logits = self.joint(joint)  # (B, T, U, V)

        return logits

class RNNTransducer(torch.nn.Module):
    def __init__(self, vocab_size, encoder_dim=256, decoder_dim=256):
        super().__init__()
        self.encoder = Encoder(hidden_dim=encoder_dim)
        self.decoder = Decoder(vocab_size, hidden_dim=decoder_dim)
        self.joint = JointNetwork(decoder_dim, vocab_size)
        self.vocab_size = vocab_size

    def forward(self, x, y):
        # x: (B, T, features), y: (B, U)
        h_enc = self.encoder(x)  # (B, T, H)
        h_dec = self.decoder(y)  # (B, U, H)

        # Both should now have the same hidden dimension
        logits = self.joint(h_enc, h_dec)  # (B, T, U, V)
        return logits



# First, let's prepare the dataset class
class AudioTextDataset(torch.utils.data.Dataset):
    def __init__(self, audio_files, texts, tokenizer, max_text_length=100,
                 target_sr=16000, max_duration=3):
        """
        Dataset class that handles:
        - File paths for batch audio processing
        - Text sequences with tokenization
        - Automatic audio preprocessing in batches
        """
        self.audio_files = [str(f) for f in audio_files]  # Ensure string paths
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_text_length = max_text_length
        self.target_sr = target_sr
        self.max_duration = max_duration

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        """
        Returns single item with:
        - audio_file: Path for batch processing
        - text: Raw text string for tokenization in collate
        """
        return {
            'audio_file': self.audio_files[idx],
            'text': self.texts[idx],
            'tokenizer': self.tokenizer,  # Pass tokenizer for collate_fn
            'max_text_length': self.max_text_length
        }

def collate_fn(batch, device):
    """
    Custom collate function that:
    1. Processes audio files in batch using preprocess_for_rnnt_torchaudio
    2. Tokenizes and pads text sequences
    3. Handles audio length preservation
    """
    # Extract batch components
    audio_files = [item['audio_file'] for item in batch]
    texts = [item['text'] for item in batch]

    # Batch process audio files
    audio_features, audio_lengths = preprocess_for_rnnt_torchaudio(
        audio_files,
        target_sr=batch[0].get('target_sr', 16000),
        max_duration=batch[0].get('max_duration', 5),
        device=device
    )  # (B, T, 80)

    target_lengths = []
    for text in texts:
        encoding = batch[0]['tokenizer'].encode(text)
        target_lengths.append(len(encoding.ids))

    # Tokenize and pad texts
    text_sequences = []
    for item in batch:
        encoding = item['tokenizer'].encode(item['text'])
        text_ids = encoding.ids

        # Add special tokens
        text_ids = [
            item['tokenizer'].token_to_id("<bos>")
        ] + text_ids + [
            item['tokenizer'].token_to_id("<eos>")
        ]

        # Pad or truncate
        max_len = item['max_text_length']
        if len(text_ids) < max_len:
            text_ids = text_ids + [item['tokenizer'].token_to_id("<pad>")] * (max_len - len(text_ids))
        else:
            text_ids = text_ids[:max_len]

        text_sequences.append(torch.tensor(text_ids, dtype=torch.long))

    target_lengths_tensor = torch.tensor(target_lengths)
    # Stack text sequences
    text_tensor = torch.stack(text_sequences)

    return {
        'audio': audio_features.to(device),  # (B, T, 80)
        'text': text_tensor.to(device),     # (B, U_max)
        'target_lengths': target_lengths_tensor.to(device),
        'audio_lengths': audio_lengths # (B,)
    }

In [30]:
def rnnt_loss(logits, targets, input_lengths, target_lengths, blank):
    """Compute RNN-T loss with proper dimensionality handling"""

    # Ensure tensors are contiguous in memory
    logits = logits.contiguous()

    # Convert logits to log probabilities
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    # Ensure correct dtypes
    targets = targets.contiguous().to(torch.int32)
    input_lengths = input_lengths.to(torch.int32)
    target_lengths = target_lengths.to(torch.int32)

    # Compute loss
    loss = torchaudio.functional.rnnt_loss(
        log_probs,
        targets,
        input_lengths,
        target_lengths,
        blank=blank,
        reduction='mean'
    )
    return loss

def train_epoch(model, dataloader, optimizer, device, tokenizer):
    model.train()
    total_loss = 0

    for batch_idx, batch in enumerate(dataloader):
        # Get batch data
        audio = batch['audio']  # (B, T, 80)
        text = batch['text']    # (B, U)

        # Get audio lengths - must match the time dimension of the encoder output
        input_lengths = batch['audio_lengths']
        target_lengths = batch['target_lengths']
        blank_id = tokenizer.token_to_id("<blank>")

        # Ensure input lengths don't exceed the encoder output length
        encoder_output = model.encoder(audio)
        # Forward pass
        optimizer.zero_grad()

        # Pass input sequence and labels (excluding EOS for the decoder input)
        # Use encoder_output directly since we already computed it
        h_dec = model.decoder(text[:, :-1])  # (B, U-1, H)
        logits = model.joint(encoder_output, h_dec)  # (B, T, U, V)

        # For targets, use labels shifted right (exclude BOS)
        targets = text[:, 1:-1]  # Remove BOS and EOS for RNN-T targets

        # Check for dimension mismatch
        # RNNT expects the third dimension of logits to be target_length+1
        required_u_dim = targets.size(1) + 1
        if logits.size(2) != required_u_dim:
            print(f"WARNING: Joint output shape mismatch. Got {logits.size(2)} but need {required_u_dim} for dim=2")

        # Validate shapes before computing loss
        if logits.size(0) != targets.size(0) or logits.size(2) < targets.size(1) + 1:
            print(f"Shape mismatch: logits {logits.shape}, targets {targets.shape}")
            print(f"Required U+1 dimension: {targets.size(1) + 1}")
            print(f"Skipping this batch {batch_idx}")
            continue

        # Compute loss
        try:
            loss = rnnt_loss(
                logits=logits,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                blank=blank_id
            )

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()

            total_loss += loss.item()

        except Exception as e:
            print(f"Error in batch {batch_idx}: {e}")
            print(f"Shapes: logits {logits.shape}, targets {targets.shape}")
            print(f"Lengths: input {input_lengths}, target {target_lengths}")
            continue

    return total_loss / max(1, len(dataloader))

In [31]:
tokenizer = Tokenizer.from_file("cv_tokenizer.json")

In [32]:
# Create datasets
train_dataset = AudioTextDataset(
    audio_files=[os.path.join(audio_directory, f) for f in train_audio_files],
    texts=train_texts,
    tokenizer=tokenizer,
)

In [33]:
test_dataset = AudioTextDataset(
    audio_files=test_audio_files,
    texts=test_texts,
    tokenizer=tokenizer,
)

In [34]:
train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=2,
        shuffle=True,
        collate_fn=lambda batch: collate_fn(batch, device=device)
    )

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=lambda batch: collate_fn(batch, device=device)
)

In [35]:
def predict(model, audio_feature, tokenizer, device, max_decode_len=100):
    """
    Perform inference with the trained model
    """
    model.eval()

    # Add batch dimension
    audio_feature = audio_feature.unsqueeze(0).to(device)

    # Initialize with BOS token
    decoded = [tokenizer.token_to_id("<bos>")]

    with torch.no_grad():
        h_enc = model.encoder(audio_feature)  # (1, T, H)

        for _ in range(max_decode_len):
            # Get last predicted token
            last_token = torch.tensor([decoded[-1]], device=device).unsqueeze(0)

            # Decoder step
            h_dec = model.decoder(last_token)  # (1, 1, H)

            # Joint network
            logits = model.joint(h_enc, h_dec)  # (1, T, 1, V)
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

            # Sum over time
            scores = log_probs.sum(1).squeeze(1)  # (1, V)
            next_token = scores.argmax(-1).item()

            # Stop if EOS is predicted
            if next_token == tokenizer.token_to_id("<eos>"):
                break

            decoded.append(next_token)

    # Convert to text
    tokens = tokenizer.decode(decoded[1:])  # Skip BOS
    return tokens


In [36]:
# Initialize model
model = RNNTransducer(vocab_size=tokenizer.get_vocab_size()).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(1):
    print(f"Epoch {epoch+1}/5")

    # Train for one epoch
    train_loss = train_epoch(
        model=model,
        dataloader=train_loader,
        optimizer=optimizer,
        device=device,
        tokenizer=tokenizer
    )

    print(f"Train Loss: {train_loss:.4f}")

    # Save checkpoint
    # if args.save_dir:
    #     os.makedirs(args.save_dir, exist_ok=True)
    #     checkpoint_path = os.path.join(args.save_dir, f"rnnt_epoch_{epoch+1}.pt")
    #     torch.save({
    #         'epoch': epoch,
    #         'model_state_dict': model.state_dict(),
    #         'optimizer_state_dict': optimizer.state_dict(),
    #         'loss': train_loss,
    #     }, checkpoint_path)
    #     print(f"Saved checkpoint to {checkpoint_path}")

print("Training complete!")

Epoch 1/5
Error in batch 0: input length mismatch
Shapes: logits torch.Size([2, 501, 99, 6150]), targets torch.Size([2, 98])
Lengths: input tensor([382, 447], dtype=torch.int32), target tensor([10,  6])


KeyboardInterrupt: 