In [24]:
import json
import csv
from collections import defaultdict

def load_predictions(predictions_filepath):
    """Read and return the parsed JSON document holding the model's predictions.

    Args:
        predictions_filepath: Path to a UTF-8 encoded JSON file.

    Returns:
        Whatever object ``json.load`` produces for the file's contents.
    """
    with open(predictions_filepath, encoding='utf-8') as handle:
        return json.load(handle)


def load_ground_truth(ground_truth_filepath):
    """Load ground-truth annotations from a JSON Lines file.

    Each non-blank line must be a JSON object. Its ``'text'`` field (default
    ``''``) and ``'labels'`` field (default ``[]``) are collected into two
    parallel lists, so index *i* of both refers to the same record.

    Whitespace-only lines, such as an extra empty line at the end of the file,
    are skipped. Passing them to ``json.loads`` would raise on the empty string
    and abort the whole evaluation.

    Args:
        ground_truth_filepath: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        tuple[list, list]: ``(ground_truth_list, review_texts)``.
    """
    ground_truth_list = []
    review_texts = []

    with open(ground_truth_filepath, 'r', encoding='utf-8') as f:
        for line in f:
            record = line.strip()
            if not record:
                continue
            data = json.loads(record)
            review_texts.append(data.get('text', ''))
            ground_truth_list.append(data.get('labels', []))

    return ground_truth_list, review_texts


def match_by_text(predictions_data, ground_truth_list, review_texts):
    """Pair each prediction with the ground-truth labels of the review with the same text.

    Matching uses exact string equality on the prediction's ``'text'`` field.
    When several ground-truth reviews share the same text, the last one wins.
    A warning reports how many entries were collapsed this way, because
    otherwise their labels would be dropped silently.

    Args:
        predictions_data: Iterable of prediction dicts. Predicted labels are
            read from ``'labels'``, falling back to ``'predictions'``, then ``[]``.
        ground_truth_list: Label lists, parallel to ``review_texts``.
        review_texts: Review texts, parallel to ``ground_truth_list``.

    Returns:
        tuple[list, list, list]: ``(matched_predictions, matched_ground_truth,
        matched_texts)``, aligned index by index, in prediction order.
        Predictions whose text has no ground-truth entry are skipped and
        counted in a printed warning.
    """
    text_to_gt = {}
    duplicate_texts = 0
    for text, gt_labels in zip(review_texts, ground_truth_list):
        if text in text_to_gt:
            duplicate_texts += 1
        # Later occurrences overwrite earlier ones (last one wins).
        text_to_gt[text] = gt_labels

    if duplicate_texts > 0:
        print(f"Warning: {duplicate_texts} ground-truth reviews repeat an earlier text; "
              f"only the last occurrence of each text is used for matching")

    matched_predictions = []
    matched_ground_truth = []
    matched_texts = []
    unmatched_count = 0

    for pred_item in predictions_data:
        pred_text = pred_item.get('text', '')
        if pred_text not in text_to_gt:
            unmatched_count += 1
            continue

        pred_labels = pred_item.get('labels', pred_item.get('predictions', []))
        matched_predictions.append(pred_labels)
        matched_ground_truth.append(text_to_gt[pred_text])
        matched_texts.append(pred_text)

    if unmatched_count > 0:
        print(f"Warning: {unmatched_count} predictions could not be matched to ground truth")

    return matched_predictions, matched_ground_truth, matched_texts


def extract_component_values(labels, component):
    """Collect the non-placeholder values of one quad component from a label list.

    Args:
        labels: Iterable of label entries. Only ``dict`` entries are inspected.
        component: One of ``'aspect'``, ``'opinion'``, ``'category'`` or
            ``'polarity'``. Any other name yields an empty list.

    Returns:
        list: The component's values in label order, duplicates kept. Falsy
        values and the placeholders ``'NULL'``, ``'NULL#NULL'`` and ``''`` are
        dropped. A missing ``'polarity'`` key counts as ``'neutral'``.
    """
    # Value substituted when a label dict lacks the requested key.
    fallback_by_component = {
        'aspect': 'NULL',
        'opinion': 'NULL',
        'category': 'NULL#NULL',
        'polarity': 'neutral',
    }

    kept = []
    for entry in labels:
        if not isinstance(entry, dict) or component not in fallback_by_component:
            continue
        candidate = entry.get(component, fallback_by_component[component])
        if candidate and candidate not in ('NULL', 'NULL#NULL', ''):
            kept.append(candidate)
    return kept


