In [None]:
# Colab setup: install the third-party packages used by this notebook.
# NOTE(review): versions are unpinned, so a later re-run may pick up different releases — consider pinning.
!pip install wandb transformers torch torchaudio jiwer scikit-learn



In [None]:
import math
import os

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn.functional as F
import torchaudio
import wandb
from google.colab import drive
from jiwer import cer, wer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from transformers import WhisperForConditionalGeneration, WhisperProcessor, get_scheduler

In [None]:
# Run configuration: every tunable knob lives in this one dict, and the rest of
# the notebook reads its values from here.
config = dict(
    # Data locations on the shared Google Drive (mounted in the last cell).
    tsv_file="/content/drive/Shareddrives/CS307-Thesis/Dataset/single-speaker/validated.tsv",
    audio_dir="/content/drive/Shareddrives/CS307-Thesis/Dataset/single-speaker/",
    # Optimisation.
    batch_size=4,
    learning_rate=5e-5,
    weight_decay=0.01,
    # NOTE(review): warmup is counted in scheduler steps; with a small max_samples
    # the whole run can end inside the warmup window (learning rate stays near 0).
    warmup_steps=500,
    epochs=3,
    # Cap on TSV rows considered after shuffling; None/0 means use every row.
    max_samples=9,
    # Save a checkpoint every N epochs.
    checkpoint_interval=2,
    # Fraction of samples held out for validation.
    validation_split=0.1,
    max_grad_norm=1.0,
    early_stopping_patience=3,
    mixed_precision=True,
    gradient_accumulation_steps=4,
)

In [None]:
_SUPPORTED_AUDIO_EXTENSIONS = (".mp3", ".wav", ".flac")


def _parse_time(value):
    """
    Convert one timestamp cell to seconds.

    Accepts plain seconds (12.5 or "12.5"), "min:sec" ("1:02.5") and
    "hour:min:sec" ("0:01:02.5").

    Raises:
        ValueError: if the value is missing/NaN, not finite, negative, or not in
            one of the formats above.
    """
    if pd.isna(value):
        raise ValueError("missing timestamp")
    text = str(value).strip()
    try:
        seconds = float(text)
    except ValueError:
        parts = text.split(":")
        if len(parts) not in (2, 3):
            raise ValueError(f"unrecognised timestamp format: {value!r}") from None
        # Horner-style accumulation: "m:s" -> m*60 + s, "h:m:s" -> (h*60 + m)*60 + s.
        seconds = 0.0
        for part in parts:
            seconds = seconds * 60 + float(part)
    # float() happily returns nan/inf for strings like "nan" or "inf"; reject them here
    # so they cannot reach the frame-index arithmetic downstream.
    if not math.isfinite(seconds) or seconds < 0:
        raise ValueError(f"invalid timestamp: {value!r}")
    return seconds


