# DATA LOADING

In [2]:
import torch
import torchaudio
import numpy as np
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch.nn as nn

# Load dataset
dataset = load_dataset("mozilla-foundation/common_voice_11_0", "ur")

# Print basic dataset info
print(f"Training samples: {len(dataset['train'])}")
print(f"Validation samples: {len(dataset['validation'])}")
print(f"Test samples: {len(dataset['test'])}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/14.4k [00:00<?, ?B/s]

common_voice_11_0.py:   0%|          | 0.00/8.13k [00:00<?, ?B/s]

languages.py:   0%|          | 0.00/3.44k [00:00<?, ?B/s]

release_stats.py:   0%|          | 0.00/60.9k [00:00<?, ?B/s]

The repository for mozilla-foundation/common_voice_11_0 contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/mozilla-foundation/common_voice_11_0.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


n_shards.json:   0%|          | 0.00/12.2k [00:00<?, ?B/s]

ur_train_0.tar:   0%|          | 0.00/111M [00:00<?, ?B/s]

ur_dev_0.tar:   0%|          | 0.00/84.7M [00:00<?, ?B/s]

ur_test_0.tar:   0%|          | 0.00/85.0M [00:00<?, ?B/s]

ur_other_0.tar:   0%|          | 0.00/993M [00:00<?, ?B/s]

ur_other_1.tar:   0%|          | 0.00/875M [00:00<?, ?B/s]

ur_other_2.tar:   0%|          | 0.00/130M [00:00<?, ?B/s]

ur_invalidated_0.tar:   0%|          | 0.00/91.9M [00:00<?, ?B/s]

train.tsv:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

dev.tsv:   0%|          | 0.00/818k [00:00<?, ?B/s]

test.tsv:   0%|          | 0.00/807k [00:00<?, ?B/s]

other.tsv:   0%|          | 0.00/21.2M [00:00<?, ?B/s]

invalidated.tsv:   0%|          | 0.00/859k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]


Reading metadata...: 4129it [00:00, 79900.90it/s]


Generating validation split: 0 examples [00:00, ? examples/s]


Reading metadata...: 3303it [00:00, 112832.39it/s]


Generating test split: 0 examples [00:00, ? examples/s]


Reading metadata...: 3302it [00:00, 75105.84it/s]


Generating other split: 0 examples [00:00, ? examples/s]


