# Audio Transcription with a Simple ASR Model
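This notebook downloads a small sample of the **LibriSpeech** dataset from HuggingFace, then trains and evaluates a basic speech-command classifier (a small CNN over mel spectrograms of the **Speech Commands** dataset) as a simple stand-in for an **Audio-to-Text** (ASR) model.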

## 1) Environment Setup

In [None]:

# Notebook shell escape: quietly install the packages listed on the command line.
!pip install torch torchaudio transformers datasets matplotlib --quiet
import torch
# Select the CUDA device when torch reports one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)


## 2) Dataset Download (LibriSpeech Dummy via HuggingFace)

In [None]:
#!/usr/bin/env python3
"""
Download a small ASR dataset from HuggingFace and save locally
so train_asr.py can use it without re-downloading.
"""

import os
from datasets import load_dataset
import torchaudio

# Root directory under which one sub-directory per split is created.
OUT_DIR = "./data/asr"
# HuggingFace dataset identifier passed to load_dataset.
DATASET_NAME = "hf-internal-testing/librispeech_asr_dummy"  # change to "librispeech_asr" for full set
# Splits to download.
# NOTE(review): confirm that DATASET_NAME actually provides every split listed
# here; presumably load_dataset fails for a split the dataset does not define.
SPLITS = ["train", "validation"]

def save_wav(audio, sample_rate, path):
    """Write ``audio`` to ``path`` at ``sample_rate`` using ``torchaudio.save``.

    Args:
        audio: Waveform tensor handed to ``torchaudio.save`` unchanged
            (presumably shaped [channels, samples] - TODO confirm with caller).
        sample_rate: Sampling rate stored in the written file.
        path: Destination file path.
    """
    torchaudio.save(path, audio, sample_rate)

def main():
    """Download every split in SPLITS and store it on disk under OUT_DIR.

    For each split a directory ``OUT_DIR/<split>`` is created and every sample
    is written as ``<index>.wav`` (audio) plus ``<index>.txt`` (transcript).
    """
    # Imported locally: this cell only imports torchaudio at the top, and the
    # waveform tensor has to be created with torch itself.
    import torch

    os.makedirs(OUT_DIR, exist_ok=True)
    for split in SPLITS:
        split_dir = os.path.join(OUT_DIR, split)
        os.makedirs(split_dir, exist_ok=True)

        print(f"[INFO] Downloading {DATASET_NAME} split: {split}")
        dataset = load_dataset(DATASET_NAME, split=split)

        for i, item in enumerate(dataset):
            audio = item["audio"]
            sr = audio["sampling_rate"]
            text = item["text"]

            wav_path = os.path.join(split_dir, f"{i}.wav")
            txt_path = os.path.join(split_dir, f"{i}.txt")

            # torchaudio has no `tensor` factory, so the tensor is built with
            # torch. float32 is used because the decoded array may be float64,
            # which torchaudio.save does not accept.
            # Assumes a 1-D (mono) sample array; unsqueeze adds the channel
            # axis - TODO confirm for multi-channel datasets.
            waveform = torch.as_tensor(audio["array"], dtype=torch.float32).unsqueeze(0)
            save_wav(waveform, sr, wav_path)
            with open(txt_path, "w", encoding="utf-8") as f:
                f.write(text)

        print(f"[INFO] Saved {len(dataset)} samples to {split_dir}")



## 3) Utilities

In [None]:
# import torch
# from torch.utils.data import DataLoader
# import torchaudio
# from torchaudio.datasets import SPEECHCOMMANDS
# from torchaudio.transforms import MelSpectrogram
# import os

# class FilteredSpeechCommands(SPEECHCOMMANDS):
#     def __init__(self, root, subset=None, allowed_labels=None, n_mels=64):
#         self.allowed_labels = allowed_labels
#         self.mel_spec = MelSpectrogram(sample_rate=16000, n_mels=n_mels)

#         dataset_path = os.path.join(root, "speechcommands")
#         download_flag = not os.path.exists(dataset_path)
#         super().__init__(root=dataset_path, download=download_flag, subset=subset)

#         if self.allowed_labels:
#             self._walker = [
#                 w for w in self._walker
#                 if os.path.basename(os.path.dirname(w)) in self.allowed_labels
#             ]

#     def __getitem__(self, n):
#         waveform, sample_rate, label, *_ = super().__getitem__(n)
#         mel = self.mel_spec(waveform).squeeze(0).transpose(0, 1)  # [time, n_mels]
#         return mel, label

# def collate_fn(batch, label_to_idx):
#     tensors, targets = [], []
#     for mel, label in batch:
#         tensors.append(mel)  # [time, n_mels]
#         targets.append(label_to_idx[label])

#     tensors = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)  # [B, max_time, n_mels]
#     tensors = tensors.permute(0, 2, 1).unsqueeze(1)  # [B, 1, n_mels, max_time]
#     return tensors, torch.tensor(targets)

# def get_speechcommands_loaders(data_dir="./data", batch_size=8, allowed_labels=None, n_mels=64):
#     if allowed_labels is None:
#         allowed_labels = ["yes", "no", "up", "down"]

#     label_to_idx = {label: idx for idx, label in enumerate(allowed_labels)}

#     train_set = FilteredSpeechCommands(
#         root=data_dir, subset="training",
#         allowed_labels=allowed_labels, n_mels=n_mels
#     )
#     test_set = FilteredSpeechCommands(
#         root=data_dir, subset="testing",
#         allowed_labels=allowed_labels, n_mels=n_mels
#     )

#     train_loader = DataLoader(
#         train_set, batch_size=batch_size, shuffle=True,
#         collate_fn=lambda b: collate_fn(b, label_to_idx)
#     )
#     test_loader = DataLoader(
#         test_set, batch_size=batch_size, shuffle=False,
#         collate_fn=lambda b: collate_fn(b, label_to_idx)
#     )

#     return train_loader, test_loader, allowed_labels

import torch
from torch.utils.data import DataLoader
import torchaudio
from torchaudio.datasets import SPEECHCOMMANDS
from torchaudio.transforms import MelSpectrogram, FrequencyMasking, TimeMasking
import os

class FilteredSpeechCommands(SPEECHCOMMANDS):
    """Speech Commands dataset limited to chosen labels, served as mel spectrograms.

    Args:
        root: Parent directory; the data is kept in ``<root>/speechcommands``.
        subset: Forwarded unchanged to ``SPEECHCOMMANDS``.
        allowed_labels: Labels to keep. A falsy value keeps every label.
        n_mels: Number of mel bins produced by the spectrogram transform.
        augment: When true, frequency and time masking are applied to each item.
    """

    def __init__(self, root, subset=None, allowed_labels=None, n_mels=64, augment=False):
        self.allowed_labels = allowed_labels
        self.mel_spec = MelSpectrogram(sample_rate=16000, n_mels=n_mels)

        # Masking transforms are only created when augmentation is requested.
        self.augment = augment
        if self.augment:
            self.freq_mask = FrequencyMasking(freq_mask_param=15)
            self.time_mask = TimeMasking(time_mask_param=35)

        dataset_path = os.path.join(root, "speechcommands")
        # The download writes the archive into dataset_path, so the directory
        # must exist beforehand; on a fresh machine it does not.
        os.makedirs(dataset_path, exist_ok=True)
        # Always request the download instead of deriving the flag from the
        # directory's existence: an empty or half-populated directory would
        # otherwise disable it for good. torchaudio is expected to skip the
        # download when the extracted data is already present.
        super().__init__(root=dataset_path, download=True, subset=subset)

        if self.allowed_labels:
            allowed = set(self.allowed_labels)
            # The label of a sample is the name of the directory holding its file.
            self._walker = [
                w for w in self._walker
                if os.path.basename(os.path.dirname(w)) in allowed
            ]

    def __getitem__(self, n):
        """Return ``(mel, label)`` for sample ``n``; ``mel`` is laid out [time, n_mels]."""
        waveform, sample_rate, label, *_ = super().__getitem__(n)
        # squeeze(0) drops the leading axis; assumes single-channel audio so the
        # result is [n_mels, time] - TODO confirm.
        mel = self.mel_spec(waveform).squeeze(0)

        if self.augment:
            # The masking transforms work on the [n_mels, time] layout, so they
            # run before the final transpose.
            mel = self.freq_mask(mel)
            mel = self.time_mask(mel)

        return mel.transpose(0, 1), label  # [time, n_mels]

def collate_fn(batch, label_to_idx):
    """Pad a list of ``(mel, label)`` pairs into one batch.

    Args:
        batch: Sequence of ``(mel, label)`` pairs, each mel laid out [time, n_mels].
        label_to_idx: Mapping from label to integer class index.

    Returns:
        ``(features, targets)`` where features is [B, 1, n_mels, max_time]
        (shorter items are zero-padded along time) and targets is a tensor
        holding the class index of every item.
    """
    spectrograms = [mel for mel, _ in batch]
    class_ids = [label_to_idx[label] for _, label in batch]

    # [B, max_time, n_mels]
    padded = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
    # Swap time and mel axes, then insert the channel axis: [B, 1, n_mels, max_time]
    features = padded.transpose(1, 2).unsqueeze(1)
    return features, torch.tensor(class_ids)

def get_speechcommands_loaders(data_dir="./data", batch_size=8, allowed_labels=None, n_mels=64):
    """Build the training and testing DataLoaders for the filtered dataset.

    Args:
        data_dir: Directory handed to ``FilteredSpeechCommands`` as ``root``.
        batch_size: Batch size of both loaders.
        allowed_labels: Labels to keep; defaults to yes / no / up / down.
        n_mels: Number of mel bins of the spectrogram features.

    Returns:
        ``(train_loader, test_loader, allowed_labels)``. The training loader is
        shuffled and augmented, the testing loader is neither.
    """
    if allowed_labels is None:
        allowed_labels = ["yes", "no", "up", "down"]

    # The class index of a label is its position in allowed_labels.
    label_to_idx = {name: position for position, name in enumerate(allowed_labels)}

    def _collate(samples):
        return collate_fn(samples, label_to_idx)

    def _make_dataset(subset, augment):
        return FilteredSpeechCommands(
            root=data_dir,
            subset=subset,
            allowed_labels=allowed_labels,
            n_mels=n_mels,
            augment=augment,
        )

    # Both datasets are created before either loader, training first.
    train_set = _make_dataset("training", True)
    test_set = _make_dataset("testing", False)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=_collate)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, collate_fn=_collate)

    return train_loader, test_loader, allowed_labels



## 4) Model Definition

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    """Three conv / batch-norm / ReLU / max-pool stages followed by a linear classifier.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Stage 1: 1 -> 16 channels. The same 2x2 max-pool is reused by every stage.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.pool = nn.MaxPool2d(2, 2)

        # Stage 2: 16 -> 32 channels.
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)

        # Stage 3: 32 -> 64 channels.
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)

        # Averaging over the remaining mel/time grid makes the classifier
        # independent of the input length.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        """Map a [B, 1, Mel, Time] batch to [B, num_classes] logits."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            x = self.pool(F.relu(norm(conv(x))))
        pooled = self.global_pool(x)
        return self.fc(pooled.flatten(1))


## 5) Training Pipeline

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import os
from utils import get_speechcommands_loaders
from models import SimpleCNN

def evaluate(model, dataloader, criterion, device):
    """Compute the average loss and the accuracy of ``model`` over ``dataloader``.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        dataloader: Iterable of ``(X, y)`` batches.
        criterion: Loss function called as ``criterion(outputs, y)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)``, both averaged over the number of samples seen.
    """
    model.eval()
    loss_sum = 0
    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            # Weight the batch loss by its size so the final mean is per sample.
            loss_sum += criterion(logits, targets).item() * inputs.size(0)
            predictions = logits.max(1)[1]
            seen += targets.size(0)
            hits += predictions.eq(targets).sum().item()
    return loss_sum / seen, hits / seen

def main(args):
    """Train ``SimpleCNN`` on the filtered speech commands and write checkpoints.

    Args:
        args: Object providing ``data_dir``, ``batch_size``, ``lr``, ``epochs``
            and ``out_dir``.

    After every epoch a checkpoint ``asr_epoch_<n>.pt`` is written to
    ``args.out_dir``; ``asr_best.pt`` is rewritten whenever the test accuracy
    exceeds the best value seen so far.
    """
    # Device preference: Apple MPS, then CUDA, then CPU.
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    labels = ["yes", "no", "up", "down"]
    trainloader, testloader, _ = get_speechcommands_loaders(
        args.data_dir, args.batch_size, labels, n_mels=64
    )

    model = SimpleCNN(num_classes=len(labels)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    os.makedirs(args.out_dir, exist_ok=True)

    best_acc = 0.0
    for epoch_no in range(1, args.epochs + 1):
        model.train()
        loss_sum = 0
        hits = 0
        seen = 0

        for inputs, targets in trainloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            logits = model(inputs)
            batch_loss = criterion(logits, targets)
            batch_loss.backward()
            optimizer.step()

            # Weight the batch loss by its size so the epoch mean is per sample.
            loss_sum += batch_loss.item() * inputs.size(0)
            predictions = logits.max(1)[1]
            seen += targets.size(0)
            hits += predictions.eq(targets).sum().item()

        train_loss = loss_sum / seen
        train_acc = hits / seen

        test_loss, test_acc = evaluate(model, testloader, criterion, device)

        print(f"Epoch {epoch_no}/{args.epochs} "
              f"| Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} "
              f"| Test Loss: {test_loss:.4f} Acc: {test_acc:.4f}")

        # The per-epoch file and the best-model file share the same payload.
        payload = {
            "model_state": model.state_dict(),
            "labels": labels,
            "model_type": "cnn"
        }
        torch.save(payload, os.path.join(args.out_dir, f"asr_epoch_{epoch_no}.pt"))

        # Only a strictly higher test accuracy replaces the best model.
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(payload, os.path.join(args.out_dir, "asr_best.pt"))
            print(f"[INFO] Best model updated with acc={best_acc:.4f}")



## 6) Inference Pipeline

In [None]:
#!/usr/bin/env python3
import argparse
import torch
import torchaudio
import torchaudio.transforms as T
from models import SimpleCNN  # your actual model

def main(args):
    """Classify one audio file with a trained ``SimpleCNN`` checkpoint.

    Args:
        args: Object providing ``ckpt`` (checkpoint path) and ``audio``
            (path of the audio file to classify).

    Prints the predicted label together with the softmax probabilities.
    """
    # Load checkpoint; it carries the label list next to the weights.
    ckpt = torch.load(args.ckpt, map_location="cpu", weights_only=True)
    labels = ckpt["labels"]
    num_classes = len(labels)

    # Create model and load weights
    model = SimpleCNN(num_classes)
    model.load_state_dict(ckpt["model_state"])

    model.eval()

    # Load and preprocess audio
    waveform, sr = torchaudio.load(args.audio)
    # The loaded waveform is channels-first. The model takes exactly one input
    # channel, so multi-channel files are averaged down to mono; without this
    # a stereo file reaches the first convolution with two channels and fails.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    waveform = T.Resample(sr, 16000)(waveform)  # resample to 16k
    mel_spec = T.MelSpectrogram(sample_rate=16000, n_mels=64)(waveform)
    mel_spec = mel_spec.unsqueeze(0)  # add batch dim -> [1, 1, n_mels, time]

    # Run inference
    with torch.no_grad():
        outputs = model(mel_spec)
        pred_idx = outputs.argmax(1).item()
        probs = torch.softmax(outputs, 1)[0].tolist()
        print(f"Predicted: {labels[pred_idx]} (Probs: {probs})")

