In [None]:
!pip install numpy pandas torch transformers wandb tqdm scikit-learn librosa

In [None]:
import os

import pandas as pd
import torch
import wandb
from datasets import Audio, Dataset, load_metric
from google.colab import drive
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import WhisperForConditionalGeneration, WhisperProcessor, Trainer, TrainingArguments

# Mount Google Drive.
# The mount point has to be /content/gdrive: every path constant below is
# rooted there, so mounting anywhere else leaves all of them nonexistent.
drive.mount('/content/gdrive')

# Initialize WandB
wandb.init(project="taglish-whisper-finetuning")

# Define training configurations
MODEL_NAME = "openai/whisper-small"  # pretrained checkpoint that gets fine-tuned
# Destination for the final model and processor saved after training.
MODEL_PATH = "/content/gdrive/Shareddrives/CS307-Thesis/Dataset/whisper_checkpoints/"
# Directory that the relative `path` entries of the TSV are joined onto.
AUDIO_DIR = "/content/gdrive/Shareddrives/CS307-Thesis/Dataset/single-speaker/"
# Tab-separated manifest: path, start_time, end_time, language, sentence.
TSV_FILE = "/content/gdrive/Shareddrives/CS307-Thesis/Dataset/single-speaker/validated.tsv"
# Passed to the training arguments as the checkpoint/output directory.
OUTPUT_DIR = "/content/gdrive/Shareddrives/CS307-Thesis/Dataset/whisper_output/"
# Names of the metrics reported during evaluation (informational only; not
# referenced elsewhere in the visible code).
EVAL_METRICS = ["wer", "cer", "accuracy", "precision", "recall", "f1"]

# Load processor and model
processor = WhisperProcessor.from_pretrained(MODEL_NAME)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)

# Check for GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Data loading function
_AUDIO_EXTENSIONS = (".mp3", ".wav", ".flac")


def _parse_time(value):
    """Convert a timestamp given as seconds or ``"min:sec"`` into seconds.

    Raises ValueError when a string is neither a plain number nor exactly two
    colon-separated numbers.
    """
    try:
        return float(value)
    except ValueError:
        minutes, seconds = map(float, value.split(":"))
        return minutes * 60 + seconds


def load_data(tsv_file, audio_dir, max_samples=None):
    """
    Load the audio manifest from a TSV file.

    Rows are shuffled with a fixed seed (42) so repeated runs see the same
    order. A row is skipped, with a printed message, when its path is missing
    or not a string, when the file extension is not .mp3/.wav/.flac (compared
    case-insensitively), when the file does not exist under ``audio_dir``, or
    when a timestamp cannot be parsed.

    Args:
        tsv_file: Path of a tab-separated file containing the columns
            'path', 'start_time', 'end_time', 'language' and 'sentence'.
        audio_dir: Directory the relative 'path' values are joined onto.
        max_samples: Keep at most this many rows after shuffling and before
            filtering. ``None`` keeps every row; ``0`` keeps none.

    Returns:
        Four parallel lists: full audio paths, transcripts, languages, and
        ``(start_seconds, end_seconds)`` tuples. Timestamps accept both the
        "sec" and the "min:sec" formats.

    Raises:
        ValueError: If any required column is absent from the TSV.
    """
    df = pd.read_csv(tsv_file, sep='\t')
    required_columns = ['path', 'start_time', 'end_time', 'language', 'sentence']

    # Verify all required columns are present
    if any(col not in df.columns for col in required_columns):
        raise ValueError(f"TSV file must contain columns: {required_columns}")

    # Shuffle and limit samples if specified
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)
    if max_samples is not None:  # explicit check so 0 is not mistaken for "no limit"
        df = df.head(max_samples)

    audio_files, transcripts, languages, timestamps = [], [], [], []

    rows = df[required_columns].itertuples(index=False, name=None)
    for audio_file, raw_start, raw_end, language, sentence in rows:
        # An empty path cell is read by pandas as NaN (a float), not a string.
        if not isinstance(audio_file, str):
            print(f"Skipping row with missing audio path: {audio_file!r}")
            continue

        if not audio_file.lower().endswith(_AUDIO_EXTENSIONS):
            print(f"Skipping unsupported file type: {audio_file}")
            continue

        full_audio_path = os.path.join(audio_dir, audio_file)
        if not os.path.exists(full_audio_path):
            print(f"Warning: Audio file not found: {full_audio_path}")
            continue

        # Parse timestamps
        try:
            start_time = _parse_time(raw_start)
            end_time = _parse_time(raw_end)
        except (ValueError, TypeError, AttributeError) as e:
            print(f"Error parsing timestamps for {audio_file}: {str(e)}")
            continue

        audio_files.append(full_audio_path)
        transcripts.append(sentence)
        timestamps.append((start_time, end_time))
        languages.append(language)

    return audio_files, transcripts, languages, timestamps

# Load dataset
audio_files, transcripts, languages, timestamps = load_data(TSV_FILE, AUDIO_DIR)
if not audio_files:
    # Fail here with a clear message instead of deep inside dataset building.
    raise RuntimeError(f"No usable audio rows were loaded from {TSV_FILE}")

# Data preparation for Trainer
SAMPLING_RATE = 16000  # rate the audio is resampled to before feature extraction


