# In [None]:
"""
Ablation Study 2: KG-Only QA Generator (No Text Context)

This is the second controlled ablation study derived from the One-Shot QA Generator.
It removes dependency on text context and uses only knowledge graph triples to evaluate 
their standalone contribution to QA quality.

Key Changes from Original:
- Removes all text context dependencies 
- Uses KG-only exemplars and generation
- Maintains same filtering pipeline and architecture
- Preserves all quality controls and statistics tracking

Purpose: Evaluate the impact of text context vs. pure knowledge graph information on QA generation quality.


"""

import pandas as pd
import json
import logging
import os
import sys
import uuid
from typing import List, Dict, Tuple, Optional
import numpy as np
from sentence_transformers import SentenceTransformer, util
import time
import re
from dataclasses import dataclass
from enum import Enum
from collections import OrderedDict

# OpenAI import fix with proper error handling
# Resolve which flavour of the OpenAI SDK is importable, once, at import time.
#   OPENAI_V1 = True  -> the `OpenAI` client class could be imported (v1.x API)
#   OPENAI_V1 = False -> only the legacy module-level `openai` API is available
# logger_msg is only stored here because logging has not been configured yet;
# it is emitted after the logger is created further down in this module.
try:
    from openai import OpenAI
    OPENAI_V1 = True
    logger_msg = "Using OpenAI v1.x API"
except ImportError:
    try:
        import openai
        OPENAI_V1 = False
        logger_msg = "Using OpenAI legacy API"
    except ImportError:
        # Neither flavour is installed: report on stdout and terminate the process
        print("ERROR: OpenAI library not installed. Run: pip install openai")
        sys.exit(1)

# Configure environment and logging
# NOTE(review): presumably set to silence the HuggingFace tokenizers
# fork/parallelism warning — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Every record goes both to a log file (relative path, so it lands in the
# current working directory) and to stdout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('ablation_kg_only_qa_generation.log'),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

# Device configuration with proper torch import
# Preference order: Apple MPS, then CUDA, then CPU.
# NOTE(review): `torch.backends.mps` may not exist on older torch builds; that
# would raise AttributeError, which the handler below does not catch — confirm
# the minimum supported torch version.
try:
    import torch
    device = torch.device("mps" if torch.backends.mps.is_available() else 
                         "cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")
except ImportError:
    # torch is mandatory (embeddings and tensors below depend on it): abort
    print("ERROR: PyTorch not installed. Run: pip install torch")
    sys.exit(1)

# Deferred from the OpenAI import step, which ran before logging was configured
logger.info(logger_msg)


class QuestionType(Enum):
    """Enumeration of supported question types for ablation study.

    The members are used as keys of the exemplar and template lookup tables,
    and their string values are interpolated into prompts and QA-pair ids
    (via ``.value``).
    """
    FACTUAL = "factual"
    RELATIONSHIP = "relationship"
    COMPARATIVE = "comparative"
    INFERENTIAL = "inferential"


class FilteringDecision(Enum):
    """Basic filtering decision outcomes.

    The string values are what gets serialized for each attempt, and they
    match the keys of the generator's ``filtering_decisions`` counters.
    """
    ACCEPTED = "accepted"
    # Question or answer has fewer words than the configured minimum
    REJECTED_LENGTH = "rejected_length"
    # Question is too similar to one that was already accepted
    REJECTED_DUPLICATE = "rejected_duplicate"
    # Parsed item had an empty question or answer
    REJECTED_PARSING = "rejected_parsing"


@dataclass
class QualityThresholds:
    """Quality thresholds for essential filtering (adapted for KG-only)."""
    # Minimum whitespace-separated word counts; anything shorter is rejected
    min_question_words: int = 8
    min_answer_words: int = 20
    # Cosine similarity strictly above these values marks a question as a
    # duplicate of a previously accepted one (historical cache / current batch)
    duplicate_similarity_threshold: float = 0.85
    batch_similarity_threshold: float = 0.85
    # NOTE(review): not referenced anywhere in the visible part of this file
    triple_weight: float = 1.0  # Only triple relevance in this ablation


@dataclass
class FilteringResult:
    """Data class representing filtering result (simplified for KG-only).

    Records the outcome of running one QA pair through the filtering
    pipeline, together with the scores and a human-readable explanation.
    """
    # True only when the pair passed every filter
    accepted: bool
    # Which filter rejected the pair, or ACCEPTED
    decision: FilteringDecision
    # In this ablation the relevance score and the triple similarity carry the
    # same value: the highest cosine similarity between the QA text and a triple
    relevance_score: float  # Now based only on triple similarity
    triple_similarity: float
    # Free-text explanation of the decision
    reasoning: str
    # Word counts of question/answer and the number of source triples
    metadata: Dict[str, float]


@dataclass
class QAPair:
    """Data class for QA pairs (simplified for KG-only ablation).

    One accepted question/answer pair plus the provenance needed to trace it
    back to the triples it was generated from.
    """
    # Unique identifier of the pair
    id: str
    question: str
    answer: str
    question_type: QuestionType
    # Extra per-pair metadata; the generator in this file passes an empty dict
    qa_metadata: Dict[str, List[str]]
    # Outcome of the filtering pipeline for this pair
    filtering_result: FilteringResult
    generation_method: str = "ablation_kg_only"
    # Ground truth information (only triples)
    # None is just a stand-in default; it is replaced by a fresh list in
    # __post_init__ so instances never share one mutable default
    source_triples: Optional[List[Tuple]] = None
    chunk_id: str = ""
    
    def __post_init__(self) -> None:
        # Normalize the "not provided" sentinel to an empty, per-instance list
        if self.source_triples is None:
            self.source_triples = []


# KG-ONLY EXEMPLAR (Designed specifically for knowledge graph analysis)
# One-shot demonstration embedded in every generation prompt:
#   "knowledge_graph"    -> example (subject, relation, object) triples
#   "exemplar_questions" -> one example question per QuestionType, written
#                           against those example triples
KG_ONLY_EXEMPLAR = {
    "knowledge_graph": [
        ("esfa", "funds", "ehe_children_further_education"),
        ("ehe_children_further_education", "enrolled_in", "further_education_colleges"),
        ("ehe_children_further_education", "enrolled_in", "sixth_form_colleges"),
        ("ehe_children_schools_academies", "not_eligible_for", "esfa_young_people_funding"),
        ("funding_rates_and_formula_guide", "provides_information_on", "esfa_funding_details"),
        ("colleges", "can_claim", "esfa_young_people_funding"),
        ("esfa_young_people_funding", "for_programme", "level_3_course"),
        ("children_compulsory_school_age", "has_achievement_status", "full_level_2_qualification"),
        ("esfa", "does_not_require_approval_from", "colleges_for_lagged_funding"),
        ("schools_and_academies", "applies_same_advice_as", "colleges_for_early_sixth_form_placement"),
        ("esfa", "considers_funding_eligibility_for", "individual_students_compulsory_school_age"),
        ("individual_students_compulsory_school_age", "has_reason", "arriving_in_uk_during_school_year_11"),
        ("groups_of_students", "not_eligible_for", "esfa_young_people_funding_due_to_non_exceptional_circumstances")
    ],
    
    # Only the question text is stored; no exemplar answers are provided
    "exemplar_questions": {
        QuestionType.FACTUAL: {
            "question": "Which types of colleges do EHE children attend when funded by the ESFA?"
        },
        
        QuestionType.RELATIONSHIP: {
            "question": "What is the relationship between colleges and ESFA young people's funding for Level 3 courses?"
        },
        
        QuestionType.COMPARATIVE: {
            "question": "How does ESFA funding eligibility differ for individuals versus groups arriving during school year 11?"
        },
        
        QuestionType.INFERENTIAL: {
            "question": "Based on the ESFA's policy regarding approval for lagged funding, what implication can be drawn about the need for direct oversight by ESFA for such claims?"
        }
    }
}

