In [None]:
"""
NOTEBOOK: Whisper Evaluation & Dataset Creation
√âtape 3: Calculer le WER avec Whisper baseline
√âtape 4: Filtrer les voix d'enfants
√âtape 5: Cr√©er le dataset d'entra√Ænement
"""

import json
import math
import statistics
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import whisper
from jiwer import wer as compute_wer

from children_voice_filter import ChildrenVoiceFilter, SegmentFilter
from extract_wor import WorSegment


# ============================================================================
# ZONE 1: WER CALCULATION - Compute the WER with the Whisper baseline
# ============================================================================

@dataclass
class TranscriptionResult:
    """Transcription outcome for one audio segment: reference text, Whisper output and scores."""
    segment_id: str  # identifier copied from the input segment dict
    speaker: str  # speaker label of the segment; used to group and filter results
    file_name: str  # name of the file the segment belongs to, copied from the input segment dict
    audio_path: str  # path of the audio clip that was transcribed
    ground_truth: str  # reference transcript of the segment
    whisper_prediction: str  # text produced by Whisper for the clip
    duration_ms: int  # clip duration in milliseconds (reports divide by 1000 to show seconds)
    wer: float  # word error rate of whisper_prediction against ground_truth
    confidence: float = 0.0  # confidence reported for the transcription; 0.0 when unavailable


class WhisperBaseline:
    """Transcribe audio with Whisper and score the output with the word error rate."""

    def __init__(self, model_name: str = "base"):
        """
        Load the Whisper model.

        Args:
            model_name: Model name passed to ``whisper.load_model``
                (tiny, base, small, medium, large)
        """
        print(f"üì¶ Chargement Whisper '{model_name}'...")
        self.model = whisper.load_model(model_name)
        print("‚úì Mod√®le charg√©\n")

    def transcribe_segment(self, audio_path: Path, language: str = "fr") -> Dict:
        """
        Transcribe one audio file.

        This is best-effort: a missing file or any transcription error yields
        an empty result instead of raising, so a long evaluation run is not
        interrupted by a single bad clip.

        Args:
            audio_path: Path of the audio file (a ``str`` is accepted too)
            language: Language code forwarded to Whisper (fr, en, ...)

        Returns:
            Dict with 'text' (stripped transcript) and 'confidence'
            (see ``_estimate_confidence``).
        """
        audio_path = Path(audio_path)
        if not audio_path.exists():
            return {"text": "", "confidence": 0.0}

        try:
            result = self.model.transcribe(
                str(audio_path),
                language=language,
                verbose=False
            )
            return {
                "text": result["text"].strip(),
                "confidence": self._estimate_confidence(result)
            }
        except Exception as e:
            # Deliberately broad: report the failing clip and keep going.
            print(f"‚ö†Ô∏è  Erreur: {audio_path.name} - {e}")
            return {"text": "", "confidence": 0.0}

    @staticmethod
    def _estimate_confidence(result: Dict) -> float:
        """
        Derive a confidence score from a transcription result.

        An explicit 'confidence' entry is returned unchanged when present.
        Otherwise the score is exp(mean of the segments' 'avg_logprob'),
        capped at 1.0, i.e. an average token probability. Returns 0.0 when
        no segment carries an 'avg_logprob'.
        """
        explicit = result.get("confidence")
        if explicit is not None:
            return explicit

        logprobs = [
            seg["avg_logprob"]
            for seg in result.get("segments", [])
            if "avg_logprob" in seg
        ]
        if not logprobs:
            return 0.0
        return min(1.0, math.exp(sum(logprobs) / len(logprobs)))

    @staticmethod
    def calculate_wer(ground_truth: str, prediction: str) -> float:
        """
        Compute the word error rate.

        Args:
            ground_truth: Reference text
            prediction: Text predicted by Whisper

        Returns:
            WER from 0.0 (perfect) to 1.0 or more (very bad). With a blank
            reference the result is 0.0 when the prediction is blank as well,
            1.0 otherwise.
        """
        # A blank reference is scored here and never handed to jiwer.
        if not ground_truth.strip():
            return 0.0 if not prediction.strip() else 1.0

        return compute_wer(ground_truth, prediction)


