In [None]:
!pip install wandb transformers torch torchaudio jiwer scikit-learn

In [88]:
import json
import logging
import math
import os
from datetime import datetime

import pandas as pd
import torch
import torchaudio
import wandb
from jiwer import wer, cer
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import WhisperProcessor, WhisperForConditionalGeneration, get_scheduler

In [89]:
# Add GPU verification at startup
def verify_gpu_status():
    """Print whether CUDA is usable and, if so, details of the active GPU.

    Index, name and total memory are all reported for the *current* CUDA
    device, so every line describes the same GPU.
    """
    print("\n=== GPU Status Check ===")
    cuda_available = torch.cuda.is_available()
    print(f"CUDA Available: {cuda_available}")
    if cuda_available:
        device_index = torch.cuda.current_device()
        props = torch.cuda.get_device_properties(device_index)
        print(f"Current Device: {device_index}")
        print(f"Device Name: {torch.cuda.get_device_name(device_index)}")
        print(f"Device Memory: {props.total_memory / 1e9:.2f} GB")
        print(f"PyTorch CUDA Version: {torch.version.cuda}")
    print("=====================\n")

In [90]:
# Run configuration: data locations and hyper-parameters. Keys are consumed by
# train_model() and create_dataloaders().
# NOTE(review): with max_samples=13 and batch_size=4 there are only a handful of
# optimizer steps over all epochs, far fewer than warmup_steps — confirm intended.
config = dict(
    tsv_file="/content/drive/Shareddrives/CS307-Thesis/Dataset/single-speaker/validated.tsv",
    audio_dir="/content/drive/Shareddrives/CS307-Thesis/Dataset/single-speaker/",
    batch_size=4,
    learning_rate=5e-5,
    weight_decay=0.01,
    warmup_steps=500,
    epochs=3,
    max_samples=13,  # cap on TSV rows loaded; a falsy value loads every row
    checkpoint_interval=3,  # save a checkpoint every N epochs
    validation_split=0.1,  # held-out fraction; at least one sample is always held out
    max_grad_norm=1.0,
    early_stopping_patience=3,
    mixed_precision=True,  # train_model turns this off when CUDA is unavailable
    gradient_accumulation_steps=4,
)

In [91]:
def validate_audio_file(audio_path):
    """Check that ``audio_path`` exists and has a supported audio extension.

    Returns:
        ``(ok, reason)`` — ``(True, "Valid")`` on success, otherwise ``False``
        with "File does not exist" or "Unsupported audio format".
    """
    supported_extensions = ('.wav', '.mp3', '.flac')
    if os.path.exists(audio_path):
        if audio_path.endswith(supported_extensions):
            return True, "Valid"
        return False, "Unsupported audio format"
    return False, "File does not exist"

In [92]:
def load_data(tsv_file, audio_dir, max_samples=None):
    """
    Load data from TSV file with timestamp handling, compatible with both "sec" and "min:sec" formats.
    """
    audio_files, transcripts, languages, timestamps = [], [], [], []

    # Read TSV file
    df = pd.read_csv(tsv_file, sep='\t')
    required_columns = ['path', 'start_time', 'end_time', 'language', 'sentence']

    # Verify all required columns are present
    if not all(col in df.columns for col in required_columns):
        raise ValueError(f"TSV file must contain columns: {required_columns}")

    # Shuffle and limit samples if specified
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)
    if max_samples:
        df = df.head(max_samples)

    for _, row in df.iterrows():
        audio_file = row['path']
        if not audio_file.endswith((".mp3", ".wav", ".flac")):
            print(f"Skipping unsupported file type: {audio_file}")
            continue

        full_audio_path = os.path.join(audio_dir, audio_file)
        if not os.path.exists(full_audio_path):
            print(f"Warning: Audio file not found: {full_audio_path}")
            continue

        # Parse timestamps
        def parse_time(time_str):
            try:
                # Check if time is already in seconds
                return float(time_str)
            except ValueError:
                # Convert from "min:sec" format to seconds
                minutes, seconds = map(float, time_str.split(":"))
                return minutes * 60 + seconds

        try:
            start_time = parse_time(row['start_time'])
            end_time = parse_time(row['end_time'])
        except Exception as e:
            print(f"Error parsing timestamps for {audio_file}: {str(e)}")
            continue

        audio_files.append(full_audio_path)
        transcripts.append(row['sentence'])
        timestamps.append((start_time, end_time))
        languages.append(row['language'])

    return audio_files, transcripts, languages, timestamps

