# In [None]:
import torch
import string
import numpy as np
import os 
import soundfile as sf
import librosa
import csv
from datasets import load_dataset, Dataset
from evaluate import load as load_metric 
from transformers import (
    Wav2Vec2Processor,
    Wav2Vec2ForCTC,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    TrainerCallback 
)
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

# --- CUSTOM CALLBACK FOR LOGGING RESULTS ---
class LoggingCallback(TrainerCallback):
    """Collect training-loss and evaluation-metric rows for export to CSV.

    Every appended row is a dict with exactly the keys in ``self.header``; fields
    that do not apply to the event (e.g. ``eval_wer`` on a training-loss row) are
    ``None``. Rows are only recorded on the local main process
    (``state.is_local_process_zero``).
    """

    def __init__(self):
        super().__init__()
        self.results = []
        self.header = ['epoch', 'step', 'train_loss', 'eval_loss', 'eval_wer', 'runtime']

    def _make_row(self, state, **fields):
        """Return a header-keyed row stamped with the current epoch and step.

        ``TrainerState.epoch`` is optional (unset before training has started, e.g.
        for an evaluation run ahead of ``train()``); it is recorded as ``None``
        rather than crashing inside ``round``.
        """
        row = dict.fromkeys(self.header)
        row['epoch'] = round(state.epoch, 2) if state.epoch is not None else None
        row['step'] = state.global_step
        row.update(fields)
        return row

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Record a training-loss row for log events that carry a ``loss`` value."""
        if logs is None or not state.is_local_process_zero or 'loss' not in logs:
            return
        self.results.append(self._make_row(state, train_loss=logs.get('loss')))

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Record an evaluation row with loss, WER and runtime taken from *metrics*."""
        if metrics is None or not state.is_local_process_zero:
            return
        self.results.append(self._make_row(
            state,
            eval_loss=metrics.get('eval_loss'),
            eval_wer=metrics.get('eval_wer'),
            runtime=metrics.get('eval_runtime'),
        ))

# --- FFmpeg PATH Setup ---
# The FFMPEG_BIN_PATH environment variable overrides the default Windows install location.
FFMPEG_BIN_PATH = os.environ.get("FFMPEG_BIN_PATH") or r"C:\ffmpeg\bin"

if os.path.isdir(FFMPEG_BIN_PATH):
    _current_path = os.environ.get("PATH", "")
    # Re-running this cell (e.g. in a notebook) must not keep prepending duplicates.
    if FFMPEG_BIN_PATH not in _current_path.split(os.pathsep):
        os.environ["PATH"] = FFMPEG_BIN_PATH + (os.pathsep + _current_path if _current_path else "")
    print(f"‚úì FFmpeg path injected: {FFMPEG_BIN_PATH}")
else:
    print(f"‚ö† Warning: FFmpeg bin path '{FFMPEG_BIN_PATH}' not found. Relying on System PATH.")

# --- Metric Setup ---
def _load_wer_metric():
    """Load the WER metric, or print a warning and return None if it is unavailable.

    Callers treat a None metric as "WER unavailable" and fall back to placeholders.
    """
    try:
        metric = load_metric("wer")
        print("‚úì WER metric loaded successfully")
        return metric
    except Exception as err:
        print(f"‚ö† Warning: Could not load WER metric. Error: {err}")
        return None


wer_metric = _load_wer_metric()

# Module-level processor; assigned by run_wav2vec2_finetune() and read by
# compute_metrics() / evaluate_and_save_predictions().
processor = None

def compute_metrics(pred):
    """Compute the word error rate (WER) for a Trainer evaluation pass.

    Args:
        pred: The evaluation prediction from ``Trainer``. ``pred.predictions`` holds
            CTC logits (argmax is taken over the last axis). ``pred.label_ids`` holds
            label ids, with ``-100`` marking padded positions.

    Returns:
        ``{"wer": float}``. The value ``99.0`` is a placeholder, returned when the WER
        metric or the module-level ``processor`` is unavailable, or when no example
        has a non-empty reference. The key is always present because training
        selects the best model on ``"wer"``.
    """
    if wer_metric is None or processor is None:
        return {"wer": 99.0}

    pred_ids = np.argmax(pred.predictions, axis=-1)

    # Replace the -100 loss-ignore marker with the pad id on a *copy*, so the
    # caller's prediction object is not mutated.
    label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)

    pred_str = processor.batch_decode(pred_ids)
    # group_tokens=False: repeated characters in a reference (e.g. "ll") are real
    # and must not be collapsed as CTC repeats.
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    # WER is undefined for an empty reference, and evaluate's "wer" metric raises on
    # one. Skip such examples so a single blank transcript cannot abort evaluation.
    pairs = [(hyp, ref) for hyp, ref in zip(pred_str, label_str) if ref.strip()]
    if not pairs:
        return {"wer": 99.0}

    hyps, refs = (list(column) for column in zip(*pairs))
    return {"wer": wer_metric.compute(predictions=hyps, references=refs)}