Reading metadata...: 0it [00:00, ?it/s][A
Reading metadata...: 10651it [00:00, 106501.98it/s][A
Reading metadata...: 23023it [00:00, 116623.60it/s][A
Reading metadata...: 34967it [00:00, 117905.64it/s][A
Reading metadata...: 46758it [00:00, 117477.18it/s][A
Reading metadata...: 58506it [00:00, 112241.71it/s][A
Reading metadata...: 69768it [00:00, 105804.89it/s][A
Reading metadata...: 85123it [00:00, 104502.13it/s]


Generating invalidated split: 0 examples [00:00, ? examples/s]


Reading metadata...: 3275it [00:00, 69566.97it/s]


Training samples: 4129
Validation samples: 3303
Test samples: 3302


# Pre-Processing

In [9]:

print("\nExample data structure:")
print(dataset['train'][0])

# 2. Create a vocabulary from the dataset
def create_vocabulary(dataset):
    vocab = set()
    for split in ['train', 'validation', 'test']:
        for item in dataset[split]:
            vocab.update(list(item['sentence']))
    return sorted(list(vocab))

vocabulary = create_vocabulary(dataset)
print(f"\nVocabulary size: {len(vocabulary)}")

# 3. Create character-to-index and index-to-character mappings
char_to_idx = {char: idx for idx, char in enumerate(vocabulary)}
idx_to_char = {idx: char for idx, char in enumerate(vocabulary)}

# 4. Set up the audio preprocessor
class AudioPreprocessor:
    def __init__(self, sample_rate=16000):
        self.sample_rate = sample_rate
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=400,
            hop_length=160,
            n_mels=80
        )

    def preprocess_audio(self, audio_path):
        try:
            # Load audio
            waveform, sr = torchaudio.load(audio_path)

            # Resample if necessary
            if sr != self.sample_rate:
                resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
                waveform = resampler(waveform)

            # Convert to mono if stereo
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            # Convert to mel spectrogram
            mel_spec = self.mel_transform(waveform)

            return mel_spec
        except Exception as e:
            print(f"Error processing {audio_path}: {str(e)}")
            return None

# 5. Create the data processor
preprocessor = AudioPreprocessor()

# 6. Create a custom dataset class
class UrduSpeechDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_split, preprocessor, char_to_idx):
        self.dataset = dataset_split
        self.preprocessor = preprocessor
        self.char_to_idx = char_to_idx

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        # Process audio
        audio_features = self.preprocessor.preprocess_audio(item['audio']['path'])

        # Convert text to indices
        text_indices = [self.char_to_idx[char] for char in item['sentence']]

        return {
            'audio': audio_features,
            'text': torch.tensor(text_indices, dtype=torch.long),
            'text_length': len(text_indices)
        }

# 7. Create data loaders
def collate_fn(batch):
    # Filter out None values (failed preprocessing)
    batch = [item for item in batch if item['audio'] is not None]

    if len(batch) == 0:
        return None

    # Get max lengths
    max_audio_len = max(item['audio'].shape[2] for item in batch)
    max_text_len = max(item['text_length'] for item in batch)

    # Pad sequences
    audio_features = []
    text_features = []
    text_lengths = []

    for item in batch:
        # Pad audio
        audio = item['audio']
        audio_pad = torch.nn.functional.pad(
            audio,
            (0, max_audio_len - audio.shape[2])
        )
        audio_features.append(audio_pad)

        # Pad text
        text = item['text']
        text_pad = torch.nn.functional.pad(
            text,
            (0, max_text_len - len(text))
        )
        text_features.append(text_pad)
        text_lengths.append(item['text_length'])

    return {
        'audio': torch.stack(audio_features),
        'text': torch.stack(text_features),
        'text_length': torch.tensor(text_lengths)
    }

# 8. Create the data loaders
batch_size = 16
train_dataset = UrduSpeechDataset(dataset['train'], preprocessor, char_to_idx)
valid_dataset = UrduSpeechDataset(dataset['validation'], preprocessor, char_to_idx)
test_dataset = UrduSpeechDataset(dataset['test'], preprocessor, char_to_idx)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn
)

valid_loader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn
)

# 9. Test the data loader
print("\nTesting data loader...")
for batch in train_loader:
    if batch is not None:
        print("\nBatch shapes:")
        print(f"Audio: {batch['audio'].shape}")
        print(f"Text: {batch['text'].shape}")
        print(f"Text lengths: {batch['text_length'].shape}")
        break


Example data structure:
{'client_id': 'e53f84d151d6cc6d45a57decde08a99efe47d7751a4ca60e58fb87ea68a35d53dcae445c65d5e73e0449a0b1cf2b4d09f32874877e8786664aa50f1f2ec2b932', 'path': '/root/.cache/huggingface/datasets/downloads/extracted/5350814842baec1cce17a4cb70aed2f5d8243e8fe4e810ff027157f331f95972/ur_train_0/common_voice_ur_31771683.mp3', 'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/5350814842baec1cce17a4cb70aed2f5d8243e8fe4e810ff027157f331f95972/ur_train_0/common_voice_ur_31771683.mp3', 'array': array([7.10542736e-14, 7.38964445e-13, 1.08002496e-12, ...,
       1.29391765e-06, 2.22157587e-06, 1.43777788e-06]), 'sampling_rate': 48000}, 'sentence': 'کبھی کبھار ہی خیالی پلاو بناتا ہوں', 'up_votes': 2, 'down_votes': 0, 'age': 'twenties', 'gender': 'male', 'accent': '', 'locale': 'ur', 'segment': ''}

Vocabulary size: 93

Testing data loader...

Batch shapes:
Audio: torch.Size([16, 1, 80, 803])
Text: torch.Size([16, 68])
Text lengths: torch.Size([16])


# Conformer-MoE

In [10]:
import os
import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
import wandb
from tqdm import tqdm
import numpy as np

# Memory optimization settings
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
torch.backends.cudnn.benchmark = True

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class MoELayer(nn.Module):
    def __init__(self, input_dim, num_experts=4, expert_dim=256):
        super().__init__()
        self.num_experts = num_experts
        self.gate = nn.Linear(input_dim, num_experts)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, expert_dim),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(expert_dim, input_dim)
            ) for _ in range(num_experts)
        ])

    def forward(self, x):
        original_shape = x.shape
        x_2d = x.view(-1, original_shape[-1])

        gates = F.softmax(self.gate(x_2d), dim=-1)
        expert_outputs = []
        for expert in self.experts:
            expert_output = expert(x_2d)
            expert_outputs.append(expert_output)
        expert_outputs = torch.stack(expert_outputs, dim=1)

        gates = gates.unsqueeze(-1)
        output = torch.sum(gates * expert_outputs, dim=1)
        output = output.view(original_shape)

        return output

class ConformerBlock(nn.Module):
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.ff1 = nn.Sequential(
            nn.Linear(dim, dim*2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(dim*2, dim)
        )

        self.attention = nn.MultiheadAttention(dim, num_heads, dropout=0.1)

        self.conv = nn.Sequential(
            nn.Conv1d(dim, dim*2, 1),
            nn.GLU(dim=1),
            nn.Conv1d(dim, dim, 3, padding=1, groups=dim),
            nn.BatchNorm1d(dim),
            nn.SiLU(),
            nn.Conv1d(dim, dim, 1)
        )

        self.moe = MoELayer(dim)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.norm4 = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        residual = x
        x = self.norm1(x)
        x = self.ff1(x)
        x = residual + self.dropout(x)

        residual = x
        x = self.norm2(x)
        x_t = x.transpose(0, 1)
        x_t, _ = self.attention(x_t, x_t, x_t)
        x = x_t.transpose(0, 1)
        x = residual + self.dropout(x)

        residual = x
        x = self.norm3(x)
        x = x.transpose(1, 2)
        x = self.conv(x)
        x = x.transpose(1, 2)
        x = residual + self.dropout(x)

        residual = x
        x = self.norm4(x)
        x = self.moe(x)
        x = residual + self.dropout(x)

        return x

class ConformerMoE(nn.Module):
    def __init__(self, num_classes=93, input_dim=80, model_dim=256, num_layers=4):
        super().__init__()
        self.input_projection = nn.Linear(input_dim, model_dim)
        self.pos_encoding = PositionalEncoding(model_dim)
        self.conformer_layers = nn.ModuleList([
            ConformerBlock(model_dim) for _ in range(num_layers)
        ])
        self.output_projection = nn.Linear(model_dim, num_classes)

    def forward(self, x):
        batch_size = x.shape[0]
        x = x.squeeze(1).transpose(1, 2)
        x = self.input_projection(x)
        x = self.pos_encoding(x)

        for layer in self.conformer_layers:
            x = torch.utils.checkpoint.checkpoint(layer, x)

        x = self.output_projection(x)
        return x

class CTCLabelConverter:
    def __init__(self, char_to_idx):
        self.char_to_idx = char_to_idx
        self.idx_to_char = {v: k for k, v in char_to_idx.items()}

    def decode(self, pred, length):
        texts = []
        for p, l in zip(pred, length):
            text = ''.join([self.idx_to_char[i] for i in p[:l]])
            texts.append(text)
        return texts

def train_epoch(model, train_loader, criterion, optimizer, device, epoch, converter, grad_accum_steps=2):
    model.train()
    total_loss = 0
    num_batches = 0
    scaler = GradScaler()
    optimizer.zero_grad()

    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch}')

    for batch_idx, batch in enumerate(progress_bar):
        if batch is None:
            continue

        audio = batch['audio'].to(device)
        text = batch['text'].to(device)
        text_lengths = batch['text_length'].to(device)

        with autocast():
            logits = model(audio)
            log_probs = F.log_softmax(logits, dim=-1)
            log_probs = log_probs.transpose(0, 1)

            input_lengths = torch.full(size=(audio.size(0),),
                                     fill_value=log_probs.size(0),
                                     dtype=torch.long).to(device)

            loss = criterion(log_probs, text, input_lengths, text_lengths) / grad_accum_steps

        scaler.scale(loss).backward()

        if (batch_idx + 1) % grad_accum_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        total_loss += loss.item() * grad_accum_steps
        num_batches += 1

        progress_bar.set_postfix({'loss': f'{loss.item() * grad_accum_steps:.4f}'})

        del loss, logits, log_probs
        torch.cuda.empty_cache()

    return total_loss / num_batches

