# In [None]:
import os
import json
import re
import time
import numpy as np
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
import google.generativeai as genai
from tqdm import tqdm  # For progress bars
import argparse  # For command-line arguments

# Download NLTK tokenizer data if not present.
# Older NLTK releases load the word tokenizer from 'punkt'; NLTK >= 3.8.2
# looks for 'punkt_tab' instead. Fetch both so nltk.word_tokenize works
# regardless of the installed version.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)

# Initialize Gemini model
def initialize_gemini_model(model_name='gemini-1.5-flash', api_key=None):
    """Configure the Gemini client and return a generative model handle.

    Args:
        model_name: Name of the Gemini model to instantiate.
        api_key: API key to use. When omitted, the key is read from the
            ``GEMINI_API_KEY`` environment variable so that no credential
            needs to be written into the source file.

    Returns:
        The ``genai.GenerativeModel`` created for ``model_name``.

    Raises:
        ValueError: If no key was passed and ``GEMINI_API_KEY`` is unset or
            empty.
    """
    if api_key is None:
        api_key = os.environ.get('GEMINI_API_KEY')
    if not api_key:
        raise ValueError("GEMINI_API_KEY environment variable is not set.")

    genai.configure(api_key=api_key)
    print("Gemini Model Initialized")
    return genai.GenerativeModel(model_name)

# Normalize text for fair comparison
def normalize_text(text):
    """Return ``text`` lower-cased with surrounding whitespace removed.

    Any falsy input (``None``, ``""``, ``0``, an empty container) normalizes
    to the empty string; every other value is converted with ``str`` first.
    """
    return str(text).lower().strip() if text else ""

# Extract JSON content from generated text
def extract_json_from_text(text):
    """Extract the last valid JSON object embedded in ``text``.

    The text is scanned for ``{`` characters and a real JSON parse is
    attempted at each one, so nested objects and braces inside string values
    are handled correctly (a flat ``{...}`` regex would return only the
    innermost object or cut a string value short).

    Args:
        text: Free-form text that may contain one or more JSON objects.

    Returns:
        The last JSON object (as a ``dict``) that parses successfully, or an
        empty dict when ``text`` is not a string or contains no valid object.
    """
    if not isinstance(text, str):
        print(f"Error extracting JSON: expected str, got {type(text).__name__}")
        return {}

    decoder = json.JSONDecoder()
    last_object = {}
    pos = text.find('{')
    while pos != -1:
        try:
            parsed, end = decoder.raw_decode(text, pos)
        except ValueError:
            # Not a JSON object at this brace; try the next candidate.
            pos = text.find('{', pos + 1)
            continue
        last_object = parsed
        # Resume after the parsed object so its nested braces are not
        # reported as separate, later matches.
        pos = text.find('{', end)
    return last_object

# Calculate BLEU and ROUGE scores
def calculate_metrics(generated_text, reference_text):
    """Score ``generated_text`` against ``reference_text`` with BLEU and ROUGE.

    Both texts are normalized first. ROUGE-1/2/L F-measures come from
    ``rouge_scorer`` (with stemming); BLEU-1 and BLEU-4 come from NLTK's
    ``sentence_bleu`` with ``method1`` smoothing.

    Returns:
        A dict with keys ``bleu1``, ``bleu4``, ``rouge1``, ``rouge2`` and
        ``rougeL``. If any step raises, the error is printed and every score
        is reported as ``0.0``.
    """
    rouge_variants = ('rouge1', 'rouge2', 'rougeL')
    bleu_weights = {
        'bleu1': (1, 0, 0, 0),
        'bleu4': (0.25, 0.25, 0.25, 0.25),
    }
    try:
        hypothesis = normalize_text(generated_text)
        reference = normalize_text(reference_text)

        # ROUGE takes the reference (target) first, then the prediction.
        rouge = rouge_scorer.RougeScorer(
            list(rouge_variants), use_stemmer=True
        ).score(reference, hypothesis)

        # sentence_bleu expects a list of tokenized references.
        reference_token_lists = [nltk.word_tokenize(reference)]
        hypothesis_tokens = nltk.word_tokenize(hypothesis)
        smoothing = SmoothingFunction().method1

        scores = {}
        for name, weights in bleu_weights.items():
            scores[name] = sentence_bleu(
                reference_token_lists,
                hypothesis_tokens,
                weights=weights,
                smoothing_function=smoothing,
            )
        for variant in rouge_variants:
            scores[variant] = rouge[variant].fmeasure
        return scores
    except Exception as e:
        print(f"Error calculating metrics: {e}")
        return dict.fromkeys(tuple(bleu_weights) + rouge_variants, 0.0)