def load_data(tsv_file, audio_dir, max_samples=None):
    """
    Load audio paths, transcripts and (start, end) timestamps from a TSV manifest.

    The TSV must contain ``path``, ``start_time``, ``end_time`` and ``sentence`` columns.
    Rows are shuffled with a fixed seed (42). If ``max_samples`` is truthy, the rows are
    truncated to that many *before* filtering, so fewer rows may be returned.

    A row is skipped, with a printed message, when:
    * its path is missing or not .mp3/.wav/.flac (case-insensitive),
    * the audio file does not exist,
    * its transcript is empty, or
    * its timestamps are missing/unparseable or end is not after start.

    Timestamps may be plain seconds, "min:sec" or "hour:min:sec".

    Args:
        tsv_file: path of the tab-separated manifest.
        audio_dir: directory the ``path`` column is relative to.
        max_samples: optional cap on the number of (shuffled) rows considered.

    Returns:
        (audio_files, transcripts, timestamps): parallel lists of full audio paths,
        transcript strings, and (start_seconds, end_seconds) float tuples.

    Raises:
        ValueError: if a required column is missing.
    """
    df = pd.read_csv(tsv_file, sep='\t')
    required_columns = ['path', 'start_time', 'end_time', 'sentence']
    if not all(col in df.columns for col in required_columns):
        raise ValueError(f"TSV file must contain columns: {required_columns}")

    # Shuffle deterministically, then optionally limit the number of rows.
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)
    if max_samples:
        df = df.head(max_samples)

    audio_files, transcripts, timestamps = [], [], []
    for _, row in df.iterrows():
        audio_file = row['path']
        # Empty cells arrive as NaN (a float), so check the type before using str methods.
        if not isinstance(audio_file, str) or not audio_file.lower().endswith(_SUPPORTED_AUDIO_EXTENSIONS):
            print(f"Skipping unsupported file type: {audio_file}")
            continue

        full_audio_path = os.path.join(audio_dir, audio_file)
        if not os.path.exists(full_audio_path):
            print(f"Warning: Audio file not found: {full_audio_path}")
            continue

        transcript = row['sentence']
        if pd.isna(transcript) or not str(transcript).strip():
            print(f"Skipping {audio_file}: empty transcript")
            continue

        try:
            start_time = _parse_time(row['start_time'])
            end_time = _parse_time(row['end_time'])
        except ValueError as e:
            print(f"Error parsing timestamps for {audio_file}: {str(e)}")
            continue

        # A non-positive span would trim the clip to zero samples later on.
        if end_time <= start_time:
            print(f"Skipping {audio_file}: end_time ({end_time}) is not after start_time ({start_time})")
            continue

        audio_files.append(full_audio_path)
        transcripts.append(str(transcript))
        timestamps.append((start_time, end_time))

    return audio_files, transcripts, timestamps

