In [None]:
# %% [markdown]
# # Whisper Sanskrit Fine-tuning Notebook
# 
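# This notebook fine-tunes a pretrained Whisper checkpoint on Sanskrit audio, step by step: it converts the M4A recordings to WAV, builds train/validation/test splits, trains with the Hugging Face `Seq2SeqTrainer`, and evaluates the fine-tuned model on the held-out test set.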

# %% [markdown]
# ## 1. Setup and Imports

# %%
# Import all necessary modules
import inspect
import json
import os
import sys
from pathlib import Path

# Import our custom modules
from config import Config
from audio_preprocessor import AudioPreprocessor
from data_utils import (
    parse_transcript_file,
    prepare_data_splits,
    DatasetProcessor,
    combine_transcript_with_audio
)
from training_utils import (
    WhisperTrainingCallback,
    MetricsComputer,
    create_data_collator
)
from inference import WhisperSanskritTranscriber, evaluate_on_test_set

# Transformers imports
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)

import torch
# Record the runtime in the cell output so logs show which PyTorch build / accelerator was used.
for label, reading in (
    ("PyTorch version", torch.__version__),
    ("CUDA available", torch.cuda.is_available()),
):
    print(f"{label}: {reading}")

# %% [markdown]
# ## 2. Configuration

# %%
# Start from the project defaults in Config, then apply the notebook-level overrides below.
config = Config()

# Overrides for this run:
#   model_name    - change to tiny/base/small/medium/large
#   batch_size    - adjust based on your GPU memory
#   max_steps     - adjust based on dataset size
notebook_overrides = {
    "model_name": "openai/whisper-small",
    "batch_size": 16,
    "max_steps": 4000,
    "learning_rate": 1e-5,
}
for attr_name, attr_value in notebook_overrides.items():
    setattr(config, attr_name, attr_value)

# Display configuration.
# NOTE(review): `__dict__` lists instance attributes only; if Config keeps some defaults as
# class attributes they will not appear here — confirm against config.py.
print("Configuration:")
for setting, setting_value in config.__dict__.items():
    print(f"  {setting}: {setting_value}")

# %% [markdown]
# ## 3. Data Preparation

# %%
# Parse transcript file
print("Parsing transcript file...")
transcript_data = parse_transcript_file(config.transcript_file)
print(f"Found {len(transcript_data)} transcript entries")
# Fail fast: audio conversion, splitting and training all need at least one entry, and an
# empty parse would otherwise surface much later as a confusing downstream error.
assert transcript_data, f"No transcript entries parsed from {config.transcript_file}"

# Display sample entries; append "..." only when the transcript was actually truncated.
PREVIEW_CHARS = 100
print("\nSample entries:")
for i, item in enumerate(transcript_data[:3]):
    transcript_text = item['transcript']
    if len(transcript_text) > PREVIEW_CHARS:
        transcript_preview = f"{transcript_text[:PREVIEW_CHARS]}..."
    else:
        transcript_preview = transcript_text
    print(f"\n{i+1}. Audio: {item['audio_filename']}")
    print(f"   Transcript: {transcript_preview}")

# %% [markdown]
# ## 4. Audio Preprocessing

# %%
# Build the converter; it is given config.cache_dir, presumably where converted WAVs are
# cached — TODO confirm in audio_preprocessor.py.
audio_preprocessor = AudioPreprocessor(config.cache_dir, config.sample_rate)

# Each transcript entry names one source recording; convert all of them up front.
print("Converting audio files from M4A to WAV...")
audio_files = [entry["audio_filename"] for entry in transcript_data]
converted_paths = audio_preprocessor.convert_audio_batch(audio_files, config.audio_dir)

n_converted = len(converted_paths)
print(f"\nSuccessfully converted {n_converted} audio files")

# %% [markdown]
# ## 5. Combine Data and Create Splits

# %%
# Combine transcript data with converted audio paths
combined_data = combine_transcript_with_audio(transcript_data, converted_paths)
# Guard against a silent empty join (e.g. every conversion failed or filenames did not match);
# otherwise the problem only shows up later as an opaque error during dataset creation/training.
assert combined_data, "No transcript entries could be matched to converted audio files"

# Create train/val/test splits
train_data, val_data, test_data = prepare_data_splits(combined_data, config)
assert train_data, "Training split is empty; check the split settings in Config and the input data"

# %% [markdown]
# ## 6. Initialize Model and Processor

# %%
# Load Whisper model and processor
print(f"Loading Whisper model: {config.model_name}")
processor = WhisperProcessor.from_pretrained(config.model_name)
model = WhisperForConditionalGeneration.from_pretrained(config.model_name)

# Configure for Sanskrit: drop the hard-coded decoder prompt and token suppression from the
# model config so the language/task below drive decoding.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
# generate() reads generation_config, and a pretrained generation_config may carry its own
# forced_decoder_ids; clear it so it cannot conflict with the language/task set next.
model.generation_config.forced_decoder_ids = None
model.generation_config.language = config.language
model.generation_config.task = config.task

# Move to device
model.to(config.device)
print(f"Model loaded on {config.device}")
print(f"Total parameters: {model.num_parameters():,}")

# %% [markdown]
# ## 7. Create Datasets

# %%
# Turn each split's raw records into model-ready datasets via the Whisper processor.
dataset_processor = DatasetProcessor(processor, config.sample_rate)

print("Creating datasets...")
split_records = {"train": train_data, "val": val_data, "test": test_data}
# Dict insertion order keeps the build order train -> val -> test.
# NOTE(review): prepare_dataset receives max_duration, so it may drop long clips; the sizes
# below can then be smaller than the raw splits — confirm in data_utils.py.
built_splits = {
    split_name: dataset_processor.prepare_dataset(
        records, audio_preprocessor, config.max_duration
    )
    for split_name, records in split_records.items()
}
train_dataset = built_splits["train"]
val_dataset = built_splits["val"]
test_dataset = built_splits["test"]

print("\nDataset sizes:")
for split_label, split_dataset in (
    ("Train", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
):
    print(f"  {split_label}: {len(split_dataset)} samples")

# %% [markdown]
# ## 8. Setup Training

# %%
# Newer Transformers releases renamed `evaluation_strategy` -> `eval_strategy` (the old name
# was later removed) and the Trainer's `tokenizer` argument -> `processing_class`. Pick whichever
# keyword the installed version accepts so this cell runs on both older and newer releases.
_training_arg_params = inspect.signature(Seq2SeqTrainingArguments.__init__).parameters
eval_strategy_kw = (
    "eval_strategy" if "eval_strategy" in _training_arg_params else "evaluation_strategy"
)
_trainer_params = inspect.signature(Seq2SeqTrainer.__init__).parameters
processing_kw = "processing_class" if "processing_class" in _trainer_params else "tokenizer"

# Create training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir=config.output_dir,
    per_device_train_batch_size=config.batch_size,
    per_device_eval_batch_size=config.batch_size,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    learning_rate=config.learning_rate,
    warmup_steps=config.warmup_steps,
    max_steps=config.max_steps,
    gradient_checkpointing=config.gradient_checkpointing,
    fp16=config.fp16,
    eval_steps=config.eval_steps,
    save_steps=config.save_steps,
    logging_steps=config.logging_steps,
    report_to=["tensorboard"],
    load_best_model_at_end=config.load_best_model_at_end,
    metric_for_best_model=config.metric_for_best_model,
    greater_is_better=config.greater_is_better,
    # Evaluate by autoregressive generation so text metrics (e.g. WER) see decoded output.
    predict_with_generate=True,
    generation_max_length=config.generation_max_length,
    save_total_limit=config.save_total_limit,
    # Keep every dataset column; presumably the custom collator needs them — TODO confirm.
    remove_unused_columns=False,
    **{eval_strategy_kw: "steps"},
)

# Create data collator
data_collator = create_data_collator(processor, model)

# Create metrics computer
compute_metrics = MetricsComputer(processor)

# Create trainer; the feature extractor is passed so it is saved alongside checkpoints.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[WhisperTrainingCallback()],
    **{processing_kw: processor.feature_extractor},
)

print("Trainer initialized!")

# %% [markdown]
# ## 9. Start Training
# 
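# **Note**: Training can take a long time; the duration depends on dataset size, `config.max_steps`, and your GPU. Checkpoints are written to `config.output_dir` every `config.save_steps` steps.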

# %%
# Announce the run; logging cadence comes from the training arguments.
for banner_line in (
    "Starting training...",
    f"Total steps: {config.max_steps}",
    f"Evaluation every {config.eval_steps} steps",
):
    print(banner_line)

train_result = trainer.train()

# Persist weights (to training_args.output_dir) and the processor into the same directory,
# so the directory can be loaded on its own later.
trainer.save_model()
processor.save_pretrained(config.output_dir)

print(f"\nModel saved to: {config.output_dir}")

# %% [markdown]
# ## 10. Evaluate on Test Set

# %%
# Score the held-out test split with the trainer's compute_metrics.
print("Evaluating on test set...")
test_results = trainer.evaluate(eval_dataset=test_dataset)

print("\nTest Results:")
for metric_name in test_results:
    print(f"  {metric_name}: {test_results[metric_name]:.4f}")

# %% [markdown]
# ## 11. Test the Fine-tuned Model

# %%
# Reload the fine-tuned model from disk for standalone inference.
transcriber = WhisperSanskritTranscriber(config.output_dir)

# Spot-check the first few test examples against their reference transcripts.
# NOTE(review): assumes each test_data record has 'original_filename', 'audio' and
# 'transcript' keys as produced by combine_transcript_with_audio — confirm.
print("Testing fine-tuned model:")
print("-" * 80)

for example_no, sample in enumerate(test_data[:5], start=1):
    print(f"\nExample {example_no}:")
    print(f"Audio: {sample['original_filename']}")

    result = transcriber.transcribe(sample['audio'])

    print(f"True: {sample['transcript']}")
    print(f"Pred: {result['text']}")
    print(f"Duration: {result['duration']:.2f}s")

# %% [markdown]
# ## 12. Detailed Evaluation

# %%
# Perform detailed evaluation on the raw test records.
# The JSON path is presumably where per-sample results are written — TODO confirm in inference.py.
detailed_results = evaluate_on_test_set(
    transcriber,
    test_data,
    os.path.join(config.output_dir, "detailed_test_results.json")
)

print("\nDetailed Evaluation Metrics:")
for key, value in detailed_results.items():
    # `:.4f` only works for numbers; show any other value (None, list, str) as-is instead of
    # aborting the whole report with TypeError/ValueError.
    try:
        shown_value = f"{value:.4f}"
    except (TypeError, ValueError):
        shown_value = value
    print(f"  {key}: {shown_value}")

# %% [markdown]
# ## 13. Interactive Transcription
# 
# Use this cell to transcribe any audio file with your fine-tuned model.

# %%
# Example: Transcribe a new audio file
def transcribe_audio(audio_path: str, model_transcriber=None):
    """Transcribe a single audio file with the fine-tuned model and print the result.

    Args:
        audio_path: Path to an existing audio file.
        model_transcriber: Object exposing ``transcribe(path)`` that returns a dict with
            ``'text'`` and ``'duration'`` keys. Defaults to the notebook-level
            ``transcriber`` created in section 11.

    Returns:
        The dict returned by ``model_transcriber.transcribe(audio_path)``.

    Raises:
        FileNotFoundError: If ``audio_path`` does not point to an existing file.
    """
    if model_transcriber is None:
        model_transcriber = transcriber  # notebook global from section 11
    # Check up front so a typo gives a clear message instead of a deep decoder/loader error.
    if not os.path.isfile(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")
    result = model_transcriber.transcribe(audio_path)
    print(f"Audio: {audio_path}")
    print(f"Transcription: {result['text']}")
    print(f"Duration: {result['duration']:.2f}s")
    return result

# Uncomment to use:
# transcribe_audio("path/to/your/audio.m4a")

# %% [markdown]
# ## 14. Save Training Summary

# %%
# Save training summary: the key facts about this run in one small, machine-readable file.
summary = {
    "model_name": config.model_name,
    "language": config.language,
    "train_samples": len(train_dataset),
    "val_samples": len(val_dataset),
    "test_samples": len(test_dataset),
    # trainer.evaluate() prefixes metric names with "eval_" by default; None if no WER was reported.
    "final_wer": test_results.get("eval_wer"),
    "training_steps": train_result.global_step,
    "device": str(config.device)
}

summary_path = os.path.join(config.output_dir, "training_summary.json")
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

print("Training Summary:")
for key, value in summary.items():
    print(f"  {key}: {value}")

# %% [markdown]
# ## 15. Next Steps
# 
# 1. **Use the model**: Load it with `WhisperSanskritTranscriber(config.output_dir)`, the directory the model and processor were saved to in section 9.
# 2. **Share the model**: Upload to Hugging Face Hub
# 3. **Create API**: Wrap in FastAPI for serving
# 4. **Optimize**: Convert to ONNX for faster inference
# 5. **Improve**: Collect more data and retrain