In [None]:
"""
Few-Shot QA Generator with Enhanced Quality Filtering

Features:
- Multiple manually constructed exemplars per question type (3 high-quality examples)
- Enhanced prompt engineering with diverse exemplar patterns
- Same essential filtering as one-shot: length requirements + duplicate detection
- Improved generation guidance through pattern diversity
- Maintains all statistical tracking and quality controls

"""

import pandas as pd
import json
import logging
import os
import sys
import uuid
from typing import List, Dict, Tuple, Optional
import numpy as np
from sentence_transformers import SentenceTransformer, util
import time
import re
from dataclasses import dataclass
from enum import Enum
from collections import OrderedDict

# Import the OpenAI SDK, preferring the v1.x client class.
# - If `from openai import OpenAI` succeeds, OPENAI_V1 is True and only the
#   `OpenAI` class is bound here.
# - Otherwise fall back to importing the `openai` module itself (the legacy
#   module-level API, configured later via `openai.api_key`); OPENAI_V1 is
#   False and only `openai` is bound.
# - If neither import works, print an install hint and exit the process.
# `logger_msg` is emitted later with logger.info, once logging is configured.
try:
    from openai import OpenAI
    OPENAI_V1 = True
    logger_msg = "Using OpenAI v1.x API"
except ImportError:
    try:
        import openai
        OPENAI_V1 = False
        logger_msg = "Using OpenAI legacy API"
    except ImportError:
        # Logging is not configured yet at this point, so report via print.
        print("ERROR: OpenAI library not installed. Run: pip install openai")
        sys.exit(1)

# Environment and logging setup (runs once at import time).
# TOKENIZERS_PARALLELISM is read by the HuggingFace `tokenizers` library;
# "false" disables its internal parallelism.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Log INFO and above both to a file in the working directory and to stdout.
# Note: logging.basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(
    handlers=[
        logging.FileHandler('few_shot_qa_generation.log'),
        logging.StreamHandler(sys.stdout),
    ],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Device configuration: prefer Apple MPS, then CUDA, then CPU.
# `device` is a module-level global used when loading the embedding model.
try:
    import torch
    # Probe the MPS backend defensively: a PyTorch build without
    # `torch.backends.mps` would otherwise raise an uncaught AttributeError here.
    _mps_backend = getattr(torch.backends, "mps", None)
    if _mps_backend is not None and _mps_backend.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    del _mps_backend  # keep the module namespace unchanged
    logger.info(f"Using device: {device}")
except ImportError:
    print("ERROR: PyTorch not installed. Run: pip install torch")
    sys.exit(1)

# Deferred from the OpenAI import block, which runs before logging is configured.
logger.info(logger_msg)


class QuestionType(Enum):
    """Enumeration of supported question types for few-shot generation.

    Each value is the lowercase string used in the generation prompt: it is
    upper-cased for the "TASK TYPE" header and inserted verbatim as the
    expected "type" field of the model's JSON output.
    """
    FACTUAL = "factual"            # specific details extractable from the text
    RELATIONSHIP = "relationship"  # connections between entities in the knowledge graph
    COMPARATIVE = "comparative"    # differences / similarities between entities or concepts
    INFERENTIAL = "inferential"    # conclusions drawn by combining several facts


class FilteringDecision(Enum):
    """Basic filtering decision outcomes.

    The values are identical to the keys of
    ``generation_stats["filtering_decisions"]`` in FewShotQAGenerator.
    """
    ACCEPTED = "accepted"                      # passed the length and duplicate filters
    REJECTED_LENGTH = "rejected_length"        # question or answer below the minimum word count
    REJECTED_DUPLICATE = "rejected_duplicate"  # too similar to an earlier accepted question
    REJECTED_PARSING = "rejected_parsing"      # not produced by filter_qa_pair; presumably set when model output cannot be parsed (confirm)


@dataclass
class QualityThresholds:
    """Basic quality thresholds for essential filtering only.

    Word counts use whitespace splitting. Similarity values are cosine
    similarities between question embeddings. A question counts as a
    duplicate only when its similarity is strictly greater than the
    threshold, so higher thresholds flag fewer duplicates.
    """
    # Length requirements - pairs below either minimum are rejected
    min_question_words: int = 8      # minimum words in the question
    min_answer_words: int = 20       # minimum words in the answer
    
    # Duplicate detection thresholds (higher = more lenient, allows more variety)
    duplicate_similarity_threshold: float = 0.85  # vs. cached historical questions
    batch_similarity_threshold: float = 0.85      # vs. questions accepted in the current batch
    
    # Weights for the relevance score (logging/analysis only, never used to filter):
    # relevance = context_weight * context_sim + triple_weight * max_triple_sim
    context_weight: float = 0.6
    triple_weight: float = 0.4


@dataclass
class FilteringResult:
    """Data class representing a basic filtering result with reasoning."""
    accepted: bool                  # True only when decision is ACCEPTED
    decision: FilteringDecision     # which filter (if any) rejected the pair
    relevance_score: float  # For logging/analysis only
    context_similarity: float       # cosine similarity of question+answer vs. context
    triple_similarity: float        # best cosine similarity of question+answer vs. any triple
    reasoning: str                  # human-readable explanation of the decision
    metadata: Dict[str, float]      # word counts and entity/relationship counts


@dataclass
class QAPair:
    """Data class representing a question-answer pair with complete metadata.

    ``source_context``, ``source_triples`` and ``chunk_id`` record the input the
    pair was generated from. ``source_triples`` may be passed as ``None``;
    ``__post_init__`` replaces it with a fresh empty list for each instance.
    """
    id: str
    question: str
    answer: str
    question_type: QuestionType
    # presumably {"mentioned_entities": [...], "mentioned_relationships": [...]},
    # as requested by the prompt's output format — confirm against the parser
    qa_metadata: Dict[str, List[str]]
    filtering_result: FilteringResult
    generation_method: str = "few_shot"
    # Add ground truth information
    source_context: str = ""
    source_triples: Optional[List[Tuple]] = None  # None is normalised to [] in __post_init__
    chunk_id: str = ""
    
    def __post_init__(self):
        # Use None as the default instead of a shared mutable list.
        if self.source_triples is None:
            self.source_triples = []


