In [28]:
!pip install torch pandas numpy datasets transformers librosa evaluate jiwer



In [29]:
import os
import torch
import pandas as pd
import numpy as np
import wandb
from datasets import Dataset, Audio
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    WhisperFeatureExtractor,
    WhisperTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
import librosa
import evaluate
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from google.colab import drive
import logging
import sys
from torch.utils.data import DataLoader
import json
from datetime import datetime
from torch.nn.utils.rnn import pad_sequence
from itertools import zip_longest


# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('training.log'),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

# Mount Google Drive
drive.mount('/content/drive')

def load_data(tsv_file, audio_dir, max_samples=None):
    """
    Load data from TSV file with timestamp handling, compatible with both "sec" and "min:sec" formats.
    """
    audio_files, transcripts, languages, timestamps = [], [], [], []
    # Read TSV file
    df = pd.read_csv(tsv_file, sep='\t')
    required_columns = ['path', 'start_time', 'end_time', 'language', 'sentence']
    # Verify all required columns are present
    if not all(col in df.columns for col in required_columns):
        raise ValueError(f"TSV file must contain columns: {required_columns}")
    # Shuffle and limit samples if specified
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)
    if max_samples:
        df = df.head(max_samples)
    for _, row in df.iterrows():
        audio_file = row['path']
        if not audio_file.endswith((".mp3", ".wav", ".flac")):
            print(f"Skipping unsupported file type: {audio_file}")
            continue
        full_audio_path = os.path.join(audio_dir, audio_file)
        if not os.path.exists(full_audio_path):
            print(f"Warning: Audio file not found: {full_audio_path}")
            continue
        # Parse timestamps
        def parse_time(time_str):
            try:
                # Check if time is already in seconds
                return float(time_str)
            except ValueError:
                # Convert from "min:sec" format to seconds
                minutes, seconds = map(float, time_str.split(":"))
                return minutes * 60 + seconds
        try:
            start_time = parse_time(row['start_time'])
            end_time = parse_time(row['end_time'])
        except Exception as e:
            print(f"Error parsing timestamps for {audio_file}: {str(e)}")
            continue
        audio_files.append(full_audio_path)
        transcripts.append(row['sentence'])
        timestamps.append((start_time, end_time))
        languages.append(row['language'])
    return audio_files, transcripts, languages, timestamps

