In [16]:
!pip install wandb transformers torch torchaudio jiwer scikit-learn



In [17]:
import os
import wandb
import pandas as pd
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import WhisperProcessor, WhisperForConditionalGeneration, get_scheduler
from jiwer import wer, cer
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import matplotlib.pyplot as plt
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_
from google.colab import drive

In [18]:
# Enhanced configurations
config = {
    "tsv_file": "/content/drive/Shareddrives/CS307-Thesis/Dataset/single-speaker/validated.tsv",
    "audio_dir": "/content/drive/Shareddrives/CS307-Thesis/Dataset/single-speaker/",
    "batch_size": 4,
    "learning_rate": 5e-5,
    "weight_decay": 0.01,
    "warmup_steps": 500,
    "epochs": 3,
    "max_samples": 100,
    "checkpoint_interval": 2,
    "validation_split": 0.1,
    "max_grad_norm": 1.0,
    "early_stopping_patience": 3,
    "mixed_precision": True,
    "gradient_accumulation_steps": 4
}

In [19]:
def load_data(tsv_file, audio_dir, max_samples=None):
    """
    Load data from TSV file with support for WAV, MP3, and FLAC formats.
    """
    audio_files, transcripts, timestamps = [], [], []

    # Read TSV file with explicit column names
    df = pd.read_csv(tsv_file, sep='\t')
    required_columns = ['path', 'start_time', 'end_time', 'sentence']

    # Verify all required columns are present
    if not all(col in df.columns for col in required_columns):
        raise ValueError(f"TSV file must contain columns: {required_columns}")

    # Shuffle the dataframe
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)

    # Limit samples if specified
    if max_samples:
        df = df.head(max_samples)

    # Define supported audio formats
    supported_formats = (".wav", ".mp3", ".flac")

    for _, row in df.iterrows():
        base_audio_file = row['path']

        # Look for file in supported formats
        for ext in supported_formats:
            audio_file = base_audio_file if base_audio_file.endswith(ext) else f"{base_audio_file}{ext}"
            full_audio_path = os.path.join(audio_dir, audio_file)

            if os.path.exists(full_audio_path):
                audio_files.append(full_audio_path)
                transcripts.append(row['sentence'])
                timestamps.append((float(row['start_time']), float(row['end_time'])))
                break
        else:
            print(f"Warning: Audio file not found for base name '{base_audio_file}' with supported formats.")

    return audio_files, transcripts, timestamps