def calculate_intersection_max_accuracy(pred_values, true_values):
    """Score predictions as |pred intersect true| / max(|pred|, |true|) over distinct values.

    Returns:
        tuple: ``(accuracy, intersection_size, n_pred, n_true)``, where all
        sizes count distinct values.
        - Both inputs empty: scores 1.0 with every count 0.
        - Exactly one input empty: scores 0.0 with an intersection of 0.
    """
    predicted = set(pred_values)
    expected = set(true_values)
    n_pred, n_true = len(predicted), len(expected)

    if not predicted and not expected:
        return 1.0, 0, 0, 0
    if not (predicted and expected):
        return 0.0, 0, n_pred, n_true

    overlap = len(predicted & expected)
    return overlap / max(n_pred, n_true), overlap, n_pred, n_true


def evaluate_polarity_given_match(pred_labels, gt_labels, match_component):
    """
    Measure polarity agreement restricted to labels whose ``match_component`` agrees.

    Both label lists are grouped by their ``match_component`` value. A missing
    value defaults to ``'NULL'``. Falsy values and the placeholders ``'NULL'``,
    ``'NULL#NULL'`` and ``''`` are ignored.

    For every key present in both groupings, each ground-truth polarity under
    that key adds one to the total. It also counts as correct when the same
    polarity string appears among the predicted polarities for that key.
    A missing ``'polarity'`` is treated as ``'neutral'``.

    Returns: (correct_polarities, total_matched_items)
    """
    def group_polarities(labels):
        # component value -> polarities of the labels carrying it, in order
        grouped = defaultdict(list)
        for entry in labels:
            if not isinstance(entry, dict):
                continue
            key = entry.get(match_component, 'NULL')
            if key and key not in ('NULL', 'NULL#NULL', ''):
                grouped[key].append(entry.get('polarity', 'neutral'))
        return grouped

    predicted = group_polarities(pred_labels)
    expected = group_polarities(gt_labels)

    correct_polarities = 0
    total_matched = 0
    for key, pred_polarities in predicted.items():
        # .get() does not insert into the defaultdict, so unmatched keys stay absent.
        for gt_pol in expected.get(key, ()):
            total_matched += 1
            if gt_pol in pred_polarities:
                correct_polarities += 1
    return correct_polarities, total_matched


def _preview_text(text, limit=100):
    """Return *text* cut to *limit* characters, with '...' appended when it was longer."""
    return text[:limit] + '...' if len(text) > limit else text


def _summarize_ground_truth(gt_labels):
    """Return (categories, distinct opinionated polarities) for one review's gold labels.

    - Categories: kept in label order with duplicates. The placeholders
      'NULL', 'NULL#NULL' and '' are dropped.
    - Polarities: lower-cased. 'null', 'neutral', empty strings and non-string
      values are ignored.

    As a result, a review is "mixed" only when it has at least two different
    opinionated polarities.
    """
    categories = []
    polarities = set()
    for label in gt_labels:
        if not isinstance(label, dict):
            continue
        category = label.get('category', 'NULL#NULL')
        if category and category not in ('NULL', 'NULL#NULL', ''):
            categories.append(category)
        polarity = label.get('polarity', 'neutral')
        # Guard: a non-string polarity (e.g. a number) has no .lower().
        if isinstance(polarity, str) and polarity.lower() not in ('', 'null', 'neutral'):
            polarities.add(polarity.lower())
    return categories, polarities


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0


