In [47]:
!pip install numpy pandas torch transformers wandb tqdm scikit-learn librosa jiwer torchaudio



In [48]:
import os
import time
import pandas as pd
import numpy as np
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from transformers import WhisperFeatureExtractor, WhisperProcessor, WhisperForConditionalGeneration
import wandb
from tqdm import tqdm
from jiwer import wer, cer

In [49]:
# Set up Google Drive mounting
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [50]:
# Hyperparameters
AUDIO_DIR = '/content/gdrive/Shareddrives/CS307-Thesis/Dataset/single-speaker/'
TSV_FILE = '/content/gdrive/Shareddrives/CS307-Thesis/Dataset/single-speaker/validated.tsv'
CHECKPOINT_DIR = '/content/gdrive/Shareddrives/CS307-Thesis/Dataset/whisper_checkpoints/'
MAX_SAMPLES = 13
BATCH_SIZE = 4
NUM_WORKERS = 2
EPOCHS = 3
LEARNING_RATE = 1e-5
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [51]:
def load_data(tsv_file, audio_dir, max_samples=None):
    """
    Load data from TSV file with timestamp handling, compatible with both "sec" and "min:sec" formats.
    """
    audio_files, transcripts, languages, timestamps = [], [], [], []

    # Read TSV file
    df = pd.read_csv(tsv_file, sep='\t')
    required_columns = ['path', 'start_time', 'end_time', 'language', 'sentence']

    # Verify all required columns are present
    if not all(col in df.columns for col in required_columns):
        raise ValueError(f"TSV file must contain columns: {required_columns}")

    # Shuffle and limit samples if specified
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)
    if max_samples:
        df = df.head(max_samples)

    for _, row in df.iterrows():
        audio_file = row['path']
        if not audio_file.endswith((".mp3", ".wav", ".flac")):
            print(f"Skipping unsupported file type: {audio_file}")
            continue

        full_audio_path = os.path.join(audio_dir, audio_file)
        if not os.path.exists(full_audio_path):
            print(f"Warning: Audio file not found: {full_audio_path}")
            continue

        # Parse timestamps
        def parse_time(time_str):
            try:
                # Check if time is already in seconds
                return float(time_str)
            except ValueError:
                # Convert from "min:sec" format to seconds
                minutes, seconds = map(float, time_str.split(":"))
                return minutes * 60 + seconds

        try:
            start_time = parse_time(row['start_time'])
            end_time = parse_time(row['end_time'])
        except Exception as e:
            print(f"Error parsing timestamps for {audio_file}: {str(e)}")
            continue

        audio_files.append(full_audio_path)
        transcripts.append(row['sentence'])
        timestamps.append((start_time, end_time))
        languages.append(row['language'])

    return audio_files, transcripts, languages, timestamps

In [52]:
# Load data
audio_files, transcripts, languages, timestamps = load_data(TSV_FILE, AUDIO_DIR, max_samples=MAX_SAMPLES)

In [53]:
# Create dataset and dataloader
class WhisperDataset(Dataset):
    def __init__(self, audio_files, transcripts, languages, timestamps):
        self.audio_files = audio_files
        self.transcripts = transcripts
        self.languages = languages
        self.timestamps = timestamps
        self.processor = WhisperProcessor.from_pretrained("openai/whisper-small")

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        audio_file = self.audio_files[idx]
        transcript = self.transcripts[idx]
        language = self.languages[idx]
        start_time, end_time = self.timestamps[idx]

        # Load audio
        audio, sampling_rate = torchaudio.load(audio_file)

        # Crop audio based on timestamps
        start_frame = int(start_time * sampling_rate)
        end_frame = int(end_time * sampling_rate)
        audio = audio[:, start_frame:end_frame]

        # Resample audio to 16kHz if necessary
        if sampling_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
            audio = resampler(audio)

        # Pad the audio data to the expected shape
        expected_shape = (1, audio.shape[1])
        if audio.shape != expected_shape:
            audio = F.pad(audio, (0, expected_shape[1] - audio.shape[1]), mode='constant', value=0)

        # Preprocess audio and text
        pixel_values = self.processor.feature_extractor(audio, sampling_rate=16000, return_tensors="pt").pixel_values
        input_ids = self.processor.tokenizer(transcript, return_tensors="pt").input_ids

        return {
            "audio": pixel_values,
            "input_ids": input_ids,
            "language": language
        }

In [54]:
dataset = WhisperDataset(audio_files, transcripts, languages, timestamps)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)