def validate(model, val_loader, criterion, device, converter):
    model.eval()
    total_loss = 0
    num_batches = 0
    all_wer = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Validation'):
            if batch is None:
                continue

            audio = batch['audio'].to(device)
            text = batch['text'].to(device)
            text_lengths = batch['text_length'].to(device)

            with autocast():
                logits = model(audio)
                log_probs = F.log_softmax(logits, dim=-1)
                log_probs = log_probs.transpose(0, 1)

                input_lengths = torch.full(size=(audio.size(0),),
                                         fill_value=log_probs.size(0),
                                         dtype=torch.long).to(device)

                loss = criterion(log_probs, text, input_lengths, text_lengths)

            pred = log_probs.argmax(dim=-1).transpose(0, 1)
            pred_texts = converter.decode(pred.cpu().numpy(), input_lengths.cpu().numpy())
            true_texts = converter.decode(text.cpu().numpy(), text_lengths.cpu().numpy())

            for pred_text, true_text in zip(pred_texts, true_texts):
                wer = calculate_wer(true_text, pred_text)
                all_wer.append(wer)

            total_loss += loss.item()
            num_batches += 1

            del loss, logits, log_probs
            torch.cuda.empty_cache()

    return total_loss / num_batches, sum(all_wer) / len(all_wer)

def calculate_wer(reference, hypothesis):
    ref_words = reference.split()
    hyp_words = hypothesis.split()

    d = np.zeros((len(ref_words) + 1, len(hyp_words) + 1))

    for i in range(len(ref_words) + 1):
        d[i, 0] = i
    for j in range(len(hyp_words) + 1):
        d[0, j] = j

    for i in range(1, len(ref_words) + 1):
        for j in range(1, len(hyp_words) + 1):
            if ref_words[i-1] == hyp_words[j-1]:
                d[i, j] = d[i-1, j-1]
            else:
                substitution = d[i-1, j-1] + 1
                insertion = d[i, j-1] + 1
                deletion = d[i-1, j] + 1
                d[i, j] = min(substitution, insertion, deletion)

    return d[len(ref_words)][len(hyp_words)] / len(ref_words)

def main():
    # Initialize wandb
    wandb.init(project="urdu-speech-recognition", name="conformer-moe-optimized")

    # Model parameters
    batch_size = 8
    num_epochs = 50
    grad_accum_steps = 2

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )

    valid_loader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )

    # Initialize model and training components
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ConformerMoE()
    model = model.to(device)
    model = torch.compile(model)

    converter = CTCLabelConverter(char_to_idx)
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2, verbose=True
    )

    # Training loop
    best_val_loss = float('inf')
    patience = 5
    patience_counter = 0

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        train_loss = train_epoch(model, train_loader, criterion, optimizer,
                               device, epoch, converter, grad_accum_steps)

        val_loss, val_wer = validate(model, valid_loader, criterion, device, converter)

        wandb.log({
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_wer': val_wer,
            'epoch': epoch
        })

        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}")
        print(f"Val WER: {val_wer:.4f}")

        scheduler.step(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0

            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'val_wer': val_wer
            }
            torch.save(checkpoint, 'checkpoints/best_model.pt')
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs!")
            break

        if (epoch + 1) % 5 == 0:
            torch.save(checkpoint, f'checkpoints/checkpoint_epoch_{epoch+1}.pt')

def evaluate():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ConformerMoE()
    checkpoint = torch.load('checkpoints/best_model.pt')
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    converter = CTCLabelConverter(char_to_idx)
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)

    test_loader = DataLoader(
        test_dataset,
        batch_size=8,
        shuffle=False,
        collate_fn=collate_fn
    )

    test_loss, test_wer = validate(model, test_loader, criterion, device, converter)
    print(f"\nTest Results:")
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test WER: {test_wer:.4f}")

    # Generate sample predictions
    with torch.no_grad():
        for batch in test_loader:
            if batch is None:
                continue

            audio = batch['audio'].to(device)
            text = batch['text']
            text_lengths = batch['text_length']

            with autocast():
                logits = model(audio)
                log_probs = F.log_softmax(logits, dim=-1)