def comprehensive_evaluation(predictions_list, ground_truth_list, review_texts):
    """Run comprehensive ABSA evaluation over index-aligned prediction / gold lists.

    Entry *i* of each input list describes the same review. Only indices
    present in all three lists are evaluated. If the lengths differ, a warning
    is printed and the surplus entries are ignored.

    For every review this computes:
      * per-component (aspect, opinion, category, polarity) intersection/max
        accuracy, plus the raw correct / predicted / ground-truth counts that
        feed micro precision, recall and F1;
      * polarity agreement conditioned on a matching aspect, opinion or
        category (via ``evaluate_polarity_given_match``);
      * whether the gold labels carry more than one distinct non-neutral
        polarity ("mixed polarity").

    Returns:
        tuple[dict, list[dict]]: ``(final_results, per_review_results)``.
        ``final_results['total_reviews']`` is the number of reviews actually
        evaluated. It is also the denominator of the averaged accuracies.
    """
    components = ('aspect', 'opinion', 'category', 'polarity')
    conditional_specs = (
        ('given_aspect_match', 'aspect'),
        ('given_opinion_match', 'opinion'),
        ('given_category_match', 'category'),
    )

    lengths = (len(predictions_list), len(ground_truth_list), len(review_texts))
    if len(set(lengths)) > 1:
        # zip() below silently stops at the shortest input; make that visible.
        print(f"Warning: input lengths differ (predictions={lengths[0]}, "
              f"ground_truth={lengths[1]}, texts={lengths[2]}); "
              f"evaluating the first {min(lengths)} reviews only")

    # Per-component running totals for intersection/max accuracy and micro P/R/F1.
    component_scores = {
        component: {'sum_accuracy': 0, 'total_correct': 0, 'total_predicted': 0, 'total_ground_truth': 0}
        for component in components
    }
    # Running totals for polarity accuracy conditioned on a component match.
    polarity_conditional_scores = {
        match_type: {'correct': 0, 'total': 0} for match_type, _ in conditional_specs
    }

    per_review_results = []
    mixed_polarity_count = 0

    for idx, (pred_labels, gt_labels, text) in enumerate(
            zip(predictions_list, ground_truth_list, review_texts)):
        gt_categories, gt_polarities = _summarize_ground_truth(gt_labels)
        has_mixed_polarity = len(gt_polarities) > 1
        if has_mixed_polarity:
            mixed_polarity_count += 1

        review_result = {
            'index': idx,
            'text': _preview_text(text),
            'text_length': len(text),
            'ground_truth_categories': '|'.join(gt_categories),
            'has_mixed_polarity': has_mixed_polarity,
        }

        # Set-based intersection/max scoring for every component, including polarity.
        for component in components:
            pred_values = extract_component_values(pred_labels, component)
            true_values = extract_component_values(gt_labels, component)
            accuracy, correct, pred_count, true_count = calculate_intersection_max_accuracy(
                pred_values, true_values)

            scores = component_scores[component]
            scores['sum_accuracy'] += accuracy
            scores['total_correct'] += correct
            scores['total_predicted'] += pred_count
            scores['total_ground_truth'] += true_count

            review_result[f'{component}_accuracy'] = accuracy
            review_result[f'{component}_pred'] = pred_count
            review_result[f'{component}_true'] = true_count
            review_result[f'{component}_correct'] = correct

        # Polarity accuracy restricted to labels whose aspect/opinion/category matched.
        for match_type, component in conditional_specs:
            correct, total = evaluate_polarity_given_match(pred_labels, gt_labels, component)
            polarity_conditional_scores[match_type]['correct'] += correct
            polarity_conditional_scores[match_type]['total'] += total

            review_result[f'polarity_{match_type}_correct'] = correct
            review_result[f'polarity_{match_type}_total'] = total

        per_review_results.append(review_result)

    # Divide by the number of reviews actually scored, not len(predictions_list):
    # zip() stops at the shortest input, so the two can differ.
    total_reviews = len(per_review_results)

    final_results = {
        'total_reviews': total_reviews,
        'reviews_with_mixed_polarity': mixed_polarity_count,
        'percentage_mixed_polarity': _safe_ratio(mixed_polarity_count, total_reviews) * 100,
        'component_metrics': {},
        'polarity_conditional_metrics': {},
    }

    for component in components:
        scores = component_scores[component]
        # Micro-averaged over all reviews: pooled counts, not per-review ratios.
        precision = _safe_ratio(scores['total_correct'], scores['total_predicted'])
        recall = _safe_ratio(scores['total_correct'], scores['total_ground_truth'])
        final_results['component_metrics'][component] = {
            'intersection_max_accuracy': _safe_ratio(scores['sum_accuracy'], total_reviews),
            'precision': precision,
            'recall': recall,
            'f1': _safe_ratio(2 * precision * recall, precision + recall),
            'total_correct': scores['total_correct'],
            'total_predicted': scores['total_predicted'],
            'total_ground_truth': scores['total_ground_truth'],
        }

    for match_type, _ in conditional_specs:
        scores = polarity_conditional_scores[match_type]
        final_results['polarity_conditional_metrics'][match_type] = {
            'accuracy': _safe_ratio(scores['correct'], scores['total']),
            'correct': scores['correct'],
            'total': scores['total'],
        }

    return final_results, per_review_results