# KG-only question generation templates (removed context references)
# Per question type, two pieces of prompt text:
#   "task_description"    -> placed under the TASK TYPE heading of the prompt
#   "generation_guidance" -> placed after "GENERATION INSTRUCTIONS:" in the prompt
KG_ONLY_QUESTION_TEMPLATES = {
    QuestionType.FACTUAL: {
        "task_description": (
            "Factual questions seek specific information that can be directly extracted from "
            "the knowledge graph relationships. They focus on entities, their properties, "
            "and direct connections represented in the triples."
        ),
        "generation_guidance": (
            "Generate factual questions that ask about specific entities, relationships, or properties "
            "directly represented in the knowledge graph triples. Focus on 'what', 'who', 'which' "
            "questions that can be answered by examining the graph structure and entity connections."
        )
    },
    
    QuestionType.RELATIONSHIP: {
        "task_description": (
            "Relationship questions explore direct and indirect connections between entities "
            "in the knowledge graph. They examine how entities are linked through various "
            "relationship types and connection patterns."
        ),
        "generation_guidance": (
            "Generate questions about relationships and connections between entities in the knowledge graph. "
            "Focus on how entities are connected, what relationships exist between them, and how "
            "these connections form meaningful patterns in the graph structure."
        )
    },
    
    QuestionType.COMPARATIVE: {
        "task_description": (
            "Comparative questions examine differences and similarities between entities or "
            "relationship patterns in the knowledge graph. They analyze contrasting paths, "
            "different relationship types, or varying entity properties."
        ),
        "generation_guidance": (
            "Generate questions comparing different entities, relationships, or patterns in the knowledge graph. "
            "These questions should highlight differences in how entities are connected, what relationships "
            "they participate in, or how their graph positions differ from one another."
        )
    },
    
    QuestionType.INFERENTIAL: {
        "task_description": (
            "Inferential questions require reasoning about the knowledge graph structure to "
            "derive insights not explicitly stated in individual triples. They synthesize "
            "multiple relationships to draw conclusions about the domain."
        ),
        "generation_guidance": (
            "Generate questions that require analysis and reasoning about the knowledge graph structure. "
            "These questions should combine information from multiple triples to draw conclusions, "
            "identify implications, or understand broader patterns in the entity relationships."
        )
    }
}


