### Install required libraries

In [None]:
!pip install transformers torchaudio jiwer sklearn

### Import libraries

In [None]:
import os
import pandas as pd
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from jiwer import wer, cer
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import matplotlib.pyplot as plt
from torch.optim import AdamW

### Data loading function

In [None]:
def load_data(tsv_file, audio_dir, max_samples=100):
    """Load a shuffled sample of audio paths and transcripts from a TSV file.

    The TSV must contain a ``path`` column (audio file name, with or without
    the ``.mp3`` extension) and a ``sentence`` column (the transcript). Rows
    missing either value are skipped.

    Args:
        tsv_file: Path to the tab-separated metadata file.
        audio_dir: Directory onto which each file name from ``path`` is joined.
        max_samples: Maximum number of rows to return. ``None`` returns every
            usable row; zero or a negative value returns no rows.

    Returns:
        A tuple ``(audio_files, transcripts)`` of equal-length lists. Both
        lists are empty if the file cannot be read or parsed.
    """
    try:
        print("Loading dataset...\n\n" + "=" * 50 + "\n")
        df = pd.read_csv(tsv_file, sep='\t')

        # Drop incomplete rows up front: a NaN path has no ``endswith`` and a
        # NaN sentence is a float, neither of which is usable downstream.
        df = df.dropna(subset=['path', 'sentence'])
        df = df.sample(frac=1).reset_index(drop=True)

        # Apply the limit before collecting, so that a limit of 0 really
        # yields nothing instead of one row.
        if max_samples is not None:
            df = df.head(max(max_samples, 0))

        audio_files = []
        for file_name in df['path'].astype(str):
            if not file_name.endswith(".mp3"):
                file_name += ".mp3"
            audio_files.append(os.path.join(audio_dir, file_name))
        transcripts = df['sentence'].astype(str).tolist()

        # Always report the final count, including when the file holds fewer
        # rows than max_samples.
        print(f"Finished loading {len(audio_files)} audio files and transcripts.\n\n" + "=" * 50 + "\n")
        return audio_files, transcripts
    except Exception as e:
        # Deliberate best-effort: callers receive empty lists on any failure.
        print(f"Error loading Common Voice data: {e}\n")
        return [], []

### Dataset class

In [None]:
class ProcessData(Dataset):
    """Dataset of (audio, transcript) pairs prepared for a Whisper-style processor.

    Each item is a dict with ``"input_features"`` (the processor's features
    for one clip) and ``"labels"`` (the token ids of its transcript).

    Args:
        audio_files: Paths of the audio clips.
        transcripts: Reference texts, one per entry in ``audio_files``.
        processor: Object that is callable on a waveform and returns an object
            with ``input_features``, and that exposes a ``tokenizer``.
        target_sampling_rate: Sampling rate the audio is resampled to before
            it is handed to the processor.
        max_label_length: Fixed length every label sequence is padded or
            truncated to, so items can be stacked by the default DataLoader
            collation. Padding uses the tokenizer's pad token; mask those
            positions (e.g. with -100) before computing a loss if padding
            should be ignored. ``None`` disables padding and truncation.

    Raises:
        ValueError: If ``audio_files`` and ``transcripts`` differ in length.
    """

    def __init__(self, audio_files, transcripts, processor,
                 target_sampling_rate=16000, max_label_length=448):
        if len(audio_files) != len(transcripts):
            raise ValueError(
                f"Got {len(audio_files)} audio files but {len(transcripts)} transcripts."
            )
        self.audio_files = audio_files
        self.transcripts = transcripts
        self.processor = processor
        self.target_sampling_rate = target_sampling_rate
        self.max_label_length = max_label_length

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        audio_path = self.audio_files[idx]
        transcript = self.transcripts[idx]

        # torchaudio.load returns (waveform, sample_rate); only the waveform
        # may be passed on to the feature extractor.
        waveform, sample_rate = torchaudio.load(audio_path)

        # Downmix multi-channel audio to mono.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # The sampling rate passed to the processor must match the data.
        if sample_rate != self.target_sampling_rate:
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, self.target_sampling_rate
            )

        input_features = self.processor(
            waveform.squeeze(0).numpy(),
            sampling_rate=self.target_sampling_rate,
            return_tensors="pt",
        ).input_features

        if self.max_label_length is None:
            labels = self.processor.tokenizer(transcript, return_tensors="pt").input_ids
        else:
            labels = self.processor.tokenizer(
                transcript,
                padding="max_length",
                max_length=self.max_label_length,
                truncation=True,
                return_tensors="pt",
            ).input_ids

        # squeeze(0) removes only the batch axis; a bare squeeze() would also
        # collapse a one-token label sequence to a 0-dim tensor.
        return {"input_features": input_features.squeeze(0), "labels": labels.squeeze(0)}