# FEW-SHOT EXEMPLARS - hand-written demonstrations shown to the model before the
# target context. Each exemplar is a dict with:
#   "context"            - source passage (the prompt shows at most its first
#                          800 characters, followed by "..." when truncated);
#   "knowledge_graph"    - (subject, predicate, object) triples for the passage;
#   "exemplar_questions" - {QuestionType: {"question": str}}.
# Every exemplar must define all four QuestionType keys: the prompt builder
# indexes exemplar["exemplar_questions"][question_type] directly, so a missing
# type raises KeyError. Only questions are demonstrated, not answers.
FEW_SHOT_EXEMPLARS = [
    # Exemplar 1: student eligibility and likelihood of completing the programme
    {
        "context": "In determining student eligibility, institutions must also satisfy themselves that there is a reasonable likelihood that the student will be able to complete their study programme before seeking funding for the student. This should include the practicality of providing a place for a student who may be unable to complete their programme if they are likely to leave the country permanently during their study programme. For the purposes of this paragraph, institutions must assume that all European Economic Area students resident in the UK before 1 January 2022 have the legal right to remain in the UK for the duration of their study programme. Once a student is enrolled, the institution is expected to take all reasonable steps to ensure that the student can complete their programme.",
        "knowledge_graph": [
            ("institution", "has_legal_duty", "verify_student_eligibility"),
            ("institution", "is_expected_to_ensure", "student_program_completion"),
            ("student", "enrolled_in", "study_programme"),
            ("institution", "must_evaluate", "student_completion_likelihood"),
            ("student", "potential_withdrawal_reason", "likely_permanent_departure"),
            ("eea_student", "has_legal_right", "remain_in_uk_during_study"),
            ("eea_student", "has_status", "resident_in_uk_before_20220101"),
            ("institution", "provides_assistance", "student_completion_support")
        ],
        "exemplar_questions": {
            QuestionType.FACTUAL: {
                "question": "What must institutions assume about EEA students who were resident in the UK before 1 January 2022?"
            },
            QuestionType.RELATIONSHIP: {
                "question": "How is a student's likelihood of permanent departure related to their eligibility for funding?"
            },
            QuestionType.COMPARATIVE: {
                "question": "How does the institution's responsibility differ before and after a student is enrolled in a study programme?"
            },
            QuestionType.INFERENTIAL: {
                "question": "Why might institutions be discouraged from enrolling students who are likely to leave the UK permanently during their study programme?"
            }
        }
    },
    # Exemplar 2: funding eligibility across the duration of a programme
    {
        "context": "Students who are attending programmes of more than one term's duration, and are eligible for funding at the start of their programme, will usually be eligible for funding for the whole duration of their study programme as well as subsequent funded study programmes studied immediately end-on to their initial funded programme. This includes students studying consecutive study programmes with no break in studies other than normal holiday periods. Similarly, students who are not eligible for funding at the start of their study programme are very unlikely to become eligible for funding during the period of their study programme.",
        "knowledge_graph": [
            ("student", "has_funding_start_status", "eligible_at_programme_start"),
            ("student", "enrolled_in", "study_programme"),
            ("study_programme", "has_duration_value", "more_than_one_term"),
            ("study_programme", "has_funding", "funded_for_duration"),
            ("funded_for_duration", "includes", "subsequent_funded_programmes"),
            ("subsequent_funded_programmes", "has_temporal_value", "immediately_end_on"),
            ("student", "participates_in", "consecutive_study_programmes"),
            ("consecutive_study_programmes", "has_time_period", "no_break_other_than_holidays"),
            ("student", "has_funding_start_status", "not_eligible_at_programme_start"),
            ("student", "related_to_funding_status", "unlikely_to_become_eligible_during_programme")
        ],
        "exemplar_questions": {
            QuestionType.FACTUAL: {
                "question": "Which students are usually eligible for funding throughout the duration of their study programme?"
            },
            QuestionType.RELATIONSHIP: {
                "question": "How does the start-of-programme funding status relate to a student's funding eligibility during their studies?"
            },
            QuestionType.COMPARATIVE: {
                "question": "What is the difference in funding eligibility between students with and without a break between consecutive study programmes?"
            },
            QuestionType.INFERENTIAL: {
                "question": "What does the policy imply about ESFA's approach to students who begin their programme ineligible for funding?"
            }
        }
    },
    # Exemplar 3: Prince's Trust Team Programme management fee and GCSE recognition
    {
        "context": "For the Prince's Trust Team Programme, the institution overhead rate (management fee) should be no more than a maximum of 15 per cent of the total ESFA funding. Any figure above 15 per cent will require prior approval from ESFA in collaboration with the Prince's Trust. For the purpose of the condition of funding, ESFA recognise that the Team Programme will support young people to progress towards General Certificate of Secondary Education standard and has been approved as a stepping stone towards a General Certificate of Secondary Education in these subjects.",
        "knowledge_graph": [
            ("esfa", "provides_funding", "princes_trust_team_programme"),
            ("princes_trust_team_programme", "has_funding_condition", "maximum_management_fee_15_percent"),
            ("maximum_management_fee_15_percent", "requires", "prior_approval_from_esfa_above_15_percent"),
            ("prior_approval_from_esfa_above_15_percent", "involves", "collaboration_with_princes_trust"),
            ("esfa", "recognizes", "princes_trust_team_programme_as_stepping_stone"),
            ("princes_trust_team_programme", "supports_learning", "general_certificate_of_secondary_education_standard"),
            ("princes_trust_team_programme", "has_progression", "general_certificate_of_secondary_education")
        ],
        "exemplar_questions": {
            QuestionType.FACTUAL: {
                "question": "What is the maximum management fee allowed for the Prince's Trust Team Programme without requiring ESFA approval?"
            },
            QuestionType.RELATIONSHIP: {
                "question": "How is the Prince's Trust involved in the approval process when the management fee exceeds 15%?"
            },
            QuestionType.COMPARATIVE: {
                "question": "How does the recognition of the Team Programme differ from a full General Certificate of Secondary Education?"
            },
            QuestionType.INFERENTIAL: {
                "question": "Why might ESFA recognize the Team Programme as a stepping stone towards GCSE standard?"
            }
        }
    }
]

# Per-question-type prompt text used by FewShotQAGenerator.generate_few_shot_prompt.
# Each entry provides two keys:
#   "task_description"    - placed under the "TASK TYPE:" header to define the
#                           question type for the model;
#   "generation_guidance" - placed after "GENERATION INSTRUCTIONS:" for the
#                           target context.
# The prompt builder reads both keys with .get(..., ""), so a missing type or
# key produces an empty section rather than an error.
FEW_SHOT_QUESTION_TEMPLATES = {
    QuestionType.FACTUAL: {
        "task_description": (
            "Factual questions seek specific, concrete information that can be directly extracted from "
            "the text or inferred from the knowledge graph relationships. They typically start with "
            "'What', 'When', 'Where', 'Who', or 'How much/many' and ask for precise details."
        ),
        "generation_guidance": (
            "Generate factual questions that require **specific** information from the context. "
            "These questions should ask for concrete details, numbers, dates, names, or specific requirements "
            "mentioned in the text. Focus on extracting precise information that can be directly answered "
            "from the provided context and knowledge graph triples."
        )
    },
    
    QuestionType.RELATIONSHIP: {
        "task_description": (
            "Relationship questions explore connections between entities. They examine how entities "
            "interact, depend on each other, or influence one another through the relationships "
            "defined in the knowledge graph."
        ),
        "generation_guidance": (
            "Generate questions about specific relationships or interactions between entities in the context. "
            "Focus on how different entities, organizations, processes, or concepts connect, influence, or "
            "interact with each other. Each question must reference at least two entities and explore "
            "their connection through the knowledge graph relationships."
        )
    },
    
    QuestionType.COMPARATIVE: {
        "task_description": (
            "Comparative questions examine differences and similarities between entities or concepts. "
            "They help understand distinctions in requirements, processes, amounts, or characteristics "
            "across different categories or instances."
        ),
        "generation_guidance": (
            "Generate questions comparing different aspects, entities, or concepts from the context. "
            "These questions should highlight differences, similarities, or contrasts between multiple "
            "items such as funding types, requirements, processes, or organizational structures. "
            "Use the knowledge graph to identify comparable entities."
        )
    },
    
    QuestionType.INFERENTIAL: {
        "task_description": (
            "Inferential questions require reasoning and synthesis of multiple pieces of information. "
            "They ask for conclusions, implications, or predictions that must be derived by combining "
            "various facts and relationships from the context and knowledge graph."
        ),
        "generation_guidance": (
            "Generate questions that require analysis, reasoning, or inference based on the context "
            "and knowledge graph. These questions should combine multiple pieces of information to "
            "draw conclusions, identify implications, or predict outcomes. They require synthesizing "
            "information from multiple knowledge graph triples."
        )
    }
}