class BaselineEvaluator:
    """Run the Whisper baseline over audio segments and score each one."""

    def __init__(self, whisper_model: str = "base"):
        """
        Create the evaluator and load the Whisper model.

        Args:
            whisper_model: Model name forwarded to WhisperBaseline
        """
        self.whisper = WhisperBaseline(whisper_model)

    def evaluate(self, audio_segments: List[Dict], sample_size: Optional[int] = None,
                 batch_size: int = 50) -> List[TranscriptionResult]:
        """
        Transcribe every segment and compute its WER.

        Args:
            audio_segments: Segment dicts; each must provide the keys
                'audio_path', 'text', 'segment_id', 'speaker', 'file_name'
                and 'duration_ms'
            sample_size: Evaluate only the first N segments (None or 0 = all)
            batch_size: Print a progress line every N segments; a value
                <= 0 disables progress output

        Returns:
            One TranscriptionResult per evaluated segment, in input order
        """
        if sample_size:
            audio_segments = audio_segments[:sample_size]

        total = len(audio_segments)
        results = []
        wer_sum = 0.0  # running total so the progress average is O(1) per report

        print(f"üé§ √âvaluation Whisper baseline ({len(audio_segments)} segments)\n")

        for done, segment in enumerate(audio_segments, start=1):
            reference = segment["text"]

            transcription = self.whisper.transcribe_segment(Path(segment["audio_path"]))
            score = self.whisper.calculate_wer(reference, transcription["text"])
            wer_sum += score

            results.append(TranscriptionResult(
                segment_id=segment["segment_id"],
                speaker=segment["speaker"],
                file_name=segment["file_name"],
                audio_path=segment["audio_path"],
                ground_truth=reference,
                whisper_prediction=transcription["text"],
                duration_ms=segment["duration_ms"],
                wer=score,
                confidence=transcription["confidence"]
            ))

            # Guard against batch_size <= 0, which would otherwise divide by zero.
            if batch_size > 0 and done % batch_size == 0:
                print(f"  ‚úì {done}/{total} | Avg WER: {wer_sum / done:.3f}")

        return results


def print_wer_statistics(results: "List[TranscriptionResult]") -> None:
    """
    Print WER statistics: global figures, a distribution and per-speaker means.

    Args:
        results: Transcription results; only ``wer`` and ``speaker`` are read.
            An empty list prints a notice and returns.
    """
    if not results:
        print("‚ùå Aucun r√©sultat")
        return

    wers = [r.wer for r in results]
    n = len(wers)

    print("\n" + "="*70)
    print("üìä WER STATISTICS")
    print("="*70)

    print(f"\nüìà Globales:")
    print(f"   Total segments:    {n}")
    print(f"   WER moyen:         {sum(wers) / n:.3f}")
    print(f"   WER min:           {min(wers):.3f}")
    print(f"   WER max:           {max(wers):.3f}")
    # True median: averages the two middle values when the count is even.
    print(f"   WER median:        {statistics.median(wers):.3f}")

    # Half-open buckets [low, high). WER can exceed 1.0, so the last bucket
    # is open-ended; without it those segments would be counted nowhere.
    print(f"\n   Distribution:")
    unbounded = float("inf")
    buckets = [(0.0, 0.1), (0.1, 0.3), (0.3, 0.5), (0.5, 1.0), (1.0, unbounded)]
    for low, high in buckets:
        count = sum(1 for w in wers if low <= w < high)
        pct = (count / n) * 100
        label = f">={low:.1f}" if high == unbounded else f"{low:.1f}-{high:.1f}"
        print(f"      {label:<7}: {count:4d} ({pct:5.1f}%)")

    # Group WER values by speaker
    by_speaker = {}
    for r in results:
        by_speaker.setdefault(r.speaker, []).append(r.wer)

    print(f"\nüë• Par speaker:")
    for speaker in sorted(by_speaker):
        speaker_wers = by_speaker[speaker]
        avg = sum(speaker_wers) / len(speaker_wers)
        print(f"      {speaker:15} {len(speaker_wers):4d} segments | WER: {avg:.3f}")

    print("\n" + "="*70 + "\n")