In [None]:
class EarlyStopping:
    """
    Signal when a monitored loss has stopped improving.

    Call the instance once per evaluation with the latest loss. A call counts as an
    improvement only when the loss is strictly below ``best_loss - min_delta``.
    After ``patience`` consecutive non-improving calls, ``should_stop`` becomes True
    and stays True.

    Args:
        patience: number of consecutive non-improving calls tolerated.
        min_delta: minimum decrease in loss that counts as an improvement.
    """

    def __init__(self, patience=3, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0         # consecutive non-improving calls
        self.best_loss = None    # lowest non-NaN loss seen so far
        self.should_stop = False

    def __call__(self, val_loss):
        """Record one validation loss and update ``counter`` / ``should_stop``."""
        val_loss = float(val_loss)
        # NaN is never an improvement and is never stored as best_loss. Storing it
        # would make every later comparison False and silently disable stopping.
        # A loss equal to the best (a plateau, or inf repeated every epoch) is also
        # not an improvement.
        improved = not math.isnan(val_loss) and (
            self.best_loss is None or val_loss < self.best_loss - self.min_delta
        )
        if improved:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True

In [None]:
def collate_fn(batch):
    """
    Pad a list of dataset items into one batch, dropping items that failed (None).

    Args:
        batch: list of dicts with 'input_features' (2-D tensor, (num_mel_bins, frames))
            and 'labels' (1-D tensor of token ids), or None for items that failed to load.

    Returns:
        {} if every item was None. Otherwise a dict with:
        * 'input_features': (batch, num_mel_bins, max_frames), zero-padded, and
        * 'labels': (batch, max_label_len) long tensor padded with -100.
    """
    items = [item for item in batch if item is not None]
    if not items:
        return {}

    # Take the mel-bin count and dtype from the data instead of hard-coding 80;
    # some Whisper feature extractors (e.g. large-v3) emit 128 bins.
    first_features = items[0]['input_features']
    num_mel_bins = first_features.size(0)
    max_input_length = max(item['input_features'].size(1) for item in items)
    max_label_length = max(item['labels'].size(0) for item in items)

    padded_input_features = first_features.new_zeros((len(items), num_mel_bins, max_input_length))
    # -100 is torch.nn.CrossEntropyLoss's default ignore_index, so padded label
    # positions do not contribute to the loss.
    padded_labels = torch.full((len(items), max_label_length), -100, dtype=torch.long)

    for i, item in enumerate(items):
        features = item['input_features']
        padded_input_features[i, :, :features.size(1)] = features
        labels = item['labels']
        padded_labels[i, :labels.size(0)] = labels

    return {
        'input_features': padded_input_features,
        'labels': padded_labels
    }

In [None]:
class ProcessData(Dataset):
    """
    Dataset that yields Whisper-ready ``input_features``/``labels`` for one audio clip per item.

    Each item:
    1. loads the audio file and downmixes it to mono,
    2. resamples it to 16 kHz,
    3. trims it to its (start, end) timestamps in seconds, and
    4. converts audio and transcript with ``processor``.

    When ``training`` is True, two augmentations are applied:
    * waveform time-stretch (STFT -> phase-vocoder TimeStretch -> inverse STFT), and
    * frequency/time masking on the log-mel ``input_features`` (SpecAugment style).

    Items that fail are reported with a printed message and returned as None.
    """

    SAMPLE_RATE = 16000
    # STFT settings shared by the forward/inverse transforms and by TimeStretch.
    N_FFT = 400
    HOP_LENGTH = 100

    def __init__(self, audio_files, transcripts, timestamps, processor, training=True):
        self.audio_files = audio_files
        self.transcripts = transcripts
        self.timestamps = timestamps
        self.processor = processor
        self.training = training

        # Built unconditionally so toggling ``training`` after construction cannot hit
        # a missing attribute.
        # TimeStretch's phase advance must use the same hop as the STFT it is applied to.
        # Its default hop is n_fft // 2 (200), not the 100 used by apply_audio_transforms.
        self.time_stretch = torchaudio.transforms.TimeStretch(
            hop_length=self.HOP_LENGTH,
            n_freq=self.N_FFT // 2 + 1,
            fixed_rate=0.98,
        )
        # These masks operate on (..., freq, time) spectrograms. They are applied to the
        # log-mel features; on a raw waveform the frequency mask can blank the whole clip.
        self.freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=30)
        self.time_mask = torchaudio.transforms.TimeMasking(time_mask_param=100)

    def apply_audio_transforms(self, audio):
        """
        Time-stretch a 1-D waveform in training mode; otherwise return it unchanged.

        The waveform goes to the complex STFT domain, is stretched with the phase
        vocoder, and is resynthesised with the inverse STFT using identical settings.
        Spectral masking happens separately, in apply_spec_augment.
        """
        if not self.training:
            return audio

        window = torch.hann_window(self.N_FFT, device=audio.device)
        spec = torch.stft(
            audio,
            n_fft=self.N_FFT,
            hop_length=self.HOP_LENGTH,
            win_length=self.N_FFT,
            window=window,
            return_complex=True
        )
        spec = self.time_stretch(spec)
        return torch.istft(
            spec,
            n_fft=self.N_FFT,
            hop_length=self.HOP_LENGTH,
            win_length=self.N_FFT,
            window=window
        )

    def apply_spec_augment(self, input_features):
        """Apply frequency then time masking to (..., n_mels, frames) features in training mode."""
        if not self.training:
            return input_features
        return self.time_mask(self.freq_mask(input_features))

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        audio_path = self.audio_files[idx]
        transcript = self.transcripts[idx]
        start_time, end_time = self.timestamps[idx]

        try:
            audio, sample_rate = torchaudio.load(audio_path)

            # Downmix multi-channel audio to a single (1, samples) channel.
            if audio.shape[0] > 1:
                audio = torch.mean(audio, dim=0, keepdim=True)

            if sample_rate != self.SAMPLE_RATE:
                resampler = torchaudio.transforms.Resample(sample_rate, self.SAMPLE_RATE)
                audio = resampler(audio)

            # Trim to the labelled segment (timestamps are in seconds).
            start_frame = int(start_time * self.SAMPLE_RATE)
            end_frame = int(end_time * self.SAMPLE_RATE)
            audio = audio[:, start_frame:end_frame]
            if audio.size(1) == 0:
                raise ValueError(f"empty audio segment for timestamps ({start_time}, {end_time})")

            waveform = self.apply_audio_transforms(audio[0])

            input_features = self.processor(
                waveform.numpy(),
                sampling_rate=self.SAMPLE_RATE,
                return_tensors="pt"
            ).input_features
            input_features = self.apply_spec_augment(input_features)

            # squeeze(0) (not squeeze()) keeps a 1-token transcript 1-D for collate_fn.
            labels = self.processor(
                text=transcript,
                return_tensors="pt"
            ).input_ids.squeeze(0).long()

            # Per the transformers Whisper docs, decoder inputs are built by shifting labels
            # right and prepending decoder_start_token_id, so a leading
            # <|startoftranscript|> in the labels would be fed twice. Drop it here.
            sot_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
            if labels.numel() > 0 and labels[0].item() == sot_id:
                labels = labels[1:]

            return {
                "input_features": input_features.squeeze(0),
                "labels": labels
            }

        except Exception as e:
            # Best effort: report and let collate_fn drop this item.
            print(f"Error processing {audio_path}: {str(e)}")
            return None