class FewShotQAGenerator:
    """
    Few-Shot QA Generator with Enhanced Quality Filtering.
    
    Features:
    - Uses multiple high-quality exemplars per question type to guide generation
    - Pattern diversity through varied exemplar structures
    - Length requirements for meaningful content
    - Duplicate detection to prevent redundancy 
    - Basic relevance scoring for analysis (not filtering)
    - Simplified decision logic with clear justification
    - Centralized statistics tracking 
    """
    
    def __init__(self, 
                 openai_api_key: str,
                 reset_duplicates_per_chunk: bool = True,
                 embedding_model: str = 'BAAI/bge-large-en-v1.5',
                 max_embedding_cache: int = 1000):
        """
        Initialize the Few-Shot QA Generator.
        
        Args:
            openai_api_key: OpenAI API key for GPT-4 access
            reset_duplicates_per_chunk: Whether to reset duplicate tracking per chunk
            max_embedding_cache: Maximum number of embeddings to cache (memory management)
        """
        if not openai_api_key or openai_api_key == "your-openai-api-key-here":
            raise ValueError("Please provide a valid OpenAI API key")
            
        self.openai_api_key = openai_api_key
        self.max_embedding_cache = max_embedding_cache
        
        # Initialize OpenAI client based on version
        if OPENAI_V1:
            self.openai_client = OpenAI(api_key=openai_api_key)
        else:
            openai.api_key = openai_api_key
            self.openai_client = None
        
        # Initialize quality thresholds
        self.thresholds = QualityThresholds()
        self.reset_duplicates_per_chunk = reset_duplicates_per_chunk
        
        # Initialize embedding model
        logger.info(f"Loading embedding model: {embedding_model}")
        try:
            self.embedding_model = SentenceTransformer(embedding_model, device=device)
            logger.info(f"Successfully loaded {embedding_model}")
        except Exception as e:
            logger.error(f"Error loading {embedding_model}: {e}")
            raise RuntimeError(f"Failed to load the specified embedding model '{embedding_model}'. Please check your internet connection and ensure the model name is correct.")
        
        # Use OrderedDict for memory management and better duplicate tracking
        self.all_questions: set = set()
        self.question_embeddings: OrderedDict[str, torch.Tensor] = OrderedDict()
        self.batch_questions: set = set()
        self.batch_embeddings: Dict[str, torch.Tensor] = {}
        
        # Enhanced statistics tracking with detailed breakdown
        self.generation_stats = {
            "total_api_calls": 0,
            "total_qa_attempts": 0,
            "successful_generations": 0,
            "filtering_decisions": {
                "accepted": 0,
                "rejected_length": 0,
                "rejected_duplicate": 0,
                "rejected_parsing": 0
            },
            "all_attempts": []
        }
        
        logger.info("Few-shot QA Generator initialized successfully")
        logger.info("Using multi-exemplar guidance with essential filters: length requirements + duplicate detection")

    def _filtering_result_to_dict(self, result: FilteringResult) -> Dict:
        """Convert FilteringResult to dictionary for JSON serialization with enum handling."""
        return {
            "accepted": result.accepted,
            "decision": result.decision.value if hasattr(result.decision, 'value') else str(result.decision),
            "relevance_score": result.relevance_score,
            "context_similarity": result.context_similarity,
            "triple_similarity": result.triple_similarity,
            "reasoning": result.reasoning,
            "metadata": result.metadata
        }

    def _manage_embedding_cache(self):
        """Manage embedding cache size to prevent memory issues."""
        if len(self.question_embeddings) > self.max_embedding_cache:
            # Remove oldest embeddings (FIFO)
            removed_count = len(self.question_embeddings) - self.max_embedding_cache
            for _ in range(removed_count):
                oldest_key = next(iter(self.question_embeddings))
                del self.question_embeddings[oldest_key]
                self.all_questions.discard(oldest_key)
            logger.debug(f"Removed {removed_count} old embeddings from cache")

    def calculate_relevance_score(self, question: str, answer: str, context: str, triples: List[Tuple]) -> Tuple[float, float, float]:
        """
        Calculate semantic relevance scores for logging/analysis purposes only.
        This is NOT used for filtering decisions.

        The question and answer are embedded together and compared (cosine
        similarity) with the context embedding and with each triple rendered
        as "subject predicate object". All usable triples are encoded in one
        batched ``encode`` call rather than one call per triple.

        Args:
            question: Question text.
            answer: Answer text.
            context: Source text chunk.
            triples: Knowledge-graph triples; entries with fewer than three
                items or that cannot be indexed are skipped. ``None`` is
                treated as no triples.

        Returns:
            ``(overall_relevance, context_similarity, max_triple_similarity)``,
            where overall = context_weight * context_similarity +
            triple_weight * max_triple_similarity, and max_triple_similarity
            is 0.0 when there are no usable triples. On any unexpected error a
            warning is logged and ``(0.0, 0.0, 0.0)`` is returned.
        """
        try:
            qa_embedding = self.embedding_model.encode(f"{question} {answer}", convert_to_tensor=True)

            # Context similarity
            context_embedding = self.embedding_model.encode(context, convert_to_tensor=True)
            context_similarity = util.cos_sim(context_embedding, qa_embedding).item()

            # Render valid triples to text first, then encode them in one batch.
            triple_texts = []
            for triple in triples or []:
                try:
                    if len(triple) >= 3:
                        triple_texts.append(" ".join((str(triple[0]), str(triple[1]), str(triple[2]))))
                except (IndexError, TypeError) as e:
                    logger.debug(f"Invalid triple format: {triple}, error: {e}")

            if triple_texts:
                triple_embeddings = self.embedding_model.encode(triple_texts, convert_to_tensor=True)
                # Expected to yield one similarity per triple; keep the best match.
                max_triple_similarity = util.cos_sim(triple_embeddings, qa_embedding).max().item()
            else:
                max_triple_similarity = 0.0

            # Weighted overall relevance (for logging only)
            overall_relevance = (context_similarity * self.thresholds.context_weight +
                                 max_triple_similarity * self.thresholds.triple_weight)

            return overall_relevance, context_similarity, max_triple_similarity

        except Exception as e:
            logger.warning(f"Error calculating relevance score: {e}")
            return 0.0, 0.0, 0.0

    def passes_length_filter(self, question: str, answer: str) -> bool:
        """Check if QA pair meets minimum length requirements."""
        try:
            q_words = len(question.split())
            a_words = len(answer.split())
            
            if q_words < self.thresholds.min_question_words or a_words < self.thresholds.min_answer_words:
                logger.debug(f"Length filter failed: Q={q_words} words (min {self.thresholds.min_question_words}), "
                            f"A={a_words} words (min {self.thresholds.min_answer_words})")
                return False
            
            return True
        except Exception as e:
            logger.warning(f"Error in length filter: {e}")
            return False

    def is_semantic_duplicate(self, new_question: str) -> Tuple[bool, torch.Tensor]:
        """
        Check if question is a semantic duplicate and return embedding.
        
        Returns:
            Tuple of (is_duplicate, question_embedding)
        """
        try:
            new_embedding = self.embedding_model.encode(new_question, convert_to_tensor=True)
            
            # Check against historical questions
            for existing_question, existing_embedding in self.question_embeddings.items():
                similarity = util.cos_sim(new_embedding, existing_embedding).item()
                if similarity > self.thresholds.duplicate_similarity_threshold:
                    logger.debug(f"Historical duplicate detected: '{new_question}' ~ '{existing_question}' "
                               f"(similarity={similarity:.3f})")
                    return True, new_embedding
            
            # Check against current batch using stored embeddings
            for batch_question, batch_embedding in self.batch_embeddings.items():
                similarity = util.cos_sim(new_embedding, batch_embedding).item()
                if similarity > self.thresholds.batch_similarity_threshold:
                    logger.debug(f"Batch duplicate detected: '{new_question}' ~ '{batch_question}' "
                               f"(similarity={similarity:.3f})")
                    return True, new_embedding
            
            return False, new_embedding
            
        except Exception as e:
            logger.warning(f"Error in duplicate detection: {e}")
            # Return False with dummy embedding on error
            dummy_embedding = torch.zeros(384)  # Default embedding size
            return False, dummy_embedding

    def filter_qa_pair(self, 
                      question: str, 
                      answer: str, 
                      context: str, 
                      triples: List[Tuple],
                      qa_metadata: Dict) -> Tuple[bool, FilteringResult]:
        """
        Apply basic filtering pipeline to QA pair with centralized statistics tracking.

        Filters run in order:
        1. the minimum word-count check (``passes_length_filter``);
        2. the semantic duplicate check (``is_semantic_duplicate``).

        The relevance score is computed for every attempt but is informational
        only; it never affects the decision.

        Every call:
        - increments ``generation_stats["total_qa_attempts"]``;
        - increments the ``filtering_decisions`` counter for its decision;
        - appends a serialised record to ``generation_stats["all_attempts"]``.

        Accepted questions are also registered for later duplicate checks.
        
        Args:
            question: Question text
            answer: Answer text
            context: Original text chunk
            triples: Knowledge graph triples
            qa_metadata: QA metadata with optional "mentioned_entities" and
                "mentioned_relationships" lists. ``None``, or missing/``None``
                entries, count as empty.
            
        Returns:
            Tuple of (acceptance_decision, filtering_result_details)
        """
        # Centralized statistics tracking - every call counts as one attempt
        self.generation_stats["total_qa_attempts"] += 1
        
        # Relevance is computed for all attempts (logging/analysis only)
        relevance_score, context_sim, triple_sim = self.calculate_relevance_score(
            question, answer, context, triples
        )
        
        question_length = len(question.split())
        answer_length = len(answer.split())
        metadata_source = qa_metadata or {}
        base_metadata = {
            "answer_length": answer_length,
            "question_length": question_length,
            "mentioned_entities": len(metadata_source.get("mentioned_entities") or []),
            "mentioned_relationships": len(metadata_source.get("mentioned_relationships") or [])
        }

        def make_result(decision: FilteringDecision, reasoning: str) -> FilteringResult:
            # All outcomes share the scores and metadata; only the decision and reasoning differ.
            return FilteringResult(
                accepted=decision is FilteringDecision.ACCEPTED,
                decision=decision,
                relevance_score=relevance_score,
                context_similarity=context_sim,
                triple_similarity=triple_sim,
                reasoning=reasoning,
                metadata=base_metadata
            )
        
        # Filter 1: Length requirements
        if not self.passes_length_filter(question, answer):
            return self._finalize_filtering(question, answer, make_result(
                FilteringDecision.REJECTED_LENGTH,
                f"Failed length requirements: Q={question_length} words, A={answer_length} words"
            ))
        
        # Filter 2: Duplicate detection
        is_duplicate, question_embedding = self.is_semantic_duplicate(question)
        if is_duplicate:
            return self._finalize_filtering(question, answer, make_result(
                FilteringDecision.REJECTED_DUPLICATE,
                "Semantic duplicate detected"
            ))
        
        # Accept QA pair - passed all filters
        result = make_result(
            FilteringDecision.ACCEPTED,
            f"Passed basic quality filters (relevance: {relevance_score:.3f})"
        )
        
        # Register the accepted question for historical and batch duplicate checks
        self.all_questions.add(question)
        self.question_embeddings[question] = question_embedding
        self.batch_questions.add(question)
        self.batch_embeddings[question] = question_embedding
        
        # Keep the historical cache bounded
        self._manage_embedding_cache()
        
        self.generation_stats["successful_generations"] += 1
        return self._finalize_filtering(question, answer, result)

    def _finalize_filtering(self, question: str, answer: str,
                            result: FilteringResult) -> Tuple[bool, FilteringResult]:
        """Count the decision, record the attempt, and return ``(accepted, result)``."""
        # FilteringDecision values are exactly the keys of the filtering_decisions counter.
        self.generation_stats["filtering_decisions"][result.decision.value] += 1
        self.generation_stats["all_attempts"].append({
            "question": question,
            "answer": answer,
            "filtering_result": self._filtering_result_to_dict(result),
            "relevance_score": result.relevance_score
        })
        return result.accepted, result

    def load_and_merge_data(self, chunks_file: str, triples_file: str) -> Dict:
        """Load and merge chunks and triples data."""
        try:
            # Load chunks data
            chunks_df = pd.read_csv(chunks_file)
            logger.info(f"Loaded {len(chunks_df)} chunks from {chunks_file}")
            logger.info(f"Chunks columns: {list(chunks_df.columns)}")
            
            # Load triples data
            triples_df = pd.read_csv(triples_file)
            logger.info(f"Loaded {len(triples_df)} triples from {triples_file}")
            logger.info(f"Triples columns: {list(triples_df.columns)}")
            
            # Check for required columns
            if 'chunk_id' not in chunks_df.columns:
                raise ValueError("chunks_file must contain 'chunk_id' column")
            if 'text' not in chunks_df.columns:
                raise ValueError("chunks_file must contain 'text' column")
            if 'chunk_id' not in triples_df.columns:
                raise ValueError("triples_file must contain 'chunk_id' column")
            
            # Try to identify triple columns
            triple_columns = None
            possible_combinations = [
                ('subject', 'predicate', 'object'),
                ('node1', 'edge', 'node2'),
                ('head', 'relation', 'tail'),
                ('entity1', 'relationship', 'entity2')
            ]
            
            for combo in possible_combinations:
                if all(col in triples_df.columns for col in combo):
                    triple_columns = combo
                    break
            
            if triple_columns is None:
                available_cols = list(triples_df.columns)
                raise ValueError(f"Could not identify triple columns. Available columns: {available_cols}. "
                               f"Expected one of: {possible_combinations}")
            
            logger.info(f"Using triple columns: {triple_columns}")
            
            # Merge on chunk_id
            merged_df = pd.merge(chunks_df, triples_df, on='chunk_id', how='inner')
            logger.info(f"Merged data contains {len(merged_df)} records")
            
            if len(merged_df) == 0:
                raise ValueError("No matching chunk_ids found between chunks and triples files")
            
            # Group by chunk_id
            grouped_data = {}
            for chunk_id, group in merged_df.groupby('chunk_id'):
                grouped_data[chunk_id] = {
                    'text': group['text'].iloc[0],
                    'triples': [(row[triple_columns[0]], row[triple_columns[1]], row[triple_columns[2]]) 
                              for _, row in group.iterrows()]
                }
            
            logger.info(f"Grouped data contains {len(grouped_data)} unique chunks")
            return grouped_data
            
        except Exception as e:
            logger.error(f"Error loading and merging data: {e}")
            raise

    def generate_few_shot_prompt(self, context: str, triples: List[Tuple], question_type: QuestionType, num_questions: int) -> str:
        """Generate few-shot prompt with multiple exemplars for enhanced pattern learning."""
        template = FEW_SHOT_QUESTION_TEMPLATES.get(question_type, {})
        task_description = template.get("task_description", "")
        generation_guidance = template.get("generation_guidance", "")
        
        # Format current context triples
        formatted_triples = []
        for triple in triples[:20]:  # Limit to prevent prompt overflow
            try:
                if len(triple) >= 3:
                    n1, e, n2 = str(triple[0]), str(triple[1]), str(triple[2])
                    formatted_triples.append(f"- {n1} → '{e}' → {n2}")
            except (IndexError, TypeError):
                continue
        
        triple_text = "\n".join(formatted_triples) if formatted_triples else "No valid triples available"

        # Build multiple exemplar demonstrations
        exemplar_demonstrations = []
        for i, exemplar in enumerate(FEW_SHOT_EXEMPLARS, 1):
            # Format exemplar triples
            exemplar_formatted_triples = []
            for triple in exemplar["knowledge_graph"]:
                n1, e, n2 = str(triple[0]), str(triple[1]), str(triple[2])
                exemplar_formatted_triples.append(f"- {n1} → '{e}' → {n2}")
            exemplar_triple_text = "\n".join(exemplar_formatted_triples)
            
            # Get the exemplar question for this type
            exemplar_question = exemplar["exemplar_questions"][question_type]["question"]
            
            demonstration = f"""
EXAMPLE {i} CONTEXT:
{exemplar["context"][:800]}{"..." if len(exemplar["context"]) > 800 else ""}

EXAMPLE {i} KNOWLEDGE GRAPH TRIPLES:
{exemplar_triple_text}

EXAMPLE {i} {question_type.value.upper()} QUESTION:
{exemplar_question}
"""
            exemplar_demonstrations.append(demonstration)

        # Combine all demonstrations
        all_demonstrations = "\n".join(exemplar_demonstrations)

        # Create comprehensive few-shot prompt
        prompt = f"""
FEW-SHOT QA GENERATION TASK

TASK TYPE: {question_type.value.upper()}
{task_description}

MULTIPLE EXEMPLAR DEMONSTRATIONS:
{all_demonstrations}


NOW GENERATE FOR NEW CONTEXT:

TARGET CONTEXT:
{context[:1200]}{"..." if len(context) > 1200 else ""}

TARGET KNOWLEDGE GRAPH TRIPLES:
{triple_text}

GENERATION INSTRUCTIONS: {generation_guidance}

ENHANCED DIVERSITY REQUIREMENTS:
- Study the patterns from ALL {len(FEW_SHOT_EXEMPLARS)} examples above
- Each question must be UNIQUE and ask about DIFFERENT aspects
- Use VARIED question starters and phrasing patterns inspired by the examples
- Focus on DIFFERENT entities, relationships, or information types
- Avoid repetitive structures or similar wordings
- Make each question distinctly different from others and from ALL exemplars

REQUIRED OUTPUT FORMAT:
[
  {{
    "id": "1",
    "question": "Your detailed question here?",
    "answer": "Your comprehensive answer here.",
    "type": "{question_type.value}",
    "qa_metadata": {{
      "mentioned_entities": ["entity1", "entity2"],
      "mentioned_relationships": ["relationship1", "relationship2"]
    }}
  }}
]

Generate {num_questions} {question_type.value} questions with answers, drawing inspiration from the diverse patterns shown in the {len(FEW_SHOT_EXEMPLARS)} examples above.
"""
        return prompt.strip()
    
    def parse_json_response(self, text_response: str) -> List[Dict]:
        """Parse a chat-completion response into a list of candidate QA dicts.

        Strategies are tried in order:
        1. Parse the whole response after removing a surrounding Markdown code
           fence (```json, ```JSON or a bare ```).
        2. Parse the span from the first ``[`` to the last ``]``.
        3. Parse the span from the first ``{`` to the last ``}`` as one QA object.

        A parsed list is accepted only if it is empty or contains at least one
        dict; non-dict elements are dropped. Without this check, strategy 2 could
        return a nested string array such as ``mentioned_entities`` when the
        model emitted a single object, and callers would then fail on ``item.get``.
        A parsed dict is accepted only if it has a ``"question"`` key.

        Args:
            text_response: Raw text returned by the model.

        Returns:
            A list of dicts, possibly empty. Malformed JSON is logged, not raised.
        """
        def _to_items(parsed) -> Optional[List[Dict]]:
            # Normalise a decoded JSON value to a list of dicts, or None if unusable.
            if isinstance(parsed, list):
                items = [entry for entry in parsed if isinstance(entry, dict)]
                if items or not parsed:
                    return items
                return None
            if isinstance(parsed, dict) and "question" in parsed:
                return [parsed]
            return None

        # Remove common markdown formatting that can break JSON parsing
        text_response = text_response.strip()
        text_response = re.sub(r"^```[A-Za-z]*", "", text_response)  # opening fence + optional language tag
        if text_response.endswith('```'):
            text_response = text_response[:-3]  # Remove trailing ```
        text_response = text_response.strip()

        try:
            items = _to_items(json.loads(text_response))
            if items is not None:
                return items
        except json.JSONDecodeError as e:
            logger.debug(f"Initial JSON parse failed: {e}")

        # Fallback 1: outermost JSON array embedded in surrounding text
        try:
            json_start = text_response.find('[')
            json_end = text_response.rfind(']') + 1
            if json_start >= 0 and json_end > json_start:
                items = _to_items(json.loads(text_response[json_start:json_end]))
                if items is not None:
                    return items
        except Exception as e:
            logger.debug(f"Array extraction failed: {e}")

        # Fallback 2: outermost JSON object embedded in surrounding text
        try:
            json_start = text_response.find('{')
            json_end = text_response.rfind('}') + 1
            if json_start >= 0 and json_end > json_start:
                parsed = json.loads(text_response[json_start:json_end])
                if isinstance(parsed, dict) and "question" in parsed:
                    return [parsed]
        except Exception as e:
            logger.debug(f"Object extraction failed: {e}")

        logger.warning(f"Failed to parse JSON response: {text_response[:200]}...")
        return []

    def call_openai_api(self, system_message: str, user_message: str, max_tokens: int = 4000, temperature: float = 0.7, model: str = "gpt-4-turbo-preview") -> str:
        """Send one system+user chat request and return the reply text.

        Uses the v1 client (``self.openai_client``) when ``OPENAI_V1`` is set,
        otherwise the legacy ``openai.ChatCompletion`` API. Both receive the same
        request parameters.

        Args:
            system_message: System prompt.
            user_message: User prompt.
            max_tokens: Completion token limit.
            temperature: Sampling temperature, clamped to [0.0, 1.0] here.
            model: Chat model name (defaults to the previously hard-coded model).

        Returns:
            The stripped reply text, or "" when the reply has no text content.

        Raises:
            Re-raises any exception from the API call after logging it.
        """
        try:
            # Ensure temperature is within valid bounds
            temperature = min(max(temperature, 0.0), 1.0)

            request = dict(
                model=model,
                messages=[
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": user_message}
                ],
                max_tokens=max_tokens,
                temperature=temperature,
                presence_penalty=0.3,
                frequency_penalty=0.3
            )

            if OPENAI_V1:
                response = self.openai_client.chat.completions.create(**request)
            else:
                response = openai.ChatCompletion.create(**request)

            # message.content is nullable in the API response; avoid None.strip()
            content = response.choices[0].message.content
            return (content or "").strip()
        except Exception as e:
            logger.error(f"OpenAI API call failed: {e}")
            raise

    def generate_qa_pairs(self, 
                         context: str, 
                         triples: List[Tuple], 
                         question_type: QuestionType, 
                         num_questions: int,
                         chunk_id: str) -> List[QAPair]:
        """Generate QA pairs using few-shot learning approach with basic filtering system."""
        qa_pairs = []
        max_retries = 3
        
        for attempt in range(max_retries):
            try:
                # Count API calls (separate from QA attempts)
                self.generation_stats["total_api_calls"] += 1
                
                prompt = self.generate_few_shot_prompt(context, triples, question_type, num_questions)
                
                system_message = """You are an expert at generating domain-specific QA pairs using few-shot learning with multiple exemplar guidance. 

You have been provided with multiple high-quality exemplars that demonstrate diverse patterns and approaches for each question type. Use these exemplars as guides to understand the range of possibilities and generate similar high-quality questions for the new context provided.

Key principles:
- Study ALL provided examples to understand pattern diversity
- Generate questions that demonstrate understanding of the domain relationships and entities
- Follow the varied patterns established by the multiple exemplars
- Ensure each question explores different aspects of the context
- Maintain the quality and depth shown in the examples

Your output MUST be valid JSON in the exact format specified. Generate questions that demonstrate understanding of the domain relationships and entities present in the provided knowledge graph, following the diverse patterns established by the exemplars.

IMPORTANT: In the qa_metadata field, accurately list the specific entities and relationships that you mention in your question and answer."""

                logger.debug(f"Generating {num_questions} {question_type.value} questions using few-shot learning (attempt {attempt + 1})")
                
                # Temperature progression with bounds
                temperature = min(0.8 + (attempt * 0.1), 1.0)
                
                text_response = self.call_openai_api(
                    system_message, 
                    prompt, 
                    temperature=temperature
                )
                
                parsed_items = self.parse_json_response(text_response)
                
                if parsed_items:
                    logger.info(f"Successfully parsed {len(parsed_items)} QA pairs from few-shot response")
                    
                    for item in parsed_items:
                        q_text = item.get("question", "").strip()
                        a_text = item.get("answer", "").strip()
                        qa_metadata = item.get("qa_metadata", {
                            "mentioned_entities": [],
                            "mentioned_relationships": []
                        })

                        if q_text and a_text:
                            # Apply filtering (statistics are handled inside filter_qa_pair)
                            accepted, filtering_result = self.filter_qa_pair(
                                q_text, a_text, context, triples, qa_metadata
                            )
                            
                            if accepted:
                                unique_id = f"{question_type.value}_fewshot_{chunk_id}_{int(time.time())}_{str(uuid.uuid4())[:8]}"
                                qa_pair = QAPair(
                                    id=unique_id,
                                    question=q_text,
                                    answer=a_text,
                                    question_type=question_type,  # Keep as enum for internal use
                                    qa_metadata=qa_metadata,
                                    filtering_result=filtering_result,
                                    generation_method="few_shot",
                                    source_context=context,
                                    source_triples=triples.copy(),
                                    chunk_id=chunk_id
                                )
                                
                                qa_pairs.append(qa_pair)
                                
                                logger.debug(f"Accepted QA pair: {filtering_result.decision.value} "
                                           f"(relevance={filtering_result.relevance_score:.3f})")
                                
                                if len(qa_pairs) >= num_questions:
                                    return qa_pairs[:num_questions]
                            else:
                                logger.debug(f"Rejected QA pair: {filtering_result.reasoning}")
                        else:
                            # Handle parsing failures - empty question or answer
                            self.generation_stats["total_qa_attempts"] += 1
                            self.generation_stats["filtering_decisions"]["rejected_parsing"] += 1
                            
                            # Store failed parsing attempt
                            dummy_result = FilteringResult(
                                accepted=False,
                                decision=FilteringDecision.REJECTED_PARSING,
                                relevance_score=0.0,
                                context_similarity=0.0,
                                triple_similarity=0.0,
                                reasoning="Empty question or answer from parsing",
                                metadata={}
                            )
                            
                            self.generation_stats["all_attempts"].append({
                                "question": q_text,
                                "answer": a_text,
                                "filtering_result": self._filtering_result_to_dict(dummy_result),
                                "relevance_score": 0.0
                            })

            except Exception as e:
                logger.error(f"Error in generation attempt {attempt + 1}: {str(e)}")
                time.sleep(min(2 ** attempt, 10))  # Exponential backoff with cap
                continue

        return qa_pairs

    def create_few_shot_dataset(self, 
                               merged_grouped: Dict, 
                               output_file: str, 
                               limit: Optional[int] = None) -> None:
        """Generate QA pairs for every chunk and write the dataset JSON file.

        For each ``chunk_id -> {"text": ..., "triples": [...]}`` entry of
        ``merged_grouped`` (only the first ``limit`` entries when ``limit`` is
        truthy; ``None`` or ``0`` means all), two QA pairs per ``QuestionType``
        are requested via ``generate_qa_pairs``. Chunks with no triples are
        skipped. The dataset has two parts:
        - ``metadata``: generation statistics, a per-type quality breakdown,
          exemplar info and thresholds;
        - ``queries``: every accepted pair with its ground truth.
        It is written to ``output_file`` as UTF-8 JSON, and a summary is logged.

        Args:
            merged_grouped: Chunk data keyed by chunk id.
            output_file: Destination path for the JSON dataset.
            limit: Optional cap on the number of chunks considered.
        """
        logger.info("Starting few-shot QA dataset creation with multi-exemplar guidance")

        questions_per_type = 2  # pairs requested per question type for every chunk
        all_qa_pairs = []
        question_types = list(QuestionType)

        # Process chunks
        chunks_to_process = list(merged_grouped.items())
        if limit:
            chunks_to_process = chunks_to_process[:limit]

        total_chunks = len(chunks_to_process)
        processed_chunks = 0  # chunks that actually had triples and were sent for generation
        logger.info(f"Processing {total_chunks} chunks with few-shot learning")

        for i, (chunk_id, chunk_data) in enumerate(chunks_to_process, 1):
            logger.info(f"Processing chunk {i}/{total_chunks}: {chunk_id}")

            # Reset batch tracking for each chunk if specified
            if self.reset_duplicates_per_chunk:
                self.batch_questions.clear()
                self.batch_embeddings.clear()

            context = chunk_data['text']
            triples = chunk_data['triples']

            if not triples:
                logger.warning(f"No triples found for chunk {chunk_id}, skipping")
                continue
            processed_chunks += 1

            # Generate QA pairs for each question type using few-shot learning
            for question_type in question_types:
                try:
                    logger.debug(f"Generating {question_type.value} questions for chunk {chunk_id} using few-shot learning")

                    qa_pairs = self.generate_qa_pairs(
                        context=context,
                        triples=triples,
                        question_type=question_type,
                        num_questions=questions_per_type,
                        chunk_id=chunk_id
                    )

                    all_qa_pairs.extend(qa_pairs)
                    logger.info(f"Generated {len(qa_pairs)} {question_type.value} QA pairs for chunk {chunk_id}")

                except Exception as e:
                    logger.error(f"Error generating {question_type.value} questions for chunk {chunk_id}: {e}")

            # Add delay between chunks to avoid rate limiting
            if i < total_chunks:
                time.sleep(2)

        embedding_model_name = self._resolve_embedding_model_name()
        question_type_stats = self._summarize_question_types(all_qa_pairs)
        # Computed once and reused for the saved metadata and the log summary, so a
        # statistics-mismatch warning is emitted only once.
        stats = self.get_filtering_statistics()

        def _avg(values):
            # Mean of a list, 0.0 for an empty list (np.mean would warn and return nan).
            return np.mean(values) if values else 0.0

        # Create final dataset structure with enhanced metadata
        dataset = {
            "metadata": {
                "creation_date": pd.Timestamp.now().isoformat(),
                "total_queries": len(all_qa_pairs),
                "generation_method": "few_shot",
                "model_version": "1.0",
                "embedding_model": embedding_model_name,
                "filtering_approach": "essential_quality_filters",
                "exemplar_info": {
                    "num_exemplars": len(FEW_SHOT_EXEMPLARS),
                    "exemplar_contexts_length": [len(ex["context"]) for ex in FEW_SHOT_EXEMPLARS],
                    "exemplar_triples_count": [len(ex["knowledge_graph"]) for ex in FEW_SHOT_EXEMPLARS],
                    "pattern_diversity": "Multiple contexts covering different policy aspects",
                    "question_type_coverage": list(FEW_SHOT_EXEMPLARS[0]["exemplar_questions"].keys())
                },
                "quality_thresholds": {
                    "min_question_words": self.thresholds.min_question_words,
                    "min_answer_words": self.thresholds.min_answer_words,
                    "duplicate_similarity_threshold": self.thresholds.duplicate_similarity_threshold,
                    "batch_similarity_threshold": self.thresholds.batch_similarity_threshold
                },
                "generation_statistics": stats,
                "question_type_breakdown": question_type_stats,
                "processing_summary": {
                    "chunks_processed": processed_chunks,
                    "chunks_skipped_no_triples": total_chunks - processed_chunks,
                    "questions_per_chunk": questions_per_type * len(question_types),
                    "question_types_generated": [q_type.value for q_type in question_types],  # Convert enums to strings
                    "overall_quality": {
                        "avg_relevance_score": _avg([qa.filtering_result.relevance_score for qa in all_qa_pairs]),
                        "avg_question_length": _avg([len(qa.question.split()) for qa in all_qa_pairs]),
                        "avg_answer_length": _avg([len(qa.answer.split()) for qa in all_qa_pairs]),
                        "avg_context_similarity": _avg([qa.filtering_result.context_similarity for qa in all_qa_pairs]),
                        "avg_triple_similarity": _avg([qa.filtering_result.triple_similarity for qa in all_qa_pairs])
                    }
                }
            },
            "queries": [self._qa_pair_to_query_dict(qa_pair) for qa_pair in all_qa_pairs]
        }

        # Ensure all data is JSON serializable before saving
        dataset = self._ensure_json_serializable(dataset)

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(dataset, f, indent=2, ensure_ascii=False)

        logger.info("Few-shot dataset creation completed!")
        logger.info(f"Total QA pairs generated: {len(all_qa_pairs)}")
        logger.info(f"Dataset saved to: {output_file}")

        # Log detailed statistics
        logger.info("\nFew-Shot Basic Filtering Statistics:")
        logger.info(f"Total API calls: {stats.get('total_api_calls', 0)}")
        logger.info(f"Total QA attempts: {stats.get('total_qa_attempts', 0)}")
        logger.info(f"Successful generations: {stats.get('successful_generations', 0)}")

        logger.info("\nFiltering Decision Breakdown:")
        for decision_type, count in stats["filtering_decisions"].items():
            if count > 0:
                percentage = (count / stats['total_qa_attempts']) * 100 if stats['total_qa_attempts'] > 0 else 0
                logger.info(f"  {decision_type}: {count} ({percentage:.1f}%)")

        if "acceptance_rate" in stats:
            logger.info(f"\nOverall Acceptance Rate: {stats['acceptance_rate']:.1f}%")

        logger.info("\nQuestion Type Breakdown:")
        for q_type, type_stats in question_type_stats.items():
            count = type_stats["count"]
            avg_relevance = type_stats["avg_relevance_score"]
            percentage = (count / len(all_qa_pairs)) * 100 if all_qa_pairs else 0
            logger.info(f"  {q_type}: {count} questions ({percentage:.1f}%) - avg relevance: {avg_relevance:.3f}")

        # Validation check
        total_decisions = sum(stats["filtering_decisions"].values())
        if total_decisions == stats["total_qa_attempts"]:
            logger.info("\n Statistics validation: PASSED")
        else:
            logger.warning(f"\n Statistics validation: FAILED ({total_decisions} decisions vs {stats['total_qa_attempts']} attempts)")

    def _resolve_embedding_model_name(self) -> str:
        """Return the embedding model's name for metadata, or 'BAAI/bge-large-en-v1.5' if it cannot be found."""
        fallback = 'BAAI/bge-large-en-v1.5'
        try:
            model = self.embedding_model
            if hasattr(model, 'model_name'):
                return model.model_name
            if hasattr(model, '_model_name'):
                return model._model_name
            if hasattr(model, 'config') and hasattr(model.config, 'name_or_path'):
                return model.config.name_or_path
            return model._modules['0'].auto_model.config.name_or_path
        except Exception:  # best-effort lookup; never let metadata naming abort the run
            return fallback

    def _summarize_question_types(self, qa_pairs) -> Dict:
        """Return count and mean quality metrics per question type, keyed by the enum's string value."""
        metrics_by_type = {}
        for qa_pair in qa_pairs:
            q_type = qa_pair.question_type.value  # Convert enum to string here
            if q_type not in metrics_by_type:
                metrics_by_type[q_type] = {
                    "relevance_scores": [],
                    "question_lengths": [],
                    "answer_lengths": [],
                    "context_similarities": [],
                    "triple_similarities": []
                }
            metrics = metrics_by_type[q_type]
            metrics["relevance_scores"].append(qa_pair.filtering_result.relevance_score)
            metrics["question_lengths"].append(len(qa_pair.question.split()))
            metrics["answer_lengths"].append(len(qa_pair.answer.split()))
            metrics["context_similarities"].append(qa_pair.filtering_result.context_similarity)
            metrics["triple_similarities"].append(qa_pair.filtering_result.triple_similarity)

        summary = {}
        for q_type, metrics in metrics_by_type.items():
            summary[q_type] = {
                "count": len(metrics["relevance_scores"]),
                "avg_relevance_score": np.mean(metrics["relevance_scores"]) if metrics["relevance_scores"] else 0.0,
                "avg_question_length": np.mean(metrics["question_lengths"]) if metrics["question_lengths"] else 0.0,
                "avg_answer_length": np.mean(metrics["answer_lengths"]) if metrics["answer_lengths"] else 0.0,
                "avg_context_similarity": np.mean(metrics["context_similarities"]) if metrics["context_similarities"] else 0.0,
                "avg_triple_similarity": np.mean(metrics["triple_similarities"]) if metrics["triple_similarities"] else 0.0
            }
        return summary

    def _qa_pair_to_query_dict(self, qa_pair) -> Dict:
        """Convert one QAPair into the dataset's query dict, with enums as strings and ground truth attached."""
        return {
            "id": qa_pair.id,
            "question": qa_pair.question,
            "answer": qa_pair.answer,
            "question_type": qa_pair.question_type.value,  # Convert enum to string
            "qa_metadata": qa_pair.qa_metadata,
            "filtering_result": {
                "accepted": qa_pair.filtering_result.accepted,
                "decision": qa_pair.filtering_result.decision.value,  # Convert enum to string
                "relevance_score": qa_pair.filtering_result.relevance_score,
                "context_similarity": qa_pair.filtering_result.context_similarity,
                "triple_similarity": qa_pair.filtering_result.triple_similarity,
                "reasoning": qa_pair.filtering_result.reasoning,
                "metadata": qa_pair.filtering_result.metadata
            },
            "generation_method": qa_pair.generation_method,
            # Add ground truth information for evaluation
            "ground_truth": {
                "source_context": qa_pair.source_context,
                "source_triples": [
                    {
                        "subject": str(triple[0]),  # Ensure string conversion
                        "predicate": str(triple[1]), 
                        "object": str(triple[2])
                    } for triple in qa_pair.source_triples if len(triple) >= 3
                ],
                "chunk_id": qa_pair.chunk_id,
                "context_length_chars": len(qa_pair.source_context),
                "num_source_triples": len(qa_pair.source_triples)
            }
        }

    def _ensure_json_serializable(self, obj):
        """Recursively convert ``obj`` into values that ``json.dump`` can write.

        Conversions applied:
        - dict -> dict with converted values; Enum keys are replaced by their ``.value``
          (``json`` rejects non-primitive keys).
        - list / tuple / set / frozenset -> list.
        - Enum members (including QuestionType and FilteringDecision) -> ``.value``.
        - NumPy scalars (e.g. ``np.int64`` and ``np.float32``, which are not
          int/float subclasses) -> native Python scalars via ``.item()``.
        - NumPy arrays -> nested lists via ``.tolist()``.
        - Other objects with a ``__dict__`` -> their attribute dict, converted.
        Anything else is returned unchanged.
        """
        if isinstance(obj, dict):
            return {
                (key.value if isinstance(key, Enum) else key): self._ensure_json_serializable(value)
                for key, value in obj.items()
            }
        if isinstance(obj, (list, tuple, set, frozenset)):
            return [self._ensure_json_serializable(item) for item in obj]
        # Must precede the __dict__ branch: Enum members also carry a __dict__.
        if isinstance(obj, (QuestionType, FilteringDecision, Enum)):
            return obj.value
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return self._ensure_json_serializable(obj.tolist())
        if hasattr(obj, '__dict__'):
            # Handle custom objects by converting to dict
            return self._ensure_json_serializable(obj.__dict__)
        return obj

    def get_filtering_statistics(self) -> Dict:
        """Return a JSON-safe snapshot of ``self.generation_stats`` with two derived fields.

        - ``acceptance_rate``: successful_generations / total_qa_attempts * 100,
          or 0.0 when no QA attempts were recorded.
        - ``validation_status``: "PASSED" when the per-decision counts in
          ``filtering_decisions`` sum to ``total_qa_attempts``; otherwise
          "FAILED", and a warning is logged.

        The stored statistics are not modified; a shallow copy is annotated and
        then converted with ``_ensure_json_serializable``.
        """
        snapshot = self.generation_stats.copy()
        attempts = snapshot["total_qa_attempts"]

        if attempts <= 0:
            snapshot["acceptance_rate"] = 0.0
        else:
            snapshot["acceptance_rate"] = (snapshot["successful_generations"] / attempts) * 100

        decided = sum(snapshot["filtering_decisions"].values())
        if decided == attempts:
            snapshot["validation_status"] = "PASSED"
        else:
            logger.warning(f"Statistics mismatch: {decided} decisions vs {attempts} attempts")
            snapshot["validation_status"] = "FAILED"

        return self._ensure_json_serializable(snapshot)