class WhisperTrainer:
    def __init__(self, config):
        """
        Initialize the WhisperTrainer with configuration.

        Args:
            config (dict): Configuration dictionary containing training parameters
        """
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        # Initialize WandB
        if self.config['use_wandb']:
            wandb.init(
                project=self.config['wandb_project_name'],
                name=self.config['wandb_run_name'],
                config=self.config
            )

        self.setup_model_and_processor()
        self.metrics = {
            'wer': evaluate.load('wer'),
            'cer': evaluate.load('cer')
        }

    def setup_model_and_processor(self):
        """Setup the Whisper model and processor."""
        try:
            self.processor = WhisperProcessor.from_pretrained(self.config['model_name'])
            self.model = WhisperForConditionalGeneration.from_pretrained(self.config['model_name'])
            self.feature_extractor = WhisperFeatureExtractor.from_pretrained(self.config['model_name'])
            self.tokenizer = WhisperTokenizer.from_pretrained(self.config['model_name'])

            # Move model to appropriate device
            self.model = self.model.to(self.device)
            logger.info("Model and processor setup completed successfully")
        except Exception as e:
            logger.error(f"Error in setting up model and processor: {str(e)}")
            raise

    def preprocess_audio(self, audio_path, start_time=None, end_time=None):
        """
        Preprocess audio file with resampling and segmentation.

        Args:
            audio_path (str): Path to audio file
            start_time (float, optional): Start time in seconds
            end_time (float, optional): End time in seconds
        """
        try:
            # Load audio with librosa
            audio, sr = librosa.load(audio_path, sr=None)

            # Resample if necessary
            if sr != 16000:
                audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
                sr = 16000

            # Extract segment if timestamps are provided
            if start_time is not None and end_time is not None:
                start_idx = int(start_time * sr)
                end_idx = int(end_time * sr)
                audio = audio[start_idx:end_idx]

            return audio, sr
        except Exception as e:
            logger.error(f"Error preprocessing audio {audio_path}: {str(e)}")
            return None, None

    def prepare_dataset(self, audio_files, transcripts, languages, timestamps):
        """
        Prepare dataset for training.
        """
        try:
            dataset_dict = {
                "audio": [],
                "text": [],
                "language": [],
                "path": []
            }

            max_length = 0
            for audio_path, transcript, lang, (start, end) in zip(
                audio_files, transcripts, languages, timestamps
            ):
                # Preprocess audio
                audio, sr = self.preprocess_audio(audio_path, start, end)
                if audio is None:
                    continue

                audio_tensor = torch.from_numpy(audio)
                max_length = max(max_length, audio_tensor.shape[0])
                dataset_dict["audio"].append(audio_tensor)
                dataset_dict["text"].append(transcript)
                dataset_dict["language"].append(lang)
                dataset_dict["path"].append(audio_path)

            # Pad the audio tensors
            for i, tensor in enumerate(dataset_dict["audio"]):
                dataset_dict["audio"][i] = torch.nn.functional.pad(tensor, (0, max_length - tensor.shape[0]), mode='constant', value=0)
            return Dataset.from_dict(dataset_dict)
        except Exception as e:
            logger.error(f"Error preparing dataset: {str(e)}")
            raise

    def compute_metrics(self, pred):
        pred_ids = pred.predictions
        label_ids = pred.label_ids

        # Replace -100 with pad token id
        label_ids[label_ids == -100] = self.tokenizer.pad_token_id

        # Decode predictions and references
        predictions = self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        references = self.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

        # Compute metrics
        wer_scores = []
        cer_scores = []

        for i in range(len(predictions)):
            wer_scores.append(self.metrics['wer'].compute(predictions=[predictions[i]], references=[references[i]]))
            cer_scores.append(self.metrics['cer'].compute(predictions=[predictions[i]], references=[references[i]]))

        wer = sum(wer_scores) / len(wer_scores)
        cer = sum(cer_scores) / len(cer_scores)

        # Log detailed predictions vs references
        for pred, ref, path in zip(predictions, references, pred.input_ids):
            logger.info(f"\nAudio: {path}")
            logger.info(f"Reference: {ref}")
            logger.info(f"Prediction: {pred}")

        return {
            "wer": wer,
            "cer": cer
        }

    def train(self, train_dataset, eval_dataset):
        """
        Train the model with the prepared datasets.
        """
        try:
            training_args = Seq2SeqTrainingArguments(
                output_dir=self.config['output_dir'],
                per_device_train_batch_size=self.config['batch_size'],
                gradient_accumulation_steps=self.config['gradient_accumulation_steps'],
                learning_rate=self.config['learning_rate'],
                warmup_steps=self.config['warmup_steps'],
                max_steps=self.config['max_steps'],
                fp16=torch.cuda.is_available(),
                evaluation_strategy="steps",
                eval_steps=self.config['eval_steps'],
                save_steps=self.config['save_steps'],
                logging_steps=self.config['logging_steps'],
                report_to="wandb" if self.config['use_wandb'] else None,
                load_best_model_at_end=True,
                metric_for_best_model="wer",
                greater_is_better=False
            )

            trainer = Seq2SeqTrainer(
                model=self.model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                compute_metrics=self.compute_metrics,
            )

            # Load checkpoint if it exists
            if os.path.exists(self.config['checkpoint_dir']):
                logger.info(f"Loading checkpoint from {self.config['checkpoint_dir']}")
                trainer.train(resume_from_checkpoint=self.config['checkpoint_dir'])
            else:
                trainer.train()

            # Save final model
            trainer.save_model(self.config['output_dir'])
            logger.info("Training completed successfully")

        except Exception as e:
            logger.error(f"Error during training: {str(e)}")
            raise

    def evaluate_model(self, eval_dataset, split_name="eval"):
        """
        Evaluate the model on a dataset.
        """
        try:
            logger.info(f"Starting evaluation on {split_name} split")

            eval_dataloader = DataLoader(
                eval_dataset,
                batch_size=self.config['eval_batch_size'],
                shuffle=False
            )

            self.model.eval()
            all_metrics = {
                'wer': [], 'cer': [],
                'predictions': [], 'references': [],
                'audio_paths': [], 'timestamps': []
            }

            with torch.no_grad():
                for batch in eval_dataloader:
                    # Generate predictions
                    inputs = self.processor(
                        batch['audio'],
                        sampling_rate=16000,
                        return_tensors="pt"
                    ).to(self.device)

                    generated_ids = self.model.generate(
                        inputs.input_features,
                        max_length=self.config['max_length']
                    )

                    # Decode predictions
                    transcriptions = self.tokenizer.batch_decode(
                        generated_ids,
                        skip_special_tokens=True
                    )

                    # Compute metrics
                    wer = self.metrics['wer'].compute(
                        predictions=transcriptions,
                        references=batch['text']
                    )
                    cer = self.metrics['cer'].compute(
                        predictions=transcriptions,
                        references=batch['text']
                    )

                    # Store results
                    all_metrics['wer'].append(wer)
                    all_metrics['cer'].append(cer)
                    all_metrics['predictions'].extend(transcriptions)
                    all_metrics['references'].extend(batch['text'])
                    all_metrics['audio_paths'].extend(batch['path'])

                    # Log detailed results
                    for pred, ref, path in zip(
                        transcriptions, batch['text'], batch['path']
                    ):
                        logger.info(f"\nAudio: {path}")
                        logger.info(f"Reference: {ref}")
                        logger.info(f"Prediction: {pred}")

            # Compute average metrics
            avg_metrics = {
                'wer': np.mean(all_metrics['wer']),
                'cer': np.mean(all_metrics['cer'])
            }

            # Log to WandB
            if self.config['use_wandb']:
                wandb.log({f"{split_name}_{k}": v for k, v in avg_metrics.items()})

            # Save detailed results
            results_file = os.path.join(
                self.config['output_dir'],
                f"{split_name}_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
            )
            with open(results_file, 'w') as f:
                json.dump(all_metrics, f, indent=2)

            logger.info(f"Evaluation results saved to {results_file}")
            return avg_metrics

        except Exception as e:
            logger.error(f"Error during evaluation: {str(e)}")
            raise