# Calculate exact match rate for JSON objects
def calculate_exact_match_rate(generated_json, reference_json):
    """Compute exact-match precision, recall and F1 over ``pointsMissed``.

    The ``pointsMissed`` values of both objects are flattened (nested lists
    are expanded) and normalized, then compared as multisets: each reference
    point can be matched by at most one generated point, so repeated
    generated points cannot push recall above 1.0.

    Args:
        generated_json: Parsed model output (a dict, possibly empty).
        reference_json: Parsed reference output (a dict, possibly empty).

    Returns:
        A dict with ``precision``, ``recall`` and ``f1``. All three are
        ``0.0`` when either input is empty, when either point list is empty,
        or when an unexpected error occurs (the error is printed).
    """
    from collections import Counter

    def zero_scores():
        return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}

    def as_point_list(value):
        # A missing/null value means "no points"; a bare scalar is one point
        # (iterating a string would otherwise split it into characters).
        if value is None:
            return []
        if isinstance(value, list):
            return value
        return [value]

    def flatten_and_normalize(points):
        flat_list = []
        for item in points:
            if isinstance(item, list):
                flat_list.extend(flatten_and_normalize(item))
            else:
                flat_list.append(normalize_text(item))
        return flat_list

    try:
        if not generated_json or not reference_json:
            return zero_scores()

        gen_points = flatten_and_normalize(
            as_point_list(generated_json.get('pointsMissed'))
        )
        ref_points = flatten_and_normalize(
            as_point_list(reference_json.get('pointsMissed'))
        )

        # Multiset intersection: a point counts min(gen count, ref count) times.
        matches = sum((Counter(gen_points) & Counter(ref_points)).values())

        precision = matches / len(gen_points) if gen_points else 0.0
        recall = matches / len(ref_points) if ref_points else 0.0
        total = precision + recall
        f1 = 2 * (precision * recall) / total if total > 0 else 0.0

        return {
            'precision': precision,
            'recall': recall,
            'f1': f1
        }
    except Exception as e:
        # Best-effort metric: malformed inputs score zero instead of aborting
        # the whole evaluation run.
        print(f"Error calculating exact match: {e}")
        return zero_scores()

# Generate a response using the Gemini model
def generate_with_gemini(model, prompt, max_retries=3, retry_delay=2):
    """Generate a response with the Gemini model, retrying on failure.

    Args:
        model: Object exposing ``generate_content(prompt)`` whose result has
            a ``text`` attribute.
        prompt: Prompt passed to the model.
        max_retries: Maximum number of attempts.
        retry_delay: Seconds to wait between consecutive attempts. There is
            no wait after the final attempt.

    Returns:
        The first non-empty response text, or ``""`` if every attempt failed
        or returned an empty response.
    """
    for attempt in range(max_retries):
        if attempt:
            # Pause only between attempts; sleeping after the last failure
            # would just delay returning the empty result.
            time.sleep(retry_delay)
        try:
            text = model.generate_content(prompt).text
            if text:
                return text
            print(f"Attempt {attempt + 1}: Empty response from Gemini.")
        except Exception as e:
            # Any failure while generating or reading the response text is
            # logged and retried; this is a deliberate best-effort boundary.
            print(f"Attempt {attempt + 1}: Error generating with Gemini: {e}")
    return ""

# Evaluate using Gemini model
def evaluate_with_gemini(data, model=None):
    """Run every example in ``data`` through Gemini and collect metrics.

    For each example the instruction preamble is prepended to
    ``example['conversations'][1]['content']`` and sent to the model; the
    reply is compared with ``example['conversations'][2]['content']`` using
    text metrics (BLEU/ROUGE) and JSON exact-match metrics. An example that
    raises is reported and skipped. Averages are printed at the end.

    Args:
        data: Sequence of examples with a ``'conversations'`` list.
        model: Gemini model to use; one is initialized when omitted.

    Returns:
        A dict mapping each metric name to the list of per-example values.
    """
    # The preamble is the same for every example, so it is built once here.
    instructions = '''
Ultra-Strict Expert Evaluation Rules:youare gove a answer to be covered json list blocks you have tand the answrr cointains the points related to points to be covered you are suppoed to select which points list blocks is missing you have to select only one list block ss the aswer and nothing feom other list blocks and by any case do not mix the asnwers from different lists under any s

Example-
answer- P1 p3 p4 p5
list of points to be covered - {[p1,p2,p3],[p4,p5],[p6,p7]}
points missed - {[p6,p7]}

don't do like this -
reference_output -->{"pointsMissed": ["All is well", "Ending the conversation"]}
Generated outut --> Points missed: [['All is well', 'Informal response', 'Ending the conversation']]
this is wrong the generated output should be  ["All is well", "Ending the conversation"] nothing else even if u get the answers wrong its fine but don't mixup the solution  
'''

    gemini = initialize_gemini_model() if model is None else model

    metric_names = (
        'bleu1', 'bleu4',
        'rouge1', 'rouge2', 'rougeL',
        'precision', 'recall', 'f1',
    )
    collected = {name: [] for name in metric_names}

    progress = tqdm(enumerate(data), total=len(data), desc="Evaluating")
    for position, example in progress:
        try:
            turns = example['conversations']
            full_prompt = instructions + turns[1]['content']
            reference_output = turns[2]['content']

            print(f'reference Text --> {reference_output}')
            generated_text = generate_with_gemini(gemini, full_prompt)
            print(f'Generated_text --> {generated_text}')

            # Structured view of both outputs for the exact-match metrics.
            reference_json = extract_json_from_text(reference_output)
            generated_json = extract_json_from_text(generated_text)

            text_scores = calculate_metrics(generated_text, reference_output)
            json_scores = calculate_exact_match_rate(generated_json, reference_json)

            for scores in (text_scores, json_scores):
                for name, value in scores.items():
                    collected[name].append(value)

            # Short pause between requests to avoid rate limiting.
            time.sleep(1)
        except Exception as e:
            print(f"Error processing example {position + 1}: {e}")

    print("\n===== EVALUATION RESULTS =====")
    print("\nAverage Metrics:")
    for name, values in collected.items():
        if values:
            print(f"{name}: {np.mean(values):.4f}")

    print("\nNumber of examples evaluated:", len(collected['bleu1']))

    return collected