In [20]:
class EarlyStopping:
    def __init__(self, patience=3, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.should_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

In [21]:
class ProcessData(Dataset):
    def __init__(self, audio_files, transcripts, timestamps, processor, training=True):
        self.audio_files = audio_files
        self.transcripts = transcripts
        self.timestamps = timestamps
        self.processor = processor
        self.training = training

        self.audio_transforms = torch.nn.Sequential(
            torchaudio.transforms.TimeStretch(fixed_rate=0.98),
            torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
            torchaudio.transforms.TimeMasking(time_mask_param=100)
        ) if training else None

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        audio_path = self.audio_files[idx]
        transcript = self.transcripts[idx]
        start_time, end_time = self.timestamps[idx]

        try:
            audio, sample_rate = torchaudio.load(audio_path)

            # Resample if necessary
            if sample_rate != 16000:
                resampler = torchaudio.transforms.Resample(sample_rate, 16000)
                audio = resampler(audio)

            # Apply audio augmentation during training
            if self.training and self.audio_transforms:
                audio = self.audio_transforms(audio)

            # Apply timestamps to trim audio
            start_frame = int(start_time * 16000)
            end_frame = int(end_time * 16000)
            audio = audio[:, start_frame:end_frame]

            input_features = self.processor(
                audio.squeeze().numpy(),
                sampling_rate=16000,
                return_tensors="pt"
            ).input_features

            labels = self.processor(
                transcript,
                return_tensors="pt"
            ).input_ids

            return {
                "input_features": input_features.squeeze(),
                "labels": labels.squeeze()
            }

        except Exception as e:
            print(f"Error processing {audio_path}: {str(e)}")
            # Return a zero tensor with appropriate shape as fallback
            return {
                "input_features": torch.zeros(80, 3000),
                "labels": torch.zeros(100)
            }

In [22]:
def create_dataloaders(audio_files, transcripts, timestamps, processor, config):
    # Create full dataset
    full_dataset = ProcessData(audio_files, transcripts, timestamps, processor)

    # Calculate split sizes
    val_size = int(len(full_dataset) * config["validation_split"])
    train_size = len(full_dataset) - val_size

    # Split dataset
    train_dataset, val_dataset = random_split(
        full_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )

    # Create training dataloader with shuffling
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config["batch_size"],
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

    # Create validation dataloader without shuffling
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    return train_dataloader, val_dataloader

In [23]:
def evaluate_model(model, processor, dataloader, device):
    """
    Evaluate the model using multiple metrics:
    - Loss
    - Word Error Rate (WER)
    - Character Error Rate (CER)
    - Precision
    - Recall
    - F1-Score
    - Accuracy
    """
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_features = batch["input_features"].to(device)
            labels = batch["labels"].to(device)

            # Get model outputs and loss
            outputs = model(input_features, labels=labels)
            total_loss += outputs.loss.item()

            # Generate predictions
            generated_ids = model.generate(input_features)

            # Decode predictions and labels
            preds = processor.batch_decode(generated_ids, skip_special_tokens=True)
            refs = processor.batch_decode(labels, skip_special_tokens=True)

            # Store predictions and references for metric calculation
            all_preds.extend(preds)
            all_labels.extend(refs)

    # Calculate average loss
    avg_loss = total_loss / len(dataloader)

    # Calculate WER and CER
    wer_score = wer(all_labels, all_preds)
    cer_score = cer(all_labels, all_preds)

    # Calculate precision, recall, and F1-score
    # We'll treat each word as a token for these metrics
    tokenized_labels = [text.split() for text in all_labels]
    tokenized_preds = [text.split() for text in all_preds]

    # Flatten the lists for sklearn metrics
    flat_labels = [word for sentence in tokenized_labels for word in sentence]
    flat_preds = [word for sentence in tokenized_preds for word in sentence]

    # Calculate precision, recall, F1
    precision, recall, f1, _ = precision_recall_fscore_support(
        flat_labels,
        flat_preds,
        average='weighted',
        zero_division=0
    )

    # Calculate accuracy
    accuracy = accuracy_score(flat_labels, flat_preds)

    # Compile all metrics
    metrics = {
        "loss": avg_loss,
        "WER": wer_score,
        "CER": cer_score,
        "Precision": precision,
        "Recall": recall,
        "F1-Score": f1,
        "Accuracy": accuracy
    }

    # Print metrics for monitoring
    print("\nEvaluation Metrics:")
    for metric_name, value in metrics.items():
        print(f"{metric_name}: {value:.4f}")

    return metrics

In [24]:
def train_model(config):
    # Initialize wandb
    wandb.init(project="taglish-whisper-fine-tuning", config=config)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model and processor
    processor = WhisperProcessor.from_pretrained("openai/whisper-base")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
    model.to(device)

    # Load data
    audio_files, transcripts, timestamps = load_data(
        config["tsv_file"],
        config["audio_dir"],
        config["max_samples"]
    )

    # Create dataloaders
    train_dataloader, val_dataloader = create_dataloaders(
        audio_files,
        transcripts,
        timestamps,
        processor,
        config
    )

    # Setup optimizer and scheduler
    optimizer = AdamW(
        model.parameters(),
        lr=config["learning_rate"],
        weight_decay=config["weight_decay"]
    )

    num_training_steps = config["epochs"] * len(train_dataloader)
    scheduler = get_scheduler(
        "cosine",
        optimizer=optimizer,
        num_warmup_steps=config["warmup_steps"],
        num_training_steps=num_training_steps
    )

    # Setup mixed precision training
    scaler = torch.cuda.amp.GradScaler() if config["mixed_precision"] else None

    # Setup early stopping
    early_stopping = EarlyStopping(patience=config["early_stopping_patience"])

    # Training loop
    for epoch in range(config["epochs"]):
        model.train()
        total_loss = 0
        optimizer.zero_grad()

        for i, batch in enumerate(train_dataloader):
            input_features = batch["input_features"].to(device)
            labels = batch["labels"].to(device)

            # Mixed precision training
            if config["mixed_precision"]:
                with torch.cuda.amp.autocast():
                    outputs = model(input_features, labels=labels)
                    loss = outputs.loss / config["gradient_accumulation_steps"]

                scaler.scale(loss).backward()

                if (i + 1) % config["gradient_accumulation_steps"] == 0:
                    scaler.unscale_(optimizer)
                    clip_grad_norm_(model.parameters(), config["max_grad_norm"])
                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()
                    optimizer.zero_grad()
            else:
                outputs = model(input_features, labels=labels)
                loss = outputs.loss / config["gradient_accumulation_steps"]
                loss.backward()

                if (i + 1) % config["gradient_accumulation_steps"] == 0:
                    clip_grad_norm_(model.parameters(), config["max_grad_norm"])
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()

            total_loss += loss.item()

            # Log batch progress
            if i % 10 == 0:
                wandb.log({
                    "batch_loss": loss.item(),
                    "learning_rate": scheduler.get_last_lr()[0]
                })

        # Validation phase with complete metrics
        val_metrics = evaluate_model(model, processor, val_dataloader, device)

        # Log all metrics to WandB
        wandb.log({
            "train_loss": total_loss / len(train_dataloader),
            "val_loss": val_metrics["loss"],
            "val_wer": val_metrics["WER"],
            "val_cer": val_metrics["CER"],
            "val_precision": val_metrics["Precision"],
            "val_recall": val_metrics["Recall"],
            "val_f1": val_metrics["F1-Score"],
            "val_accuracy": val_metrics["Accuracy"],
            "epoch": epoch + 1
        })

        # Early stopping check
        early_stopping(val_metrics["loss"])
        if early_stopping.should_stop:
            print(f"Early stopping triggered at epoch {epoch + 1}")
            break

        # Save checkpoint
        if (epoch + 1) % config["checkpoint_interval"] == 0:
            checkpoint_dir = f"/content/drive/Shareddrives/CS307-Thesis/Dataset/whisper_checkpoints/checkpoint_epoch_{epoch + 1}"
            os.makedirs(checkpoint_dir, exist_ok=True)

            # Save model and processor
            model.save_pretrained(checkpoint_dir)
            processor.save_pretrained(checkpoint_dir)

            # Save training state
            torch.save({
                'epoch': epoch,
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'metrics': val_metrics
            }, os.path.join(checkpoint_dir, 'training_state.pt'))

    wandb.finish()

In [25]:
if __name__ == "__main__":
    drive.mount('/content/drive')
    train_model(config)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).




ValueError: num_samples should be a positive integer value, but got num_samples=0