# --- Custom Data Collator for CTC ---
@dataclass
class DataCollatorCTCWithPadding:
    """Pad a batch of ``{"input_values", "labels"}`` examples for CTC training.

    Audio and label sequences have unrelated lengths, so each is padded on its own:
    audio through ``processor.feature_extractor.pad`` and labels through
    ``processor.tokenizer.pad``. Padded label positions are set to ``-100`` so the
    loss ignores them.
    """

    # Supplies feature_extractor.pad (audio) and tokenizer.pad (labels).
    processor: Wav2Vec2Processor
    # Padding strategy forwarded to both pad calls (True pads to the longest example).
    padding: Union[bool, str] = True
    # Optional maximum lengths for audio inputs / label sequences.
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    # Optional multiple to round padded lengths up to, for audio inputs / labels.
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Return a padded batch with ``input_values``, ``attention_mask`` and ``labels``."""
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        # Pad audio input features
        batch = self.processor.feature_extractor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Pad labels with the tokenizer directly. The former
        # `with self.processor.as_target_processor():` wrapper is deprecated. It only
        # changes what `processor(...)` dispatches to, so it never affected this
        # explicit `tokenizer.pad` call.
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            max_length=self.max_length_labels,
            pad_to_multiple_of=self.pad_to_multiple_of_labels,
            return_tensors="pt",
        )

        # Mark every padded label slot with -100 so the loss calculation ignores it.
        batch["labels"] = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # The feature extractor may not return a mask; supply an all-ones mask so
        # the key is always present.
        if "attention_mask" not in batch:
            batch["attention_mask"] = torch.ones_like(batch["input_values"], dtype=torch.long)

        return batch


def load_audio_file(audio_path, target_sr=16000):
    """Load an audio file as a mono float32 waveform at ``target_sr`` Hz.

    Multi-channel audio is averaged to mono *before* resampling. The resampler then
    always sees a 1-D signal; ``librosa.resample`` works along the last axis, which
    for soundfile's ``(frames, channels)`` layout would be the channel axis.

    Args:
        audio_path: Path to a file that soundfile can read.
        target_sr: Desired output sample rate in Hz.

    Returns:
        ``(waveform, target_sr)``, where ``waveform`` is a 1-D ``np.float32`` array.
        If reading or resampling fails, a warning is printed and one second of
        silence is returned, so dataset preprocessing can continue.
    """
    try:
        audio_array, sample_rate = sf.read(audio_path)

        # Downmix to mono first so resampling never runs across channels.
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)

        if sample_rate != target_sr:
            audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=target_sr)

        # The model inputs are float32; soundfile returns float64 by default.
        return audio_array.astype(np.float32, copy=False), target_sr
    except Exception as e:
        print(f"‚ö† Error loading {audio_path}: {e}")
        # Fallback: 1 second of silence at the target rate.
        return np.zeros(target_sr, dtype=np.float32), target_sr

def write_results_to_csv(results_list, output_dir, header):
    """Write collected metric rows to ``<output_dir>/training_results.csv``.

    Args:
        results_list: Row dicts keyed by the names in *header*. ``None`` values are
            written as empty cells.
        output_dir: Destination directory; created if missing.
        header: Column names, in output order.

    Best-effort: any failure is printed and swallowed, so a reporting problem
    cannot fail a finished training run.
    """
    csv_path = os.path.join(output_dir, "training_results.csv")
    print(f"\nüíæ Saving results to: {csv_path}")

    try:
        os.makedirs(output_dir, exist_ok=True)
        # Explicit UTF-8 so the file does not depend on the platform locale encoding
        # (the same encoding the evaluation predictions CSV uses).
        with open(csv_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=header)
            writer.writeheader()
            writer.writerows(
                {key: ('' if value is None else value) for key, value in row.items()}
                for row in results_list
            )
        print(f"‚úì Successfully saved {len(results_list)} metric entries.")
    except Exception as e:
        print(f"‚ùå Failed to write CSV file: {e}")


def _clean_transcript(text):
    """Lowercase *text* and remove ASCII punctuation; ``None`` becomes ``""``.

    This matches the cleaning applied to transcripts before training, so references
    are compared in the same normalized form the model learned to produce.
    """
    if text is None:
        return ""
    return str(text).lower().translate(str.maketrans('', '', string.punctuation))


def _sample_wer(pred, ref):
    """Return the WER of one prediction against one cleaned reference, or ``None``.

    * ``None`` when the WER metric is unavailable, or the reference has no words
      (WER is undefined there).
    * ``1.0`` for an empty prediction against a non-empty reference. Every reference
      word is then a deletion, so WER = N / N exactly. Skipping these samples would
      hide the worst cases and make the summary statistics look better than they are.
    * Otherwise, the metric's value rounded to 4 decimals.
    """
    if wer_metric is None or not ref.strip():
        return None
    if pred is None or not str(pred).strip():
        return 1.0
    wer = wer_metric.compute(predictions=[str(pred)], references=[ref])
    return round(wer, 4) if wer is not None else None


def evaluate_and_save_predictions(trainer, eval_dataset, eval_data_raw, output_dir):
    """Predict on *eval_dataset* and save per-sample results to a CSV.

    Writes ``<output_dir>/evaluation_predictions.csv`` with the columns ``path``,
    ``ground_truth`` (cleaned), ``prediction`` and per-sample ``wer``. It then prints
    the mean, best and worst per-sample WER. Best-effort: any failure is printed
    with a traceback and swallowed, so the calling pipeline continues.

    Args:
        trainer: Object whose ``predict(eval_dataset)`` result exposes CTC logits
            as ``.predictions``.
        eval_dataset: Preprocessed dataset passed to ``trainer.predict``.
        eval_data_raw: Raw rows, aligned with *eval_dataset*, with ``"audio_path"``
            and ``"text"`` columns.
        output_dir: Directory for the CSV; created if missing.
    """
    print("\n" + "="*60)
    print("üìä Running Final Evaluation and Generating Predictions CSV")
    print("="*60 + "\n")

    try:
        predictions = trainer.predict(eval_dataset)
        pred_ids = np.argmax(predictions.predictions, axis=-1)
        pred_str = list(processor.batch_decode(pred_ids))

        audio_paths = list(eval_data_raw["audio_path"])
        ground_truth = list(eval_data_raw["text"])

        # The three sequences should be row-aligned; if any differ, keep the common prefix.
        lengths = (len(audio_paths), len(ground_truth), len(pred_str))
        min_len = min(lengths)
        if min_len < max(lengths):
            print(f"‚ö† Warning: Length mismatch detected. Using first {min_len} samples.")
            audio_paths = audio_paths[:min_len]
            ground_truth = ground_truth[:min_len]
            pred_str = pred_str[:min_len]

        ground_truth_cleaned = [_clean_transcript(text) for text in ground_truth]

        individual_wers = []
        for pred, ref in zip(pred_str, ground_truth_cleaned):
            try:
                individual_wers.append(_sample_wer(pred, ref))
            except Exception as e:
                print(f"‚ö† Warning: WER calculation failed for one sample: {e}")
                individual_wers.append(None)

        results_data = [
            {
                'path': "" if path is None else str(path),
                'ground_truth': ref,
                'prediction': "" if pred is None else str(pred),
                'wer': "" if wer is None else wer,
            }
            for path, ref, pred, wer in zip(audio_paths, ground_truth_cleaned, pred_str, individual_wers)
        ]

        os.makedirs(output_dir, exist_ok=True)
        csv_path = os.path.join(output_dir, "evaluation_predictions.csv")
        print(f"üíæ Saving evaluation predictions to: {csv_path}")

        with open(csv_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=['path', 'ground_truth', 'prediction', 'wer'])
            writer.writeheader()
            writer.writerows(results_data)
        print(f"‚úì Successfully saved {len(results_data)} predictions.")

        # Summary statistics over samples that have a defined WER.
        valid_wers = [w for w in individual_wers if isinstance(w, (int, float))]
        if valid_wers:
            avg_wer = sum(valid_wers) / len(valid_wers)
            print(f"\nüìà Evaluation Summary:")
            print(f"  Average WER: {avg_wer:.4f}")
            print(f"  Best WER: {min(valid_wers):.4f}")
            print(f"  Worst WER: {max(valid_wers):.4f}")
            print(f"  Samples evaluated: {len(valid_wers)}/{len(results_data)}")
        else:
            print("‚ö† Warning: No valid WER scores calculated.")

    except Exception as e:
        print(f"‚ùå Error during evaluation predictions: {e}")
        print("‚ö† Continuing without evaluation predictions CSV...")
        import traceback
        traceback.print_exc()
        

# --- Main Fine-Tuning Function ---
def _normalize_transcript(text):
    """Lowercase *text*, strip ASCII punctuation and collapse all whitespace.

    ``None`` becomes ``""``. Collapsing whitespace (tabs, newlines, repeated spaces)
    into single spaces keeps those characters out of the vocabulary. It also
    prevents doubled ``|`` word delimiters in the encoded labels.
    """
    if text is None:
        return ""
    without_punct = str(text).lower().translate(str.maketrans('', '', string.punctuation))
    return " ".join(without_punct.split())


def _build_vocab_dict(texts):
    """Return the CTC character vocabulary ``{token: id}`` built from *texts*.

    Layout: ``|`` (word delimiter) at id 0, then every other character seen in the
    normalized texts in sorted order, then ``[UNK]`` and ``[PAD]`` last.
    """
    chars = set()
    for text in texts:
        chars.update(_normalize_transcript(text).replace(" ", ""))
    vocab_dict = {token: idx for idx, token in enumerate(["|"] + sorted(chars))}
    vocab_dict["[UNK]"] = len(vocab_dict)
    vocab_dict["[PAD]"] = len(vocab_dict)
    return vocab_dict


def _encode_labels(text, encoder):
    """Encode a normalized transcript as one label id per character.

    Spaces become the ``|`` word delimiter, and characters missing from *encoder*
    are dropped. If nothing remains, the result falls back to ``[encoder["|"]]`` so
    no example has an empty target.
    """
    label_ids = [encoder[ch] for ch in text.replace(" ", "|") if ch in encoder]
    return label_ids or [encoder["|"]]


def _build_processor(vocab_dict, output_dir, base_model):
    """Create a ``Wav2Vec2Processor`` whose tokenizer uses *vocab_dict*.

    ``Wav2Vec2Processor.from_pretrained`` cannot inject a vocabulary; a
    ``vocab_dict=`` keyword is not a tokenizer argument and does not change the
    loaded vocabulary. Instead, the vocabulary is written to
    ``<output_dir>/vocab.json`` and a ``Wav2Vec2CTCTokenizer`` is built from that
    file. The feature extractor is still loaded from *base_model*.
    """
    import json
    from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor

    vocab_path = os.path.join(output_dir, "vocab.json")
    with open(vocab_path, "w", encoding="utf-8") as f:
        json.dump(vocab_dict, f, ensure_ascii=False)

    tokenizer = Wav2Vec2CTCTokenizer(
        vocab_path,
        unk_token="[UNK]",
        pad_token="[PAD]",
        word_delimiter_token="|",
    )
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(base_model)
    return Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)


def _build_training_args(output_dir, use_cuda):
    """Build the TrainingArguments so they work across transformers versions.

    * ``evaluation_strategy`` was renamed ``eval_strategy``; whichever field this
      install defines is used.
    * fp16 mixed precision is only enabled when CUDA is available; it needs a GPU.
    * DataLoader worker processes are disabled on Windows. There, spawned workers
      cannot re-import classes defined in an interactive ``__main__``, such as the
      collator defined in a notebook.
    """
    arg_fields = getattr(TrainingArguments, "__dataclass_fields__", {})
    eval_strategy_key = "eval_strategy" if "eval_strategy" in arg_fields else "evaluation_strategy"
    return TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=2,           # effective batch size 32 (16 * 2)
        learning_rate=1e-4,
        num_train_epochs=20,                     # upper bound; early stopping usually ends sooner
        logging_steps=50,
        save_steps=50,
        eval_steps=50,                           # kept equal to save_steps for load_best_model_at_end
        save_strategy="steps",
        save_total_limit=2,
        metric_for_best_model="wer",
        greater_is_better=False,                 # lower WER is better
        load_best_model_at_end=True,
        fp16=use_cuda,
        bf16=False,
        dataloader_num_workers=0 if os.name == "nt" else 2,
        report_to="none",
        **{eval_strategy_key: "steps"},
    )


def run_wav2vec2_finetune(
    output_dir: str = "./wav2vec2-gpu-finetune-model",
    data_file_path: Optional[str] = None,
    eval_data_file_path: Optional[str] = None,
):
    """Fine-tune ``facebook/wav2vec2-base`` for character-level CTC speech recognition.

    Sets the module-level ``processor`` (used by the metric helpers). Writes the
    following into *output_dir*: checkpoints, ``vocab.json``, the best model with
    its processor, ``training_results.csv`` and ``evaluation_predictions.csv``.

    Args:
        output_dir: Destination directory; created if missing.
        data_file_path: Training CSV with ``path`` and ``transcription`` columns.
            ``None`` uses the original author's local default path.
        eval_data_file_path: Evaluation CSV with the same columns; ``None`` uses
            the original default path.

    Raises:
        FileNotFoundError: If either CSV file does not exist.
        RuntimeError: If the CSVs cannot be loaded as datasets.
    """
    import inspect

    global processor

    print("\n" + "="*60)
    print("üöÄ Starting Wav2Vec2 Fine-tuning Pipeline")
    print("="*60 + "\n")

    # Check device
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    print(f"üñ•Ô∏è  Device: {device.upper()}")
    if not use_cuda:
        print("‚ö†Ô∏è  Warning: CUDA not found. Training will default to CPU.")

    print()

    # --- File Paths ---
    # NOTE: these defaults point at the original author's machine; pass explicit paths instead.
    DATA_FILE_PATH = data_file_path or r"C:\Users\18jvo\Desktop\ASR_Local\new_audio_paths.csv"
    EVAL_DATA_FILE_PATH = eval_data_file_path or r"C:\Users\18jvo\eval_pathfinal_file.csv"

    AUDIO_COLUMN_NAME = "path"
    TEXT_COLUMN_NAME = "transcription"
    BASE_MODEL = "facebook/wav2vec2-base"

    # Validate file paths
    if not os.path.exists(DATA_FILE_PATH):
        raise FileNotFoundError(f"Training data not found: {DATA_FILE_PATH}")
    if not os.path.exists(EVAL_DATA_FILE_PATH):
        raise FileNotFoundError(f"Evaluation data not found: {EVAL_DATA_FILE_PATH}")

    os.makedirs(output_dir, exist_ok=True)

    print(f"‚úì Data files found:\n  Train: {DATA_FILE_PATH}\n  Eval:  {EVAL_DATA_FILE_PATH}\n")

    # Load datasets
    print("üìä Loading datasets...")
    try:
        full_dataset = load_dataset(
            "csv",
            data_files={"train": DATA_FILE_PATH, "eval": EVAL_DATA_FILE_PATH},
        )
    except Exception as e:
        raise RuntimeError(f"Failed to load datasets: {e}") from e
    train_data_raw = full_dataset["train"]
    val_data_raw = full_dataset["eval"]
    print(f"‚úì Loaded {len(train_data_raw)} training samples, {len(val_data_raw)} evaluation samples\n")

    # --- Data Preparation ---
    train_data_raw = train_data_raw.rename_column(AUDIO_COLUMN_NAME, "audio_path").rename_column(TEXT_COLUMN_NAME, "text")
    val_data_raw = val_data_raw.rename_column(AUDIO_COLUMN_NAME, "audio_path").rename_column(TEXT_COLUMN_NAME, "text")

    print("üìù Creating vocabulary...")
    vocab_dict = _build_vocab_dict(list(train_data_raw["text"]) + list(val_data_raw["text"]))
    print(f"‚úì Vocabulary created with {len(vocab_dict)} tokens")

    # --- Initialize Processor and Model ---
    print("ü§ñ Initializing model and processor...")
    processor = _build_processor(vocab_dict, output_dir, BASE_MODEL)

    model = Wav2Vec2ForCTC.from_pretrained(
        BASE_MODEL,
        ctc_loss_reduction="mean",
        pad_token_id=processor.tokenizer.pad_token_id,
        vocab_size=len(processor.tokenizer),
    )
    model.freeze_feature_encoder()
    model.to(device)

    print("‚úì Model and processor initialized\n")

    data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
    label_encoder = processor.tokenizer.encoder

    # --- Preprocess Dataset ---
    def prepare_dataset(batch):
        """Turn one raw row into model ``input_values`` plus CTC ``labels``."""
        audio_array, sample_rate = load_audio_file(batch["audio_path"], target_sr=16000)
        batch["input_values"] = processor(audio_array, sampling_rate=sample_rate).input_values[0]
        batch["labels"] = _encode_labels(_normalize_transcript(batch["text"]), label_encoder)
        return batch

    print("‚öôÔ∏è Preprocessing datasets (loading audio files directly)...")
    # Single process for stability; cache disabled so edits to preprocessing always apply.
    map_kwargs = dict(num_proc=1, load_from_cache_file=False)
    try:
        processed_train_data = train_data_raw.map(
            prepare_dataset, remove_columns=train_data_raw.column_names, **map_kwargs
        )
        processed_val_data = val_data_raw.map(
            prepare_dataset, remove_columns=val_data_raw.column_names, **map_kwargs
        )
        print("‚úì Data preprocessed\n")
    except Exception as e:
        print(f"‚ùå Preprocessing failed: {e}")
        raise

    # --- Training Configuration ---
    print("‚öôÔ∏è Configuring training parameters (GPU Optimized)...")
    training_args = _build_training_args(output_dir, use_cuda)

    print("‚úì GPU Optimized Configuration:")
    print(f"  Effective Batch Size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
    print(f"  Learning Rate: {training_args.learning_rate}")
    print(f"  FP16 Enabled: {training_args.fp16}")

    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=5,
        early_stopping_threshold=0.01,  # WER gains under 0.01 (absolute) don't reset patience
    )
    logging_callback = LoggingCallback()

    # --- Training ---
    print("="*60)
    print("üéØ Starting Training on GPU")
    print("="*60 + "\n")

    trainer_kwargs = dict(
        model=model,
        args=training_args,
        train_dataset=processed_train_data,
        eval_dataset=processed_val_data,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[early_stopping_callback, logging_callback],
    )
    # Trainer has no `feature_extractor` argument. The preprocessing object goes in
    # `processing_class` (newer transformers) or `tokenizer` (older), whichever
    # this install accepts.
    trainer_params = inspect.signature(Trainer.__init__).parameters
    processing_key = "processing_class" if "processing_class" in trainer_params else "tokenizer"
    trainer_kwargs[processing_key] = processor.feature_extractor
    trainer = Trainer(**trainer_kwargs)

    trainer.train()

    print("\n" + "="*60)
    print("‚úÖ Training Complete!")
    print("="*60)

    # load_best_model_at_end only restores the best weights in memory. Persist them,
    # plus the full processor (tokenizer + feature extractor), into output_dir itself.
    trainer.save_model(output_dir)
    processor.save_pretrained(output_dir)

    # --- RESULTS SAVING ---
    write_results_to_csv(logging_callback.results, output_dir, logging_callback.header)

    # Best-effort: a failure here must not discard a finished training run.
    try:
        evaluate_and_save_predictions(trainer, processed_val_data, val_data_raw, output_dir)
    except Exception as e:
        print(f"‚ö† Warning: Could not generate evaluation predictions CSV: {e}")
        print("Training completed successfully, but predictions CSV was not created.")

    print(f"\nBest model saved to: {output_dir}")

# --- Main Execution ---
if __name__ == "__main__":
    import traceback

    # The output location can be overridden without editing the file.
    output_path = os.environ.get("WAV2VEC2_OUTPUT_DIR", "./wav2vec2-gpu-finetune-model")

    try:
        run_wav2vec2_finetune(output_dir=output_path)
        print(f"\nüéâ Success! Model and results saved in: {output_path}")
    except Exception as exc:
        # Top-level boundary: report the failure together with its full traceback.
        print(f"\n‚ùå FATAL ERROR: {exc}")
        traceback.print_exc()