def prepare_example(example):
    """Turn one manifest row into model inputs.

    Reads the decoded waveform from ``example["audio"]``, keeps the span
    between ``start_time`` and ``end_time`` (seconds), and stores:

    * ``input_features`` - output of the processor's feature extractor
    * ``labels`` - token ids of ``example["sentence"]``

    NOTE(review): slicing assumes start_time/end_time are offsets *inside*
    the referenced audio file. If the files are already pre-cut clips and the
    timestamps refer to a longer source recording, drop the slicing - confirm
    against how the dataset was produced.
    """
    audio = example["audio"]
    samples, rate = audio["array"], audio["sampling_rate"]

    start, end = example["start_time"], example["end_time"]
    segment = samples
    # Comparisons with NaN are False, so rows without usable timestamps fall
    # through and keep the whole clip.
    if end > start >= 0:
        clipped = samples[int(start * rate):int(end * rate)]
        # An empty slice means the timestamps lie outside this file; keep the
        # full clip rather than feeding the model zero samples.
        if len(clipped) > 0:
            segment = clipped

    example["input_features"] = processor.feature_extractor(
        segment, sampling_rate=rate
    ).input_features[0]
    example["labels"] = processor.tokenizer(example["sentence"]).input_ids
    return example


# Casting the path column to Audio makes `datasets` decode and resample each
# file lazily when a row is accessed; the processor needs waveforms, not paths.
raw_dataset = Dataset.from_dict({
    "audio": audio_files,
    "sentence": transcripts,
    "start_time": [start for start, _ in timestamps],
    "end_time": [end for _, end in timestamps],
}).cast_column("audio", Audio(sampling_rate=SAMPLING_RATE))

# Keep only the two columns the model consumes: input_features and labels.
dataset = raw_dataset.map(prepare_example, remove_columns=raw_dataset.column_names)

# Evaluation metrics
wer_metric = load_metric("wer")
cer_metric = load_metric("cer")


def compute_metrics(pred):
    """Score predicted transcripts against the reference transcripts.

    Args:
        pred: Evaluation output exposing ``predictions`` and ``label_ids``
            (assumed to be numpy arrays, as handed over by ``Trainer``).
            ``predictions`` may be token ids, raw logits with a trailing
            vocabulary axis, or a tuple whose first element is one of those.

    Returns:
        Dict with "wer" and "cer" from the loaded metrics, plus "accuracy"
        (fraction of transcripts that match exactly) and weighted
        "precision", "recall" and "f1" computed by scikit-learn with each
        distinct transcript string treated as a class.
    """
    pad_token_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        # Models that return extra outputs put the logits first.
        pred_ids = pred_ids[0]
    if pred_ids.ndim == 3:
        # (batch, sequence, vocab) logits -> most likely token per position.
        pred_ids = pred_ids.argmax(axis=-1)

    # Work on copies so the arrays owned by the caller are left untouched.
    pred_ids = pred_ids.copy()
    label_ids = pred.label_ids.copy()

    # -100 marks padded/ignored positions and is not a valid token id, so it
    # must be replaced before decoding.
    ignored = label_ids == -100
    if pred_ids.shape == label_ids.shape:
        # Predictions aligned position-by-position with the labels carry
        # arbitrary tokens where the label is ignored; blank those out too.
        pred_ids[ignored] = pad_token_id
    pred_ids[pred_ids == -100] = pad_token_id
    label_ids[ignored] = pad_token_id

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    accuracy = accuracy_score(label_str, pred_str)
    # zero_division=0 yields the same numbers as the default but without a
    # warning for every transcript that is never predicted exactly.
    precision, recall, f1, _ = precision_recall_fscore_support(
        label_str, pred_str, average="weighted", zero_division=0
    )

    return {
        "wer": wer,
        "cer": cer,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1
    }

# Batch collation
def collate_speech_batch(features):
    """Pad a list of examples into one batch of tensors.

    Each example is expected to carry ``input_features`` (audio features) and
    ``labels`` (token ids). The two need different padding, so the feature
    extractor pads the audio and the tokenizer pads the labels. Padded label
    positions are set to -100 so they are excluded from the loss.
    """
    audio_inputs = [{"input_features": f["input_features"]} for f in features]
    batch = processor.feature_extractor.pad(audio_inputs, return_tensors="pt")

    label_inputs = [{"input_ids": f["labels"]} for f in features]
    padded = processor.tokenizer.pad(label_inputs, return_tensors="pt")
    labels = padded["input_ids"].masked_fill(padded["attention_mask"].ne(1), -100)

    # The model prepends the decoder start token itself when it shifts the
    # labels, so drop it here if tokenization already added it.
    if (labels[:, 0] == model.config.decoder_start_token_id).all().cpu().item():
        labels = labels[:, 1:]

    batch["labels"] = labels
    return batch


def logits_to_token_ids(logits, labels):
    """Reduce logits to token ids before the Trainer accumulates them.

    Storing full vocabulary-sized logits for a whole evaluation set exhausts
    memory quickly; only the arg-max ids are needed for text metrics.
    """
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)


# Hold out part of the data so model selection is not done on training samples.
splits = dataset.train_test_split(test_size=0.1, seed=42)

# Training arguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    evaluation_strategy="epoch",
    # load_best_model_at_end requires the save strategy to match the
    # evaluation strategy; step-based saving makes the constructor raise.
    save_strategy="epoch",
    per_device_train_batch_size=7,
    per_device_eval_batch_size=6,
    num_train_epochs=3,
    save_total_limit=2,
    logging_dir="./logs",
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,  # lower word error rate is better
    report_to="wandb"
)

# Define Trainer
# NOTE(review): predictions here are teacher-forced (the decoder sees the
# reference prefix), so the reported WER/CER is optimistic compared with free
# decoding. Switch to Seq2SeqTrainer with predict_with_generate=True if true
# decoding scores are needed.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=splits["train"],
    eval_dataset=splits["test"],
    tokenizer=processor.tokenizer,
    data_collator=collate_speech_batch,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=logits_to_token_ids
)

# Train and save the model
trainer.train()
trainer.save_model(MODEL_PATH)
processor.save_pretrained(MODEL_PATH)
wandb.finish()