class KGOnlyQAGenerator:
    """
    Ablation Study 2: KG-Only QA Generator (No Text Context Required).
    
    This version removes all dependency on text context and uses only knowledge graph 
    triples to evaluate their standalone contribution to QA generation quality while 
    maintaining the same architecture and filtering pipeline as the original.
    """
    
    def __init__(self, 
                 openai_api_key: str,
                 reset_duplicates_per_chunk: bool = True,
                 embedding_model: str = 'BAAI/bge-large-en-v1.5',
                 max_embedding_cache: int = 1000):
        """
        Initialize the KG-Only QA Generator.
        
        Args:
            openai_api_key: OpenAI API key for GPT-4 access
            reset_duplicates_per_chunk: Whether to reset duplicate tracking per chunk
            max_embedding_cache: Maximum number of embeddings to cache (memory management)
        """
        if not openai_api_key or openai_api_key == "your-openai-api-key-here":
            raise ValueError("Please provide a valid OpenAI API key")
            
        self.openai_api_key = openai_api_key
        self.max_embedding_cache = max_embedding_cache
        
        # Initialize OpenAI client based on version
        if OPENAI_V1:
            self.openai_client = OpenAI(api_key=openai_api_key)
        else:
            openai.api_key = openai_api_key
            self.openai_client = None
        
        # Initialize quality thresholds (adapted for KG-only)
        self.thresholds = QualityThresholds()
        self.reset_duplicates_per_chunk = reset_duplicates_per_chunk
        
        # Initialize embedding model
        logger.info(f"Loading embedding model: {embedding_model}")
        try:
            self.embedding_model = SentenceTransformer(embedding_model, device=device)
            logger.info(f"Successfully loaded {embedding_model}")
        except Exception as e:
            logger.error(f"Error loading {embedding_model}: {e}")
            raise RuntimeError(f"Failed to load the specified embedding model '{embedding_model}'.")
        
        # Duplicate tracking (same as original)
        self.all_questions: set = set()
        self.question_embeddings: OrderedDict[str, torch.Tensor] = OrderedDict()
        self.batch_questions: set = set()
        self.batch_embeddings: Dict[str, torch.Tensor] = {}
        
        # statistics tracking
        self.generation_stats = {
            "total_api_calls": 0,
            "total_qa_attempts": 0,
            "successful_generations": 0,
            "filtering_decisions": {
                "accepted": 0,
                "rejected_length": 0,
                "rejected_duplicate": 0,
                "rejected_parsing": 0
            },
            "all_attempts": []
        }
        
        logger.info("KG-Only QA Generator initialized successfully")
        logger.info("ABLATION STUDY 2: Using knowledge graph triples only (no text context)")

    def _filtering_result_to_dict(self, result: FilteringResult) -> Dict:
        """Convert FilteringResult to dictionary for JSON serialization."""
        return {
            "accepted": result.accepted,
            "decision": result.decision.value,
            "relevance_score": result.relevance_score,
            "triple_similarity": result.triple_similarity,
            "reasoning": result.reasoning,
            "metadata": result.metadata
        }

    def _manage_embedding_cache(self):
        """Manage embedding cache size to prevent memory issues."""
        if len(self.question_embeddings) > self.max_embedding_cache:
            removed_count = len(self.question_embeddings) - self.max_embedding_cache
            for _ in range(removed_count):
                oldest_key = next(iter(self.question_embeddings))
                del self.question_embeddings[oldest_key]
                self.all_questions.discard(oldest_key)
            logger.debug(f"Removed {removed_count} old embeddings from cache")

    def calculate_relevance_score(self, question: str, answer: str, triples: List[Tuple]) -> Tuple[float, float]:
        """
        Calculate KG-only relevance score for logging/analysis.
        Simplified version that only uses triple similarity (no context).
        """
        try:
            qa_combined = f"{question} {answer}"
            qa_embedding = self.embedding_model.encode(qa_combined, convert_to_tensor=True)
            
            # Triple similarities only
            triple_similarities = []
            for triple in triples:
                try:
                    if len(triple) >= 3:
                        n1, edge, n2 = str(triple[0]), str(triple[1]), str(triple[2])
                        triple_text = f"{n1} {edge} {n2}"
                        triple_embedding = self.embedding_model.encode(triple_text, convert_to_tensor=True)
                        triple_similarity = util.cos_sim(triple_embedding, qa_embedding).item()
                        triple_similarities.append(triple_similarity)
                except (IndexError, TypeError) as e:
                    logger.debug(f"Invalid triple format: {triple}, error: {e}")
                    continue
            
            max_triple_similarity = max(triple_similarities) if triple_similarities else 0.0
            
            # Overall relevance is just max triple similarity in this ablation
            overall_relevance = max_triple_similarity
            
            return overall_relevance, max_triple_similarity
            
        except Exception as e:
            logger.warning(f"Error calculating relevance score: {e}")
            return 0.0, 0.0

    def passes_length_filter(self, question: str, answer: str) -> bool:
        """Check if QA pair meets minimum length requirements (unchanged)."""
        try:
            q_words = len(question.split())
            a_words = len(answer.split())
            
            if q_words < self.thresholds.min_question_words or a_words < self.thresholds.min_answer_words:
                logger.debug(f"Length filter failed: Q={q_words} words (min {self.thresholds.min_question_words}), "
                            f"A={a_words} words (min {self.thresholds.min_answer_words})")
                return False
            
            return True
        except Exception as e:
            logger.warning(f"Error in length filter: {e}")
            return False

    def is_semantic_duplicate(self, new_question: str) -> Tuple[bool, torch.Tensor]:
        """Check if question is a semantic duplicate (unchanged from original)."""
        try:
            new_embedding = self.embedding_model.encode(new_question, convert_to_tensor=True)
            
            # Check against historical questions
            for existing_question, existing_embedding in self.question_embeddings.items():
                similarity = util.cos_sim(new_embedding, existing_embedding).item()
                if similarity > self.thresholds.duplicate_similarity_threshold:
                    logger.debug(f"Historical duplicate detected: '{new_question}' ~ '{existing_question}' "
                               f"(similarity={similarity:.3f})")
                    return True, new_embedding
            
            # Check against current batch
            for batch_question, batch_embedding in self.batch_embeddings.items():
                similarity = util.cos_sim(new_embedding, batch_embedding).item()
                if similarity > self.thresholds.batch_similarity_threshold:
                    logger.debug(f"Batch duplicate detected: '{new_question}' ~ '{batch_question}' "
                               f"(similarity={similarity:.3f})")
                    return True, new_embedding
            
            return False, new_embedding
            
        except Exception as e:
            logger.warning(f"Error in duplicate detection: {e}")
            dummy_embedding = torch.zeros(384)
            return False, dummy_embedding

    def filter_qa_pair(self, 
                      question: str, 
                      answer: str, 
                      triples: List[Tuple],
                      qa_metadata: Dict) -> Tuple[bool, FilteringResult]:
        """
        Run one QA pair through the KG-only filtering pipeline.

        The pair is scored against the triples, then checked for minimum length
        and for semantic duplication. Every call is counted as an attempt and
        recorded in the generation statistics, whatever the outcome. An accepted
        question is also registered for future duplicate checks.

        Args:
            question: Generated question text
            answer: Generated answer text
            triples: Source triples the pair was generated from
            qa_metadata: Accepted for interface compatibility; not read here

        Returns:
            (accepted, FilteringResult) describing the decision.
        """
        stats = self.generation_stats

        # Centralized statistics tracking
        stats["total_qa_attempts"] += 1

        # Scored before filtering so rejected pairs are logged with a score too
        relevance_score, triple_sim = self.calculate_relevance_score(question, answer, triples)

        base_metadata = {
            "answer_length": len(answer.split()),
            "question_length": len(question.split()),
            "num_source_triples": len(triples)
        }

        def _conclude(decision, reasoning):
            # Build the result, bump the matching counter and log the attempt
            outcome = FilteringResult(
                accepted=decision is FilteringDecision.ACCEPTED,
                decision=decision,
                relevance_score=relevance_score,
                triple_similarity=triple_sim,
                reasoning=reasoning,
                metadata=base_metadata
            )
            stats["filtering_decisions"][decision.value] += 1
            stats["all_attempts"].append({
                "question": question,
                "answer": answer,
                "filtering_result": self._filtering_result_to_dict(outcome),
                "relevance_score": relevance_score
            })
            return outcome.accepted, outcome

        # Filter 1: Length requirements
        if not self.passes_length_filter(question, answer):
            return _conclude(
                FilteringDecision.REJECTED_LENGTH,
                f"Failed length requirements: Q={len(question.split())} words, A={len(answer.split())} words"
            )

        # Filter 2: Duplicate detection
        is_duplicate, question_embedding = self.is_semantic_duplicate(question)
        if is_duplicate:
            return _conclude(FilteringDecision.REJECTED_DUPLICATE, "Semantic duplicate detected")

        # Accepted: remember the question in both the historical and batch registries
        self.all_questions.add(question)
        self.question_embeddings[question] = question_embedding
        self.batch_questions.add(question)
        self.batch_embeddings[question] = question_embedding

        self._manage_embedding_cache()

        stats["successful_generations"] += 1
        return _conclude(
            FilteringDecision.ACCEPTED,
            f"Passed basic quality filters (relevance: {relevance_score:.3f})"
        )

    def load_triples_data(self, triples_file: str) -> Dict:
        """
        Load triples data only (no chunks file required for this ablation).
        
        Args:
            triples_file: Path to CSV file containing triples
            
        Returns:
            Dictionary mapping chunk_id to triples
        """
        try:
            # Load triples data (only file needed)
            triples_df = pd.read_csv(triples_file)
            logger.info(f"Loaded {len(triples_df)} triples from {triples_file}")
            logger.info(f"Triples columns: {list(triples_df.columns)}")
            
            # Check for required columns
            if 'chunk_id' not in triples_df.columns:
                raise ValueError("triples_file must contain 'chunk_id' column")
            
            # Try to identify triple columns
            triple_columns = None
            possible_combinations = [
                ('subject', 'predicate', 'object'),
                ('node1', 'edge', 'node2'),
                ('head', 'relation', 'tail'),
                ('entity1', 'relationship', 'entity2')
            ]
            
            for combo in possible_combinations:
                if all(col in triples_df.columns for col in combo):
                    triple_columns = combo
                    break
            
            if triple_columns is None:
                available_cols = list(triples_df.columns)
                raise ValueError(f"Could not identify triple columns. Available columns: {available_cols}. "
                               f"Expected one of: {possible_combinations}")
            
            logger.info(f"Using triple columns: {triple_columns}")
            
            # Group by chunk_id (KG-only approach)
            grouped_data = {}
            for chunk_id, group in triples_df.groupby('chunk_id'):
                grouped_data[chunk_id] = {
                    'text': "",  # Empty string - not used in this ablation
                    'triples': [(row[triple_columns[0]], row[triple_columns[1]], row[triple_columns[2]]) 
                              for _, row in group.iterrows()]
                }
            
            logger.info(f"Processed {len(grouped_data)} unique chunks from triples data")
            return grouped_data
            
        except Exception as e:
            logger.error(f"Error loading triples data: {e}")
            raise

    def generate_kg_only_prompt(self, triples: List[Tuple], question_type: QuestionType, num_questions: int) -> str:
        """Generate prompt using only KG triples (no text context)."""
        template = KG_ONLY_QUESTION_TEMPLATES.get(question_type, {})
        task_description = template.get("task_description", "")
        generation_guidance = template.get("generation_guidance", "")
        
        # Get the exemplar question for this type
        exemplar_question = KG_ONLY_EXEMPLAR["exemplar_questions"][question_type]["question"]
        
        # Format current triples
        formatted_triples = []
        for triple in triples[:25]:  # Limit to prevent prompt overflow but allow more than chunks
            try:
                if len(triple) >= 3:
                    n1, e, n2 = str(triple[0]), str(triple[1]), str(triple[2])
                    formatted_triples.append(f"- {n1} → '{e}' → {n2}")
            except (IndexError, TypeError):
                continue
        
        triple_text = "\n".join(formatted_triples) if formatted_triples else "No valid triples available"
    
        # Format exemplar triples
        exemplar_formatted_triples = []
        for triple in KG_ONLY_EXEMPLAR["knowledge_graph"]:
            n1, e, n2 = str(triple[0]), str(triple[1]), str(triple[2])
            exemplar_formatted_triples.append(f"- {n1} → '{e}' → {n2}")
        exemplar_triple_text = "\n".join(exemplar_formatted_triples)
    
        # Create KG-only prompt (no text context references)
        prompt = f"""
ABLATION STUDY: KG-ONLY QA GENERATION TASK

TASK TYPE: {question_type.value.upper()}
{task_description}

EXEMPLAR DEMONSTRATION:

EXAMPLE KNOWLEDGE GRAPH TRIPLES:
{exemplar_triple_text}

EXAMPLE {question_type.value.upper()} QUESTION:
{exemplar_question}

NOW GENERATE FOR NEW KNOWLEDGE GRAPH:

TARGET KNOWLEDGE GRAPH TRIPLES:
{triple_text}

GENERATION INSTRUCTIONS: {generation_guidance}

IMPORTANT: Generate questions and answers based ONLY on the knowledge graph triples provided. 
Do not reference or require additional text context information.

DIVERSITY REQUIREMENTS:
- Each question must be UNIQUE and ask about DIFFERENT entities or relationships
- Use VARIED question starters and phrasing patterns  
- Focus on DIFFERENT triples, entity connections, or relationship patterns
- Avoid repetitive structures or similar wordings
- Make each question distinctly different from others and from the exemplar

REQUIRED OUTPUT FORMAT:
[
  {{
    "id": "1",
    "question": "Your detailed question here?",
    "answer": "Your comprehensive answer here.",
    "type": "{question_type.value}"
  }}
]

Generate {num_questions} {question_type.value} questions with answers based solely on the knowledge graph triples.
"""
        return prompt.strip()
    
    def parse_json_response(self, text_response: str) -> List[Dict]:
        """Parse JSON response from GPT-4 (unchanged from original)."""
        text_response = text_response.strip()
        if text_response.startswith('```json'):
            text_response = text_response[7:]
        if text_response.endswith('```'):
            text_response = text_response[:-3]
        text_response = text_response.strip()
        
        try:
            parsed = json.loads(text_response)
            if isinstance(parsed, list):
                return parsed
            elif isinstance(parsed, dict) and "question" in parsed:
                return [parsed]
        except json.JSONDecodeError as e:
            logger.debug(f"Initial JSON parse failed: {e}")
        
        # Fallback parsing strategies
        try:
            json_start = text_response.find('[')
            json_end = text_response.rfind(']') + 1
            if json_start >= 0 and json_end > json_start:
                json_str = text_response[json_start:json_end]
                parsed = json.loads(json_str)
                if isinstance(parsed, list):
                    return parsed
        except Exception as e:
            logger.debug(f"Array extraction failed: {e}")
        
        try:
            json_start = text_response.find('{')
            json_end = text_response.rfind('}') + 1
            if json_start >= 0 and json_end > json_start:
                json_str = text_response[json_start:json_end]
                parsed = json.loads(json_str)
                if isinstance(parsed, dict) and "question" in parsed:
                    return [parsed]
        except Exception as e:
            logger.debug(f"Object extraction failed: {e}")
        
        logger.warning(f"Failed to parse JSON response: {text_response[:200]}...")
        return []

    def call_openai_api(self, system_message: str, user_message: str, max_tokens: int = 4000,
                        temperature: float = 0.7, model: str = "gpt-4-turbo-preview") -> str:
        """Send one system + user exchange to the chat completion API.

        The request is identical for both SDK flavours; only the entry point
        differs (client object for v1.x, module-level call for the legacy SDK).

        Args:
            system_message: Content of the system message
            user_message: Content of the user message
            max_tokens: Completion token limit passed to the API
            temperature: Sampling temperature; clamped to the range [0.0, 1.0]
            model: Chat model name passed to the API

        Returns:
            The content of the first choice, stripped of surrounding whitespace.

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            # Built once so the two SDK branches cannot drift apart
            request_kwargs = {
                "model": model,
                "messages": [
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": user_message}
                ],
                "max_tokens": max_tokens,
                "temperature": min(max(temperature, 0.0), 1.0),
                "presence_penalty": 0.3,
                "frequency_penalty": 0.3
            }

            if OPENAI_V1:
                create_completion = self.openai_client.chat.completions.create
            else:
                create_completion = openai.ChatCompletion.create

            response = create_completion(**request_kwargs)
            return response.choices[0].message.content.strip()
        except Exception as e:
            logger.error(f"OpenAI API call failed: {e}")
            raise

    def generate_qa_pairs(self, 
                         triples: List[Tuple], 
                         question_type: QuestionType, 
                         num_questions: int,
                         chunk_id: str) -> List[QAPair]:
        """Generate QA pairs using KG-only approach (no text context)."""
        qa_pairs = []
        max_retries = 3
        
        for attempt in range(max_retries):
            try:
                self.generation_stats["total_api_calls"] += 1
                
                prompt = self.generate_kg_only_prompt(triples, question_type, num_questions)
                
                system_message = """You are an expert at generating domain-specific QA pairs using knowledge graph triples only (no text context required). 