# Load data in batches
def load_data_in_batches(data, batch_size):
    """Lazily yield consecutive slices of ``data`` of at most ``batch_size`` items."""
    batch_starts = range(0, len(data), batch_size)
    yield from (data[start:start + batch_size] for start in batch_starts)

# Batch evaluation
def batch_evaluate_data(data_path='new_updated_examples.json', batch_size=10):
    """Load the evaluation examples from a JSON file and evaluate them in batches.

    The whole file is loaded, then split into batches that are passed to
    ``evaluate_with_gemini`` one at a time with a shared model. Per-example
    metric values from all batches are concatenated and their overall
    averages printed.

    Args:
        data_path: Path of the JSON file holding the list of examples.
        batch_size: Number of examples per batch; must be at least 1.

    Returns:
        A dict mapping each metric name to the list of values collected over
        all batches.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Fail before touching the file or the API: a zero step would otherwise
    # only surface once batching starts, and a negative one would silently
    # evaluate nothing.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # One model instance is shared by all batches.
    model = initialize_gemini_model()

    all_metrics = {
        'bleu1': [], 'bleu4': [],
        'rouge1': [], 'rouge2': [], 'rougeL': [],
        'precision': [], 'recall': [], 'f1': []
    }

    for batch in load_data_in_batches(data, batch_size):
        print("\nProcessing batch...")

        batch_metrics = evaluate_with_gemini(batch, model)

        for key, values in batch_metrics.items():
            all_metrics.setdefault(key, []).extend(values)

    # Final aggregate results
    print("\n===== FINAL EVALUATION RESULTS =====")
    print("\nOverall Average Metrics:")
    for key, values in all_metrics.items():
        if values:
            avg = np.mean(values)
            print(f"{key}: {avg:.4f}")

    print("\nTotal examples evaluated:", len(all_metrics['bleu1']))

    return all_metrics

# Evaluate a single prompt-response pair
def evaluate_single_pair():
    """Run one fixed prompt through Gemini and score it, as a quick smoke test.

    Prints the prompt, the reference output, the generated output and every
    metric.

    Returns:
        A single dict combining the text metrics and the JSON exact-match
        metrics.
    """
    gemini = initialize_gemini_model()

    # Fixed test case: the prompt and the output the model is expected to give.
    test_prompt = """Act like an expert evaluator, where given an answer and a list of points to be covered in the answer you give a json of missing points
Answer: Yeah. So the cost of product is $5. 
List of points to be covered: [['the cost of the product is 5', 'the product cost is 5', 'cost is 5 ', 'its cost is 5 ', 'just 5 ']]"""
    expected_output = """{"pointsMissed": ["the cost of the product is 5", "the product cost is 5", "cost is 5 ", "its cost is 5 ", "just 5 "]}"""

    generated = generate_with_gemini(gemini, test_prompt)
    for label, shown in (
        ("\nPrompt:", test_prompt),
        ("\nReference Output:", expected_output),
        ("\nGenerated Output:", generated),
    ):
        print(label, shown)

    # Text-level scores first, then the structured comparison.
    text_scores = calculate_metrics(generated, expected_output)
    expected_json = extract_json_from_text(expected_output)
    produced_json = extract_json_from_text(generated)
    json_scores = calculate_exact_match_rate(produced_json, expected_json)

    print("\n===== SINGLE PAIR EVALUATION RESULTS =====")
    for heading, scores in (
        ("\nText Metrics:", text_scores),
        ("\nJSON Content Metrics:", json_scores),
    ):
        print(heading)
        for name, score in scores.items():
            print(f"{name}: {score:.4f}")

    combined = dict(text_scores)
    combined.update(json_scores)
    return combined

# Main function
if __name__ == "__main__":
    # Command-line entry point: run either the single-pair smoke test or the
    # full batch evaluation (the default).
    cli = argparse.ArgumentParser(description="Evaluate Gemini model.")
    cli.add_argument(
        '--mode',
        choices=['single', 'batch'],
        default='batch',
        help="Evaluation mode: single or batch.",
    )
    # parse_known_args ignores unrecognized arguments instead of exiting.
    options, _unrecognized = cli.parse_known_args()

    entry_point = evaluate_single_pair if options.mode == 'single' else batch_evaluate_data
    entry_point()