In [55]:
# Initialize Weights & Biases
wandb.init(project="whisper-fine-tuning", config={
    "batch_size": BATCH_SIZE,
    "learning_rate": LEARNING_RATE,
    "epochs": EPOCHS,
    "max_samples": MAX_SAMPLES
})

In [56]:
# Load Whisper model and fine-tune
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

In [57]:
# Load checkpoint if available
start_epoch = 0
if os.path.exists(os.path.join(CHECKPOINT_DIR, "checkpoint.pt")):
    checkpoint = torch.load(os.path.join(CHECKPOINT_DIR, "checkpoint.pt"))
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    print(f"Loaded checkpoint from epoch {start_epoch}")

for epoch in range(start_epoch, EPOCHS):
    train_loss = 0
    train_wer, train_cer, train_acc, train_precision, train_recall, train_f1 = 0, 0, 0, 0, 0, 0
    model.train()
    for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}", unit="batch"):
        audio = batch["audio"].to(DEVICE)
        input_ids = batch["input_ids"].to(DEVICE)
        language = batch["language"]

        optimizer.zero_grad()
        output = model(audio, input_ids=input_ids, return_dict=True)
        loss = output.loss
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Evaluate metrics
        predicted_ids = output.logits.argmax(-1)
        predicted_text = [model.processor.tokenizer.decode(p, skip_special_tokens=True) for p in predicted_ids]
        true_text = [model.processor.tokenizer.decode(t, skip_special_tokens=True) for t in input_ids]
        train_wer += wer(true_text, predicted_text)
        train_cer += cer(true_text, predicted_text)
        train_acc += (np.array(predicted_text) == np.array(true_text)).mean()
        train_precision += (np.array(predicted_text) == np.array(true_text)).mean()
        train_recall += (np.array(predicted_text) == np.array(true_text)).mean()
        train_f1 += 2 * train_precision * train_recall / (train_precision + train_recall)

    train_loss /= len(dataloader)
    train_wer /= len(dataloader)
    train_cer /= len(dataloader)
    train_acc /= len(dataloader)
    train_precision /= len(dataloader)
    train_recall /= len(dataloader)
    train_f1 /= len(dataloader)

    # Log metrics to Weights & Biases
    wandb.log({
        "train_loss": train_loss,
        "train_wer": train_wer,
        "train_cer": train_cer,
        "train_accuracy": train_acc,
        "train_precision": train_precision,
        "train_recall": train_recall,
        "train_f1": train_f1
    })

    # Save checkpoint
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict()
    }
    torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, "checkpoint.pt"))
    print(f"Checkpoint saved for epoch {epoch+1}")

print("Training complete!")

Epoch 1/3:   0%|          | 0/4 [00:00<?, ?batch/s]Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b72544f41f0><function _MultiProcessingDataLoaderIter.__del__ at 0x7b72544f41f0>

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
        if w.is_alive():self._shutdown_workers()

  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
        assert self._parent_pid == os.getpid(), 'can only test a child process'if w.is_alive():

AssertionEr

ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "<ipython-input-53-57badd0b684f>", line 38, in __getitem__
    pixel_values = self.processor.feature_extractor(audio, sampling_rate=16000, return_tensors="pt").pixel_values
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/feature_extraction_whisper.py", line 282, in __call__
    padded_inputs = self.pad(
  File "/usr/local/lib/python3.10/dist-packages/transformers/feature_extraction_sequence_utils.py", line 210, in pad
    outputs = self._pad(
  File "/usr/local/lib/python3.10/dist-packages/transformers/feature_extraction_sequence_utils.py", line 282, in _pad
    processed_features[self.model_input_names[0]] = np.pad(
  File "/usr/local/lib/python3.10/dist-packages/numpy/lib/arraypad.py", line 748, in pad
    pad_width = _as_pairs(pad_width, array.ndim, as_index=True)
  File "/usr/local/lib/python3.10/dist-packages/numpy/lib/arraypad.py", line 522, in _as_pairs
    return np.broadcast_to(x, (ndim, 2)).tolist()
  File "/usr/local/lib/python3.10/dist-packages/numpy/lib/stride_tricks.py", line 413, in broadcast_to
    return _broadcast_to(array, shape, subok=subok, readonly=True)
  File "/usr/local/lib/python3.10/dist-packages/numpy/lib/stride_tricks.py", line 349, in _broadcast_to
    it = np.nditer(
ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (2,2)  and requested shape (3,2)