You have been provided with a high-quality exemplar that demonstrates the expected format and quality for each question type. Use this exemplar as a guide to generate similar high-quality questions for the new knowledge graph provided.

Generate questions that can be answered using ONLY the information present in the knowledge graph triples. Focus on entities, relationships, and graph structure patterns.

Your output MUST be valid JSON in the exact format specified with only id, question, answer, and type fields."""

                logger.debug(f"Generating {num_questions} {question_type.value} questions (attempt {attempt + 1}) - KG ONLY")
                
                temperature = min(0.8 + (attempt * 0.1), 1.0)
                
                text_response = self.call_openai_api(
                    system_message, 
                    prompt, 
                    temperature=temperature
                )
                
                parsed_items = self.parse_json_response(text_response)
                
                if parsed_items:
                    logger.info(f"Successfully parsed {len(parsed_items)} QA pairs from KG-only response")
                    
                    for item in parsed_items:
                        q_text = item.get("question", "").strip()
                        a_text = item.get("answer", "").strip()
                        qa_metadata = {}  # Empty dict since simplified format has no metadata

                        if q_text and a_text:
                            # Apply filtering (only triples passed)
                            accepted, filtering_result = self.filter_qa_pair(
                                q_text, a_text, triples, qa_metadata
                            )
                            
                            if accepted:
                                unique_id = f"{question_type.value}_kg_only_{chunk_id}_{int(time.time())}_{str(uuid.uuid4())[:8]}"
                                qa_pair = QAPair(
                                    id=unique_id,
                                    question=q_text,
                                    answer=a_text,
                                    question_type=question_type,
                                    qa_metadata=qa_metadata,
                                    filtering_result=filtering_result,
                                    generation_method="ablation_kg_only",
                                    source_triples=triples.copy(),
                                    chunk_id=chunk_id
                                )
                                
                                qa_pairs.append(qa_pair)
                                
                                logger.debug(f"Accepted QA pair: {filtering_result.decision.value} "
                                           f"(relevance={filtering_result.relevance_score:.3f})")
                                
                                if len(qa_pairs) >= num_questions:
                                    return qa_pairs[:num_questions]
                            else:
                                logger.debug(f"Rejected QA pair: {filtering_result.reasoning}")
                        else:
                            # Handle parsing failures
                            self.generation_stats["total_qa_attempts"] += 1
                            self.generation_stats["filtering_decisions"]["rejected_parsing"] += 1
                            
                            dummy_result = FilteringResult(
                                accepted=False,
                                decision=FilteringDecision.REJECTED_PARSING,
                                relevance_score=0.0,
                                triple_similarity=0.0,
                                reasoning="Empty question or answer from parsing",
                                metadata={"answer_length": 0, "question_length": 0, "num_source_triples": 0}
                            )
                            
                            self.generation_stats["all_attempts"].append({
                                "question": q_text,
                                "answer": a_text,
                                "filtering_result": self._filtering_result_to_dict(dummy_result),
                                "relevance_score": 0.0
                            })

            except Exception as e:
                logger.error(f"Error in generation attempt {attempt + 1}: {str(e)}")
                time.sleep(min(2 ** attempt, 10))
                continue

        return qa_pairs

    def create_kg_only_dataset(self,
                              merged_grouped: Dict,
                              output_file: str,
                              limit: Optional[int] = None,
                              questions_per_type: int = 2) -> None:
        """Create ablation dataset using knowledge graph only (no text context).

        For every selected chunk, ``generate_qa_pairs`` is called once per
        ``QuestionType`` member. The returned pairs are aggregated into
        per-type and overall statistics and written, together with run
        metadata, to ``output_file`` as JSON. A summary is logged afterwards.

        Args:
            merged_grouped: Mapping of chunk id to chunk data; each value is
                indexed with ``'triples'`` to obtain the chunk's triples.
            output_file: Path of the JSON dataset to write.
            limit: Number of leading chunks to process. ``None`` processes
                every chunk; ``0`` processes none.
            questions_per_type: Number of QA pairs requested per question
                type for each chunk (defaults to the previous fixed value 2).

        Raises:
            KeyError: If a chunk's data has no ``'triples'`` entry.
            OSError: If ``output_file`` cannot be opened for writing.
        """
        logger.info("Starting ABLATION STUDY 2: KG-Only QA dataset creation")

        all_qa_pairs = []
        question_types = list(QuestionType)

        chunks_to_process = list(merged_grouped.items())
        # Compare against None explicitly: a truthiness test would treat
        # limit=0 as "no limit" and process every chunk.
        if limit is not None:
            chunks_to_process = chunks_to_process[:limit]

        total_chunks = len(chunks_to_process)
        logger.info(f"Processing {total_chunks} chunks with KG-only approach")

        for i, (chunk_id, chunk_data) in enumerate(chunks_to_process, 1):
            logger.info(f"Processing chunk {i}/{total_chunks}: {chunk_id} - ABLATION: KG ONLY")

            # Reset batch tracking for each chunk if specified
            if self.reset_duplicates_per_chunk:
                self.batch_questions.clear()
                self.batch_embeddings.clear()

            triples = chunk_data['triples']

            if not triples:
                logger.warning(f"No triples found for chunk {chunk_id}, skipping")
                continue

            # Generate QA pairs for each question type using KG-only approach
            for question_type in question_types:
                try:
                    logger.debug(f"Generating {question_type.value} questions for chunk {chunk_id} using KG-only approach")

                    qa_pairs = self.generate_qa_pairs(
                        triples=triples,
                        question_type=question_type,
                        num_questions=questions_per_type,
                        chunk_id=chunk_id
                    )

                    all_qa_pairs.extend(qa_pairs)
                    logger.info(f"Generated {len(qa_pairs)} {question_type.value} QA pairs for chunk {chunk_id}")

                except Exception as e:
                    # One failing question type must not abort the whole run.
                    logger.error(f"Error generating {question_type.value} questions for chunk {chunk_id}: {e}")
                    continue

            # Add delay between chunks to avoid rate limiting
            if i < total_chunks:
                time.sleep(2)

        # Get embedding model name safely. Any lookup failure falls back to the
        # default name; `except Exception` (not a bare except) keeps
        # KeyboardInterrupt/SystemExit propagating.
        default_embedding_name = 'BAAI/bge-large-en-v1.5'
        try:
            if hasattr(self.embedding_model, 'model_name'):
                embedding_model_name = self.embedding_model.model_name
            elif hasattr(self.embedding_model, '_model_name'):
                embedding_model_name = self.embedding_model._model_name
            elif hasattr(self.embedding_model, 'config') and hasattr(self.embedding_model.config, 'name_or_path'):
                embedding_model_name = self.embedding_model.config.name_or_path
            else:
                embedding_model_name = self.embedding_model._modules['0'].auto_model.config.name_or_path
        except Exception:
            embedding_model_name = default_embedding_name

        def _mean(values) -> float:
            """Average of ``values`` as a plain float; 0.0 for an empty list."""
            return float(np.mean(values)) if values else 0.0

        # Collect quality metrics per question type (insertion order follows
        # the first appearance of each type in all_qa_pairs).
        question_type_quality = {}
        for qa_pair in all_qa_pairs:
            metrics = question_type_quality.setdefault(qa_pair.question_type.value, {
                "relevance_scores": [],
                "question_lengths": [],
                "answer_lengths": [],
                "triple_similarities": []
            })
            metrics["relevance_scores"].append(qa_pair.filtering_result.relevance_score)
            metrics["question_lengths"].append(len(qa_pair.question.split()))
            metrics["answer_lengths"].append(len(qa_pair.answer.split()))
            metrics["triple_similarities"].append(qa_pair.filtering_result.triple_similarity)

        # Every pair contributes exactly one value to each list, so the list
        # length is the per-type count.
        question_type_stats = {
            q_type: {
                "count": len(metrics["relevance_scores"]),
                "avg_relevance_score": _mean(metrics["relevance_scores"]),
                "avg_question_length": _mean(metrics["question_lengths"]),
                "avg_answer_length": _mean(metrics["answer_lengths"]),
                "avg_triple_similarity": _mean(metrics["triple_similarities"])
            }
            for q_type, metrics in question_type_quality.items()
        }

        # Computed once and reused for both the metadata and the log summary,
        # so a statistics mismatch is not warned about repeatedly.
        stats = self.get_filtering_statistics()

        # Create final dataset structure with metadata
        dataset = {
            "metadata": {
                "creation_date": pd.Timestamp.now().isoformat(),
                "total_queries": len(all_qa_pairs),
                "generation_method": "ablation_kg_only",
                "model_version": "Ablation 2.0",
                "embedding_model": embedding_model_name,
                "filtering_approach": "essential_quality_filters_kg_only",
                "ablation_study": {
                    "study_number": 2,
                    "study_name": "KG Only (No Text Context)",
                    "removed_components": ["text_context", "context_similarity_scoring"],
                    "retained_components": ["triple_similarity", "length_filtering", "duplicate_detection"],
                    "purpose": "Evaluate contribution of text context vs. knowledge graph information to QA quality"
                },
                "exemplar_info": {
                    "num_exemplar_triples": len(KG_ONLY_EXEMPLAR["knowledge_graph"]),
                    "exemplar_question_types": [q_type.value for q_type in KG_ONLY_EXEMPLAR["exemplar_questions"]],
                    "text_context_used": False
                },
                "quality_thresholds": {
                    "min_question_words": self.thresholds.min_question_words,
                    "min_answer_words": self.thresholds.min_answer_words,
                    "duplicate_similarity_threshold": self.thresholds.duplicate_similarity_threshold,
                    "batch_similarity_threshold": self.thresholds.batch_similarity_threshold
                },
                "generation_statistics": stats,
                "question_type_breakdown": question_type_stats,
                "processing_summary": {
                    "chunks_processed": total_chunks,
                    "questions_per_chunk": questions_per_type * len(question_types),
                    "question_types_generated": list(question_type_stats.keys()),
                    "overall_quality": {
                        "avg_relevance_score": _mean([qa.filtering_result.relevance_score for qa in all_qa_pairs]),
                        "avg_question_length": _mean([len(qa.question.split()) for qa in all_qa_pairs]),
                        "avg_answer_length": _mean([len(qa.answer.split()) for qa in all_qa_pairs]),
                        "avg_triple_similarity": _mean([qa.filtering_result.triple_similarity for qa in all_qa_pairs])
                    }
                }
            },
            "queries": []
        }

        # Convert QA pairs to dictionary format
        for qa_pair in all_qa_pairs:
            filtering = qa_pair.filtering_result
            dataset["queries"].append({
                "id": qa_pair.id,
                "question": qa_pair.question,
                "answer": qa_pair.answer,
                "question_type": qa_pair.question_type.value,
                "qa_metadata": qa_pair.qa_metadata,
                "filtering_result": {
                    "accepted": filtering.accepted,
                    "decision": filtering.decision.value,
                    "relevance_score": filtering.relevance_score,
                    "triple_similarity": filtering.triple_similarity,
                    "reasoning": filtering.reasoning,
                    "metadata": filtering.metadata
                },
                "generation_method": qa_pair.generation_method,
                # Add ground truth information for evaluation
                "ground_truth": {
                    # Triples with fewer than three elements are left out of
                    # the listing but still counted in num_source_triples.
                    "source_triples": [
                        {
                            "subject": triple[0],
                            "predicate": triple[1],
                            "object": triple[2]
                        } for triple in qa_pair.source_triples if len(triple) >= 3
                    ],
                    "chunk_id": qa_pair.chunk_id,
                    "num_source_triples": len(qa_pair.source_triples),
                    "text_context_used": False,  # Key difference from original
                    "ablation_study": "kg_only"
                }
            })

        # Save dataset
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(dataset, f, indent=2, ensure_ascii=False)

        logger.info("ABLATION STUDY 2: KG-only dataset creation completed!")
        logger.info(f"Total QA pairs generated: {len(all_qa_pairs)}")
        logger.info(f"Dataset saved to: {output_file}")

        # Log detailed statistics
        logger.info("\nAblation Study 2 - KG-Only Filtering Statistics:")
        logger.info(f"Total API calls: {stats.get('total_api_calls', 0)}")
        logger.info(f"Total QA attempts: {stats.get('total_qa_attempts', 0)}")
        logger.info(f"Successful generations: {stats.get('successful_generations', 0)}")

        logger.info("\nFiltering Decision Breakdown:")
        for decision_type, count in stats["filtering_decisions"].items():
            if count > 0:
                percentage = (count / stats['total_qa_attempts']) * 100 if stats['total_qa_attempts'] > 0 else 0
                logger.info(f"  {decision_type}: {count} ({percentage:.1f}%)")

        if "acceptance_rate" in stats:
            logger.info(f"\nOverall Acceptance Rate: {stats['acceptance_rate']:.1f}%")

        # Log question type breakdown
        logger.info("\nQuestion Type Breakdown:")
        for q_type, type_stats in question_type_stats.items():
            count = type_stats["count"]
            avg_relevance = type_stats["avg_relevance_score"]
            percentage = (count / len(all_qa_pairs)) * 100 if all_qa_pairs else 0
            logger.info(f"  {q_type}: {count} questions ({percentage:.1f}%) - avg relevance: {avg_relevance:.3f}")

        # Validation check: every recorded attempt should have exactly one decision
        total_decisions = sum(stats["filtering_decisions"].values())
        if total_decisions == stats["total_qa_attempts"]:
            logger.info("\nStatistics validation: PASSED")
        else:
            logger.warning(f"\nStatistics validation: FAILED ({total_decisions} decisions vs {stats['total_qa_attempts']} attempts)")

        # Ablation-specific logging
        logger.info("\nABLATION STUDY 2 SUMMARY:")
        logger.info(" Removed: Text context and context similarity scoring")
        logger.info(" Retained: Knowledge graph triples and triple similarity")
        logger.info(" Purpose: Evaluate text context contribution to QA generation quality")

    def get_filtering_statistics(self) -> Dict:
        """Return a copy of the generation statistics with derived fields added.

        The copy is shallow and gains two keys: ``acceptance_rate`` (successful
        generations as a percentage of QA attempts, 0.0 when there were no
        attempts) and ``validation_status`` (``"PASSED"`` when the filtering
        decision counts sum to the number of attempts, otherwise ``"FAILED"``,
        in which case a warning is logged).
        """
        summary = self.generation_stats.copy()
        attempts = summary["total_qa_attempts"]

        # Acceptance rate is relative to the QA pairs that were processed.
        summary["acceptance_rate"] = (
            (summary["successful_generations"] / attempts) * 100 if attempts > 0 else 0.0
        )

        # Consistency check: each attempt should map to exactly one decision.
        decisions_recorded = sum(summary["filtering_decisions"].values())
        counts_agree = decisions_recorded == attempts
        if not counts_agree:
            logger.warning(f"Statistics mismatch: {decisions_recorded} decisions vs {attempts} attempts")
        summary["validation_status"] = "PASSED" if counts_agree else "FAILED"

        return summary


# Analysis and utility functions for ablation study

def analyze_kg_ablation_results(dataset_path: str, generator_stats: Optional[Dict] = None) -> Dict:
    """
    Analyze the results of ablation study 2 (KG-only).

    Loads the dataset JSON, derives quality and question-type statistics from
    the queries it contains, and combines them with the generation statistics.

    Args:
        dataset_path: Path to the generated ablation dataset
        generator_stats: Optional generator statistics for complete analysis;
            when given and non-empty they replace the ``generation_statistics``
            stored in the dataset metadata

    Returns:
        The analysis dictionary, or ``{"error": <message>}`` when the dataset
        contains no queries or an exception occurs while loading or analyzing.
    """
    try:
        with open(dataset_path, 'r', encoding='utf-8') as f:
            dataset = json.load(f)

        queries = dataset.get("queries", [])
        metadata = dataset.get("metadata", {})

        if not queries:
            return {"error": "No queries found in dataset"}

        # Get ablation-specific information
        ablation_info = metadata.get("ablation_study", {})
        generation_stats = generator_stats or metadata.get("generation_statistics", {})

        # Every query stored in the dataset is an accepted one
        accepted_count = len(queries)

        relevance_scores = []
        question_lengths = []
        answer_lengths = []
        triple_similarities = []
        question_type_distribution = {}

        for query in queries:
            filtering_info = query.get("filtering_result", {})
            relevance_scores.append(filtering_info.get("relevance_score", 0.0))
            # .get keeps one malformed entry from failing the whole analysis,
            # matching how the other fields are read.
            question_lengths.append(len(query.get("question", "").split()))
            answer_lengths.append(len(query.get("answer", "").split()))
            triple_similarities.append(filtering_info.get("triple_similarity", 0.0))

            # Count question types
            q_type = query.get("question_type", "unknown")
            question_type_distribution[q_type] = question_type_distribution.get(q_type, 0) + 1

        # Use generation statistics for complete picture
        total_attempts = generation_stats.get("total_qa_attempts", accepted_count)
        rejected_count = total_attempts - accepted_count

        # A missing API-call count keeps the previous divisor of 1. An explicit
        # zero used to raise ZeroDivisionError and turn the whole analysis into
        # an error result; it now yields an efficiency of 0.0.
        api_calls_divisor = generation_stats.get("total_api_calls", 1)
        qa_per_api_call = accepted_count / api_calls_divisor if api_calls_divisor else 0.0

        # `queries` is non-empty here, so every metric list has at least one
        # value and the aggregates below are always defined.
        analysis_results = {
            "ablation_study_info": ablation_info,
            "total_qa_attempts": total_attempts,
            "final_dataset_queries": accepted_count,
            "decision_distribution": {
                "accepted_count": accepted_count,
                "rejected_count": rejected_count,
                "acceptance_rate": (accepted_count / total_attempts) * 100 if total_attempts > 0 else 0
            },
            "detailed_filtering_breakdown": generation_stats.get("filtering_decisions", {}),
            "question_type_distribution": question_type_distribution,
            "quality_statistics": {
                "avg_relevance_score": np.mean(relevance_scores),
                "min_relevance_score": min(relevance_scores),
                "max_relevance_score": max(relevance_scores),
                "std_relevance_score": np.std(relevance_scores),
                "avg_question_length": np.mean(question_lengths),
                "avg_answer_length": np.mean(answer_lengths),
                "avg_triple_similarity": np.mean(triple_similarities)
            },
            "generation_efficiency": {
                "api_calls": generation_stats.get("total_api_calls", 0),
                "qa_per_api_call": qa_per_api_call,
                "acceptance_rate": generation_stats.get("acceptance_rate", 0)
            },
            "ablation_specific_metrics": {
                "context_components_removed": ablation_info.get("removed_components", []),
                "kg_components_retained": ablation_info.get("retained_components", []),
                "kg_only_approach": True,
                "context_similarity_available": False
            }
        }

        return analysis_results

    except Exception as e:
        # Best-effort analysis: report the failure to the caller instead of raising.
        logger.error(f"Error analyzing KG ablation results: {e}")
        return {"error": str(e)}


def generate_kg_ablation_report(dataset_path: str, output_report_path: str, generator_stats: Optional[Dict] = None) -> None:
    """
    Write a JSON report describing the outcome of ablation study 2.

    The dataset is analyzed with ``analyze_kg_ablation_results``; if that
    analysis reports an error, the error is logged and no report is written.
    Any exception raised while building or saving the report is logged rather
    than propagated.

    Args:
        dataset_path: Path to the ablation study dataset
        output_report_path: Path to save the ablation report
        generator_stats: Optional generator statistics for complete analysis
    """
    try:
        analysis = analyze_kg_ablation_results(dataset_path, generator_stats)

        if "error" in analysis:
            logger.error(f"Analysis failed: {analysis['error']}")
            return

        # Sections of the analysis that the summary parts below draw from.
        study_info = analysis.get("ablation_study_info", {})
        decisions = analysis.get('decision_distribution', {})
        quality = analysis.get('quality_statistics', {})
        efficiency = analysis.get('generation_efficiency', {})

        report_metadata = {
            "generation_date": pd.Timestamp.now().isoformat(),
            "dataset_analyzed": dataset_path,
            "report_type": "Ablation_Study_2_Analysis",
            "approach": "KG-only generation (no text context)"
        }

        study_details = {
            "study_number": 2,
            "study_name": "KG Only (No Text Context)",
            "hypothesis": "Text context significantly improves QA generation quality beyond KG structure",
            "removed_components": study_info.get("removed_components", []),
            "retained_components": study_info.get("retained_components", []),
            "expected_impact": "Questions may lack nuanced understanding without contextual text information"
        }

        executive_summary = {
            "total_qa_attempts": analysis.get("total_qa_attempts", 0),
            "final_questions_generated": analysis.get("final_dataset_queries", 0),
            "overall_acceptance_rate": f"{decisions.get('acceptance_rate', 0):.1f}%",
            "avg_relevance_score": f"{quality.get('avg_relevance_score', 0):.3f}",
            "triple_similarity_only": True,
            "api_efficiency": f"{efficiency.get('qa_per_api_call', 0):.1f} QA pairs per API call"
        }

        comparison_baseline = {
            "note": "This is ablation study 2. Compare results with full one-shot model and study 1 to assess text context contribution.",
            "key_differences": [
                "No text context used in generation",
                "No context similarity scoring in relevance calculation",
                "KG-only exemplar guidance",
                "Simplified relevance scoring (triple similarity only)"
            ]
        }

        expected_findings = [
            "Questions may be more abstract without contextual grounding",
            "Factual questions may be more entity-focused rather than content-specific",
            "Answer quality may depend heavily on KG completeness",
            "Generation may produce more formal, structured questions"
        ]

        # Assemble in the order the sections should appear in the JSON file.
        report = {
            "report_metadata": report_metadata,
            "ablation_study_details": study_details,
            "executive_summary": executive_summary,
            "detailed_analysis": analysis,
            "comparison_baseline": comparison_baseline,
            "expected_findings": expected_findings
        }

        with open(output_report_path, 'w', encoding='utf-8') as report_fh:
            json.dump(report, report_fh, indent=2, ensure_ascii=False)

        logger.info(f"Ablation study 2 report generated: {output_report_path}")

    except Exception as e:
        logger.error(f"Error generating KG ablation report: {e}")


def main():
    """
    Main function for ablation study 2: KG-only QA generation.

    Interactive workflow: obtains the OpenAI API key, prompts for the triples
    CSV path, the output file name and an optional chunk limit, builds the
    KG-only dataset, then analyzes it and writes
    ``Ablation_2_kg_only_analysis_report.json``.

    The API key is taken from the ``OPENAI_API_KEY`` environment variable when
    it is set; otherwise it is read with ``getpass`` so the secret is not
    echoed to the terminal. Unexpected errors are logged with a traceback and
    end the process with exit status 1.
    """
    # Local import: only this entry point needs it.
    import getpass

    logger.info("Starting ABLATION STUDY 2: KG-Only QA Generation")
    logger.info("Purpose: Evaluate contribution of text context vs. knowledge graph information to QA quality")

    try:
        # Get API key without echoing it on screen
        openai_api_key = os.environ.get("OPENAI_API_KEY", "").strip()
        if openai_api_key:
            logger.info("Using OpenAI API key from the OPENAI_API_KEY environment variable")
        else:
            openai_api_key = getpass.getpass("Please enter your OpenAI API key: ").strip()

        if not openai_api_key or openai_api_key == "your-openai-api-key-here":
            logger.error("Please provide a valid OpenAI API key")
            return

        # Initialize generator for ablation study
        generator = KGOnlyQAGenerator(
            openai_api_key=openai_api_key,
            reset_duplicates_per_chunk=True,
            embedding_model='BAAI/bge-large-en-v1.5',
            max_embedding_cache=1000
        )

        # File paths - only triples file needed for KG-only study
        triples_file = input("Enter path to triples CSV file (or press Enter for 'Ontology_Guided_Triples.csv'): ").strip() or "Ontology_Guided_Triples.csv"
        output_file = input("Enter output file name (or press Enter for 'Ablation_2_kg_only_qa_dataset.json'): ").strip() or "Ablation_2_kg_only_qa_dataset.json"

        if not os.path.exists(triples_file):
            logger.error(f"Triples file not found: {triples_file}")
            logger.info("Please ensure your triples file exists and has the correct path")
            return

        # Load triples data only (no chunks file needed)
        logger.info("Loading triples data for KG-only ablation study")
        merged_data = generator.load_triples_data(triples_file)

        if not merged_data:
            logger.error("No data loaded. Please check your input files.")
            return

        # Get user input for number of chunks to process
        raw_limit = input("Enter number of chunks to process (or press Enter for all): ").strip()
        try:
            limit = int(raw_limit) if raw_limit else None
        except ValueError:
            limit = None
            logger.info("Using default: processing all chunks")

        # A zero or negative count is not a usable limit (a negative value
        # would silently drop trailing chunks through slicing), so it is
        # handled like any other invalid entry.
        if limit is not None and limit <= 0:
            logger.warning(f"Ignoring non-positive chunk limit ({limit}); processing all chunks")
            limit = None

        if limit is None:
            logger.info("Processing all chunks for ablation study")
        else:
            logger.info(f"Processing {limit} chunks for ablation study")

        # Create KG-only dataset
        logger.info("Creating Ablation Study 2 dataset (KG-only approach)")
        generator.create_kg_only_dataset(
            merged_grouped=merged_data,
            output_file=output_file,
            limit=limit
        )

        # Get final statistics from generator
        final_stats = generator.get_filtering_statistics()

        # Generate analysis report
        logger.info("Analyzing KG ablation study results")
        analysis_results = analyze_kg_ablation_results(output_file, final_stats)

        if "error" in analysis_results:
            logger.error(f"Analysis failed: {analysis_results['error']}")
        else:
            logger.info("Ablation Study 2 Results:")
            logger.info(f"  Total QA attempts: {analysis_results.get('total_qa_attempts', 0)}")
            logger.info(f"  Final questions generated: {analysis_results.get('final_dataset_queries', 0)}")
            logger.info(f"  Overall acceptance rate: {analysis_results.get('decision_distribution', {}).get('acceptance_rate', 0):.1f}%")
            logger.info(f"  Average relevance score (KG-only): {analysis_results.get('quality_statistics', {}).get('avg_relevance_score', 0):.3f}")

        # Generate comprehensive report
        report_file = "Ablation_2_kg_only_analysis_report.json"
        generate_kg_ablation_report(output_file, report_file, final_stats)

        # Display final statistics
        logger.info("\nFINAL Ablation Study 2 Statistics:")
        logger.info(f"Total API calls: {final_stats.get('total_api_calls', 0)}")
        logger.info(f"Total QA attempts: {final_stats.get('total_qa_attempts', 0)}")
        logger.info(f"Successful generations: {final_stats.get('successful_generations', 0)}")

        logger.info("\nFiltering Decision Breakdown:")
        for decision_type, count in final_stats["filtering_decisions"].items():
            if count > 0:
                percentage = (count / final_stats['total_qa_attempts']) * 100 if final_stats['total_qa_attempts'] > 0 else 0
                logger.info(f"  {decision_type}: {count} ({percentage:.1f}%)")

        if "acceptance_rate" in final_stats:
            logger.info(f"\nOverall Acceptance Rate: {final_stats['acceptance_rate']:.1f}%")

        # Validation check
        validation_status = final_stats.get("validation_status", "UNKNOWN")
        logger.info(f"Statistics validation: {validation_status}")

        logger.info("\nOutputs generated:")
        logger.info(f"  Ablation Dataset: {output_file}")
        logger.info(f"  Analysis Report: {report_file}")
        logger.info("\nABLATION STUDY 2 completed successfully!")
        logger.info("\nAblation Study 2 Features:")
        logger.info("    Removed: Text context and context similarity")
        logger.info("    Retained: Knowledge graph triples and triple similarity")
        logger.info("    Purpose: Evaluate text context contribution to QA generation quality")

    except KeyboardInterrupt:
        logger.info("\nProcess interrupted by user")
    except Exception as e:
        logger.error(f"Error in main execution: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        sys.exit(1)

# Run the interactive ablation workflow only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()