In [93]:
class EarlyStopping:
    """Signal when a monitored validation loss has stopped improving.

    Call the instance once per validation round with the latest loss. A round
    counts as an improvement only when the loss is at most
    ``best_loss - min_delta``. After ``patience`` consecutive rounds without
    improvement, ``should_stop`` becomes True and stays True.

    NaN losses are counted as rounds without improvement and never become the
    best loss. A NaN best loss would make every later comparison false and
    reset the counter forever.
    """

    def __init__(self, patience=3, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0  # consecutive rounds without improvement
        self.best_loss = None  # best non-NaN loss seen so far
        self.should_stop = False

    def __call__(self, val_loss):
        """Record one validation loss and update ``counter`` / ``should_stop``."""
        if math.isnan(val_loss):
            improved = False
        else:
            improved = self.best_loss is None or val_loss <= self.best_loss - self.min_delta

        if improved:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True

In [94]:
def collate_fn(batch):
    """
    Pad a list of dataset items into one batch.

    ``None`` items (samples whose preprocessing failed) are dropped, and an
    empty dict is returned when nothing is left. ``input_features`` are
    zero-padded along the last (time) axis. ``labels`` are padded with -100,
    the value the loss ignores. The number of feature channels is taken from
    the data instead of being hard-coded. Per-sample ``transcript`` and
    ``language`` values are passed through as lists, so evaluation can group
    metrics by language.
    """
    batch = [item for item in batch if item is not None]
    if not batch:
        return {}

    feature_list = [item['input_features'] for item in batch]
    # reshape(-1): a one-token transcript squeezes down to a 0-d tensor.
    label_list = [item['labels'].reshape(-1) for item in batch]

    num_channels = feature_list[0].size(0)
    max_input_length = max(features.size(1) for features in feature_list)
    max_label_length = max(labels.size(0) for labels in label_list)

    padded_input_features = torch.zeros(
        len(batch), num_channels, max_input_length, dtype=feature_list[0].dtype
    )
    padded_labels = torch.full((len(batch), max_label_length), -100, dtype=torch.long)

    for i, (features, labels) in enumerate(zip(feature_list, label_list)):
        padded_input_features[i, :, :features.size(1)] = features
        padded_labels[i, :labels.size(0)] = labels

    return {
        'input_features': padded_input_features,
        'labels': padded_labels,
        'transcript': [item.get('transcript', '') for item in batch],
        'language': [item.get('language', 'tl-en') for item in batch],
    }

In [95]:
class ProcessData(Dataset):
    """Dataset turning (audio file, transcript, segment) triples into Whisper inputs.

    Each item is a dict with:
        - ``input_features``: processor features with the leading batch dim squeezed.
        - ``labels``: transcript token ids as int64.
        - ``transcript`` and ``language``.

    Any failure while loading or processing a sample is recorded in
    ``debug_stats``, and ``None`` is returned so ``collate_fn`` can drop the sample.

    Args:
        audio_files: paths of the audio files.
        transcripts: transcript for each file.
        timestamps: ``(start_sec, end_sec)`` segment to cut from each file.
        processor: Whisper processor (feature extractor + tokenizer).
        languages: optional language tag per file; defaults to "tl-en".
        training: when True, apply waveform time-stretch augmentation.
        spec_augment: when True (and ``training``), also apply frequency and
            time masking to the input features. Off by default.
        verbose: print per-step processing logs. Errors are always printed.

    Note:
        ``debug_stats`` is mutated by whichever process runs ``__getitem__``.
        With DataLoader worker processes (``num_workers > 0``), each worker
        updates its own copy, so the main process will not see those counts.
    """

    _SAMPLE_RATE = 16000
    # STFT parameters for the time-stretch augmentation.
    _N_FFT = 400
    _HOP_LENGTH = 100

    def __init__(self, audio_files, transcripts, timestamps, processor, languages=None, training=True,
                 spec_augment=False, verbose=True):
        self.audio_files = audio_files
        self.transcripts = transcripts
        self.timestamps = timestamps
        self.processor = processor
        self.languages = languages or ["tl-en"] * len(audio_files)  # Default to tl-en if not provided
        self.training = training
        self.spec_augment = spec_augment
        self.verbose = verbose
        self.debug_stats = {
            "processed": 0,
            "errors": 0,
            "error_types": {}
        }

        print(f"Initializing dataset with {len(audio_files)} samples")

        if training:
            # TimeStretch must use the same hop length and number of frequency
            # bins (n_fft // 2 + 1) as the STFT it is applied to. Its default hop
            # (n_fft // 2) would mis-advance the phase for our hop of 100.
            self.time_stretch = torchaudio.transforms.TimeStretch(
                hop_length=self._HOP_LENGTH,
                n_freq=self._N_FFT // 2 + 1,
                fixed_rate=0.98,
            )
            # SpecAugment-style masks; only applied when spec_augment=True.
            self.freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=30)
            self.time_mask = torchaudio.transforms.TimeMasking(time_mask_param=100)

    def apply_audio_transforms(self, audio):
        """Time-stretch a 1-D waveform during training; return it unchanged otherwise.

        If the augmentation fails, a warning is printed and the input waveform
        is returned as-is.
        """
        if not self.training:
            return audio

        try:
            window = torch.hann_window(self._N_FFT, device=audio.device)
            spec = torch.stft(
                audio,
                n_fft=self._N_FFT,
                hop_length=self._HOP_LENGTH,
                win_length=self._N_FFT,
                window=window,
                return_complex=True
            )
            spec_stretched = self.time_stretch(spec)
            return torch.istft(
                spec_stretched,
                n_fft=self._N_FFT,
                hop_length=self._HOP_LENGTH,
                win_length=self._N_FFT,
                window=window
            )
        except Exception as e:
            print(f"Warning: Audio augmentation failed: {str(e)}")
            return audio

    def log_processing_step(self, idx, step, info, force=False):
        """Print a structured log line for one processing step (when verbose or forced)."""
        if self.verbose or force:
            print(f"\n[Sample {idx}][{step}] {info}")

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        """Load, trim, optionally augment, and featurize sample ``idx``; ``None`` on failure."""
        try:
            audio_path = self.audio_files[idx]
            transcript = self.transcripts[idx]
            start_time, end_time = self.timestamps[idx]

            self.log_processing_step(idx, "Start", "Beginning processing")
            self.log_processing_step(idx, "Info", f"Path: {audio_path}")
            self.log_processing_step(idx, "Info", f"Transcript: {transcript}")
            self.log_processing_step(idx, "Info", f"Timestamps: {start_time}-{end_time}")

            audio, sample_rate = torchaudio.load(audio_path)
            self.log_processing_step(idx, "Audio Load", f"Shape: {audio.shape}, Sample rate: {sample_rate}")

            # Downmix multi-channel audio to mono.
            if audio.shape[0] > 1:
                audio = torch.mean(audio, dim=0, keepdim=True)
                self.log_processing_step(idx, "Mono Convert", "Converted stereo to mono")

            if sample_rate != self._SAMPLE_RATE:
                audio = torchaudio.transforms.Resample(sample_rate, self._SAMPLE_RATE)(audio)
                self.log_processing_step(idx, "Resample", "Resampled to 16kHz")

            # Cut the labelled segment out of the (16 kHz) waveform.
            num_samples = audio.shape[1]
            start_frame = int(start_time * self._SAMPLE_RATE)
            end_frame = int(end_time * self._SAMPLE_RATE)
            audio = audio[:, start_frame:end_frame]
            self.log_processing_step(idx, "Trim", f"Trimmed shape: {audio.shape}")
            if audio.shape[1] == 0:
                # Timestamps fall outside the audio (or describe an empty span).
                raise ValueError(
                    f"Empty segment {start_time}-{end_time}s for audio with {num_samples} samples at 16kHz"
                )

            if self.training:
                audio = self.apply_audio_transforms(audio.squeeze()).unsqueeze(0)
                self.log_processing_step(idx, "Augment", "Applied audio transforms")

            input_features = self.processor(
                audio.squeeze().numpy(),
                sampling_rate=self._SAMPLE_RATE,
                return_tensors="pt"
            ).input_features
            if self.training and self.spec_augment:
                # Masks operate on the (..., freq, time) feature tensor.
                input_features = self.time_mask(self.freq_mask(input_features))
                self.log_processing_step(idx, "SpecAugment", "Applied frequency/time masking")
            self.log_processing_step(idx, "Features", f"Processed features shape: {input_features.shape}")

            labels = self.processor(
                text=transcript,
                return_tensors="pt"
            ).input_ids.squeeze()
            self.log_processing_step(idx, "Labels", f"Processed labels shape: {labels.shape}")

            self.debug_stats["processed"] += 1
            return {
                "input_features": input_features.squeeze(),
                "labels": labels.long(),
                "transcript": transcript,
                "language": self.languages[idx]
            }

        except Exception as e:
            self.log_error(type(e).__name__, str(e))
            self.log_processing_step(idx, "Error", f"Failed: {str(e)}", force=True)
            return None

    def log_error(self, error_type, details):
        """Record one failure, grouped by exception type name."""
        self.debug_stats["error_types"].setdefault(error_type, []).append(details)
        self.debug_stats["errors"] += 1

    def get_stats(self):
        """Summarise processing counts and the error breakdown."""
        return {
            "total_items": len(self),
            "successfully_processed": self.debug_stats["processed"],
            "errors": self.debug_stats["errors"],
            "error_breakdown": self.debug_stats["error_types"]
        }

