In [None]:
!pip install numpy pandas torch transformers wandb tqdm scikit-learn librosa

In [None]:
import os
import time

import numpy as np
import pandas as pd
import torch
import torchaudio
import wandb
from jiwer import wer, cer
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import WhisperFeatureExtractor, WhisperProcessor, WhisperForConditionalGeneration

In [None]:
# Mount Google Drive at /content/drive: AUDIO_DIR, TSV_FILE and CHECKPOINT_DIR
# in the configuration cell are all rooted at that mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Hyperparameters

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Optimisation settings.
LEARNING_RATE = 1e-5
EPOCHS = 3

# Data loading settings: MAX_SAMPLES is forwarded to load_data as `max_samples`,
# BATCH_SIZE and NUM_WORKERS are forwarded to the DataLoader.
MAX_SAMPLES = 1000
BATCH_SIZE = 4
NUM_WORKERS = 4

# Locations on the mounted drive: audio clips, their TSV index, and checkpoints.
AUDIO_DIR = '/content/drive/Shareddrives/CS307-Thesis/Dataset/single-speaker/'
TSV_FILE = '/content/drive/Shareddrives/CS307-Thesis/Dataset/single-speaker/validated.tsv'
CHECKPOINT_DIR = '/content/drive/Shareddrives/CS307-Thesis/Dataset/whisper_checkpoints/'

In [None]:
# Read the TSV index and collect, for at most MAX_SAMPLES rows, four parallel
# sequences: audio paths, transcripts, languages and timestamps.
# NOTE(review): `load_data` is not defined or imported in any cell visible in
# this notebook - confirm where it comes from before a Restart & Run All.
(audio_files, transcripts, languages, timestamps) = load_data(
    TSV_FILE,
    AUDIO_DIR,
    max_samples=MAX_SAMPLES,
)