### Load Data and Prepare Dataloader

In [None]:
# Locations of the metadata TSV and the directory holding the audio clips.
tsv_file = "/path/to/train.tsv"
audio_dir = "/path/to/audio_files"
audio_files, transcripts = load_data(tsv_file, audio_dir)

# load_data reports failure by returning empty lists. Stop here with a clear
# message instead of failing later inside the DataLoader or the metrics code.
if not audio_files:
    raise RuntimeError(
        f"No samples were loaded from {tsv_file!r}; check tsv_file and audio_dir."
    )

processor = WhisperProcessor.from_pretrained("openai/whisper-base")

train_dataset = ProcessData(audio_files, transcripts, processor)
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)

### Model and Optimizer Setup

In [None]:
# Load the pretrained checkpoint that matches the processor loaded above.
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
# Size the token embedding matrix to the tokenizer's vocabulary length.
model.resize_token_embeddings(len(processor.tokenizer))

# Use the GPU when one is available; later cells read this global `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Optimizer over all model parameters.
optimizer = AdamW(model.parameters(), lr=5e-5)

### Function to evaluate model

In [None]:
def evaluate_model(model, processor, eval_dataloader):
    """Transcribe every batch and score the output against the references.

    WER and CER are the mean of the per-sample error rates. Precision,
    recall, F1 and accuracy are computed by passing the decoded reference and
    predicted strings to scikit-learn as class labels, so they only reward
    transcripts that match the reference exactly.

    Args:
        model: Model exposing ``eval()``, ``parameters()`` and ``generate()``.
        processor: Processor exposing ``batch_decode`` and a ``tokenizer``.
        eval_dataloader: Iterable of batches with ``"input_features"`` and
            ``"labels"`` tensors.

    Returns:
        Dict with the keys ``"WER"``, ``"CER"``, ``"Precision"``,
        ``"Recall"``, ``"F1-Score"`` and ``"Accuracy"``.

    Raises:
        ValueError: If ``eval_dataloader`` yields no samples.
    """
    model.eval()

    # Follow the model's own device instead of relying on a global variable.
    model_device = next(model.parameters()).device
    pad_token_id = processor.tokenizer.pad_token_id

    total_preds = []
    total_labels = []
    total_wer = 0.0
    total_cer = 0.0

    with torch.no_grad():
        for batch in eval_dataloader:
            input_features = batch["input_features"].to(model_device)
            labels = batch["labels"]

            # Negative ids (such as a -100 loss mask) are not valid tokens;
            # map them to the pad token so decoding skips them.
            if pad_token_id is not None:
                labels = labels.masked_fill(labels < 0, pad_token_id)

            generated_ids = model.generate(input_features)
            preds = processor.batch_decode(generated_ids, skip_special_tokens=True)
            refs = processor.batch_decode(labels, skip_special_tokens=True)

            # Accumulate per-sample sums; the single division by the sample
            # count happens once, after the loop.
            total_wer += sum(wer(r, p) for r, p in zip(refs, preds))
            total_cer += sum(cer(r, p) for r, p in zip(refs, preds))

            total_preds.extend(preds)
            total_labels.extend(refs)

    num_samples = len(total_labels)
    if num_samples == 0:
        raise ValueError("eval_dataloader produced no samples; cannot compute metrics.")

    avg_wer = total_wer / num_samples
    avg_cer = total_cer / num_samples
    # zero_division=0 yields the same 0.0 as the default but without a
    # warning for every reference string that was never predicted.
    precision, recall, f1, _ = precision_recall_fscore_support(
        total_labels, total_preds, average="weighted", zero_division=0
    )
    accuracy = accuracy_score(total_labels, total_preds)

    return {
        "WER": avg_wer,
        "CER": avg_cer,
        "Precision": precision,
        "Recall": recall,
        "F1-Score": f1,
        "Accuracy": accuracy
    }

