In [None]:
!pip install wandb transformers torchaudio jiwer sklearn

In [None]:
import os
import wandb
import pandas as pd
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import WhisperProcessor, WhisperForConditionalGeneration, get_scheduler
from jiwer import wer, cer
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import matplotlib.pyplot as plt
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_
from google.colab import drive

In [None]:
# Mount Google Drive so the TSV manifest, audio files and checkpoints can live on Drive.
drive.mount('/content/drive')

# Configurations for flexible parameter adjustment
config = {
    "tsv_file": "/content/drive/MyDrive/path/to/train.tsv",      # TODO: point at the real manifest
    "audio_dir": "/content/drive/MyDrive/path/to/audio_files",   # joined with each row's `path`
    "batch_size": 4,
    "learning_rate": 5e-5,
    "weight_decay": 0.01,      # AdamW weight decay
    "warmup_steps": 500,       # LR warm-up length, counted in optimizer/scheduler steps
    "epochs": 3,
    "max_samples": 100,        # number of manifest rows to load
    "checkpoint_interval": 2,  # save a checkpoint every N epochs
    "seed": 42,                # torch RNG seed (DataLoader shuffling, dropout)
}

# Seed torch so DataLoader shuffling and dropout are repeatable between runs.
torch.manual_seed(config["seed"])

# Initialize WandB after the config exists so the run records its hyperparameters.
wandb.init(project="taglish-whisper-fine-tuning", config=config)

In [None]:
# Pull the pretrained Whisper processor (feature extractor + tokenizer) and the
# seq2seq model from the same checkpoint so their vocabularies agree.
whisper_checkpoint = "openai/whisper-base"
processor = WhisperProcessor.from_pretrained(whisper_checkpoint)
model = WhisperForConditionalGeneration.from_pretrained(whisper_checkpoint)
# Size the embedding matrix to the tokenizer's vocabulary.
model.resize_token_embeddings(len(processor.tokenizer))

In [None]:
# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
model.to(device)