# Analysis and utility functions

def analyze_few_shot_filtering_results(dataset_path: str, generator_stats: Optional[Dict] = None) -> Dict:
    """
    Analyze the results of few-shot filtering using both dataset and generation statistics.

    Args:
        dataset_path: Path to a dataset JSON file with top-level ``metadata`` and
            ``queries`` keys.
        generator_stats: Optional live generator statistics. When non-empty they
            take precedence over ``metadata["generation_statistics"]`` from the file.

    Returns:
        A dict of acceptance, quality and efficiency statistics, or
        ``{"error": <message>}`` when the dataset has no queries or cannot be
        read or analyzed.

    ``qa_per_api_call`` is 0.0 when no API calls are recorded (missing or 0).
    Previously a stored count of 0 raised ``ZeroDivisionError``, and a missing
    count silently divided by 1.
    """
    try:
        with open(dataset_path, 'r', encoding='utf-8') as f:
            dataset = json.load(f)

        queries = dataset.get("queries", [])
        metadata = dataset.get("metadata", {})

        if not queries:
            return {"error": "No queries found in dataset"}

        # Prefer live generator statistics; fall back to those saved in the dataset
        generation_stats = generator_stats or metadata.get("generation_statistics", {})

        # Every query in the final dataset was accepted by the filters
        accepted_count = len(queries)

        quality_metrics = {
            "relevance_scores": [],
            "question_lengths": [],
            "answer_lengths": [],
            "context_similarities": [],
            "triple_similarities": []
        }
        question_type_distribution = {}

        for query in queries:
            filtering_info = query.get("filtering_result") or {}
            quality_metrics["relevance_scores"].append(filtering_info.get("relevance_score", 0.0))
            quality_metrics["question_lengths"].append(len(str(query.get("question", "")).split()))
            quality_metrics["answer_lengths"].append(len(str(query.get("answer", "")).split()))
            quality_metrics["context_similarities"].append(filtering_info.get("context_similarity", 0.0))
            quality_metrics["triple_similarities"].append(filtering_info.get("triple_similarity", 0.0))

            q_type = query.get("question_type", "unknown")
            question_type_distribution[q_type] = question_type_distribution.get(q_type, 0) + 1

        def _mean(values):
            return np.mean(values) if values else 0

        relevance = quality_metrics["relevance_scores"]

        total_attempts = generation_stats.get("total_qa_attempts", accepted_count)
        rejected_count = total_attempts - accepted_count
        api_calls = generation_stats.get("total_api_calls", 0)

        return {
            "total_qa_attempts": total_attempts,
            "final_dataset_queries": accepted_count,
            "decision_distribution": {
                "accepted_count": accepted_count,
                "rejected_count": rejected_count,
                "acceptance_rate": (accepted_count / total_attempts) * 100 if total_attempts > 0 else 0
            },
            "detailed_filtering_breakdown": generation_stats.get("filtering_decisions", {}),
            "question_type_distribution": question_type_distribution,
            "quality_statistics": {
                "avg_relevance_score": _mean(relevance),
                "min_relevance_score": min(relevance) if relevance else 0,
                "max_relevance_score": max(relevance) if relevance else 0,
                "std_relevance_score": np.std(relevance) if relevance else 0,
                "avg_question_length": _mean(quality_metrics["question_lengths"]),
                "avg_answer_length": _mean(quality_metrics["answer_lengths"]),
                "avg_context_similarity": _mean(quality_metrics["context_similarities"]),
                "avg_triple_similarity": _mean(quality_metrics["triple_similarities"])
            },
            "generation_efficiency": {
                "api_calls": api_calls,
                "qa_per_api_call": accepted_count / api_calls if api_calls else 0.0,
                "acceptance_rate": generation_stats.get("acceptance_rate", 0)
            },
            "exemplar_effectiveness": metadata.get("exemplar_info", {}),
            "pattern_diversity_analysis": {
                "num_exemplars_used": metadata.get("exemplar_info", {}).get("num_exemplars", 0),
                "pattern_coverage": "Multiple regulatory contexts with diverse structures",
                "learning_approach": "Few-shot pattern recognition from varied examples"
            }
        }

    except Exception as e:
        logger.error(f"Error analyzing few-shot filtering results: {e}")
        return {"error": str(e)}