# Training & Evaluation

In [15]:
from torch.amp import autocast
class CTCLabelConverter:
    def __init__(self, char_to_idx):
        self.char_to_idx = char_to_idx
        self.idx_to_char = {v: k for k, v in char_to_idx.items()}

    def decode(self, pred, length):
        texts = []
        for p, l in zip(pred, length):
            text = ''.join([self.idx_to_char[i] for i in p[:l]])
            texts.append(text)
        return texts

def train_epoch(model, train_loader, criterion, optimizer, device, epoch, converter, grad_accum_steps=2):
    model.train()
    total_loss = 0
    num_batches = 0
    scaler = GradScaler()
    optimizer.zero_grad()

    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch}')

    for batch_idx, batch in enumerate(progress_bar):
        if batch is None:
            continue

        audio = batch['audio'].to(device)
        text = batch['text'].to(device)
        text_lengths = batch['text_length'].to(device)

        # Update autocast context
        with autocast(device_type='cuda', dtype=torch.float16):
            logits = model(audio)
            log_probs = F.log_softmax(logits, dim=-1)
            log_probs = log_probs.transpose(0, 1)

            input_lengths = torch.full(size=(audio.size(0),),
                                     fill_value=log_probs.size(0),
                                     dtype=torch.long).to(device)

            loss = criterion(log_probs, text, input_lengths, text_lengths) / grad_accum_steps

        scaler.scale(loss).backward()

        if (batch_idx + 1) % grad_accum_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        total_loss += loss.item() * grad_accum_steps
        num_batches += 1

        progress_bar.set_postfix({'loss': f'{loss.item() * grad_accum_steps:.4f}'})

        del loss, logits, log_probs
        torch.cuda.empty_cache()

    return total_loss / num_batches

def validate(model, val_loader, criterion, device, converter):
    model.eval()
    total_loss = 0
    num_batches = 0
    all_wer = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Validation'):
            if batch is None:
                continue

            audio = batch['audio'].to(device)
            text = batch['text'].to(device)
            text_lengths = batch['text_length'].to(device)

            # Updated autocast call
            with autocast(device_type='cuda', dtype=torch.float16):
                logits = model(audio)
                log_probs = F.log_softmax(logits, dim=-1)
                log_probs = log_probs.transpose(0, 1)

                input_lengths = torch.full(size=(audio.size(0),),
                                         fill_value=log_probs.size(0),
                                         dtype=torch.long).to(device)

                loss = criterion(log_probs, text, input_lengths, text_lengths)

            pred = log_probs.argmax(dim=-1).transpose(0, 1)
            pred_texts = converter.decode(pred.cpu().numpy(), input_lengths.cpu().numpy())
            true_texts = converter.decode(text.cpu().numpy(), text_lengths.cpu().numpy())

            for pred_text, true_text in zip(pred_texts, true_texts):
                wer = calculate_wer(true_text, pred_text)
                all_wer.append(wer)

            total_loss += loss.item()
            num_batches += 1

            del loss, logits, log_probs
            torch.cuda.empty_cache()

    return total_loss / num_batches, sum(all_wer) / len(all_wer)

def calculate_wer(reference, hypothesis):
    ref_words = reference.split()
    hyp_words = hypothesis.split()

    d = np.zeros((len(ref_words) + 1, len(hyp_words) + 1))

    for i in range(len(ref_words) + 1):
        d[i, 0] = i
    for j in range(len(hyp_words) + 1):
        d[0, j] = j

    for i in range(1, len(ref_words) + 1):
        for j in range(1, len(hyp_words) + 1):
            if ref_words[i-1] == hyp_words[j-1]:
                d[i, j] = d[i-1, j-1]
            else:
                substitution = d[i-1, j-1] + 1
                insertion = d[i, j-1] + 1
                deletion = d[i-1, j] + 1
                d[i, j] = min(substitution, insertion, deletion)

    return d[len(ref_words)][len(hyp_words)] / len(ref_words)

def main():
    # Initialize wandb
    wandb.init(project="urdu-speech-recognition", name="conformer-moe-optimized")

    # Model parameters
    batch_size = 8
    num_epochs = 10
    grad_accum_steps = 2

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )

    valid_loader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )

    # Initialize model and training components
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ConformerMoE()
    model = model.to(device)
    # Remove torch.compile() call

    converter = CTCLabelConverter(char_to_idx)
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2, verbose=True
    )

    # Training loop
    best_val_loss = float('inf')
    patience = 5
    patience_counter = 0

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        train_loss = train_epoch(model, train_loader, criterion, optimizer,
                               device, epoch, converter, grad_accum_steps)

        val_loss, val_wer = validate(model, valid_loader, criterion, device, converter)

        wandb.log({
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_wer': val_wer,
            'epoch': epoch
        })

        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}")
        print(f"Val WER: {val_wer:.4f}")

        scheduler.step(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0

            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'val_wer': val_wer
            }
            torch.save(checkpoint, 'checkpoints/best_model.pt')
            print("Saved new best model!")
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs!")
            break

        if (epoch + 1) % 5 == 0:
            torch.save(checkpoint, f'checkpoints/checkpoint_epoch_{epoch+1}.pt')