def print_results(results):
    """Pretty-print the summary dict produced by ``comprehensive_evaluation`` to stdout.

    Expected keys:
      * ``total_reviews``, ``reviews_with_mixed_polarity``,
        ``percentage_mixed_polarity``;
      * ``component_metrics``: per component, ``intersection_max_accuracy``,
        ``precision``, ``recall``, ``f1`` and the three ``total_*`` counts;
      * ``polarity_conditional_metrics``: per match type, ``accuracy``,
        ``correct`` and ``total``.
    """
    components = ('aspect', 'opinion', 'category', 'polarity')
    heavy_rule = "=" * 70
    light_rule = "-" * 50

    def section(title):
        # Blank line, dashed rule, title, dashed rule.
        print("\n" + light_rule)
        print(title)
        print(light_rule)

    print(heavy_rule)
    print("CATEGORY DETECTION EVALUATION RESULTS")
    print(heavy_rule)

    print("\nDataset Overview:")
    print(f"  Total Reviews: {results['total_reviews']}")
    mixed = results['reviews_with_mixed_polarity']
    mixed_pct = results['percentage_mixed_polarity']
    print(f"  Reviews with Mixed Polarity: {mixed} ({mixed_pct:.1f}%)")

    section("OVERALL CATEGORY DETECTION METRICS")
    by_component = results['component_metrics']

    print("\nMethod 1 - Intersection/Max Accuracy:")
    for name in components:
        score = by_component[name]['intersection_max_accuracy']
        print(f"  {name.capitalize():<10} {score:.4f}")

    print("\nMethod 2 - Standard Metrics:")
    print(f"{'Component':<10} {'Precision':<10} {'Recall':<10} {'F1-Score':<10}")
    print("-" * 40)
    for name in components:
        row = by_component[name]
        print(f"{name.capitalize():<10} {row['precision']:<10.4f} {row['recall']:<10.4f} {row['f1']:<10.4f}")

    print("\nPolarity Accuracy (Conditional):")
    conditional = results['polarity_conditional_metrics']
    for match_type, prefix in (
        ('given_aspect_match', "  When Aspect Matches:   "),
        ('given_opinion_match', "  When Opinion Matches:  "),
        ('given_category_match', "  When Category Matches: "),
    ):
        print(f"{prefix}{conditional[match_type]['accuracy']:.4f}")

    section("DETAILED STATISTICS")
    for name in components:
        row = by_component[name]
        print(f"\n{name.capitalize()}:")
        print(f"  Total Predicted: {row['total_predicted']}")
        print(f"  Total Ground Truth: {row['total_ground_truth']}")
        print(f"  Total Correct: {row['total_correct']}")

    print("\nPolarity (Conditional):")
    for match_type, label in (
        ('given_aspect_match', 'Aspect Match'),
        ('given_opinion_match', 'Opinion Match'),
        ('given_category_match', 'Category Match'),
    ):
        row = conditional[match_type]
        print(f"  {label} - Correct: {row['correct']}/{row['total']}")