In [96]:
def create_dataloaders(audio_files, transcripts, timestamps, processor, config, languages=None):
    """Create training and validation dataloaders with a reproducible split.

    The split is drawn with a fixed seed (42). Training samples come from a
    ``ProcessData`` with ``training=True`` (augmentation on). Validation samples
    come from one with ``training=False``, so validation audio is never augmented.

    Args:
        audio_files, transcripts, timestamps, languages: parallel per-sample
            lists as returned by ``load_data``.
        processor: Whisper processor handed to ``ProcessData``.
        config: reads ``batch_size`` and ``validation_split``, plus the optional
            ``num_workers`` (default 2).

    Returns:
        ``(train_dataloader, val_dataloader)``.

    Raises:
        ValueError: if fewer than two samples are given (each split needs one).
    """
    from torch.utils.data import Subset  # local import keeps this cell self-contained

    total_samples = len(audio_files)
    if total_samples < 2:
        raise ValueError(f"Need at least 2 samples for a train/validation split, got {total_samples}")

    train_source = ProcessData(audio_files, transcripts, timestamps, processor, languages=languages, training=True)
    val_source = ProcessData(audio_files, transcripts, timestamps, processor, languages=languages, training=False)

    # Keep at least one sample on each side of the split.
    val_size = min(max(int(total_samples * config["validation_split"]), 1), total_samples - 1)
    train_size = total_samples - val_size

    print(f"Splitting dataset: {train_size} training samples, {val_size} validation samples")

    # random_split only needs the lengths and the generator, so splitting plain
    # indices gives a seeded assignment that can be mapped onto either dataset view.
    train_split, val_split = random_split(
        range(total_samples),
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )
    train_dataset = Subset(train_source, list(train_split.indices))
    val_dataset = Subset(val_source, list(val_split.indices))

    num_workers = config.get("num_workers", 2)
    pin_memory = torch.cuda.is_available()  # pinned memory only helps host-to-GPU copies

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config["batch_size"],
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=collate_fn
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=collate_fn
    )

    return train_dataloader, val_dataloader