In [None]:
def create_dataloaders(audio_files, transcripts, timestamps, processor, config):
    """
    Build train and validation DataLoaders from parallel lists of samples.

    Samples are split with a fixed seed (42) according to ``config["validation_split"]``.
    At least one sample is held out whenever a split is requested and more than one
    sample exists. The training subset uses a ProcessData with augmentation
    (``training=True``). The validation subset uses ``training=False``, so validation
    metrics are computed on un-augmented audio.

    Reads config keys ``validation_split`` and ``batch_size``, plus the optional
    ``num_workers`` (default 2).

    Returns:
        (train_dataloader, val_dataloader)

    Raises:
        ValueError: if no samples were provided.
    """
    num_samples = len(audio_files)
    if num_samples == 0:
        raise ValueError("No samples to build dataloaders from; check the TSV file and audio directory.")

    validation_split = config["validation_split"]
    val_size = int(num_samples * validation_split)
    # e.g. 9 samples * 0.1 truncates to 0, which gave an empty validation loader and a
    # validation loss of inf.
    if validation_split > 0 and num_samples > 1:
        val_size = max(val_size, 1)
    train_size = num_samples - val_size

    # Split indices rather than one dataset, so train and validation can be served by
    # differently configured ProcessData instances.
    train_split, val_split = random_split(
        range(num_samples),
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )
    train_dataset = Subset(
        ProcessData(audio_files, transcripts, timestamps, processor, training=True),
        train_split.indices
    )
    val_dataset = Subset(
        ProcessData(audio_files, transcripts, timestamps, processor, training=False),
        val_split.indices
    )

    num_workers = config.get("num_workers", 2)
    # Pinned memory only helps host->GPU copies; on a CPU-only runtime it just warns.
    pin_memory = torch.cuda.is_available()

    def make_loader(dataset, shuffle):
        return DataLoader(
            dataset,
            batch_size=config["batch_size"],
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
            collate_fn=collate_fn
        )

    train_dataloader = make_loader(train_dataset, shuffle=True)
    val_dataloader = make_loader(val_dataset, shuffle=False)
    return train_dataloader, val_dataloader

In [None]:
# Padding sentinel for word alignment. It contains a space, so it can never equal a
# token produced by str.split().
_WORD_PAD = "<pad> "


