# In [None]:
import json
import pandas as pd
import time
from typing import Dict, List, Optional
import logging
from tqdm import tqdm
import re
import os
from datetime import datetime
from together import Together
import random

class ChunksOnlyLLMJudgeEvaluator:
    """
    Ablation Study 1: LLM Judge Evaluation System for Chunks-Only Model
    Evaluates QA pairs generated using only chunk-based retrieval without KG information.
    Together AI Version with Gemma-2-27b-it

    Each QA pair is scored by an LLM judge on four 1-5 metrics (see METRICS); this
    ablation has no KG-alignment metric. Progress is checkpointed to
    ``<output_path>.checkpoint`` (CSV) so an interrupted run can resume.
    """

    # The four judge metrics, in the order they are requested, parsed and reported.
    METRICS = ('Relevance', 'Accuracy', 'Completeness', 'Fluency')

    # "Metric: 4" / "Metric: 4/5" on one line (markdown '*' is stripped beforehand).
    # The trailing \b rejects multi-digit values such as "10" instead of reading "1".
    _SCORE_AFTER_COLON = re.compile(r':\s*([1-5])\b')

    # Free-text fallbacks such as "relevance should be a 5" / "accuracy is 4".
    # \b around "is" keeps it from matching inside words like "this".
    _REASONING_PATTERNS = {
        metric: re.compile(rf'{metric.lower()}.*?\b(?:should be|is)\b.*?\b([1-5])\b')
        for metric in METRICS
    }

    def __init__(self, api_key: str, model_name: str = "google/gemma-2-27b-it"):
        """
        Initialize the Chunks-Only LLM Judge Evaluator.

        Args:
            api_key: Together AI API key
            model_name: Model identifier for the LLM judge
        """
        self.api_key = api_key
        self.model_name = model_name

        # Initialize Together AI client
        self.client = Together(api_key=api_key)

        # Setup logging (note: basicConfig only takes effect if the root logger
        # has no handlers yet)
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # Judge prompt; {question}, {answer} and {source_context} are filled via str.format.
        self.evaluation_prompt = """You are a STRICT QA evaluator following precise scoring guidelines for a CHUNKS-ONLY model.

You will evaluate a model-generated question-answer pair on FOUR metrics using a 1-5 scale where:
5 = Excellent  
4 = Good  
3 = Fair  
2 = Poor  
1 = Very Poor

==============================
DETAILED SCORING CRITERIA
==============================

1. RELEVANCE (1-5): Does the question appropriately relate to the source chunks?
   5: Perfectly relevant to the chunks, clearly grounded in the text
   4: Mostly relevant, with minor off-topic elements  
   3: Addresses the main question but misses some important points from chunks
   2: Loosely related, with significant tangents or irrelevance to chunks
   1: Entirely irrelevant or unrelated to the source chunks

2. ACCURACY (1-5): Is the answer factually correct based on the source chunks?
   5: All facts are accurate and fully verifiable in the chunks
   4: Mostly accurate; contains only minor factual issues
   3: Some factual inconsistencies or assumptions beyond chunks
   2: Several factual errors that affect reliability
   1: Mostly inaccurate or misleading information

3. COMPLETENESS (1-5): Does the answer fully address the question using available chunks?
   5: Thorough and complete response utilizing chunk information effectively
   4: Covers most parts but misses minor aspects available in chunks
   3: Addresses main part, omits some key details from chunks
   2: Partial answer with significant gaps despite relevant chunk information
   1: Severely incomplete or ignores available chunk information

4. FLUENCY (1-5): Is the answer well-written and grammatically correct?
   5: Excellent grammar and clarity; highly readable
   4: Minor grammatical or structural issues
   3: Understandable, but contains noticeable language errors
   2: Somewhat unclear due to poor grammar or phrasing
   1: Difficult to read or understand

==============================
EVALUATION CONTEXT
==============================
This model uses ONLY text chunks for retrieval — no structured knowledge graph information.
Evaluate how well the model utilizes the available chunk information to answer questions.

==============================
INPUT
==============================
**Question:** {question}  
**Answer:** {answer}  
**Source Chunks:** {source_context}

JUST return your answer in this exact format:

Relevance: X  
Accuracy: X  
Completeness: X  
Fluency: X

Where X is a number from 1 to 5."""

    def get_source_context(self, item: Dict) -> str:
        """Return the source chunks of a QA item as text.

        Looks up ``item["ground_truth"]["source_context"]`` first, then
        ``item["source_context"]``. Returns ``"No source context available"`` when
        neither key is present (or the value is None). A list/tuple of chunks is
        joined with blank lines so the judge sees plain text, not a Python repr.
        """
        context = None
        ground_truth = item.get("ground_truth")
        if isinstance(ground_truth, dict) and "source_context" in ground_truth:
            context = ground_truth["source_context"]
        elif "source_context" in item:
            context = item["source_context"]

        if context is None:
            return "No source context available"
        if isinstance(context, (list, tuple)):
            return "\n\n".join(str(chunk) for chunk in context)
        return context

    def call_llm_judge(self, question: str, answer: str, source_context: str,
                      max_retries: int = 3) -> Optional[Dict[str, int]]:
        """
        Call the LLM judge to evaluate a QA pair with retry logic using Together AI.

        Args:
            question: The question to evaluate
            answer: The answer to evaluate
            source_context: Source context for the QA pair
            max_retries: Maximum number of attempts (API errors, empty replies and
                unparseable replies all consume an attempt)

        Returns:
            Dict with evaluation scores or None if every attempt failed
        """
        # Format the prompt (no KG triples needed for chunks-only model)
        prompt = self.evaluation_prompt.format(
            question=question,
            answer=answer,
            source_context=source_context
        )

        for attempt in range(max_retries):
            is_last_attempt = attempt == max_retries - 1
            try:
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=0.0,  # Deterministic for consistent scoring
                    max_tokens=150,   # Short response expected
                    top_p=0.8
                )

                if response and response.choices:
                    content = response.choices[0].message.content
                    if content:
                        parsed_result = self.parse_evaluation_response(content)
                        if parsed_result:
                            return parsed_result
                        self.logger.warning(
                            f"Could not parse judge response on attempt {attempt + 1}/{max_retries}"
                        )
                    else:
                        self.logger.warning(f"Empty response content from {self.model_name}")
                else:
                    self.logger.warning(f"No response or choices from {self.model_name}")

            except Exception as e:
                error_text = str(e).lower()
                wait_time = 2 ** attempt  # Exponential backoff
                if "rate limit" in error_text or "429" in error_text:
                    if is_last_attempt:
                        self.logger.warning(f"Rate limit hit on final attempt {attempt + 1}/{max_retries}")
                    else:
                        self.logger.warning(f"Rate limit hit, waiting {wait_time}s before retry {attempt + 1}/{max_retries}")
                        time.sleep(wait_time)
                elif "timeout" in error_text:
                    self.logger.warning(f"Timeout on attempt {attempt + 1}/{max_retries}")
                    if not is_last_attempt:
                        time.sleep(wait_time)
                else:
                    self.logger.error(f"API Error on attempt {attempt + 1}: {str(e)}")
                    if not is_last_attempt:
                        time.sleep(wait_time)

        return None

    @classmethod
    def _canonical_metric(cls, label: str) -> Optional[str]:
        """Map a free-form label (e.g. "Relevance score") to its METRICS name, or None."""
        label_lower = label.lower()
        for metric in cls.METRICS:
            if metric.lower() in label_lower:
                return metric
        return None

    def parse_evaluation_response(self, response: str) -> Optional[Dict[str, int]]:
        """
        Parse the LLM response into integer scores for the 4 metrics (no KG alignment).

        Expected formats (markdown bold such as "**Relevance:** 4" is also accepted):
        Relevance: 4
        Accuracy: 5/5
        Completeness:3
        Fluency:  4

        Strategy, from most to least trustworthy:
          1. "Metric: N" lines, then free-text "metric ... is/should be ... N".
          2. ``_parse_with_fallback`` (looser per-line regexes).
          3. If at least 2 metrics were found, the missing ones are filled with the
             mean of the found ones, rounded half up. NOTE: this imputes judge
             scores; it is logged so it can be audited.

        Returns:
            Dict mapping every name in METRICS to 1-5, or None if unparseable.
        """
        try:
            scores: Dict[str, int] = {}
            for raw_line in response.strip().split('\n'):
                # Drop markdown emphasis so "**Relevance:** 4" reads as "Relevance: 4".
                line = raw_line.replace('*', '').strip()
                label, sep, _ = line.partition(':')
                if not sep:
                    continue
                metric = self._canonical_metric(label)
                score_match = self._SCORE_AFTER_COLON.search(line)
                if metric and score_match:
                    scores[metric] = int(score_match.group(1))

            # Extract from reasoning patterns like "Relevance should be a 5"
            if len(scores) < len(self.METRICS):
                text_lower = response.lower()
                for metric, pattern in self._REASONING_PATTERNS.items():
                    if metric not in scores:
                        match = pattern.search(text_lower)
                        if match:
                            scores[metric] = int(match.group(1))

            if all(metric in scores for metric in self.METRICS):
                return scores

            # Real values recovered by the looser parser beat imputed averages.
            fallback_scores = self._parse_with_fallback(response)
            if fallback_scores:
                return fallback_scores

            if len(scores) >= 2:
                self.logger.warning(f"Got partial response with {len(scores)} scores: {scores}")
                # Round half up; built-in round() would send 2.5 -> 2 but 3.5 -> 4.
                avg_score = int(sum(scores.values()) / len(scores) + 0.5)
                for metric in self.METRICS:
                    if metric not in scores:
                        scores[metric] = avg_score
                        self.logger.info(f"Filled missing {metric} with average {avg_score}")
                return scores

            self.logger.warning(f"Missing metrics in response: {response}")
            return None

        except Exception as e:
            self.logger.error(f"Failed to parse response: {response}. Error: {str(e)}")
            return None

    def _parse_with_fallback(self, response: str) -> Optional[Dict[str, int]]:
        """Fallback parsing with progressively looser per-line regexes for 4 metrics.

        For each pattern in turn, every "label: text" line is scanned and the pattern
        is applied to the text after the first colon. The first pattern that yields
        all four metrics wins; returns None if none does.
        """
        patterns = [
            r':\s*([1-5])(?:/5)?',          # value itself contains "x: N"
            r'\b([1-5])\b',                 # any standalone digit 1-5
            r'([1-5])\s*(?:out of 5|/5)?'   # digit anywhere in the value
        ]
        lines = response.strip().split('\n')

        for pattern in patterns:
            scores: Dict[str, int] = {}
            for raw_line in lines:
                label, sep, score_text = raw_line.replace('*', '').partition(':')
                if not sep:
                    continue
                metric = self._canonical_metric(label.strip())
                score_match = re.search(pattern, score_text.strip())
                if metric and score_match:
                    scores[metric] = int(score_match.group(1))

            if all(metric in scores for metric in self.METRICS):
                return scores

        return None

    def validate_evaluation_scores(self, evaluation: Dict[str, int], qa_id: str) -> Dict[str, int]:
        """
        Coerce scores to int and clamp them to the valid range [1, 5].

        Values are converted via float() and then truncated toward zero
        (e.g. "4.7" -> 4). Non-numeric values (and NaN) become the middle score 3.
        Any change is logged as a warning for ``qa_id``.
        """
        validated = {}
        issues = []

        for metric, raw_score in evaluation.items():
            try:
                numeric = float(raw_score)
            except (ValueError, TypeError):
                numeric = float('nan')

            if pd.isna(numeric):
                validated[metric] = 3  # Default to middle score
                issues.append(f"{metric}: non-numeric '{raw_score}' → 3")
                continue

            # Clamp before int() so +/-inf clamp cleanly instead of overflowing.
            clamped_score = int(min(max(numeric, 1), 5))
            validated[metric] = clamped_score

            # Compare against the numeric value so "4" -> 4 is not reported as a change.
            if clamped_score != numeric:
                issues.append(f"{metric}: {raw_score} → {clamped_score}")

        if issues:
            self.logger.warning(f"Score validation issues for {qa_id}: {'; '.join(issues)}")

        return validated

    def load_qa_dataset(self, file_path: str) -> List[Dict]:
        """Load QA items from a JSON file.

        Accepts either ``{"queries": [...]}`` or a top-level list of items.
        Returns [] (and logs an error) if the file cannot be read or parsed.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            if isinstance(data, list):
                queries = data
            elif isinstance(data, dict):
                queries = data.get('queries', [])
            else:
                queries = []
            self.logger.info(f"Loaded {len(queries)} QA pairs from {file_path}")
            return queries

        except Exception as e:
            self.logger.error(f"Failed to load dataset: {str(e)}")
            return []

    def save_checkpoint(self, results: List[Dict], checkpoint_path: str):
        """Write evaluation results to a CSV checkpoint.

        The file is written to a temporary path and then moved into place, so an
        interruption mid-write cannot corrupt an existing checkpoint.
        """
        try:
            tmp_path = f"{checkpoint_path}.tmp"
            pd.DataFrame(results).to_csv(tmp_path, index=False)
            os.replace(tmp_path, checkpoint_path)
            self.logger.info(f"Checkpoint saved: {len(results)} results to {checkpoint_path}")
        except Exception as e:
            self.logger.error(f"Failed to save checkpoint: {str(e)}")

    def load_checkpoint(self, checkpoint_path: str) -> List[Dict]:
        """Load evaluation results from a CSV checkpoint ([] if absent or unreadable).

        ``qa_id`` is read as text so IDs like "001" survive the CSV round trip
        instead of being re-typed by pandas as the integer 1.
        """
        try:
            if os.path.exists(checkpoint_path):
                df = pd.read_csv(checkpoint_path, dtype={'qa_id': str})
                results = df.to_dict('records')
                self.logger.info(f"Loaded checkpoint: {len(results)} results from {checkpoint_path}")
                return results
            return []
        except Exception as e:
            self.logger.error(f"Failed to load checkpoint: {str(e)}")
            return []

    @staticmethod
    def _item_id(item: Dict, index: int):
        """Return the item's 'id', or the positional fallback 'item_<index>'."""
        if isinstance(item, dict):
            return item.get('id', f'item_{index}')
        return f'item_{index}'

    def evaluate_dataset(self, dataset_path: str, output_path: str,
                        sample_size: Optional[int] = None,
                        delay_seconds: float = 1.0,
                        checkpoint_interval: int = 50) -> pd.DataFrame:
        """
        Evaluate QA dataset using chunks-only LLM judge with checkpointing.

        Items already present in ``<output_path>.checkpoint`` are skipped. On
        Ctrl-C the current results are checkpointed before re-raising, so a rerun
        resumes where it stopped. After a completed run the results are written to
        ``output_path`` and the checkpoint is deleted.

        Args:
            dataset_path: Path to the QA dataset JSON file
            output_path: Path to save evaluation results (CSV)
            sample_size: Number of samples to evaluate (None for all); the sample
                is reproducible (fixed seed 42)
            delay_seconds: Delay between API calls to avoid rate limits
            checkpoint_interval: Save checkpoint every N successful evaluations
                (<= 0 disables periodic checkpoints)

        Returns:
            DataFrame with evaluation results (empty if nothing succeeded)
        """
        qa_items = self.load_qa_dataset(dataset_path)
        if not qa_items:
            self.logger.error("No QA items loaded. Exiting.")
            return pd.DataFrame()

        if sample_size and sample_size < len(qa_items):
            # A private RNG seeded with 42 draws the same sample as
            # random.seed(42) + random.sample, without resetting the global RNG.
            qa_items = random.Random(42).sample(qa_items, sample_size)
            self.logger.info(f"Sampling {sample_size} items for evaluation")

        checkpoint_path = f"{output_path}.checkpoint"
        results = self.load_checkpoint(checkpoint_path)

        # Compare IDs as strings: a CSV checkpoint cannot preserve the JSON id type.
        processed_ids = {str(result.get('qa_id')) for result in results}

        # Keep each item's position so fallback IDs stay stable without an O(n^2) search.
        remaining_items = [
            (index, item) for index, item in enumerate(qa_items)
            if str(self._item_id(item, index)) not in processed_ids
        ]

        if processed_ids:
            self.logger.info(f"Resuming from checkpoint: {len(results)} completed, {len(remaining_items)} remaining")

        failed_evaluations = 0

        pbar = tqdm(
            total=len(qa_items),
            initial=len(qa_items) - len(remaining_items),
            desc="Evaluating Chunks-Only QA pairs",
            unit="items"
        )

        try:
            for index, item in remaining_items:
                # Bound before the try so the error handler can always name the item.
                qa_id = self._item_id(item, index)
                try:
                    question = item.get('question', '')
                    answer = item.get('answer', '')
                    question_type = item.get('question_type', 'unknown')

                    # Extract source context (chunks only)
                    source_context = self.get_source_context(item)

                    if not question or not answer:
                        self.logger.warning(f"Skipping item {qa_id}: missing question or answer")
                        pbar.update(1)
                        continue

                    evaluation = self.call_llm_judge(question, answer, source_context)

                    if evaluation:
                        evaluation = self.validate_evaluation_scores(evaluation, qa_id)

                        result = {
                            'qa_id': qa_id,
                            'question_type': question_type,
                            'question': question,
                            'answer': answer,
                            'Relevance': evaluation['Relevance'],
                            'Accuracy': evaluation['Accuracy'],
                            'Completeness': evaluation['Completeness'],
                            'Fluency': evaluation['Fluency'],
                            'Overall_Score': sum(evaluation[m] for m in self.METRICS) / len(self.METRICS)
                        }
                        results.append(result)

                        pbar.set_postfix({
                            'Completed': len(results),
                            'Failed': failed_evaluations,
                            'Success_Rate': f"{len(results)/(len(results)+failed_evaluations)*100:.1f}%",
                            'Last_Score': f"{result['Overall_Score']:.1f}"
                        })

                        if checkpoint_interval > 0 and len(results) % checkpoint_interval == 0:
                            self.save_checkpoint(results, checkpoint_path)
                    else:
                        failed_evaluations += 1
                        self.logger.warning(f"Failed to evaluate item {qa_id}")

                    pbar.update(1)

                    # Rate limiting
                    if delay_seconds > 0:
                        time.sleep(delay_seconds)

                except Exception as e:
                    failed_evaluations += 1
                    self.logger.error(f"Error processing item {qa_id}: {str(e)}")
                    pbar.update(1)
        except KeyboardInterrupt:
            # Persist work done since the last periodic checkpoint, then stop.
            if results:
                self.save_checkpoint(results, checkpoint_path)
                self.logger.warning("Evaluation interrupted; progress saved to checkpoint")
            raise
        finally:
            pbar.close()

        if results:
            df = pd.DataFrame(results)
            df.to_csv(output_path, index=False)

            if os.path.exists(checkpoint_path):
                os.remove(checkpoint_path)

            self.print_evaluation_summary(df, failed_evaluations)

            return df
        else:
            self.logger.error("No successful evaluations completed")
            return pd.DataFrame()

    def print_evaluation_summary(self, df: pd.DataFrame, failed_count: int):
        """Print summary statistics of the chunks-only evaluation to stdout.

        Args:
            df: One row per successful evaluation, with the METRICS columns,
                'Overall_Score' and optionally 'question_type'.
            failed_count: Number of items whose evaluation failed.
        """
        total_attempted = len(df) + failed_count
        success_rate = len(df) / total_attempted * 100 if total_attempted else 0.0

        print(f"\n{'='*70}")
        print(" ABLATION STUDY 1: CHUNKS-ONLY MODEL EVALUATION SUMMARY")
        print(f"{'='*70}")

        print(" Evaluation Statistics:")
        print(f"   Total Evaluated: {len(df)}")
        print(f"   Failed Evaluations: {failed_count}")
        print(f"   Success Rate: {success_rate:.1f}%")
        print(f"   Evaluation Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print("   Model Type: CHUNKS-ONLY (No Knowledge Graph)")
        print(f"   Judge Model: {self.model_name}")

        if df.empty:
            # Every statistic below divides by len(df).
            print("\n No successful evaluations to summarize.")
            return

        print("\n Average Scores by Metric:")
        for metric in self.METRICS:
            column = df[metric]
            print(f"   {metric}: {column.mean():.2f} ± {column.std():.2f} (range: {column.min()}-{column.max()})")

        overall = df['Overall_Score']
        print("\n Overall Performance:")
        print(f"   Mean Overall Score: {overall.mean():.2f}")
        print(f"   Median Overall Score: {overall.median():.2f}")
        print(f"   Best Score: {overall.max():.2f}")
        print(f"   Worst Score: {overall.min():.2f}")

        # Half-open buckets [lo, hi); an Overall_Score of exactly 5.0 falls in no
        # bucket and is reported on the "Perfect" line instead.
        print("\n Score Distribution:")
        for min_score, max_score in [(1.0, 2.0), (2.0, 3.0), (3.0, 4.0), (4.0, 5.0)]:
            count = int(((overall >= min_score) & (overall < max_score)).sum())
            percentage = (count / len(df)) * 100
            print(f"   {min_score:.1f}-{max_score:.1f}: {count} ({percentage:.1f}%)")

        perfect_scores = int((overall == 5.0).sum())
        print(f"   Perfect (5.0): {perfect_scores} ({(perfect_scores/len(df))*100:.1f}%)")

        if 'question_type' in df.columns:
            print("\n Performance by Question Type:")
            type_summary = df.groupby('question_type').agg({
                'Overall_Score': ['mean', 'std', 'count'],
                'Relevance': 'mean',
                'Accuracy': 'mean',
                'Completeness': 'mean',
                'Fluency': 'mean'
            }).round(3)

            for qtype in type_summary.index:
                stats = type_summary.loc[qtype]
                mean_score = stats[('Overall_Score', 'mean')]
                std_score = stats[('Overall_Score', 'std')]
                # The row Series is upcast to float; show the count as an integer.
                count = int(stats[('Overall_Score', 'count')])
                print(f"   {qtype}: {mean_score:.2f} ± {std_score:.2f} (n={count})")

        print("\n Ablation Study Notes:")
        print("   - This evaluation focuses on chunk-based retrieval only")
        print("   - No Knowledge Graph information was used")
        print("   - Evaluation uses 4 metrics (no KG Alignment)")
        print("   - Results can be compared with full model performance")

def run_chunks_only_evaluation():
    """Run evaluation for Ablation Study 1: Chunks-Only Model.

    Interactive entry point:
    - Resolves the Together AI key from ``API_KEY`` or, if that is empty, from the
      ``TOGETHER_API_KEY`` environment variable. Exits early if neither is set.
    - Asks for confirmation on stdin.
    - Runs ``ChunksOnlyLLMJudgeEvaluator.evaluate_dataset``.
    - Prints timing and score statistics.
    """
    # Configuration - Together AI
    API_KEY = ""  # Your Together AI API key (leave empty to use TOGETHER_API_KEY)
    DATASET_PATH = "Ablation_1_chunks_only_qa_dataset.json"  # Update with your chunks-only dataset path
    OUTPUT_PATH = "Ablation_Study_1_Chunks_Only_Gemma2_evaluation_results.csv"
    CHECKPOINT_INTERVAL = 50  # Save every N successful evaluations

    # Prefer the environment so the key never has to be committed to source.
    api_key = API_KEY or os.environ.get("TOGETHER_API_KEY", "")
    if not api_key:
        print(" No Together AI API key configured. Set API_KEY or the TOGETHER_API_KEY environment variable.")
        return

    # Initialize chunks-only evaluator with Gemma-2-27b-it
    evaluator = ChunksOnlyLLMJudgeEvaluator(
        api_key=api_key,
        model_name="google/gemma-2-27b-it"
    )

    print("=" * 70)
    print(" ABLATION STUDY 1: CHUNKS-ONLY MODEL EVALUATION")
    print("=" * 70)
    print(" This evaluation will assess QA pairs generated using ONLY chunks")
    print(" (no Knowledge Graph information)")
    print(" ")
    print(" Evaluation Details:")
    print("   - Model: Chunks-Only Retrieval")
    print(f"   - Judge Model: {evaluator.model_name}")
    print("   - API Provider: Together AI")
    print("   - Metrics: 4 (Relevance, Accuracy, Completeness, Fluency)")
    print("   - Scoring: Strict 1-5 scale")
    print(f"   - Checkpoints: Every {CHECKPOINT_INTERVAL} evaluations")
    print("   - Can resume if interrupted")
    print(" ")

    # Confirm before starting
    response = input("\nProceed with chunks-only evaluation? (y/n): ").strip().lower()
    if response != 'y':
        print(" Evaluation cancelled.")
        return

    start_time = time.time()

    results_df = evaluator.evaluate_dataset(
        dataset_path=DATASET_PATH,
        output_path=OUTPUT_PATH,
        sample_size=None,  # Process all items
        delay_seconds=1.0,  # 1 second delay between requests
        checkpoint_interval=CHECKPOINT_INTERVAL
    )

    duration = time.time() - start_time

    if not results_df.empty:
        print("\n CHUNKS-ONLY EVALUATION COMPLETED!")
        print(f"  Total time: {duration/60:.1f} minutes")
        print(f" Evaluated: {len(results_df)} QA pairs")
        print(f" Results saved to: {OUTPUT_PATH}")
        print(f" Judge Model: {evaluator.model_name}")
        print(" API Provider: Together AI")

        # Rates are only meaningful for a measurable duration (avoid dividing by 0).
        if duration > 0:
            print("\n Processing Statistics:")
            print(f"   Average processing time: {duration/len(results_df):.2f} seconds per QA pair")
            print(f"   Items per minute: {len(results_df)/(duration/60):.1f}")

        overall = results_df['Overall_Score']
        print("\n Chunks-Only Performance Preview:")
        print(f"   Mean Overall Score: {overall.mean():.2f}")
        print(f"   Score Range: {overall.min():.2f} - {overall.max():.2f}")
        print(f"   Standard Deviation: {overall.std():.2f}")

    else:
        print(" Chunks-only evaluation failed. Check logs for details.")

if __name__ == "__main__":
    run_chunks_only_evaluation()