In [97]:
def evaluate_model(model, processor, dataloader, device, log_samples=5,
                   language=None, task="transcribe", max_length=256):
    """Evaluate a Whisper model: teacher-forced loss plus metrics on generated text.

    Args:
        model: Whisper seq2seq model; switched to eval mode here.
        processor: Whisper processor used to decode generated ids and labels.
        dataloader: yields batches produced by ``collate_fn``; empty dicts are skipped.
        device: device the inputs are moved to.
        log_samples: number of individual predictions to print.
        language: optional language passed through to ``model.generate``. When
            None, no language is forced and generation uses the model's defaults.
        task: Whisper task passed to ``model.generate`` (None to omit it).
        max_length: maximum generated sequence length.

    Returns:
        ``{"overall": metrics, "by_language": {lang: metrics}}``, where each
        ``metrics`` dict comes from ``calculate_metrics``. The loss is averaged
        over batches that were evaluated successfully. Per-language loss is an
        approximation: the overall loss scaled by that language's share of samples.
    """
    model.eval()
    total_loss = 0.0
    evaluated_batches = 0
    all_preds = []
    all_labels = []
    samples_logged = 0

    # Predictions/references grouped by the language tag carried in each batch.
    metrics_by_language = {
        "en": {"preds": [], "labels": []},
        "tl": {"preds": [], "labels": []},
        "tl-en": {"preds": [], "labels": []}
    }

    # The decoder prompt (task / language) is configured through generate();
    # the teacher-forced forward pass only needs `labels`.
    generate_kwargs = {"max_length": max_length}
    if task is not None:
        generate_kwargs["task"] = task
    if language is not None:
        generate_kwargs["language"] = language

    print("\n=== Starting Evaluation ===")

    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if not batch:
                print(f"Batch {batch_idx} is empty, skipping...")
                continue

            input_features = batch["input_features"].to(device)
            labels = batch["labels"].to(device)
            languages = batch.get("language", ["tl-en"] * input_features.size(0))

            try:
                outputs = model(input_features, labels=labels)
                generated_ids = model.generate(input_features, **generate_kwargs)
                decoded_preds = processor.batch_decode(generated_ids, skip_special_tokens=True)
                # -100 marks label padding; swap it for a real token id before decoding.
                labels_clean = torch.where(labels != -100, labels, processor.tokenizer.pad_token_id)
                decoded_labels = processor.batch_decode(labels_clean, skip_special_tokens=True)
            except Exception as e:
                print(f"Error in batch {batch_idx}: {str(e)}")
                continue

            total_loss += outputs.loss.item()
            evaluated_batches += 1

            for pred, label, lang in zip(decoded_preds, decoded_labels, languages):
                pred = pred.strip()
                label = label.strip()
                # WER is undefined for an empty reference, so skip those samples.
                # An empty *prediction* is a real (fully wrong) output and is kept.
                if not label:
                    continue

                all_preds.append(pred)
                all_labels.append(label)

                if lang in metrics_by_language:
                    metrics_by_language[lang]["preds"].append(pred)
                    metrics_by_language[lang]["labels"].append(label)

                if samples_logged < log_samples:
                    # Against a non-empty reference, an empty hypothesis is all
                    # deletions, i.e. an error rate of exactly 1.
                    sample_wer = wer([label], [pred]) if pred else 1.0
                    sample_cer = cer([label], [pred]) if pred else 1.0
                    print("\n--- Sample #{} (Language: {}) ---".format(samples_logged + 1, lang))
                    print(f"Predicted : {pred}")
                    print(f"Actual    : {label}")
                    print(f"WER       : {sample_wer:.4f}")
                    print(f"CER       : {sample_cer:.4f}")
                    print("-" * 50)
                    samples_logged += 1

    overall_metrics = calculate_metrics(all_preds, all_labels, total_loss, evaluated_batches)

    language_metrics = {}
    for lang, data in metrics_by_language.items():
        if data["preds"]:
            lang_loss = total_loss * (len(data["preds"]) / len(all_preds))  # approximate loss share
            language_metrics[lang] = calculate_metrics(data["preds"], data["labels"], lang_loss, evaluated_batches)

    print("\n=== Overall Evaluation Metrics ===")
    for metric, value in overall_metrics.items():
        print(f"{metric}: {value:.4f}")

    print("\n=== Language-Specific Metrics ===")
    for lang, metrics in language_metrics.items():
        print(f"\n{lang.upper()} Metrics:")
        for metric, value in metrics.items():
            print(f"{metric}: {value:.4f}")

    return {
        "overall": overall_metrics,
        "by_language": language_metrics
    }