# ============================================================================
# ZONE 2: CHILDREN FILTERING - Keep only the children's segments
# ============================================================================

def filter_children_results(
    results: "List[TranscriptionResult]",
    speakers_info: Dict[str, str],
) -> "Tuple[List[TranscriptionResult], List[TranscriptionResult], List[TranscriptionResult]]":
    """
    Partition results by the role of their speaker.

    Args:
        results: Results of BaselineEvaluator.evaluate()
        speakers_info: Dict {speaker_name: role}; speakers missing from it
            are treated as role "Unknown"

    Returns:
        Tuple ``(children, adults, unknown)``: results whose role is in
        ChildrenVoiceFilter.CHILD_ROLES, in ChildrenVoiceFilter.ADULT_ROLES,
        or in neither, each keeping the input order.
    """
    children = []
    adults = []
    unknown = []

    for r in results:
        role = speakers_info.get(r.speaker, "Unknown")

        if role in ChildrenVoiceFilter.CHILD_ROLES:
            bucket = children
        elif role in ChildrenVoiceFilter.ADULT_ROLES:
            bucket = adults
        else:
            bucket = unknown
        bucket.append(r)

    return children, adults, unknown


def print_filtering_report(children: "List[TranscriptionResult]",
                          adults: "List[TranscriptionResult]",
                          unknown: "List[TranscriptionResult]") -> None:
    """
    Print how segments were distributed between children, adults and unknown roles.

    For the children group it also prints the total audio duration, the mean
    WER and a per-speaker breakdown. Works when all three lists are empty
    (every share is then reported as 0.0%).

    Args:
        children: Results attributed to child speakers
        adults: Results attributed to adult speakers
        unknown: Results whose speaker role is not recognised
    """
    total = len(children) + len(adults) + len(unknown)

    def share(group) -> float:
        # Percentage of all segments; 0.0 avoids dividing by zero on empty input.
        return len(group) / total * 100 if total else 0.0

    print("\n" + "="*70)
    print("üßí CHILDREN FILTERING REPORT")
    print("="*70)

    print(f"\nüìä Distribution:")
    print(f"   Total segments:        {total}")
    print(f"   ‚úÖ Enfants:            {len(children)} ({share(children):.1f}%)")
    print(f"   ‚ùå Adultes:            {len(adults)} ({share(adults):.1f}%)")
    print(f"   ‚ùì R√¥le inconnu:       {len(unknown)} ({share(unknown):.1f}%)")

    if children:
        wers_children = [r.wer for r in children]
        print(f"\nüë®‚Äçüë©‚Äçüëß Enfants:")
        print(f"   Segments:              {len(children)}")
        # duration_ms is in milliseconds: /1000 -> seconds, /60 -> minutes
        print(f"   Dur√©e audio:           {sum(r.duration_ms for r in children) / 1000 / 60:.1f} minutes")
        print(f"   WER moyen:             {sum(wers_children) / len(wers_children):.3f}")

        # Group the children's results by speaker
        by_speaker = {}
        for r in children:
            by_speaker.setdefault(r.speaker, []).append(r)

        print(f"\n   Par speaker ({len(by_speaker)}):")
        for speaker in sorted(by_speaker):
            segs = by_speaker[speaker]
            wer_avg = sum(r.wer for r in segs) / len(segs)
            duration = sum(r.duration_ms for r in segs) / 1000
            print(f"      {speaker:15} {len(segs):4d} segments | {duration:7.1f}s | WER: {wer_avg:.3f}")

    print("\n" + "="*70 + "\n")


# ============================================================================
# ZONE 3: DATASET CREATION - Build the dataset for fine-tuning
# ============================================================================