def evaluate():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ConformerMoE()
    checkpoint = torch.load('checkpoints/best_model.pt')
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    converter = CTCLabelConverter(char_to_idx)
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)

    test_loader = DataLoader(
        test_dataset,
        batch_size=8,
        shuffle=False,
        collate_fn=collate_fn
    )

    test_loss, test_wer = validate(model, test_loader, criterion, device, converter)
    print(f"\nTest Results:")
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test WER: {test_wer:.4f}")

    # Generate sample predictions
    with torch.no_grad():
        for batch in test_loader:
            if batch is None:
                continue

            audio = batch['audio'].to(device)
            text = batch['text']
            text_lengths = batch['text_length']

            with autocast(device_type='cuda', dtype=torch.float16):
                logits = model(audio)
                log_probs = F.log_softmax(logits, dim=-1)
                pred = log_probs.argmax(dim=-1)
                input_lengths = torch.full((pred.size(0),), pred.size(1), dtype=torch.long)

                pred_texts = converter.decode(pred.cpu().numpy(), input_lengths.numpy())
                true_texts = converter.decode(text.cpu().numpy(), text_lengths.numpy())

                print("\nSample Predictions:")
                for i in range(min(5, len(pred_texts))):
                    print(f"\nPredicted: {pred_texts[i]}")
                    print(f"True: {true_texts[i]}")
                    print(f"WER: {calculate_wer(true_texts[i], pred_texts[i]):.4f}")

            break

if __name__ == "__main__":
    # Create directories
    os.makedirs('checkpoints', exist_ok=True)

    # Set random seeds for reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    np.random.seed(42)

    # Training
    print("Starting training...")
    main()

    # Evaluation
    print("\nStarting evaluation...")
    evaluate()

Starting training...