In [98]:
def calculate_metrics(all_preds, all_labels, total_loss, num_batches):
    """Compute loss, WER/CER and word-level scores for decoded transcripts.

    Args:
        all_preds: predicted transcripts.
        all_labels: reference transcripts, parallel to ``all_preds``.
        total_loss: loss summed over ``num_batches`` batches.
        num_batches: number of batches the loss was summed over.

    Returns:
        A dict with ``loss`` (``inf`` when ``num_batches`` is 0). When both lists
        are non-empty it also contains ``WER`` and ``CER``.

        ``Precision``, ``Recall``, ``F1-Score`` and ``Accuracy`` are added only
        when the references and predictions contain the same total number of
        words. They compare the flattened word sequences position by position,
        so they are a rough proxy rather than an alignment-based score.

        If WER/CER cannot be computed they fall back to 1.0. If the word-level
        scores fail they fall back to 0.0. A failure in one group does not
        discard the other.
    """
    metrics = {
        "loss": total_loss / num_batches if num_batches > 0 else float('inf')
    }

    if not all_preds or not all_labels:
        return metrics

    try:
        metrics["WER"] = float(wer(all_labels, all_preds))
        metrics["CER"] = float(cer(all_labels, all_preds))
    except Exception as e:
        print(f"Error calculating metrics: {str(e)}")
        metrics.update({"WER": 1.0, "CER": 1.0})

    try:
        pairs = list(zip(all_labels, all_preds))  # truncates to the shorter list
        flat_labels = [word for label, _ in pairs for word in label.split()]
        flat_preds = [word for _, pred in pairs for word in pred.split()]

        if flat_labels and len(flat_labels) == len(flat_preds):
            precision, recall, f1, _ = precision_recall_fscore_support(
                flat_labels,
                flat_preds,
                average='weighted',
                zero_division=0
            )
            accuracy = accuracy_score(flat_labels, flat_preds)
            # Plain floats keep the dict easy to serialize (json, torch.save).
            metrics.update({
                "Precision": float(precision),
                "Recall": float(recall),
                "F1-Score": float(f1),
                "Accuracy": float(accuracy)
            })
    except Exception as e:
        print(f"Error calculating metrics: {str(e)}")
        metrics.update({
            "Precision": 0.0,
            "Recall": 0.0,
            "F1-Score": 0.0,
            "Accuracy": 0.0
        })

    return metrics