In [None]:
# Optimizer setup.
# Following the usual transformer recipe, 1-D parameters (biases and LayerNorm
# weights) are excluded from weight decay; only matrices are decayed.
# Frozen parameters (requires_grad=False) are not handed to the optimizer.
decay_params, no_decay_params = [], []
for param_name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    if param.ndim < 2:
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer = AdamW(
    [
        {"params": decay_params, "weight_decay": config["weight_decay"]},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=config["learning_rate"],
)

In [None]:
# Load dataset with timestamp support
def load_data(tsv_file, audio_dir, max_samples, seed=None):
    """Read a shuffled subset of a TSV manifest into parallel lists.

    Args:
        tsv_file: Tab-separated manifest with ``path`` and ``sentence`` columns and
            optional ``start`` / ``end`` columns (seconds).
        audio_dir: Directory joined with each row's ``path``; ``.mp3`` is appended
            when the path does not already end with it.
        max_samples: Maximum number of rows kept after shuffling. ``0`` keeps no
            rows; ``None`` keeps every row.
        seed: Optional ``random_state`` for the shuffle so runs are repeatable;
            ``None`` keeps pandas' unseeded shuffle.

    Returns:
        ``(audio_files, transcripts, timestamps)``. Each timestamp is a
        ``(start, end)`` tuple; a missing or blank ``start`` becomes ``0`` and a
        missing or blank ``end`` becomes ``None``.
    """
    manifest = (
        pd.read_csv(tsv_file, sep="\t")
        .sample(frac=1, random_state=seed)
        .reset_index(drop=True)
    )
    if max_samples is not None:
        # Slicing up front replaces the old append-then-check counter, which still
        # returned one row when max_samples was 0.
        manifest = manifest.head(max(0, max_samples))

    audio_files, transcripts, timestamps = [], [], []
    for _, row in manifest.iterrows():
        audio_file = str(row["path"])
        if not audio_file.endswith(".mp3"):
            audio_file += ".mp3"

        start_time = row.get("start", 0)
        end_time = row.get("end", None)
        # Blank cells come back as NaN, which is truthy and makes int(NaN * rate)
        # raise when the audio is trimmed; normalise them to "not given".
        if pd.isna(start_time):
            start_time = 0
        if pd.isna(end_time):
            end_time = None

        audio_files.append(os.path.join(audio_dir, audio_file))
        transcripts.append(row["sentence"])
        timestamps.append((start_time, end_time))
    return audio_files, transcripts, timestamps

In [None]:
# Custom Dataset with timestamp support
class ProcessData(Dataset):
    """Map-style dataset yielding Whisper input features and tokenized transcripts.

    Each item is ``{"input_features": Tensor, "labels": LongTensor}``.

    Audio processing, in order:
    - load the file with torchaudio;
    - down-mix to mono;
    - resample to 16 kHz when needed;
    - optionally trim to the ``(start, end)`` window, given in seconds, from ``timestamps``.
    """

    SAMPLING_RATE = 16000  # rate passed to the Whisper feature extractor

    def __init__(self, audio_files, transcripts, timestamps, processor):
        self.audio_files = audio_files
        self.transcripts = transcripts
        self.timestamps = timestamps
        self.processor = processor

    def __len__(self):
        return len(self.audio_files)

    @staticmethod
    def _seconds_or_none(value):
        """Return ``value`` as float seconds, or None when it is None/NaN."""
        return None if pd.isna(value) else float(value)

    def __getitem__(self, idx):
        audio_path = self.audio_files[idx]
        transcript = self.transcripts[idx]
        start_time, end_time = self.timestamps[idx]

        audio, sample_rate = torchaudio.load(audio_path)
        # torchaudio returns (channels, frames). Averaging the channels gives a 1-D
        # waveform, so stereo files are not passed on as a batch of two signals.
        audio = audio.mean(dim=0)
        if sample_rate != self.SAMPLING_RATE:
            audio = torchaudio.functional.resample(audio, sample_rate, self.SAMPLING_RATE)

        start_time = self._seconds_or_none(start_time)
        end_time = self._seconds_or_none(end_time)
        if start_time or end_time:
            start_frame = int(start_time * self.SAMPLING_RATE) if start_time else 0
            # An open end must be None, not -1: a -1 stop index drops the final sample.
            end_frame = int(end_time * self.SAMPLING_RATE) if end_time else None
            audio = audio[start_frame:end_frame]

        input_features = self.processor(
            audio.numpy(), sampling_rate=self.SAMPLING_RATE, return_tensors="pt"
        ).input_features
        # WhisperProcessor treats its first positional argument as audio, so the
        # transcript has to go to the tokenizer explicitly.
        labels = self.processor.tokenizer(transcript, return_tensors="pt").input_ids

        # Index the batch dimension instead of squeeze() so one-token labels stay 1-D.
        return {"input_features": input_features[0], "labels": labels[0]}

In [None]:
# Load the data
audio_files, transcripts, timestamps = load_data(config["tsv_file"], config["audio_dir"], config["max_samples"])
train_dataset = ProcessData(audio_files, transcripts, timestamps, processor)


def collate_batch(batch):
    """Stack input features and right-pad label ids to a common length.

    The default collate cannot stack label tensors of different lengths.
    Labels are therefore padded with -100, the ignore index used by Hugging Face
    seq2seq losses, so padding does not contribute to the loss.
    """
    start_token_id = model.config.decoder_start_token_id
    label_seqs = []
    for item in batch:
        labels = item["labels"]
        # The HF Whisper fine-tuning recipe drops a leading decoder-start token,
        # because the model prepends one itself when shifting labels right.
        # TODO confirm against the installed transformers version.
        if labels.numel() > 0 and labels[0] == start_token_id:
            labels = labels[1:]
        label_seqs.append(labels)

    return {
        "input_features": torch.stack([item["input_features"] for item in batch]),
        "labels": torch.nn.utils.rnn.pad_sequence(label_seqs, batch_first=True, padding_value=-100),
    }


train_dataloader = DataLoader(
    train_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    collate_fn=collate_batch,
)

In [None]:
# Scheduler setup with warmup steps (the scheduler is stepped once per batch).
num_training_steps = config["epochs"] * len(train_dataloader)

# With small runs (the default 100 samples / batch 4 / 3 epochs gives 75 steps), a
# 500-step warm-up never finishes. The LR would peak at a fraction of
# `learning_rate` and the cosine decay would never begin, so fall back to
# warming up over the first 10% of training.
num_warmup_steps = config["warmup_steps"]
if num_warmup_steps >= num_training_steps:
    num_warmup_steps = num_training_steps // 10
    print(
        f"warmup_steps={config['warmup_steps']} >= total steps={num_training_steps}; "
        f"using {num_warmup_steps} warm-up steps instead."
    )

scheduler = get_scheduler(
    "cosine",                 # "linear" is a drop-in alternative
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

In [None]:
# Evaluation function with metrics
def evaluate_model(model, processor, eval_dataloader):
    """Transcribe every batch in ``eval_dataloader`` and score the predictions.

    Args:
        model: Whisper seq2seq model; ``model.generate`` produces token ids.
        processor: WhisperProcessor used to decode predicted and reference ids.
        eval_dataloader: Yields dicts with ``"input_features"`` and ``"labels"``
            tensors. Label entries equal to -100 are treated as padding.

    Returns:
        A dict with these keys:
        - ``"WER"``, ``"CER"``: mean of per-utterance jiwer scores.
        - ``"Precision"``, ``"Recall"``, ``"F1-Score"``, ``"Accuracy"``:
          scikit-learn scores over whole transcripts. Each distinct reference
          string is a class, so a prediction counts only on an exact match.

    Raises:
        ValueError: If ``eval_dataloader`` yields no samples.

    The model's train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    # Send inputs to wherever the model's weights live rather than a global device.
    model_device = next(model.parameters()).device
    pad_token_id = processor.tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = processor.tokenizer.eos_token_id

    all_preds, all_refs = [], []
    wer_sum, cer_sum = 0.0, 0.0
    try:
        with torch.no_grad():
            for batch in eval_dataloader:
                input_features = batch["input_features"].to(model_device)
                # -100 is a loss-ignore marker, not a token id, and cannot be
                # decoded. Map it to the pad token, which should be a special
                # token for Whisper and so dropped by skip_special_tokens.
                labels = batch["labels"].masked_fill(batch["labels"] == -100, pad_token_id)

                generated_ids = model.generate(input_features)
                preds = processor.batch_decode(generated_ids, skip_special_tokens=True)
                refs = processor.batch_decode(labels, skip_special_tokens=True)

                # Accumulate per-utterance sums and divide once by the sample count.
                # Averaging per batch and then dividing by the sample count shrank
                # both metrics by roughly the batch size.
                wer_sum += sum(wer(ref, pred) for ref, pred in zip(refs, preds))
                cer_sum += sum(cer(ref, pred) for ref, pred in zip(refs, preds))
                all_preds.extend(preds)
                all_refs.extend(refs)
    finally:
        model.train(was_training)

    num_samples = len(all_refs)
    if num_samples == 0:
        raise ValueError("eval_dataloader yielded no samples; cannot compute metrics")

    precision, recall, f1, _ = precision_recall_fscore_support(
        all_refs, all_preds, average="weighted", zero_division=0
    )
    accuracy = accuracy_score(all_refs, all_preds)

    return {
        "WER": wer_sum / num_samples,
        "CER": cer_sum / num_samples,
        "Precision": precision,
        "Recall": recall,
        "F1-Score": f1,
        "Accuracy": accuracy
    }

In [None]:
# Training loop with WandB logging, scheduler step, and checkpoint saving.
# NOTE(review): evaluation reuses the training samples, so these metrics measure fit
# to the training data, not generalization. TODO: hold out a validation split.
eval_dataloader = DataLoader(
    train_dataset,
    batch_size=config["batch_size"],
    collate_fn=train_dataloader.collate_fn,  # batch/pad exactly as during training
)
metrics_before = evaluate_model(model, processor, eval_dataloader)
print("Metrics Before Training:", metrics_before)
wandb.log(metrics_before)

max_grad_norm = config.get("max_grad_norm", 1.0)
num_batches = len(train_dataloader)

for epoch in range(config["epochs"]):
    model.train()
    total_loss = 0.0
    for batch in train_dataloader:
        input_features = batch["input_features"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(input_features, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()

        loss.backward()
        # Bound the gradient norm before each update (clip_grad_norm_ was imported
        # for this but never called).
        clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # per-batch LR schedule
        optimizer.zero_grad()

    avg_loss = total_loss / max(num_batches, 1)
    wandb.log({"loss": avg_loss, "epoch": epoch + 1, "learning_rate": scheduler.get_last_lr()[0]})

    # Save every `checkpoint_interval` epochs AND after the final epoch, so the
    # trained weights are kept even when epochs is not a multiple of the interval.
    is_last_epoch = epoch + 1 == config["epochs"]
    if (epoch + 1) % config["checkpoint_interval"] == 0 or is_last_epoch:
        checkpoint_dir = f"/content/drive/MyDrive/whisper_checkpoints/checkpoint_epoch_{epoch + 1}"
        model.save_pretrained(checkpoint_dir)
        processor.save_pretrained(checkpoint_dir)

    print(f"Epoch {epoch + 1} completed. Loss: {avg_loss}")

# Evaluation after training
metrics_after = evaluate_model(model, processor, eval_dataloader)
print("Metrics After Training:", metrics_after)
wandb.log(metrics_after)

In [None]:
# Plotting metrics comparison
def plot_metrics(metrics_before, metrics_after):
    """Draw a grouped bar chart comparing each metric before and after fine-tuning.

    Bars are paired by metric name, taken from ``metrics_before``. The two dicts
    therefore need not share key order, but a name missing from ``metrics_after``
    raises KeyError.
    """
    metric_names = list(metrics_before)
    before_values = [metrics_before[name] for name in metric_names]
    after_values = [metrics_after[name] for name in metric_names]

    bar_width = 0.4
    positions = list(range(len(metric_names)))

    fig, ax = plt.subplots(figsize=(12, 6))
    # Shift each series half a bar width off the tick so the pairs sit side by
    # side instead of overlapping.
    ax.bar([p - bar_width / 2 for p in positions], before_values, width=bar_width,
           label="Before Training", color="skyblue")
    ax.bar([p + bar_width / 2 for p in positions], after_values, width=bar_width,
           label="After Training", color="salmon")
    ax.set_xlabel("Metrics")
    ax.set_ylabel("Scores")
    ax.set_title("Model Performance Before and After Fine-Tuning")
    ax.set_xticks(positions)
    ax.set_xticklabels(metric_names)
    ax.legend()
    plt.show()

# Visualize the before/after metrics, then mark the W&B run as finished.
plot_metrics(metrics_before=metrics_before, metrics_after=metrics_after)
wandb.finish()