In [None]:
# Dataset that turns (audio file, transcript, language, timestamps) rows into model-ready tensors
class WhisperDataset(Dataset):
    """Map-style dataset producing Whisper input features and padded transcript token ids.

    Args:
        audio_files: Sequence of audio file paths readable by ``torchaudio.load``.
        transcripts: Sequence of transcript strings, parallel to ``audio_files``.
        languages: Sequence of per-sample language values, returned untouched.
        timestamps: Sequence of ``(start_time, end_time)`` pairs in seconds that
            select the segment of each file to use.
        max_label_length: Fixed length every tokenized transcript is padded or
            truncated to, so the default DataLoader collation can stack items.
            448 is the decoder position limit of the published Whisper configs.
    """

    # Sampling rate the Whisper feature extractor is called with.
    SAMPLE_RATE = 16000

    def __init__(self, audio_files, transcripts, languages, timestamps, max_label_length=448):
        self.audio_files = audio_files
        self.transcripts = transcripts
        self.languages = languages
        self.timestamps = timestamps
        self.max_label_length = max_label_length
        self.processor = WhisperProcessor.from_pretrained("openai/whisper-small")

    def __len__(self):
        """Return the number of samples."""
        return len(self.audio_files)

    def __getitem__(self, idx):
        """Load, crop and featurize one sample.

        Returns:
            dict with
              ``"audio"``: feature tensor with the extractor's leading batch
              dimension removed,
              ``"input_ids"``: 1-D tensor of ``max_label_length`` token ids,
              right-padded with the tokenizer's pad token,
              ``"language"``: the language value stored for this sample.
        """
        audio_file = self.audio_files[idx]
        transcript = self.transcripts[idx]
        language = self.languages[idx]
        start_time, end_time = self.timestamps[idx]

        # torchaudio returns (channels, samples) plus the file's own sample rate.
        waveform, source_rate = torchaudio.load(audio_file)
        # Down-mix to mono; for a single channel this is just the channel itself.
        waveform = waveform.mean(dim=0)
        # The crop indices and the feature extractor both assume SAMPLE_RATE,
        # so bring the audio to that rate before doing anything else.
        if source_rate != self.SAMPLE_RATE:
            waveform = torchaudio.functional.resample(waveform, source_rate, self.SAMPLE_RATE)

        # Crop audio based on timestamps (seconds -> sample indices).
        first_sample = int(start_time * self.SAMPLE_RATE)
        last_sample = int(end_time * self.SAMPLE_RATE)
        waveform = waveform[first_sample:last_sample]

        # Whisper's feature extractor exposes its output as `input_features`.
        input_features = self.processor.feature_extractor(
            waveform.numpy(),
            sampling_rate=self.SAMPLE_RATE,
            return_tensors="pt",
        ).input_features

        # Pad/truncate to a fixed length so samples in a batch can be stacked.
        input_ids = self.processor.tokenizer(
            transcript,
            padding="max_length",
            max_length=self.max_label_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids

        # Both processors return a leading batch dimension of 1; drop it so the
        # DataLoader adds the real batch dimension.
        return {
            "audio": input_features.squeeze(0),
            "input_ids": input_ids.squeeze(0),
            "language": language
        }

In [None]:
# Wrap the loaded samples in the dataset, then batch and shuffle them for training.
dataset = WhisperDataset(
    audio_files=audio_files,
    transcripts=transcripts,
    languages=languages,
    timestamps=timestamps,
)
dataloader = DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

In [None]:
# Start a Weights & Biases run and record the settings this run was launched with.
wandb.init(
    project="whisper-fine-tuning",
    config=dict(
        batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        epochs=EPOCHS,
        max_samples=MAX_SAMPLES,
    ),
)

In [None]:
# Fetch the pretrained Whisper checkpoint and place it on the training device.
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
model = model.to(DEVICE)

# AdamW over every model parameter, using the configured learning rate.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
def _build_labels(input_ids, pad_token_id, eos_token_id, decoder_start_token_id):
    """Turn padded transcript token ids into decoder targets.

    Args:
        input_ids: (batch, length) tensor of token ids, right-padded with ``pad_token_id``.
        pad_token_id: Id used for padding.
        eos_token_id: Id of the end-of-sequence token.
        decoder_start_token_id: Id the model prepends to the decoder input, or None.

    Returns:
        ``(labels, ignore_mask)``: both (batch, trimmed_length); ``ignore_mask``
        is True at the padding positions that must not contribute to the loss
        or to the decoded text.
    """
    labels = input_ids
    # The model prepends the decoder start token itself when it shifts the labels
    # right, so a start token at position 0 of every row would be duplicated.
    if (
        decoder_start_token_id is not None
        and labels.size(1) > 1
        and bool((labels[:, 0] == decoder_start_token_id).all())
    ):
        labels = labels[:, 1:]

    ignore_mask = labels == pad_token_id
    if pad_token_id == eos_token_id:
        # Padding reuses the end-of-sequence id: keep the first occurrence in
        # each row as a real target and ignore only the ones after it.
        ignore_mask = ignore_mask & (ignore_mask.cumsum(dim=-1) > 1)

    # With right padding the kept tokens form a prefix; drop the all-padding tail
    # so the decoder is not run over positions that are ignored anyway.
    longest = max(int((~ignore_mask).sum(dim=-1).max()), 1)
    return labels[:, :longest], ignore_mask[:, :longest]


# Load checkpoint if available
checkpoint_path = os.path.join(CHECKPOINT_DIR, "checkpoint.pt")
start_epoch = 0
if os.path.exists(checkpoint_path):
    # map_location lets a checkpoint written on a GPU runtime load on a CPU-only one.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    print(f"Loaded checkpoint from epoch {start_epoch}")

# The model object carries no tokenizer; reuse the one owned by the dataset.
tokenizer = dataset.processor.tokenizer
pad_token_id = tokenizer.pad_token_id
# Guard the per-epoch averages against an empty dataloader.
num_batches = max(len(dataloader), 1)

for epoch in range(start_epoch, EPOCHS):
    train_loss = 0.0
    train_wer = train_cer = 0.0
    train_acc = train_precision = train_recall = train_f1 = 0.0
    model.train()
    for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}", unit="batch"):
        audio = batch["audio"].to(DEVICE)
        input_ids = batch["input_ids"].to(DEVICE)

        # Tolerate items that still carry the processor's singleton batch dimension.
        if audio.dim() == 4:
            audio = audio.squeeze(1)
        if input_ids.dim() == 3:
            input_ids = input_ids.squeeze(1)

        labels, ignore_mask = _build_labels(
            input_ids,
            pad_token_id,
            tokenizer.eos_token_id,
            model.config.decoder_start_token_id,
        )

        optimizer.zero_grad()
        # Whisper takes the audio as `input_features`; the loss is only computed
        # when `labels` is given, and positions set to -100 are excluded from it.
        output = model(
            input_features=audio,
            labels=labels.masked_fill(ignore_mask, -100),
            return_dict=True,
        )
        loss = output.loss
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Evaluate metrics on the teacher-forced predictions. Predictions at
        # ignored positions are overwritten with the pad id so they are dropped
        # by skip_special_tokens instead of polluting the decoded text.
        predicted_ids = output.logits.detach().argmax(-1).masked_fill(ignore_mask, pad_token_id)
        predicted_text = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)
        true_text = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # jiwer rejects empty reference strings, so score only non-empty references.
        scored = [(ref, hyp) for ref, hyp in zip(true_text, predicted_text) if ref.strip()]
        if scored:
            references = [ref for ref, _ in scored]
            hypotheses = [hyp for _, hyp in scored]
            train_wer += wer(references, hypotheses)
            train_cer += cer(references, hypotheses)

        # Accuracy, precision and recall are all the sentence-level exact-match
        # rate of this batch (no separate definitions exist for them here).
        batch_match = float((np.array(predicted_text) == np.array(true_text)).mean())
        batch_precision = batch_match
        batch_recall = batch_match
        train_acc += batch_match
        train_precision += batch_precision
        train_recall += batch_recall
        # F1 from this batch's values only, defined as 0 when both are 0.
        if batch_precision + batch_recall > 0:
            train_f1 += 2 * batch_precision * batch_recall / (batch_precision + batch_recall)

    train_loss /= num_batches
    train_wer /= num_batches
    train_cer /= num_batches
    train_acc /= num_batches
    train_precision /= num_batches
    train_recall /= num_batches
    train_f1 /= num_batches

    # Log metrics to Weights & Biases
    wandb.log({
        "train_loss": train_loss,
        "train_wer": train_wer,
        "train_cer": train_cer,
        "train_accuracy": train_acc,
        "train_precision": train_precision,
        "train_recall": train_recall,
        "train_f1": train_f1
    })

    # Save checkpoint (create the directory on first use).
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict()
    }
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved for epoch {epoch+1}")

print("Training complete!")