class TrainingDatasetBuilder:
    """Build train/test splits from transcription results and write them to disk."""

    def __init__(self, output_dir: Path):
        """
        Initialise the builder and create the output directory.

        Args:
            output_dir: Directory where the dataset files are written
                (a ``str`` is accepted too); created if missing
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def create_splits(self, results: "List[TranscriptionResult]",
                     train_ratio: float = 0.8,
                     min_wer: Optional[float] = None,
                     max_wer: Optional[float] = None) -> Dict:
        """
        Filter the results by WER and split them into train and test parts.

        The split is positional: the first ``int(len * train_ratio)`` kept
        results form the train part and the rest the test part. Nothing is
        shuffled, so the order of ``results`` decides the membership.

        Args:
            results: Filtered results (children only)
            train_ratio: Train share (0.8 = 80% train, 20% test)
            min_wer: Minimum WER to keep, inclusive (None = no limit)
            max_wer: Maximum WER to keep, inclusive (None = no limit)

        Returns:
            Dict with 'train' and 'test' lists and 'total', the number of
            results kept after WER filtering
        """
        filtered = results
        if min_wer is not None or max_wer is not None:
            filtered = [
                r for r in results
                if (min_wer is None or r.wer >= min_wer) and
                   (max_wer is None or r.wer <= max_wer)
            ]

        # Compare against None explicitly: a 0.0 bound is a real bound and
        # must be reported as such, not as the "no limit" placeholder.
        lower = '0' if min_wer is None else min_wer
        upper = 'inf' if max_wer is None else max_wer

        print(f"üìä Creating splits from {len(filtered)} segments")
        print(f"   WER range: [{lower}, {upper}]")

        split_idx = int(len(filtered) * train_ratio)

        return {
            "train": filtered[:split_idx],
            "test": filtered[split_idx:],
            "total": len(filtered)
        }

    def save_jsonl(self, results: "List[TranscriptionResult]", output_file: Path,
                  data_type: str = "training", language: str = "fr"):
        """
        Write results as JSONL, one ``{"audio", "text", "language"}`` object per line.

        Args:
            results: Results to save; ``audio_path`` and ``ground_truth`` are written
            output_file: Path of the JSONL file (overwritten if present)
            data_type: Label used in the confirmation message
            language: Language code stored in every entry
        """
        output_file = Path(output_file)

        with open(output_file, "w", encoding="utf-8") as f:
            for r in results:
                entry = {
                    "audio": r.audio_path,
                    "text": r.ground_truth,
                    "language": language
                }
                f.write(json.dumps(entry, ensure_ascii=False) + "\n")

        print(f"‚úÖ {data_type}: {len(results)} segments ‚Üí {output_file.name}")

    def save_metadata_json(self, results: "List[TranscriptionResult]", output_file: Path):
        """
        Write every field of each result to a single JSON array.

        Args:
            results: Results to save
            output_file: Path of the JSON file (overwritten if present)
        """
        output_file = Path(output_file)

        data = [
            {
                "segment_id": r.segment_id,
                "speaker": r.speaker,
                "file_name": r.file_name,
                "audio_path": r.audio_path,
                "ground_truth": r.ground_truth,
                "whisper_prediction": r.whisper_prediction,
                "duration_ms": r.duration_ms,
                "wer": r.wer,
                "confidence": r.confidence
            }
            for r in results
        ]

        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

        print(f"‚úÖ Metadata: {len(results)} segments ‚Üí {output_file.name}")

    def create_complete_dataset(self, results: "List[TranscriptionResult]",
                               train_ratio: float = 0.8,
                               min_wer: Optional[float] = None,
                               max_wer: Optional[float] = None):
        """
        Create the complete dataset: train/test JSONL files plus JSON metadata.

        Writes train.jsonl, eval.jsonl, train_metadata.json and
        eval_metadata.json into the builder's output directory.

        Args:
            results: Filtered results (children only)
            train_ratio: Train share of the split
            min_wer: Minimum WER to keep (None = no limit)
            max_wer: Maximum WER to keep (None = no limit)
        """
        print("\n" + "="*70)
        print("üì¶ CREATING TRAINING DATASET")
        print("="*70 + "\n")

        splits = self.create_splits(results, train_ratio, min_wer, max_wer)
        train, test = splits["train"], splits["test"]

        print(f"   Train: {len(train)} segments")
        print(f"   Test:  {len(test)} segments\n")

        # JSONL files are the fine-tuning inputs
        self.save_jsonl(train, self.output_dir / "train.jsonl", "Training")
        self.save_jsonl(test, self.output_dir / "eval.jsonl", "Evaluation")

        # Metadata files keep every field for later analysis
        self.save_metadata_json(train, self.output_dir / "train_metadata.json")
        self.save_metadata_json(test, self.output_dir / "eval_metadata.json")

        print("\n" + "="*70)
        print("‚úÖ DATASET CREATED")
        print("="*70)
        print(f"\nüìÅ Fichiers g√©n√©r√©s:")
        print(f"   train.jsonl ................. {len(train)} segments")
        print(f"   eval.jsonl .................. {len(test)} segments")
        print(f"   train_metadata.json ......... Stats compl√®tes (train)")
        print(f"   eval_metadata.json .......... Stats compl√®tes (eval)")
        print(f"\nüìÇ Emplacement: {self.output_dir}\n")


# ============================================================================
# ZONE 4: PIPELINE ORCHESTRATION
# ============================================================================

def create_whisper_dataset_pipeline(audio_segments: List[Dict],
                                   matched_pairs: List[Tuple[Path, Path]],
                                   output_dir: Path,
                                   whisper_model: str = "base",
                                   sample_size: int = None,
                                   train_ratio: float = 0.8):
    """
    Run the whole chain: Whisper transcription, WER, children filter, dataset.

    Args:
        audio_segments: Results of AudioSegmenter.segment_all()
        matched_pairs: List of (cha_file, audio_file) pairs; only the .cha
            file of each pair is read here
        output_dir: Output directory; the dataset goes into its
            "training_dataset" subdirectory
        whisper_model: Whisper model to use
        sample_size: Evaluate only N segments (for testing)
        train_ratio: Train share of the train/test split

    Returns:
        Dict with 'all_results', 'children', 'adults' and 'unknown' result lists
    """
    heavy_rule = "=" * 70
    light_rule = "-" * 70

    print("\n" + heavy_rule)
    print("üöÄ WHISPER EVALUATION & DATASET PIPELINE")
    print(heavy_rule)

    # Step 1: transcribe and score every segment
    print("\n1Ô∏è‚É£  STEP 1: Evaluate Whisper baseline")
    print(light_rule)

    baseline = BaselineEvaluator(whisper_model=whisper_model)
    results = baseline.evaluate(audio_segments, sample_size=sample_size, batch_size=50)
    print_wer_statistics(results)

    # Step 2: split results by speaker role
    print("2Ô∏è‚É£  STEP 2: Filter children voices")
    print(light_rule)

    # Merge the speaker -> role tables of all .cha files; when a speaker
    # name occurs in several files, the last file read wins.
    speaker_roles = {}
    for cha_path, _audio_path in matched_pairs:
        speaker_roles.update(ChildrenVoiceFilter.extract_speakers_from_cha(cha_path))

    children, adults, unknown = filter_children_results(results, speaker_roles)
    print_filtering_report(children, adults, unknown)

    # Step 3: write the dataset from the children's results, without WER limits
    print("3Ô∏è‚É£  STEP 3: Create training dataset")
    print(light_rule)

    dataset_builder = TrainingDatasetBuilder(output_dir / "training_dataset")
    dataset_builder.create_complete_dataset(
        children,
        train_ratio=train_ratio,
        min_wer=None,
        max_wer=None
    )

    summary = {
        "all_results": results,
        "children": children,
        "adults": adults,
        "unknown": unknown
    }
    return summary


# ============================================================================
# MAIN: Using the pipeline
# ============================================================================

if __name__ == "__main__":
    # Configuration
    output_dir = Path("output/whisper_evaluation")
    
    # Exemple: tu dois avoir les audio_segments du AudioSegmenter
    # Ici on montre juste la structure
    
    print("üìù Exemple d'utilisation:")
    print("\nfrom data_processing_notebook import *")
    print("from whisper_evaluation_dataset import *")
    print("\n# 1. Processing")
    print("pipeline = DataProcessingPipeline(cha_dir, audio_dir, output_dir)")
    print("results = pipeline.run()")
    print("\n# 2. Whisper + Dataset")
    print("create_whisper_dataset_pipeline(")
    print("    audio_segments=results['extracted'],")
    print("    matched_pairs=...,")
    print("    output_dir=output_dir,")
    print("    sample_size=500  # Optionnel: test sur 500 segments")
    print(")")