In [99]:
_BASE_MODEL_NAME = "openai/whisper-base"
_DEFAULT_CHECKPOINT_ROOT = "/content/drive/Shareddrives/CS307-Thesis/Dataset/whisper_checkpoints"


def _add_tagalog_token(processor, model):
    """Register "<|tl|>" as an additional special token; resize embeddings if the vocab grew."""
    # NOTE(review): Whisper's tokenizer may already contain "<|tl|>" as a language
    # token (then nothing is added) — confirm this step is still needed.
    special_tokens = {"additional_special_tokens": ["<|tl|>"]}
    num_added_tokens = processor.tokenizer.add_special_tokens(special_tokens)
    if num_added_tokens > 0:
        model.resize_token_embeddings(len(processor.tokenizer))


def _load_base_model():
    """Load the multilingual base Whisper processor and model, with the Tagalog token added."""
    print("Starting fresh training with multilingual base model")
    processor = WhisperProcessor.from_pretrained(_BASE_MODEL_NAME)
    model = WhisperForConditionalGeneration.from_pretrained(_BASE_MODEL_NAME)
    _add_tagalog_token(processor, model)
    return processor, model


def _load_model_and_state(checkpoint_path, device):
    """Return ``(processor, model, training_state, start_epoch)``.

    Loads from ``checkpoint_path`` when it exists, falling back to the base
    model if that fails. ``training_state`` is None unless
    ``training_state.pt`` was loaded from the checkpoint.
    """
    if not (checkpoint_path and os.path.exists(checkpoint_path)):
        processor, model = _load_base_model()
        return processor, model, None, 0

    print(f"Loading checkpoint from {checkpoint_path}")
    try:
        processor = WhisperProcessor.from_pretrained(checkpoint_path)
        model = WhisperForConditionalGeneration.from_pretrained(checkpoint_path)
        _add_tagalog_token(processor, model)
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        processor, model = _load_base_model()
        return processor, model, None, 0

    training_state = None
    start_epoch = 0
    training_state_path = os.path.join(checkpoint_path, 'training_state.pt')
    if os.path.exists(training_state_path):
        try:
            training_state = torch.load(training_state_path, map_location=device)
            start_epoch = training_state.get('epoch', 0) + 1
            print(f"Resuming from epoch {start_epoch}")
        except Exception as e:
            print(f"Warning: Could not load training state: {e}")
            start_epoch = 0
    return processor, model, training_state, start_epoch