def save_detailed_results(results, per_review_results, output_file):
    """Write one CSV row per review dict in *per_review_results* to *output_file*.

    ``results`` is accepted for call-site compatibility but is not used.

    Columns come in a fixed order:
      1. review metadata;
      2. accuracy/correct/pred/true for each of aspect, opinion, category
         and polarity;
      3. correct/total for each conditional-polarity match type.

    Rows are written through ``csv.DictWriter`` with its default settings.
    """
    columns = ['index', 'text', 'text_length', 'has_mixed_polarity', 'ground_truth_categories']
    for component in ('aspect', 'opinion', 'category', 'polarity'):
        columns += [f'{component}_{suffix}' for suffix in ('accuracy', 'correct', 'pred', 'true')]
    for matched in ('aspect', 'opinion', 'category'):
        columns += [f'polarity_given_{matched}_match_{suffix}' for suffix in ('correct', 'total')]

    with open(output_file, 'w', newline='', encoding='utf-8') as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows(per_review_results)


def main(predictions_file="../datasets/DeepSeek-R1-Distill-Qwen-32B_predictions.json",
         ground_truth_file="../datasets/laptop_quad_test.tsv.jsonl",
         csv_file="absa_evaluation_results.csv",
         json_file="absa_evaluation_summary.json"):
    """Run the end-to-end evaluation: load, match by text, score, report and save.

    Every path defaults to the value this script previously hard-coded, so a
    bare ``main()`` behaves exactly as before. Pass other paths to evaluate a
    different model or dataset without editing the function.

    Args:
        predictions_file: JSON file of model predictions.
        ground_truth_file: JSON Lines file of gold annotations.
        csv_file: Destination for the per-review CSV report.
        json_file: Destination for the aggregate JSON summary.

    Errors are printed rather than raised. A ``FileNotFoundError`` prints a
    short message; any other exception prints its message and traceback.
    """
    print("Loading data files...")

    try:
        # Load data
        predictions_data = load_predictions(predictions_file)
        print(f"Loaded {len(predictions_data)} predictions")

        ground_truth_list, review_texts = load_ground_truth(ground_truth_file)
        print(f"Loaded {len(ground_truth_list)} ground truth reviews")

        # Align predictions with gold labels via identical review text.
        print("\nMatching predictions to ground truth by text...")
        matched_preds, matched_gt, matched_texts = match_by_text(
            predictions_data, ground_truth_list, review_texts
        )
        print(f"Successfully matched {len(matched_preds)} reviews")

        if not matched_preds:
            print("\nERROR: No matches found!")
            return

        print("\nRunning comprehensive evaluation...")
        results, per_review_results = comprehensive_evaluation(
            matched_preds, matched_gt, matched_texts
        )

        print_results(results)

        save_detailed_results(results, per_review_results, csv_file)
        print(f"\nDetailed results saved to: {csv_file}")

        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
        print(f"Summary saved to: {json_file}")

    except FileNotFoundError as e:
        print(f"Error: Could not find file - {e}")
    except Exception as e:  # top-level boundary: report instead of propagating
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()


if __name__ == '__main__':
    # Only run the evaluation pipeline when executed as a script, not on import.
    main()

Loading data files...
Loaded 100 predictions
Loaded 816 ground truth reviews

Matching predictions to ground truth by text...
Successfully matched 100 reviews

Running comprehensive evaluation...
CATEGORY DETECTION EVALUATION RESULTS

Dataset Overview:
  Total Reviews: 100
  Reviews with Mixed Polarity: 4 (4.0%)

--------------------------------------------------
OVERALL CATEGORY DETECTION METRICS
--------------------------------------------------

Method 1 - Intersection/Max Accuracy:
  Aspect     0.2800
  Opinion    0.4667
  Category   0.1850
  Polarity   0.4600

Method 2 - Standard Metrics:
Component  Precision  Recall     F1-Score  
----------------------------------------
Aspect     0.4000     0.0460     0.0825    
Opinion    0.5455     0.0811     0.1412    
Category   0.3725     0.1743     0.2375    
Polarity   0.8704     0.4434     0.5875    

Polarity Accuracy (Conditional):
  When Aspect Matches:   0.6000
  When Opinion Matches:  0.6667
  When Category Matches: 0.8636

-------