In [27]:
# =====================================================
# 1. Imports & Setup
# =====================================================
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchaudio
from datasets import load_dataset
import jiwer
import wandb
import numpy as np
from tqdm import tqdm

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [28]:
# =====================================================
# 2. Config (for sweep compatibility)
# =====================================================
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--project", type=str, default="asr-rnn-sweep")
parser.add_argument("--accent", type=str, default="shona")
parser.add_argument("--epochs", type=int, default=2)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--hidden_dim", type=int, default=256)
parser.add_argument("--num_layers", type=int, default=2)
parser.add_argument("--dropout", type=float, default=0.1)
parser.add_argument("--max_train_samples", type=int, default=200)
parser.add_argument("--max_val_samples", type=int, default=50)
args, _ = parser.parse_known_args()

wandb.init(project=args.project, config=vars(args))
config = wandb.config

In [30]:
# =====================================================
# 3. Load Real Dataset (AfriSpeech subset)
# =====================================================
# NOTE(review): this cell failed with
#   "RuntimeError: Dataset scripts are no longer supported, but found afrispeech-200.py".
# The dataset is defined by a Python loading script, which recent `datasets` releases
# (4.x) refuse to run. Until a script-free (e.g. Parquet) copy is used, pin an older
# `datasets` (<4.0) in the environment; 3.x additionally needs `trust_remote_code=True`
# in this call. TODO: confirm against the installed `datasets` version.
dataset = load_dataset("tobiolatunji/afrispeech-200", config.accent)

# Clamp the requested subset sizes: `select(range(n))` raises when n exceeds the split length.
train_full = dataset["train"]
train_set = train_full.select(range(min(config.max_train_samples, len(train_full))))
val_split = "dev" if "dev" in dataset else "test"  # fall back to "test" when there is no dev split
val_full = dataset[val_split]
val_set = val_full.select(range(min(config.max_val_samples, len(val_full))))

README.md: 0.00B [00:00, ?B/s]

afrispeech-200.py: 0.00B [00:00, ?B/s]

RuntimeError: Dataset scripts are no longer supported, but found afrispeech-200.py

In [None]:





# =====================================================
# 4. Feature Extraction & Tokenization
# =====================================================
from torchaudio.transforms import MelSpectrogram, Resample

sample_rate = 16000  # target rate; clips at other rates are resampled in collate_fn
n_mels = 80

mel_transform = MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels)

# NOTE(review): AfriSpeech-200 appears to store its text under "transcript", while
# Common Voice-style datasets use "sentence"; use whichever the loaded split has.
text_column = "sentence" if "sentence" in train_set.column_names else "transcript"

# Character vocabulary built from the lower-cased training transcripts.
chars = sorted(set("".join(train_set[text_column]).lower()))
char2idx = {c: i for i, c in enumerate(chars, start=1)}  # reserve 0 for the CTC blank
idx2char = {i: c for c, i in char2idx.items()}
vocab_size = len(char2idx) + 1  # +1 for the blank

def encode_text(text):
    """Convert ``text`` into a 1-D LongTensor of vocabulary ids.

    The text is lower-cased first. Characters missing from ``char2idx`` are
    dropped entirely rather than mapped to the blank id 0.
    """
    ids = []
    for ch in text.lower():
        if ch in char2idx:
            ids.append(char2idx[ch])
    return torch.tensor(ids, dtype=torch.long)

_RESAMPLERS = {}  # source sampling rate -> cached Resample transform


def collate_fn(batch):
    """Turn a list of dataset rows into padded tensors for CTC training.

    Each row is expected to have an ``audio`` dict with ``array`` and
    ``sampling_rate`` keys, and its text under ``sentence`` or ``transcript``.

    Returns:
        specs: (B, T_max, n_mels) mel features, zero-padded along time.
        labels: (B, S_max) label ids, zero-padded (CTC only reads the first
            ``label_lengths[i]`` entries of each row).
        input_lengths: (B,) number of valid frames per utterance.
        label_lengths: (B,) number of valid labels per utterance.
    """
    specs, labels, input_lengths, label_lengths = [], [], [], []
    for row in batch:
        audio = row["audio"]
        wav = torch.as_tensor(audio["array"], dtype=torch.float32)
        src_rate = audio["sampling_rate"]
        if src_rate != sample_rate:
            # Building a Resample kernel is not free; reuse one per source rate
            # instead of constructing a new transform for every clip.
            resampler = _RESAMPLERS.get(src_rate)
            if resampler is None:
                resampler = _RESAMPLERS[src_rate] = Resample(src_rate, sample_rate)
            wav = resampler(wav)
        mel = mel_transform(wav).transpose(0, 1)  # (n_mels, T) -> (T, n_mels)
        # NOTE(review): raw power mel values are fed to the RNN; log-compression
        # (e.g. torch.log(mel + 1e-6)) is the usual ASR front end and worth trying.
        specs.append(mel)
        # Datasets name the text column differently (Common Voice: "sentence",
        # AfriSpeech-200 apparently: "transcript").
        text = row["sentence"] if "sentence" in row else row["transcript"]
        label = encode_text(text)
        labels.append(label)
        input_lengths.append(mel.shape[0])
        label_lengths.append(len(label))
    specs = nn.utils.rnn.pad_sequence(specs, batch_first=True)
    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
    return specs, labels, torch.tensor(input_lengths), torch.tensor(label_lengths)

# Shuffle only the training data; validation runs one utterance per batch, in order.
train_loader, val_loader = (
    DataLoader(train_set, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn),
    DataLoader(val_set, batch_size=1, shuffle=False, collate_fn=collate_fn),
)

# =====================================================
# 5. Basic RNN Model (CTC-compatible)
# =====================================================
class BasicRNNASR(nn.Module):
    """Bidirectional vanilla-RNN acoustic model producing per-frame logits for CTC.

    Args:
        input_dim: number of input features per frame (``n_mels`` in this notebook).
        hidden_dim: hidden size of each RNN direction.
        output_dim: number of output classes, including the CTC blank.
        num_layers: number of stacked RNN layers.
        dropout: dropout between stacked RNN layers. ``nn.RNN`` only applies it
            *between* layers, so it is forced to 0 when ``num_layers == 1``. That
            case was a silent no-op that also emitted a UserWarning.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout):
        super().__init__()
        self.rnn = nn.RNN(
            input_dim, hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
            bidirectional=True,
        )
        self.fc = nn.Linear(hidden_dim * 2, output_dim)  # bidirectional doubles dim

    def forward(self, x, lengths=None):
        """Map batch-first features ``x`` (B, T, input_dim) to logits (B, T, output_dim).

        Args:
            x: zero-padded, batch-first feature tensor.
            lengths: optional (B,) valid frame counts. When given, the batch is
                packed so the backward direction starts at each utterance's last
                real frame instead of first running over trailing padding. Padded
                positions of the output then contain ``fc`` applied to zeros.
        """
        if lengths is None:
            out, _ = self.rnn(x)
        else:
            # pack_padded_sequence requires CPU int64 lengths.
            cpu_lengths = torch.as_tensor(lengths, dtype=torch.int64).cpu()
            packed = nn.utils.rnn.pack_padded_sequence(
                x, cpu_lengths, batch_first=True, enforce_sorted=False
            )
            packed_out, _ = self.rnn(packed)
            out, _ = nn.utils.rnn.pad_packed_sequence(
                packed_out, batch_first=True, total_length=x.size(1)
            )
        return self.fc(out)

# Acoustic model, CTC objective (blank id 0 matches the tokenizer's reserved id), optimizer.
model = BasicRNNASR(
    input_dim=n_mels,
    hidden_dim=config.hidden_dim,
    output_dim=vocab_size,
    num_layers=config.num_layers,
    dropout=config.dropout,
).to(device)
# zero_infinity=True zeroes infinite losses (e.g. a transcript longer than its input
# frames can align to) instead of letting inf/NaN gradients reach the optimizer.
criterion = nn.CTCLoss(blank=0, zero_infinity=True)
optimizer = optim.Adam(params=model.parameters(), lr=config.lr)

# =====================================================
# 6. Training & Validation Loops
# =====================================================
def train_one_epoch():
    """Run one optimisation pass over ``train_loader``.

    Each batch's CTC loss is logged to W&B as ``train_batch_loss``.

    Returns:
        float: the summed batch losses divided by the number of batches.
    """
    model.train()
    batch_losses = []
    for specs, labels, input_lens, label_lens in tqdm(train_loader, desc="Train"):
        specs = specs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # CTCLoss expects (T, B, vocab) log-probabilities; the model is batch-first.
        log_probs = nn.functional.log_softmax(model(specs), dim=-1)
        loss = criterion(log_probs.transpose(0, 1), labels, input_lens, label_lens)
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        batch_losses.append(loss_value)
        wandb.log({"train_batch_loss": loss_value})
    return sum(batch_losses) / len(train_loader)

def ctc_greedy_decode(pred_ids, blank=0):
    """Collapse a frame-level argmax path into a CTC label sequence.

    Greedy CTC decoding merges consecutive repeated ids and then drops the blank.
    Unlike ``np.unique`` (which sorts and removes *every* repeat), this keeps the
    label order and keeps legitimate double letters separated by a blank, e.g. the
    path h,e,l,_,l,o decodes to "hello".

    Args:
        pred_ids: iterable of integer frame predictions (list, array or tensor).
        blank: id of the CTC blank token.

    Returns:
        list[int]: the decoded label ids.
    """
    decoded, prev = [], None
    for idx in pred_ids:
        idx = int(idx)
        if idx != prev and idx != blank:
            decoded.append(idx)
        prev = idx
    return decoded


def evaluate():
    """Run one pass over ``val_loader`` and return (mean CTC loss, mean WER).

    Uses the notebook-level ``model``, ``criterion``, ``val_loader``, ``idx2char``
    and ``device``. Each utterance is greedily CTC-decoded over its valid frames
    only, so batch sizes larger than 1 also work.

    Returns:
        tuple[float, float]: validation loss averaged per batch, and the mean WER
        over utterances with a non-empty reference (NaN if there are none).
    """
    model.eval()
    total_loss, wers = 0.0, []
    with torch.no_grad():
        for specs, labels, input_lens, label_lens in tqdm(val_loader, desc="Val"):
            specs, labels = specs.to(device), labels.to(device)
            logits = model(specs)
            log_probs = nn.functional.log_softmax(logits, dim=-1)
            loss = criterion(log_probs.transpose(0, 1), labels, input_lens, label_lens)
            total_loss += loss.item()

            pred_ids = torch.argmax(log_probs, dim=-1).cpu()  # (B, T)
            labels_cpu = labels.cpu()
            for b in range(pred_ids.size(0)):
                # Decode only this utterance's real frames / labels, not the padding.
                path = pred_ids[b, : int(input_lens[b])].tolist()
                decoded = "".join(idx2char[i] for i in ctc_greedy_decode(path) if i in idx2char)
                ref_ids = labels_cpu[b, : int(label_lens[b])].tolist()
                true_text = "".join(idx2char[i] for i in ref_ids if i in idx2char)
                # jiwer raises on an empty reference; WER is undefined there anyway.
                if true_text.strip():
                    wers.append(jiwer.wer(true_text, decoded))
    mean_wer = float(np.mean(wers)) if wers else float("nan")
    return total_loss / len(val_loader), mean_wer

for epoch in range(config.epochs):
    train_loss = train_one_epoch()
    val_loss, val_wer = evaluate()
    epoch_metrics = {
        "epoch": epoch + 1,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "val_wer": val_wer,
    }
    wandb.log(epoch_metrics)
    print(
        f"Epoch {epoch_metrics['epoch']}: Train Loss {train_loss:.4f}, "
        f"Val Loss {val_loss:.4f}, Val WER {val_wer:.3f}"
    )

# Finish the W&B run so any remaining logged data is uploaded.
wandb.finish()

# =====================================================
# 7. Sweep Config (Small)
# =====================================================
# Random search over a small grid, minimising the per-epoch `val_wer` logged above.
# Parameter names must match the argparse / wandb.config keys of the config cell.
sweep_config = dict(
    method="random",
    metric=dict(name="val_wer", goal="minimize"),
    parameters={
        name: {"values": choices}
        for name, choices in [
            ("lr", [1e-3, 5e-4]),
            ("hidden_dim", [128, 256]),
            ("num_layers", [1, 2]),
            ("dropout", [0.1, 0.3]),
            ("batch_size", [4, 8]),
        ]
    },
)

# =====================================================
# 8. Sweep Launch (Manual Trigger)
# =====================================================
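# Uncomment once ready. `wandb.agent` calls `function` once per trial, so it must be a
# callable that actually trains (wandb.init -> train -> log "val_wer" -> wandb.finish).
# The `lambda: None` placeholder below would start trials that never train or log anything.
# sweep_id = wandb.sweep(sweep_config, project=args.project)
# wandb.agent(sweep_id, function=lambda: None)  # placeholder: replace with a training function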