def _optimizer_step(model, optimizer, scheduler, scaler, max_grad_norm):
    """Clip gradients, step optimizer and scheduler (via the AMP scaler if given), then zero grads."""
    if scaler is not None:
        scaler.unscale_(optimizer)  # clip on real, unscaled gradients
        clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
    else:
        clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
    scheduler.step()
    optimizer.zero_grad()


def _save_checkpoint(checkpoint_dir, model, processor, optimizer, scheduler, epoch, metrics):
    """Save model, processor and optimizer/scheduler state so training can resume at ``epoch + 1``."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    model.save_pretrained(checkpoint_dir)
    processor.save_pretrained(checkpoint_dir)
    # JSON round-trip turns numpy scalars into plain Python floats before pickling.
    plain_metrics = json.loads(json.dumps(metrics, default=float))
    torch.save({
        'epoch': epoch,
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'metrics': plain_metrics
    }, os.path.join(checkpoint_dir, 'training_state.pt'))


def _run_training(config, checkpoint_path):
    """Body of ``train_model``; ``config`` is already a private copy and may be adjusted."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not torch.cuda.is_available():
        print("WARNING: CUDA is not available. Training will proceed on CPU, which will be much slower.")
        config["mixed_precision"] = False

    processor, model, training_state, start_epoch = _load_model_and_state(checkpoint_path, device)
    model.to(device)

    print("\nLoading data...")
    audio_files, transcripts, languages, timestamps = load_data(
        config["tsv_file"],
        config["audio_dir"],
        config["max_samples"]
    )
    print(f"Total samples loaded: {len(audio_files)}")
    if len(audio_files) == 0:
        raise ValueError("No audio files loaded. Please check your data paths and file formats.")

    # Make sure the validation split holds at least one sample.
    total_samples = len(audio_files)
    min_val_samples = 1
    if total_samples * config["validation_split"] < min_val_samples:
        adjusted_split = min_val_samples / total_samples
        print(f"Warning: Adjusting validation split from {config['validation_split']} to {adjusted_split} to ensure at least {min_val_samples} validation sample(s)")
        config["validation_split"] = adjusted_split

    train_dataloader, val_dataloader = create_dataloaders(
        audio_files,
        transcripts,
        timestamps,
        processor,
        config,
        languages=languages
    )
    print(f"Training batches: {len(train_dataloader)}")
    print(f"Validation batches: {len(val_dataloader)}")
    if len(train_dataloader) == 0:
        raise ValueError("Training split is empty; load more samples or lower validation_split.")

    accum_steps = max(1, int(config["gradient_accumulation_steps"]))

    optimizer = AdamW(
        model.parameters(),
        lr=config["learning_rate"],
        weight_decay=config["weight_decay"]
    )

    # The scheduler advances once per optimizer step, i.e. once per accumulation
    # group. A trailing partial group at the end of an epoch also counts.
    steps_per_epoch = math.ceil(len(train_dataloader) / accum_steps)
    num_training_steps = config["epochs"] * steps_per_epoch
    if config["warmup_steps"] >= num_training_steps:
        print(f"Warning: warmup_steps ({config['warmup_steps']}) >= total optimizer steps "
              f"({num_training_steps}); the learning rate stays in warmup for the whole run.")
    scheduler = get_scheduler(
        "cosine",
        optimizer=optimizer,
        num_warmup_steps=config["warmup_steps"],
        num_training_steps=num_training_steps
    )

    if training_state is not None:
        try:
            optimizer.load_state_dict(training_state['optimizer_state_dict'])
            scheduler.load_state_dict(training_state['scheduler_state_dict'])
        except Exception as e:
            print(f"Warning: Could not load optimizer/scheduler states: {e}")

    scaler = torch.amp.GradScaler() if config["mixed_precision"] and torch.cuda.is_available() else None
    early_stopping = EarlyStopping(patience=config["early_stopping_patience"])
    checkpoint_root = config.get("checkpoint_dir", _DEFAULT_CHECKPOINT_ROOT)

    for epoch in range(start_epoch, config["epochs"]):
        model.train()
        epoch_loss = 0.0
        trained_batches = 0
        has_pending_grads = False
        optimizer.zero_grad()

        for i, batch in enumerate(train_dataloader):
            if not batch:
                continue

            input_features = batch["input_features"].to(device)
            labels = batch["labels"].to(device)

            # Teacher forcing: the model builds decoder inputs from `labels`,
            # so no decoder prompt ids are passed to the forward call.
            if scaler is not None:
                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                    outputs = model(input_features, labels=labels)
                    loss = outputs.loss / accum_steps
                scaler.scale(loss).backward()
            else:
                outputs = model(input_features, labels=labels)
                loss = outputs.loss / accum_steps
                loss.backward()
            has_pending_grads = True

            if (i + 1) % accum_steps == 0:
                _optimizer_step(model, optimizer, scheduler, scaler, config["max_grad_norm"])
                has_pending_grads = False

            batch_loss = outputs.loss.item()  # unscaled, for reporting
            epoch_loss += batch_loss
            trained_batches += 1

            if i % 10 == 0:
                wandb.log({
                    "batch_loss": batch_loss,
                    "learning_rate": scheduler.get_last_lr()[0]
                })

        # Flush a trailing partial accumulation group. Without this, an epoch
        # with fewer batches than accum_steps would never update the model.
        if has_pending_grads:
            _optimizer_step(model, optimizer, scheduler, scaler, config["max_grad_norm"])

        val_metrics = evaluate_model(model, processor, val_dataloader, device)

        wandb.log({
            "train_loss": epoch_loss / max(trained_batches, 1),
            **val_metrics,
            "epoch": epoch + 1
        })

        # evaluate_model nests the aggregate metrics under "overall".
        early_stopping(val_metrics["overall"]["loss"])
        if early_stopping.should_stop:
            print(f"Early stopping triggered at epoch {epoch + 1}")
            break

        if (epoch + 1) % config["checkpoint_interval"] == 0:
            checkpoint_dir = os.path.join(checkpoint_root, f"checkpoint_epoch_{epoch + 1}")
            _save_checkpoint(checkpoint_dir, model, processor, optimizer, scheduler, epoch, val_metrics)