def numpy_json_default(obj):
    """
    Convert NumPy values into native Python types for ``json.dump``.

    Use this as the ``default`` hook. Report statistics are computed with
    NumPy helpers, and some values may end up as NumPy scalars or arrays that
    the standard JSON encoder rejects. ``np.float64`` is a ``float`` subclass
    and serializes natively, but types such as ``np.int64`` or ``np.float32``
    are not subclasses and do not.

    Args:
        obj: An object the JSON encoder could not serialize on its own.

    Returns:
        The native Python equivalent: ``.item()`` for NumPy scalars and
        ``.tolist()`` for arrays.

    Raises:
        TypeError: For any other type, as ``json.dump`` expects of a
            ``default`` hook (so genuinely unsupported objects still fail loudly).
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def generate_few_shot_report(dataset_path: str, output_report_path: str, generator_stats: Optional[Dict] = None) -> None:
    """
    Generate a comprehensive JSON report on few-shot filtering performance.

    The dataset is analyzed via ``analyze_few_shot_filtering_results``. That
    analysis is wrapped with an executive summary and descriptive sections,
    and the result is written to ``output_report_path`` as indented UTF-8 JSON.

    Args:
        dataset_path: Path to the few-shot filtering dataset.
        output_report_path: Path to save the filtering report. Missing parent
            directories are created.
        generator_stats: Optional generator statistics for complete analysis.

    Returns:
        None. Failures are logged (with traceback) rather than raised, so that
        report generation never aborts the caller's pipeline.
    """
    try:
        # Analyze with complete statistics
        analysis = analyze_few_shot_filtering_results(dataset_path, generator_stats)

        if "error" in analysis:
            logger.error(f"Analysis failed: {analysis['error']}")
            return

        summary_rate = analysis.get('decision_distribution', {}).get('acceptance_rate', 0)
        summary_relevance = analysis.get('quality_statistics', {}).get('avg_relevance_score', 0)
        summary_efficiency = analysis.get('generation_efficiency', {}).get('qa_per_api_call', 0)
        num_exemplars = analysis.get('exemplar_effectiveness', {}).get('num_exemplars', 0)

        # Generate comprehensive report
        report = {
            "report_metadata": {
                "generation_date": pd.Timestamp.now().isoformat(),
                "dataset_analyzed": dataset_path,
                "report_type": "Few_Shot_QA_Analysis",
                "approach": "Multi-exemplar guided generation with essential quality filtering"
            },
            "executive_summary": {
                "total_qa_attempts": analysis.get("total_qa_attempts", 0),
                "final_questions_generated": analysis.get("final_dataset_queries", 0),
                "overall_acceptance_rate": f"{summary_rate:.1f}%",
                "avg_relevance_score": f"{summary_relevance:.3f}",
                "api_efficiency": f"{summary_efficiency:.1f} QA pairs per API call"
            },
            "detailed_analysis": analysis,
            "few_shot_approach": {
                "method": "Multi-exemplar guided generation using diverse high-quality template questions",
                "exemplar_sources": "Multiple manually constructed regulatory texts with KG triples",
                "pattern_diversity": "Varied contexts covering different policy aspects and structures",
                "question_types_covered": ["factual", "relationship", "comparative", "inferential"],
                "filtering_approach": "Essential quality filters: length requirements + duplicate detection"
            },
            "advantages_of_few_shot": [
                "Provides diverse pattern recognition through multiple exemplars",
                "Enables learning from varied question structures and approaches",
                "Reduces over-fitting to single exemplar patterns",
                "Improves generalization across different contexts",
                "Maintains quality while increasing structural variety",
                "Better coverage of domain-specific language patterns",
                "Enhanced robustness through pattern diversity"
            ],
            "pattern_learning_insights": {
                "exemplar_diversity": f"Used {num_exemplars} diverse exemplars",
                "structural_variety": "Multiple question formulation approaches learned",
                "domain_adaptation": "Better adaptation to regulatory language patterns",
                "quality_consistency": "Maintained quality across diverse pattern types"
            }
        }

        # Make sure the target directory exists so the write below cannot fail
        # merely because a nested output path was requested.
        report_dir = os.path.dirname(output_report_path)
        if report_dir:
            os.makedirs(report_dir, exist_ok=True)

        # Save report; NumPy scalars inside the analysis are converted to
        # native types instead of aborting serialization.
        with open(output_report_path, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2, ensure_ascii=False, default=numpy_json_default)

        logger.info(f"Few-shot filtering report generated: {output_report_path}")

    except Exception as e:
        # Best-effort: keep the pipeline running, but preserve the traceback.
        logger.exception(f"Error generating few-shot filtering report: {e}")


def main():
    """
    Run the interactive few-shot QA generation pipeline end to end.

    Steps:
        1. Obtain the OpenAI API key. It is read from the ``OPENAI_API_KEY``
           environment variable if set; otherwise the user is prompted without
           echoing the key.
        2. Ask for the chunks, triples and output paths, and verify that both
           inputs are existing files *before* constructing
           ``FewShotQAGenerator``. The generator is configured with a large
           embedding model, so a mistyped path should not cost that setup.
        3. Load and merge the input data, optionally limiting how many chunks
           are processed.
        4. Build the few-shot dataset, analyze it, write the JSON report, and
           log summary statistics.

    Returns early (after logging) on a missing API key, missing input files,
    or empty merged data. Calls ``sys.exit(1)`` on an unexpected error.
    """
    logger.info("Starting Few-Shot QA Generation with Multi-Exemplar Guidance")

    try:
        # Prefer the environment variable so the key never has to be typed.
        # Otherwise prompt without echoing the secret to the terminal/notebook.
        openai_api_key = os.environ.get("OPENAI_API_KEY", "").strip()
        if not openai_api_key:
            from getpass import getpass
            openai_api_key = getpass("Please enter your OpenAI API key: ").strip()

        if not openai_api_key or openai_api_key == "your-openai-api-key-here":
            logger.error("Please provide a valid OpenAI API key")
            return

        # File paths (collected and validated before the generator is built)
        chunks_file = input("Enter path to chunks CSV file (or press Enter for 'chunks.csv'): ").strip() or "chunks.csv"
        triples_file = input("Enter path to triples CSV file (or press Enter for 'Ontology_Guided_Triples.csv'): ").strip() or "Ontology_Guided_Triples.csv"
        output_file = input("Enter output file name (or press Enter for 'Few-Shot_qa_dataset.json'): ").strip() or "Few-Shot_qa_dataset.json"

        # isfile (not exists) so a directory path is rejected up front
        if not os.path.isfile(chunks_file):
            logger.error(f"Chunks file not found: {chunks_file}")
            logger.info("Please ensure your chunks file exists and has the correct path")
            return

        if not os.path.isfile(triples_file):
            logger.error(f"Triples file not found: {triples_file}")
            logger.info("Please ensure your triples file exists and has the correct path")
            return

        # Initialize generator with few-shot learning
        generator = FewShotQAGenerator(
            openai_api_key=openai_api_key,
            reset_duplicates_per_chunk=True,
            embedding_model='BAAI/bge-large-en-v1.5',
            max_embedding_cache=1000
        )

        # Load and merge data
        logger.info("Loading and merging data files")
        merged_data = generator.load_and_merge_data(chunks_file, triples_file)

        if not merged_data:
            logger.error("No data loaded. Please check your input files.")
            return

        # Number of chunks to process: blank means all. Non-numeric or
        # non-positive values fall back to "all" instead of being passed on
        # as a surprising limit.
        raw_limit = input("Enter number of chunks to process (or press Enter for all): ").strip()
        limit = None
        if raw_limit:
            try:
                limit = int(raw_limit)
            except ValueError:
                logger.warning(f"Invalid chunk count {raw_limit!r}; using default: processing all chunks")
            else:
                if limit <= 0:
                    logger.warning(f"Chunk count must be positive (got {limit}); using default: processing all chunks")
                    limit = None

        if limit is None:
            logger.info("Processing all chunks")
        else:
            logger.info(f"Processing {limit} chunks")

        # Create few-shot learning dataset
        logger.info("Creating Few-Shot QA dataset using multi-exemplar guidance")
        generator.create_few_shot_dataset(
            merged_grouped=merged_data,
            output_file=output_file,
            limit=limit
        )

        # Get final statistics from generator
        final_stats = generator.get_filtering_statistics()

        # Generate analysis report
        logger.info("Analyzing few-shot filtering results")
        analysis_results = analyze_few_shot_filtering_results(output_file, final_stats)

        if "error" in analysis_results:
            logger.error(f"Analysis failed: {analysis_results['error']}")
        else:
            logger.info("Few-Shot Filtering Results:")
            logger.info(f"  Total QA attempts: {analysis_results.get('total_qa_attempts', 0)}")
            logger.info(f"  Final questions generated: {analysis_results.get('final_dataset_queries', 0)}")
            logger.info(f"  Overall acceptance rate: {analysis_results.get('decision_distribution', {}).get('acceptance_rate', 0):.1f}%")
            logger.info(f"  Average relevance score: {analysis_results.get('quality_statistics', {}).get('avg_relevance_score', 0):.3f}")

        # Generate comprehensive report
        report_file = "Few_Shot_QA_analysis_report.json"
        generate_few_shot_report(output_file, report_file, final_stats)

        # Display final statistics. Use .get throughout so a stats dict missing
        # a key degrades to zeros instead of raising KeyError at the very end.
        total_attempts = final_stats.get('total_qa_attempts', 0)
        logger.info("\nFINAL Few-Shot Statistics:")
        logger.info(f"Total API calls: {final_stats.get('total_api_calls', 0)}")
        logger.info(f"Total QA attempts: {total_attempts}")
        logger.info(f"Successful generations: {final_stats.get('successful_generations', 0)}")

        logger.info("\nFiltering Decision Breakdown:")
        for decision_type, count in final_stats.get("filtering_decisions", {}).items():
            if count > 0:
                percentage = (count / total_attempts) * 100 if total_attempts > 0 else 0
                logger.info(f"  {decision_type}: {count} ({percentage:.1f}%)")

        if "acceptance_rate" in final_stats:
            logger.info(f"\nOverall Acceptance Rate: {final_stats['acceptance_rate']:.1f}%")

        # Validation check
        validation_status = final_stats.get("validation_status", "UNKNOWN")
        logger.info(f"Statistics validation: {validation_status}")

        logger.info("\nOutputs generated:")
        logger.info(f"  Dataset: {output_file}")
        logger.info(f"  Analysis Report: {report_file}")
        logger.info("\nFew-Shot QA generation completed successfully!")
        logger.info("\nFew-Shot Learning Features:")
        logger.info(f"   ✓ Multi-exemplar guidance ({len(FEW_SHOT_EXEMPLARS)} diverse examples)")
        logger.info("   ✓ Pattern diversity learning from varied contexts")
        logger.info("   ✓ Enhanced structural variety in questions")
        logger.info("   ✓ Essential filtering: length + duplicate detection")
        logger.info("   ✓ Balanced coverage across question types")
        logger.info("   ✓ Robust domain-specific pattern learning")

    except KeyboardInterrupt:
        logger.info("\nProcess interrupted by user")
    except Exception as e:
        # logger.exception records the message and the full traceback together.
        logger.exception(f"Error in main execution: {str(e)}")
        sys.exit(1)

# Entry point: run the interactive pipeline when executed directly. In a
# notebook kernel ``__name__`` is typically also "__main__", so running this
# cell starts main() as well.
if __name__ == "__main__":
    main()