### Function to plot metrics

In [None]:
def plot_metrics(metrics_before, metrics_after):
    """Draw a grouped bar chart comparing two metric dictionaries.

    Args:
        metrics_before: Mapping of metric name to value before training. Its
            keys, in order, define the groups on the x axis.
        metrics_after: Mapping of metric name to value after training. It
            must contain every key of ``metrics_before``.

    Raises:
        KeyError: If ``metrics_after`` lacks a key of ``metrics_before``.
    """
    metrics_names = list(metrics_before.keys())
    before_values = [metrics_before[name] for name in metrics_names]
    # Look values up by name so both bars of a group show the same metric
    # even if the two dictionaries are ordered differently.
    after_values = [metrics_after[name] for name in metrics_names]

    bar_width = 0.4
    positions = list(range(len(metrics_names)))
    # Shift each series half a bar width off the tick so the two bars sit
    # side by side instead of overlapping.
    before_x = [pos - bar_width / 2 for pos in positions]
    after_x = [pos + bar_width / 2 for pos in positions]

    plt.figure(figsize=(12, 6))
    plt.bar(before_x, before_values, width=bar_width, label="Before Training", color="skyblue")
    plt.bar(after_x, after_values, width=bar_width, label="After Training", color="salmon")

    plt.xlabel("Metrics")
    plt.ylabel("Scores")
    plt.title("Model Performance Before and After Fine-Tuning")
    plt.xticks(positions, metrics_names)
    plt.legend()
    plt.show()


### Evaluate Model Before Training

In [None]:
# NOTE(review): this loader is built from the same audio_files/transcripts as
# the training dataset, so the reported metrics are not held-out scores.
eval_dataloader = DataLoader(ProcessData(audio_files, transcripts, processor), batch_size=4)

# Baseline metrics of the pretrained model, kept for the comparison plot.
metrics_before = evaluate_model(model, processor, eval_dataloader)
print("Metrics Before Training:", metrics_before)

### Training Loop with Checkpoint Saving

In [None]:
epochs = 3
checkpoint_interval = 2

# Token ids used to find label padding that must not contribute to the loss.
pad_token_id = processor.tokenizer.pad_token_id
eos_token_id = processor.tokenizer.eos_token_id

for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    num_batches = 0

    for batch in train_dataloader:
        input_features = batch["input_features"].to(device)
        labels = batch["labels"].to(device)

        # Replace label padding with -100 so the loss ignores it. When the pad
        # token doubles as the end-of-sequence token, keep the first
        # occurrence (the real end marker) and mask only the repeats.
        if pad_token_id is not None:
            is_padding = labels == pad_token_id
            if pad_token_id == eos_token_id:
                is_padding = is_padding & (is_padding.cumsum(dim=-1) > 1)
            labels = labels.masked_fill(is_padding, -100)

        optimizer.zero_grad()
        outputs = model(input_features, labels=labels)
        loss = outputs.loss

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Save on the configured interval and always after the last epoch, so the
    # final weights are never lost when epochs is not a multiple of it.
    if (epoch + 1) % checkpoint_interval == 0 or epoch + 1 == epochs:
        model.save_pretrained(f"./checkpoint_epoch_{epoch + 1}")
        processor.save_pretrained(f"./processor_epoch_{epoch + 1}")

    # Report the mean over the epoch; the last batch alone is a noisy signal
    # and is undefined when the loader is empty.
    mean_loss = epoch_loss / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch + 1} completed. Average loss: {mean_loss}")

### Evaluate Model After Training

In [None]:
# Re-run the same evaluation on the fine-tuned model, reusing eval_dataloader
# so the numbers are comparable with metrics_before.
metrics_after = evaluate_model(model, processor, eval_dataloader)
print("Metrics After Training:", metrics_after)

### Plot Performance Before and After Training

In [None]:
# Compare the metric dictionaries from the two evaluation runs in a bar chart.
plot_metrics(metrics_before, metrics_after)