def _word_level_metrics(references, hypotheses):
    """
    Position-wise word precision/recall/F1/accuracy over paired transcripts.

    Each (reference, hypothesis) pair is split on whitespace. The shorter word list is
    padded with a sentinel so every position of every pair is scored:
    * extra hypothesis words are scored against the sentinel, and
    * missing words score the sentinel against the reference word.

    Precision/recall/F1 are support-weighted averages over the reference vocabulary.
    Accuracy is the fraction of matching positions.

    Returns:
        (precision, recall, f1, accuracy) as floats; all 0.0 if there are no reference words.
    """
    flat_refs, flat_hyps = [], []
    for reference, hypothesis in zip(references, hypotheses):
        ref_words, hyp_words = reference.split(), hypothesis.split()
        width = max(len(ref_words), len(hyp_words))
        flat_refs.extend(ref_words + [_WORD_PAD] * (width - len(ref_words)))
        flat_hyps.extend(hyp_words + [_WORD_PAD] * (width - len(hyp_words)))

    vocabulary = sorted(set(flat_refs) - {_WORD_PAD})
    if not vocabulary:
        return 0.0, 0.0, 0.0, 0.0

    precision, recall, f1, _ = precision_recall_fscore_support(
        flat_refs,
        flat_hyps,
        labels=vocabulary,
        average='weighted',
        zero_division=0
    )
    accuracy = accuracy_score(flat_refs, flat_hyps)
    return float(precision), float(recall), float(f1), float(accuracy)