VBox(children=(Label(value='0.012 MB of 0.012 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

  scaler = GradScaler()



Epoch 1/10


Epoch 0: 100%|██████████| 517/517 [02:19<00:00,  3.71it/s, loss=1.6410]
Validation: 100%|██████████| 413/413 [01:21<00:00,  5.09it/s]


Train Loss: 10.5313
Val Loss: 1.6217
Val WER: 1.0116
Saved new best model!

Epoch 2/10


Epoch 1: 100%|██████████| 517/517 [02:13<00:00,  3.86it/s, loss=1.9806]
Validation: 100%|██████████| 413/413 [01:15<00:00,  5.50it/s]


Train Loss: 1.4235
Val Loss: 1.4725
Val WER: 1.0315
Saved new best model!

Epoch 3/10


Epoch 2: 100%|██████████| 517/517 [02:23<00:00,  3.59it/s, loss=1.4560]
Validation: 100%|██████████| 413/413 [01:14<00:00,  5.52it/s]


Train Loss: 1.3254
Val Loss: 1.4021
Val WER: 1.0026
Saved new best model!

Epoch 4/10


Epoch 3: 100%|██████████| 517/517 [02:08<00:00,  4.01it/s, loss=1.2483]
Validation: 100%|██████████| 413/413 [01:06<00:00,  6.23it/s]


Train Loss: 1.2821
Val Loss: 1.4003
Val WER: 1.0787
Saved new best model!

Epoch 5/10


Epoch 4: 100%|██████████| 517/517 [02:07<00:00,  4.04it/s, loss=1.8298]
Validation: 100%|██████████| 413/413 [01:09<00:00,  5.92it/s]


Train Loss: 1.2541
Val Loss: 1.3493
Val WER: 1.1267
Saved new best model!

Epoch 6/10


Epoch 5: 100%|██████████| 517/517 [02:07<00:00,  4.06it/s, loss=1.4048]
Validation: 100%|██████████| 413/413 [01:07<00:00,  6.11it/s]


Train Loss: 1.2318
Val Loss: 1.3431
Val WER: 1.0879
Saved new best model!

Epoch 7/10


Epoch 6: 100%|██████████| 517/517 [02:05<00:00,  4.12it/s, loss=1.9832]
Validation: 100%|██████████| 413/413 [01:06<00:00,  6.17it/s]


Train Loss: 1.2119
Val Loss: 1.3356
Val WER: 1.1132
Saved new best model!

Epoch 8/10


Epoch 7: 100%|██████████| 517/517 [02:06<00:00,  4.09it/s, loss=1.5686]
Validation: 100%|██████████| 413/413 [01:12<00:00,  5.73it/s]


Train Loss: 1.1935
Val Loss: 1.3224
Val WER: 1.1282
Saved new best model!

Epoch 9/10


Epoch 8: 100%|██████████| 517/517 [02:09<00:00,  3.99it/s, loss=1.8812]
Validation: 100%|██████████| 413/413 [01:10<00:00,  5.89it/s]


Train Loss: 1.1780
Val Loss: 1.3123
Val WER: 1.1343
Saved new best model!

Epoch 10/10


Epoch 9: 100%|██████████| 517/517 [02:08<00:00,  4.02it/s, loss=1.5389]
Validation: 100%|██████████| 413/413 [01:09<00:00,  5.96it/s]


Train Loss: 1.1619
Val Loss: 1.3224
Val WER: 1.0588

Starting evaluation...


  checkpoint = torch.load('checkpoints/best_model.pt')
Validation: 100%|██████████| 413/413 [01:11<00:00,  5.75it/s]



Test Results:
Test Loss: 1.3508
Test WER: 1.1624

Sample Predictions:

Predicted:  اہ                              ؑؑ                                                                                                                                                                                                                                                                                                                                                                                                                                         ک                                                                                                                                                                                                                                              ہ ی                                                                                  ہ ے                                                                                                                        ہ ے       

In [5]:
import json

conformer_results = {
    'wer': 1.1624,  # Test WER from your results
    'loss': 1.3508,  # Test Loss from your results
    'predictions': [
        "اہ ک ہ ی ہ ے ہ ے ۔",
        "اہ ک ک ک ہ ہ ے ہ ے ۔",
        "او ک ک ا ک ک ک کی کی کی ک ک کی ہیے کیں ک ے ےے",
        "اس کی ک ک ک ک ک ہ ے ہ ے ۔",
        "او ک ک ک ک ہ ے ہ ے ۔"
    ],
    'references': [
        "یہی تناسب یوتھ کا بھی ہے۔",
        "اب اس کا حال تو یہی ہے کہ دعا کریں",
        "سپریم کورٹ میں ڈپٹی سپیکر قومی اسمبلی کی رولنگ پر از خود نوٹس کیس کی سماعت جاری ہے",
        "اس طرز عمل کا جمہوریت سے کیا واسطہ؟",
        "آئی ایم ایف کے ساتھ کن شرائط پر بات ہو رہی ہے؟"
    ]
}

# Save Conformer results
with open('conformer_results.json', 'w', encoding='utf-8') as f:
    json.dump(conformer_results, f, ensure_ascii=False, indent=2)

# WHisper Model

In [4]:
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
import gc
import torchaudio

def resample_audio(waveform, orig_sr, target_sr=16000):
    """Resample audio to target sampling rate"""
    if orig_sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_sr, target_sr)
        return resampler(torch.from_numpy(waveform).float()).numpy()
    return waveform

def evaluate_whisper_simple():
    # Load dataset
    print("Loading dataset...")
    dataset = load_dataset("mozilla-foundation/common_voice_11_0", "ur")
    test_dataset = dataset['test']

    # Load model and processor
    print("Loading Whisper model...")
    processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
    model = WhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-large-v2",
        low_cpu_mem_usage=True,
        use_safetensors=True
    )

    # Move model to device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Enable eval mode
    model.eval()

    # Process in smaller batches
    batch_size = 1  # Process one at a time to avoid memory issues
    total_wer = 0
    count = 0
    predictions = []
    references = []

    print("Starting evaluation...")

    try:
        for i in tqdm(range(0, min(len(test_dataset), 100), batch_size)):  # Test first 100 samples
            # Get sample
            sample = test_dataset[i]

            # Resample audio to 16kHz
            audio_array = resample_audio(
                sample["audio"]["array"],
                sample["audio"]["sampling_rate"],
                target_sr=16000
            )

            try:
                # Process audio
                input_features = processor(
                    audio_array,
                    sampling_rate=16000,
                    return_tensors="pt"
                ).input_features.to(device)

                # Generate transcription
                with torch.no_grad():
                    generated_ids = model.generate(
                        input_features,
                        max_length=225,
                        language="ur",
                        task="transcribe"
                    )

                # Decode prediction
                transcription = processor.batch_decode(
                    generated_ids,
                    skip_special_tokens=True
                )[0]

                # Store results
                predictions.append(transcription)
                references.append(sample["sentence"])

                # Calculate WER
                wer = calculate_wer(sample["sentence"], transcription)
                total_wer += wer
                count += 1

                # Print progress
                if i % 10 == 0:
                    print(f"\nSample {i}:")
                    print(f"Reference: {sample['sentence']}")
                    print(f"Predicted: {transcription}")
                    print(f"WER: {wer:.4f}")

            except Exception as e:
                print(f"Error processing sample {i}: {str(e)}")
                continue

            # Clear memory
            del input_features, generated_ids
            torch.cuda.empty_cache()
            gc.collect()

    except Exception as e:
        print(f"Error occurred: {str(e)}")

    finally:
        # Save results even if interrupted
        average_wer = total_wer / count if count > 0 else 0
        results = {
            'whisper': {
                'wer': float(average_wer),
                'predictions': predictions,
                'references': references
            }
        }

        # Save results
        import json
        with open('whisper_results.json', 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        print(f"\nFinal Results:")
        print(f"Average WER: {average_wer:.4f}")
        print(f"Processed {count} samples")
        print("Results saved to whisper_results.json")

        return results

def calculate_wer(reference, hypothesis):
    """
    Calculate Word Error Rate (WER) between reference and hypothesis
    """
    ref_words = reference.split()
    hyp_words = hypothesis.split()

    # Create matrix
    d = np.zeros((len(ref_words) + 1, len(hyp_words) + 1))

    # Initialize first row and column
    for i in range(len(ref_words) + 1):
        d[i, 0] = i
    for j in range(len(hyp_words) + 1):
        d[0, j] = j

    # Compute WER
    for i in range(1, len(ref_words) + 1):
        for j in range(1, len(hyp_words) + 1):
            if ref_words[i-1] == hyp_words[j-1]:
                d[i, j] = d[i-1, j-1]
            else:
                substitution = d[i-1, j-1] + 1
                insertion = d[i, j-1] + 1
                deletion = d[i-1, j] + 1
                d[i, j] = min(substitution, insertion, deletion)

    return d[len(ref_words)][len(hyp_words)] / len(ref_words)

if __name__ == "__main__":
    print("Starting Whisper evaluation...")
    whisper_results = evaluate_whisper_simple()

Starting Whisper evaluation...
Loading dataset...
Loading Whisper model...
Starting evaluation...


  0%|          | 0/100 [00:00<?, ?it/s]You have passed task=transcribe, but also have set `forced_decoder_ids` to [[1, None], [2, 50359]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of task=transcribe.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



Sample 0:
Reference: یہی تناسب "یوتھ" کا بھی ہے۔
Predicted:  یہی تناسب یوت کا بھی ہے
WER: 0.3333


 10%|█         | 10/100 [00:20<02:30,  1.68s/it]


Sample 10:
Reference: وہ بے وفا ثابت ہوا
Predicted:  بھوک بے وفا ثابت ہوا
WER: 0.2000


 20%|██        | 20/100 [00:38<02:32,  1.90s/it]


Sample 20:
Reference: ، اپنے گریبان میں جھانکنے کی بجائے دوسروں کو برا بھلا کہا جائے
Predicted:  اپنے گرپان میں جاکنے کے بجائے دوسرے کو بُرا بلا کہا جائیں
WER: 0.6154


 30%|███       | 30/100 [01:03<03:16,  2.80s/it]


Sample 30:
Reference: سروج خان بالی وڈ
Predicted:  ملوچ ہونے والی بارڈ
WER: 1.0000


 40%|████      | 40/100 [01:20<01:51,  1.85s/it]


Sample 40:
Reference: پاکستان تو اس وقت غیر معمولی حالات سے گزر رہا ہے۔
Predicted:  پاکستان تو اس وقت غیر معمولی حالات سے گزر رہا ہے
WER: 0.0909


 50%|█████     | 50/100 [01:38<01:29,  1.79s/it]


Sample 50:
Reference: اگر عمران خان ٹھیک کر رہے ہیں۔
Predicted:  اگر عمران خان ٹھیک کر رہے ہیں
WER: 0.1429


 60%|██████    | 60/100 [01:55<01:03,  1.60s/it]


Sample 60:
Reference: جمائیکا
Predicted:  جیمائکہ
WER: 1.0000


 70%|███████   | 70/100 [02:11<00:50,  1.69s/it]


Sample 70:
Reference: وہاں جھڑپیں جاری ہوں
Predicted:  وہاں جھڑ پہنچ جاری ہوں
WER: 0.5000


 80%|████████  | 80/100 [02:28<00:29,  1.46s/it]


Sample 80:
Reference: کھارے پانی کو استعمال کے قابل بنانے والے پلانٹ کا افتتاح کیا ہے
Predicted:  خارے پانی کو استعمال کا قابل بنانے والا پلانڈ کا افتتاہ کیا ہے
WER: 0.3846


 90%|█████████ | 90/100 [02:45<00:17,  1.75s/it]


Sample 90:
Reference: اسے پرے کرنا چاہیے۔
Predicted:  اسے پرے کرنا چاہیے
WER: 0.2500


100%|██████████| 100/100 [03:02<00:00,  1.83s/it]


Final Results:
Average WER: 0.3241
Processed 100 samples
Results saved to whisper_results.json





# Comparision

In [19]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tabulate import tabulate
import numpy as np

def compare_results():
    # Load results
    try:
        # Load Whisper results
        with open('whisper_results.json', 'r', encoding='utf-8') as f:
            whisper_data = json.load(f)

        # Load Conformer results
        with open('conformer_results.json', 'r', encoding='utf-8') as f:
            conformer_data = json.load(f)

        # Create comparison table
        comparison_data = {
            'Metric': ['Average WER'],
            'Whisper': [f"{whisper_data['whisper']['wer']:.4f}"],
            'Conformer-MoE': [f"{conformer_data['wer']:.4f}"]
        }

        # Print comparison table
        print("\n=== Model Performance Comparison ===")
        df = pd.DataFrame(comparison_data)
        print(tabulate(df, headers='keys', tablefmt='fancy_grid', showindex=False))

        # Show sample predictions
        print("\n=== Sample Predictions ===")
        num_samples = min(5, len(whisper_data['whisper']['predictions']))

        for i in range(num_samples):
            print(f"\nExample {i+1}:")
            print(f"Reference:     {whisper_data['whisper']['references'][i]}")
            print(f"Whisper:       {whisper_data['whisper']['predictions'][i]}")
            print(f"Conformer-MoE: {conformer_data['predictions'][i]}")
            print("-" * 80)

        # Calculate per-sample WER for both models
        whisper_wers = []
        conformer_wers = []

        for i in range(min(len(whisper_data['whisper']['predictions']), len(conformer_data['predictions']))):
            ref = whisper_data['whisper']['references'][i]
            whisper_pred = whisper_data['whisper']['predictions'][i]
            conformer_pred = conformer_data['predictions'][i]

            whisper_wers.append(calculate_wer(ref, whisper_pred))
            conformer_wers.append(calculate_wer(ref, conformer_pred))

        # Create WER distribution plot
        plt.figure(figsize=(10, 6))
        plt.hist(whisper_wers, alpha=0.5, label='Whisper', bins=20)
        plt.hist(conformer_wers, alpha=0.5, label='Conformer-MoE', bins=20)
        plt.xlabel('Word Error Rate (WER)')
        plt.ylabel('Count')
        plt.title('Distribution of WER Scores')
        plt.legend()
        plt.savefig('wer_comparison.png')
        plt.close()

        # Calculate additional statistics
        stats = {
            'Whisper': {
                'Mean WER': np.mean(whisper_wers),
                'Median WER': np.median(whisper_wers),
                'Std WER': np.std(whisper_wers),
                'Min WER': np.min(whisper_wers),
                'Max WER': np.max(whisper_wers)
            },
            'Conformer-MoE': {
                'Mean WER': np.mean(conformer_wers),
                'Median WER': np.median(conformer_wers),
                'Std WER': np.std(conformer_wers),
                'Min WER': np.min(conformer_wers),
                'Max WER': np.max(conformer_wers)
            }
        }

        # Print detailed statistics
        print("\n=== Detailed Statistics ===")
        stats_df = pd.DataFrame({
            'Metric': ['Mean WER', 'Median WER', 'Std WER', 'Min WER', 'Max WER'],
            'Whisper': [
                f"{stats['Whisper']['Mean WER']:.4f}",
                f"{stats['Whisper']['Median WER']:.4f}",
                f"{stats['Whisper']['Std WER']:.4f}",
                f"{stats['Whisper']['Min WER']:.4f}",
                f"{stats['Whisper']['Max WER']:.4f}"
            ],
            'Conformer-MoE': [
                f"{stats['Conformer-MoE']['Mean WER']:.4f}",
                f"{stats['Conformer-MoE']['Median WER']:.4f}",
                f"{stats['Conformer-MoE']['Std WER']:.4f}",
                f"{stats['Conformer-MoE']['Min WER']:.4f}",
                f"{stats['Conformer-MoE']['Max WER']:.4f}"
            ]
        })
        print(tabulate(stats_df, headers='keys', tablefmt='fancy_grid', showindex=False))

        # Create length analysis
        print("\n=== Length Analysis ===")
        whisper_lengths = [len(pred.split()) for pred in whisper_data['whisper']['predictions']]
        conformer_lengths = [len(pred.split()) for pred in conformer_data['predictions']]
        reference_lengths = [len(ref.split()) for ref in whisper_data['whisper']['references']]

        length_df = pd.DataFrame({
            'Metric': ['Average Length', 'Length Std Dev'],
            'Reference': [f"{np.mean(reference_lengths):.2f}", f"{np.std(reference_lengths):.2f}"],
            'Whisper': [f"{np.mean(whisper_lengths):.2f}", f"{np.std(whisper_lengths):.2f}"],
            'Conformer-MoE': [f"{np.mean(conformer_lengths):.2f}", f"{np.std(conformer_lengths):.2f}"]
        })
        print(tabulate(length_df, headers='keys', tablefmt='fancy_grid', showindex=False))

        # Save all results
        all_results = {
            'overall_comparison': comparison_data,
            'detailed_stats': stats,
            'length_analysis': {
                'reference': {'mean': np.mean(reference_lengths), 'std': np.std(reference_lengths)},
                'whisper': {'mean': np.mean(whisper_lengths), 'std': np.std(whisper_lengths)},
                'conformer': {'mean': np.mean(conformer_lengths), 'std': np.std(conformer_lengths)}
            },
            'sample_predictions': [
                {
                    'reference': whisper_data['whisper']['references'][i],
                    'whisper_prediction': whisper_data['whisper']['predictions'][i],
                    'conformer_prediction': conformer_data['predictions'][i],
                    'whisper_wer': whisper_wers[i],
                    'conformer_wer': conformer_wers[i]
                }
                for i in range(min(len(whisper_wers), len(conformer_wers)))
            ]
        }

        with open('complete_comparison_results.json', 'w', encoding='utf-8') as f:
            json.dump(all_results, f, ensure_ascii=False, indent=2)

        print("\nResults saved to 'complete_comparison_results.json'")
        print("WER distribution plot saved as 'wer_comparison.png'")

    except Exception as e:
        print(f"Error during comparison: {str(e)}")

def calculate_wer(reference, hypothesis):
    """
    Calculate Word Error Rate (WER) between reference and hypothesis
    """
    ref_words = reference.split()
    hyp_words = hypothesis.split()

    d = np.zeros((len(ref_words) + 1, len(hyp_words) + 1))

    for i in range(len(ref_words) + 1):
        d[i, 0] = i
    for j in range(len(hyp_words) + 1):
        d[0, j] = j

    for i in range(1, len(ref_words) + 1):
        for j in range(1, len(hyp_words) + 1):
            if ref_words[i-1] == hyp_words[j-1]:
                d[i, j] = d[i-1, j-1]
            else:
                substitution = d[i-1, j-1] + 1
                insertion = d[i, j-1] + 1
                deletion = d[i-1, j] + 1
                d[i, j] = min(substitution, insertion, deletion)

    return d[len(ref_words)][len(hyp_words)] / len(ref_words)

if __name__ == "__main__":
    compare_results()


=== Model Performance Comparison ===
╒═════════════╤═══════════╤═════════════════╕
│ Metric      │   Whisper │   Conformer-MoE │
╞═════════════╪═══════════╪═════════════════╡
│ Average WER │    0.3241 │          1.1624 │
╘═════════════╧═══════════╧═════════════════╛

=== Sample Predictions ===

Example 1:
Reference:     یہی تناسب "یوتھ" کا بھی ہے۔
Whisper:        یہی تناسب یوت کا بھی ہے
Conformer-MoE: اہ ک ہ ی ہ ے ہ ے ۔
--------------------------------------------------------------------------------

Example 2:
Reference:     اب اس کا حال تو یہی ہے کہ دعا کریں
Whisper:        اب اس کا حل تو یہی ہے کہ دعا کریں
Conformer-MoE: اہ ک ک ک ہ ہ ے ہ ے ۔
--------------------------------------------------------------------------------

Example 3:
Reference:     سپریم کورٹ میں ڈپٹی سپیکر قومی اسمبلی کی رولنگ پر از خود نوٹس کیس کی سماعت جاری ہے
Whisper:        سپریم کورڈ میں ڈپٹی سپیکر قومی اسمبلی کی رولنگ پر اسخد نوٹس کیس کی سماعت جاری ہے۔
Conformer-MoE: او ک ک ا ک ک ک کی کی کی ک ک کی ہیے کیں ک ے