def main():
    # Configuration
    config = {
        'model_name': 'openai/whisper-small',  # The name of the pre-trained Whisper model to use for fine-tuning
        'output_dir': '/content/drive/Shareddrives/CS307-Thesis/Dataset/whisper-taglish/single-speaker/output',  # The directory where the trained model and other artifacts will be saved
        'checkpoint_dir': '/content/drive/Shareddrives/CS307-Thesis/Dataset/whisper-taglish/single-speaker/checkpoints',  # The directory where checkpoints will be saved and loaded from
        'batch_size': 8,  # The batch size for training
        'eval_batch_size': 4,  # The batch size for evaluation
        'gradient_accumulation_steps': 2,  # The number of gradient accumulation steps
        'learning_rate': 1e-5,  # The learning rate for training
        'warmup_steps': 500,  # The number of warmup steps for the learning rate scheduler
        'max_steps': 5000,  # The maximum number of training steps
        'eval_steps': 1000,  # The number of steps between each evaluation
        'save_steps': 1000,  # The number of steps between each checkpoint save
        'logging_steps': 100,  # The number of steps between each logging step
        'max_length': 1000,  # The maximum length of the output transcription
        'use_wandb': True,  # Whether to use Weights & Biases (WandB) for experiment tracking
        'wandb_project_name': 'whisper-taglish',  # The name of the WandB project
        'wandb_run_name': f'whisper-small-taglish-{datetime.now().strftime("%Y%m%d_%H%M%S")}'  # The name of the WandB run, generated with the current timestamp
    }

    # Load data using the provided function
    audio_files, transcripts, languages, timestamps = load_data(
        tsv_file='/content/drive/Shareddrives/CS307-Thesis/Dataset/single-speaker/validated.tsv',
        audio_dir='/content/drive/Shareddrives/CS307-Thesis/Dataset/single-speaker/',
        max_samples=None
    )

    # Initialize trainer
    trainer = WhisperTrainer(config)

    # Prepare datasets
    full_dataset = trainer.prepare_dataset(audio_files, transcripts, languages, timestamps)

    # Split dataset
    train_test_split = full_dataset.train_test_split(test_size=0.2, seed=42)
    train_dataset = train_test_split['train']
    eval_dataset = train_test_split['test']

    # Evaluate pre-training
    logger.info("Evaluating model before training...")
    pre_train_metrics = trainer.evaluate_model(eval_dataset, split_name="pre_training")
    logger.info(f"Pre-training metrics: {pre_train_metrics}")

    # Train model
    logger.info("Starting training...")
    trainer.train(train_dataset, eval_dataset)

    # Evaluate post-training
    logger.info("Evaluating model after training...")
    post_train_metrics = trainer.evaluate_model(eval_dataset, split_name="post_training")
    logger.info(f"Post-training metrics: {post_train_metrics}")

if __name__ == "__main__":
    main()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


ERROR:__main__:Error during evaluation: operands could not be broadcast together with remapped shapes [original->remapped]: (2,2)  and requested shape (3,2)


ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (2,2)  and requested shape (3,2)