def evaluate_model(model, processor, dataloader, device):
    """
    Evaluate a Whisper model on a dataloader: teacher-forced loss, WER/CER and word metrics.

    Args:
        model: seq2seq model whose forward accepts ``labels`` and returns ``.loss``,
            and which supports ``generate``.
        processor: processor used to decode generated ids and labels.
        dataloader: yields dicts from collate_fn ({} for fully-failed batches, which are skipped).
        device: device to move batches to.

    Returns:
        dict with keys "loss", "WER", "CER", "Precision", "Recall", "F1-Score", "Accuracy".
        "loss" is the mean over non-empty batches (inf if there were none). The other
        metrics are 0.0 when no non-empty reference transcript was seen.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            if not batch:
                continue

            input_features = batch["input_features"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_features, labels=labels)
            total_loss += outputs.loss.item()
            num_batches += 1

            generated_ids = model.generate(
                input_features,
                language="en",
                task="transcribe",
                max_length=256  # Adjust as needed
            )

            decoded_preds = processor.batch_decode(generated_ids, skip_special_tokens=True)
            # -100 marks label padding; swap in a real token id so it can be decoded.
            labels_clean = torch.where(labels != -100, labels, processor.tokenizer.pad_token_id)
            decoded_labels = processor.batch_decode(labels_clean, skip_special_tokens=True)

            for pred, label in zip(decoded_preds, decoded_labels):
                # Only pairs without a reference are skipped. An empty prediction is a
                # real error (all deletions) and must count toward WER/CER.
                if label.strip():
                    all_preds.append(pred)
                    all_labels.append(label)

    avg_loss = total_loss / num_batches if num_batches else float('inf')

    if all_labels:
        wer_score = wer(all_labels, all_preds)
        cer_score = cer(all_labels, all_preds)
        precision, recall, f1, accuracy = _word_level_metrics(all_labels, all_preds)
    else:
        wer_score = cer_score = precision = recall = f1 = accuracy = 0.0

    return {
        "loss": avg_loss,
        "WER": wer_score,
        "CER": cer_score,
        "Precision": precision,
        "Recall": recall,
        "F1-Score": f1,
        "Accuracy": accuracy
    }

In [None]:
DEFAULT_CHECKPOINT_DIR = "/content/drive/Shareddrives/CS307-Thesis/Dataset/whisper_checkpoints"


def _load_model_and_processor(checkpoint_path, device):
    """
    Load processor and model from a checkpoint dir if it exists, else from openai/whisper-base.

    Returns:
        (processor, model, training_state). training_state is the dict saved by
        _save_checkpoint, or None for a fresh start.
    """
    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}")
        processor = WhisperProcessor.from_pretrained(checkpoint_path)
        model = WhisperForConditionalGeneration.from_pretrained(checkpoint_path)
        # map_location lets a GPU-written state be resumed on a different device.
        training_state = torch.load(
            os.path.join(checkpoint_path, 'training_state.pt'),
            map_location=device
        )
    else:
        print("Starting fresh training with base model")
        processor = WhisperProcessor.from_pretrained("openai/whisper-base")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
        training_state = None
    return processor, model, training_state


def _optimizer_step(model, optimizer, scheduler, scaler, max_grad_norm):
    """Clip gradients, take one optimizer and scheduler step, then clear gradients."""
    if scaler is not None:
        # Unscale first so clipping sees the true gradient magnitudes.
        scaler.unscale_(optimizer)
        clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
    else:
        clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
    scheduler.step()
    optimizer.zero_grad()


def _train_one_epoch(model, dataloader, optimizer, scheduler, scaler, device, config):
    """
    Run one training epoch with gradient accumulation (fp16 autocast when ``scaler`` is set).

    An optimizer step is taken every ``gradient_accumulation_steps`` batches and at the
    end of the epoch, so a trailing partial group is never discarded.

    Returns:
        Mean (unscaled) loss over the non-empty batches, or inf if there were none.
    """
    accum_steps = config["gradient_accumulation_steps"]
    num_batches = len(dataloader)
    model.train()
    optimizer.zero_grad()

    total_loss = 0.0
    processed = 0
    has_pending_grads = False

    for i, batch in enumerate(dataloader):
        if not batch:  # collate_fn returns {} when every item in the batch failed
            continue

        input_features = batch["input_features"].to(device)
        labels = batch["labels"].to(device)

        if scaler is not None:
            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                outputs = model(input_features, labels=labels)
        else:
            outputs = model(input_features, labels=labels)
        loss = outputs.loss

        # Average over the batches of this accumulation group; the last group of the
        # epoch may be shorter than accum_steps.
        group_size = min(accum_steps, num_batches - (i // accum_steps) * accum_steps)
        scaled_loss = loss / group_size
        if scaler is not None:
            scaler.scale(scaled_loss).backward()
        else:
            scaled_loss.backward()
        has_pending_grads = True

        if (i + 1) % accum_steps == 0 or i + 1 == num_batches:
            _optimizer_step(model, optimizer, scheduler, scaler, config["max_grad_norm"])
            has_pending_grads = False

        total_loss += loss.item()
        processed += 1

        if i % 10 == 0:
            wandb.log({
                "batch_loss": loss.item(),
                "learning_rate": scheduler.get_last_lr()[0]
            })

    # Flush gradients left pending when the epoch's final batch was empty and skipped.
    if has_pending_grads:
        _optimizer_step(model, optimizer, scheduler, scaler, config["max_grad_norm"])

    return total_loss / processed if processed else float('inf')


def _save_checkpoint(checkpoint_root, epoch, model, processor, optimizer, scheduler, val_metrics):
    """Write model/processor weights and optimizer/scheduler state for 0-based ``epoch``."""
    checkpoint_dir = os.path.join(checkpoint_root, f"checkpoint_epoch_{epoch + 1}")
    os.makedirs(checkpoint_dir, exist_ok=True)

    model.save_pretrained(checkpoint_dir)
    processor.save_pretrained(checkpoint_dir)

    torch.save({
        'epoch': epoch,
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'metrics': val_metrics
    }, os.path.join(checkpoint_dir, 'training_state.pt'))


def train_model(config, checkpoint_path=None):
    """
    Fine-tune Whisper on the dataset described by ``config``, logging to wandb.

    If ``checkpoint_path`` exists, the model, processor, optimizer and scheduler are
    restored from it and training resumes at the following epoch. Otherwise training
    starts from openai/whisper-base.

    A checkpoint is written every ``checkpoint_interval`` epochs under
    ``config.get("checkpoint_dir", DEFAULT_CHECKPOINT_DIR)``. The wandb run is always
    finished, even if training raises.
    """
    wandb.init(project="taglish-whisper-fine-tuning", config=config)
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        processor, model, training_state = _load_model_and_processor(checkpoint_path, device)
        start_epoch = 0
        if training_state is not None:
            start_epoch = training_state['epoch'] + 1
            print(f"Resuming from epoch {start_epoch}")
        # Make the label prefix tokens (language/task) match the prompt evaluation uses
        # at generation time.
        processor.tokenizer.set_prefix_tokens(language="en", task="transcribe")
        model.to(device)

        audio_files, transcripts, timestamps = load_data(
            config["tsv_file"],
            config["audio_dir"],
            config["max_samples"]
        )
        train_dataloader, val_dataloader = create_dataloaders(
            audio_files,
            transcripts,
            timestamps,
            processor,
            config
        )

        optimizer = AdamW(
            model.parameters(),
            lr=config["learning_rate"],
            weight_decay=config["weight_decay"]
        )

        # The scheduler advances once per optimizer step, i.e. once per accumulation
        # group (trailing partial groups included), not once per batch.
        steps_per_epoch = math.ceil(len(train_dataloader) / config["gradient_accumulation_steps"])
        num_training_steps = config["epochs"] * steps_per_epoch
        if config["warmup_steps"] >= num_training_steps:
            print(
                f"Warning: warmup_steps ({config['warmup_steps']}) >= total optimizer steps "
                f"({num_training_steps}); the learning rate will not leave warmup."
            )
        scheduler = get_scheduler(
            "cosine",
            optimizer=optimizer,
            num_warmup_steps=config["warmup_steps"],
            num_training_steps=num_training_steps
        )

        if training_state is not None:
            optimizer.load_state_dict(training_state['optimizer_state_dict'])
            scheduler.load_state_dict(training_state['scheduler_state_dict'])

        # fp16 autocast/GradScaler only apply on CUDA; fall back to fp32 elsewhere.
        use_amp = config["mixed_precision"] and device.type == "cuda"
        scaler = torch.amp.GradScaler() if use_amp else None

        early_stopping = EarlyStopping(patience=config["early_stopping_patience"])
        checkpoint_root = config.get("checkpoint_dir", DEFAULT_CHECKPOINT_DIR)

        for epoch in range(start_epoch, config["epochs"]):
            train_loss = _train_one_epoch(
                model, train_dataloader, optimizer, scheduler, scaler, device, config
            )
            val_metrics = evaluate_model(model, processor, val_dataloader, device)

            wandb.log({
                "train_loss": train_loss,
                **val_metrics,
                "epoch": epoch + 1
            })

            # Save before the early-stopping check, so the epoch that triggers stopping
            # is still written when it falls on a checkpoint interval.
            if (epoch + 1) % config["checkpoint_interval"] == 0:
                _save_checkpoint(
                    checkpoint_root, epoch, model, processor, optimizer, scheduler, val_metrics
                )

            early_stopping(val_metrics["loss"])
            if early_stopping.should_stop:
                print(f"Early stopping triggered at epoch {epoch + 1}")
                break
    finally:
        wandb.finish()

In [None]:
if __name__ == "__main__":
    drive.mount('/content/drive')

    # Point this at a checkpoint_epoch_<N> directory written by train_model to resume.
    # If the path is None or does not exist, train_model starts fresh from the base model.
    resume_from = (
        "/content/drive/Shareddrives/CS307-Thesis/Dataset/whisper_checkpoints/checkpoint_epoch_2"
    )
    train_model(config, checkpoint_path=resume_from)

[34m[1mwandb[0m: Currently logged in as: [33midcsalvame[0m ([33midcsalvame-n-a[0m). Use [1m`wandb login --relogin`[0m to force relogin


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50358, 50359, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


0,1
Accuracy,▁▁▁
CER,▁▁▁
F1-Score,▁▁▁
Precision,▁▁▁
Recall,▁▁▁
WER,▁▁▁
batch_loss,▁▄█
epoch,▁▅█
learning_rate,▁▁▁
train_loss,▃▁█

0,1
Accuracy,0.0
CER,0.0
F1-Score,0.0
Precision,0.0
Recall,0.0
WER,0.0
batch_loss,2.16708
epoch,3.0
learning_rate,0.0
loss,inf