def train_model(config, checkpoint_path=None):
    """Fine-tune Whisper on the TSV dataset described by ``config``, logging to wandb.

    Args:
        config: hyper-parameter dict (see the config cell). It is copied first,
            so adjustments made during training (disabling mixed precision on
            CPU, widening the validation split) do not leak into the caller's
            dict. The optional key ``checkpoint_dir`` overrides where
            ``checkpoint_epoch_N`` directories are written.
        checkpoint_path: optional checkpoint directory to resume from. If it
            cannot be loaded, training starts from the base model.

    The wandb run is always finished, even if training raises.
    """
    config = dict(config)
    wandb.init(project="taglish-whisper-fine-tuning", config=config)
    try:
        _run_training(config, checkpoint_path)
    finally:
        wandb.finish()

In [None]:
if __name__ == "__main__":
    # Mount Google Drive when running in Colab. Elsewhere, the dataset paths in
    # `config` must already be reachable.
    try:
        from google.colab import drive
    except ImportError:
        drive = None

    if drive is not None:
        drive.mount('/content/drive')
    else:
        print("google.colab is not available; skipping Google Drive mount.")

    # Verify CUDA installation and GPU availability before starting
    verify_gpu_status()

    # Start (or resume) training
    checkpoint_path = "/content/drive/Shareddrives/CS307-Thesis/Dataset/whisper_checkpoints/checkpoint_epoch_2"
    train_model(config, checkpoint_path=checkpoint_path)