# GriceBench Phase 3: Detector-Human Agreement Analysis

## What This Notebook Does

1. **Loads annotation sample** with detector labels
2. **Calculates agreement metrics** between detector predictions and ground-truth labels (derived from each example's injected violation type and annotation category)
3. **Cohen's Kappa** for each maxim (Quantity, Quality, Relation, Manner)
4. **Error analysis** - which violations does the detector miss (false negatives), and which does it wrongly flag (false positives)?
5. **Decision point** - determine if detector needs retraining

---

## Required Dataset: `gricebench-scientific-fix`

**Files needed (from Phase 2):**

| File | Description |
|------|-------------|
| `annotation_sample_1000.json` | 600+ examples with detector labels |
| `relation_repair_mrr.json` | MRR results (for reference) |

**Local source paths (upload these files to the Kaggle dataset, either at its root or inside a `phase2output/` subfolder):**
- `c:\Users\pushk\OneDrive\Documents\Research Model\GriceBench\results\phase2output\annotation_sample_1000.json`
- `c:\Users\pushk\OneDrive\Documents\Research Model\GriceBench\results\phase2output\relation_repair_mrr.json`

In [None]:
# ============================================================================
# CELL 1: IMPORTS AND CONFIGURATION
# ============================================================================

import json
import numpy as np
from pathlib import Path
from collections import defaultdict
from typing import Dict, List

# Dataset mount point (read from) and working directory (written to).
DATA_INPUT = Path("/kaggle/input/gricebench-scientific-fix")
OUTPUT_DIR = Path("/kaggle/working")

# Echo the resolved configuration, one setting per line.
print("\n".join([
    "Configuration:",
    f"  Input: {DATA_INPUT}",
    f"  Output: {OUTPUT_DIR}",
]))

In [None]:
# ============================================================================
# CELL 2: LOAD ANNOTATION SAMPLE
# ============================================================================

print("=" * 70)
print("LOADING ANNOTATION SAMPLE")
print("=" * 70)

# The file may sit at the dataset root or inside a phase2output/ subfolder,
# depending on how the Phase 2 outputs were uploaded. Use the first candidate
# that exists; if neither does, annotation_path ends up as the last one.
candidate_paths = [
    DATA_INPUT / "annotation_sample_1000.json",
    DATA_INPUT / "phase2output" / "annotation_sample_1000.json",
]
annotation_path = next(
    (path for path in candidate_paths if path.exists()), candidate_paths[-1]
)

if annotation_path.exists():
    with open(annotation_path, 'r', encoding='utf-8') as f:
        annotations = json.load(f)
    print(f"✅ Loaded {len(annotations)} examples")
else:
    print("❌ annotation_sample_1000.json not found!")
    # Report every location that was tried, not only the last one.
    for path in candidate_paths:
        print(f"   Checked: {path}")
    # Fall back to an empty list so `annotations` is always defined.
    annotations = []

# Preview the structure of the first record.
if annotations:
    print(f"\nSample keys: {list(annotations[0].keys())}")
    if 'labels' in annotations[0]:
        print(f"Labels structure: {annotations[0]['labels']}")

In [None]:
# ============================================================================
# CELL 3: ANALYZE LABEL DISTRIBUTION
# ============================================================================

print("=" * 70)
print("LABEL DISTRIBUTION ANALYSIS")
print("=" * 70)

maxims = ['quantity', 'quality', 'relation', 'manner']

# Tally detector positives per maxim (only items carrying detector labels can
# contribute) and the annotation-category mix over all items.
detector_counts = defaultdict(int)
examples_with_labels = 0
category_counts = defaultdict(int)

for item in annotations:
    labels = item.get('labels', {})
    if labels:
        examples_with_labels += 1
        for maxim in maxims:
            # Parse with int() exactly as the agreement cell does, so these
            # counts match the predictions that are scored later.
            if int(labels.get(maxim, 0)) == 1:
                detector_counts[maxim] += 1

    category = item.get('annotation_category', 'unknown')
    category_counts[category] += 1

print(f"\nTotal examples: {len(annotations)}")
print(f"Examples with detector labels: {examples_with_labels}")

print("\n📊 Detector Predictions (positive count):")
for maxim in maxims:
    count = detector_counts[maxim]
    # Positive rate among the items that actually have detector predictions.
    pct = count / examples_with_labels * 100 if examples_with_labels else 0
    print(f"   {maxim.capitalize()}: {count} ({pct:.1f}%)")

print("\n📁 Category Distribution:")
# Sort by string form so a null/non-string category cannot break sorting.
for cat, count in sorted(category_counts.items(), key=lambda kv: str(kv[0])):
    print(f"   {cat}: {count}")

In [None]:
# ============================================================================
# CELL 4: COHEN'S KAPPA CALCULATION
# ============================================================================

def calculate_cohen_kappa(y_true: List[int], y_pred: List[int]) -> float:
    """
    Compute Cohen's kappa between two binary (0/1) label sequences.

    κ = (p_o - p_e) / (1 - p_e), where p_o is the fraction of positions on
    which the two sequences agree and p_e is the agreement expected by chance
    from each sequence's positive rate.

    Args:
        y_true: Reference labels (0 or 1).
        y_pred: Predicted labels (0 or 1), aligned position-by-position
            with ``y_true``.

    Returns:
        Kappa as a float. Returns 0.0 for empty input. If chance agreement is
        exactly 1 (for 0/1 inputs, both sequences constant and equal), returns
        1.0 when observed agreement is also 1, else 0.0.

    Raises:
        ValueError: If the two sequences have different lengths.
    """
    total = len(y_true)
    if total != len(y_pred):
        raise ValueError("Lists must be same length")
    if not total:
        return 0.0

    matches = sum(t == p for t, p in zip(y_true, y_pred))
    observed = matches / total

    rate_true = sum(y_true) / total
    rate_pred = sum(y_pred) / total
    chance = (rate_true * rate_pred) + ((1 - rate_true) * (1 - rate_pred))

    # Guard the division: kappa is undefined when chance agreement is total.
    if chance == 1.0:
        return 1.0 if observed == 1.0 else 0.0

    return (observed - chance) / (1 - chance)

# Confirm the cell ran (emoji previously mis-encoded as Mac Roman mojibake).
print("Cohen's Kappa function defined ✅")

In [None]:
# ============================================================================
# CELL 5: CALCULATE AGREEMENT USING GROUND TRUTH
# ============================================================================
#
# Each annotation item carries:
# - 'labels': the detector's per-maxim predictions (what is being evaluated)
# - 'violation_type': the violation that was injected (the reference label)
#
# For every item that has detector labels, one (reference, prediction) pair is
# appended per maxim, so ground_truth[m][j] and detector_pred[m][j] describe
# the j-th *labeled* item in `annotations` order.

print("\n".join([
    "=" * 70,
    "DETECTOR-GROUND TRUTH AGREEMENT",
    "=" * 70,
]))

# Injected violation type -> the maxim it violates.
violation_to_maxim = {
    'quantity_over': 'quantity',
    'quantity_under': 'quantity',
    'quality_unsupported': 'quality',
    'quality_contradictory': 'quality',
    'relation_off_topic': 'relation',
    'relation_tangential': 'relation',
    'manner_ambiguous': 'manner',
    'manner_jargon': 'manner',
    'manner_shuffled': 'manner',
}


def _injected_label(item: dict, maxim: str) -> int:
    """Return 1 if ``item`` is known to violate ``maxim``, else 0.

    Evidence is checked in order: the injected ``violation_type`` (prefix
    match or lookup in ``violation_to_maxim``); for ``multi_*`` types, the
    ``metadata['maxims_violated']`` list; finally an ``annotation_category``
    equal to ``"<maxim>_positive"``.
    """
    violation_type = item.get('violation_type', '')
    if violation_type:
        if violation_type.startswith(maxim) or violation_to_maxim.get(violation_type) == maxim:
            return 1
        if violation_type.startswith('multi_'):
            if maxim in item.get('metadata', {}).get('maxims_violated', []):
                return 1
    return int(item.get('annotation_category', '') == f"{maxim}_positive")


ground_truth = defaultdict(list)
detector_pred = defaultdict(list)
valid_count = 0

for item in annotations:
    detector_labels = item.get('labels', {})
    if not detector_labels:
        # Nothing to score without detector output.
        continue

    valid_count += 1
    for maxim in maxims:
        reference = _injected_label(item, maxim)
        prediction = int(detector_labels.get(maxim, 0))
        ground_truth[maxim].append(reference)
        detector_pred[maxim].append(prediction)

print(f"\nValid examples for agreement: {valid_count}")

In [None]:
# ============================================================================
# CELL 6: CALCULATE AND DISPLAY RESULTS
# ============================================================================
#
# Per-maxim agreement between the reference labels and detector predictions
# built in the previous cell. Precision, recall and F1 fall back to 0 when
# their denominator is zero.

print("\n" + "=" * 70)
print("AGREEMENT RESULTS")
print("=" * 70)

results = {}

for maxim in maxims:
    y_true = ground_truth[maxim]
    y_pred = detector_pred[maxim]

    # Nothing was scored for this maxim (e.g. no labeled examples).
    if not y_true:
        continue

    kappa = calculate_cohen_kappa(y_true, y_pred)

    # Confusion-matrix cells, treating 1 as "maxim violated".
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)

    accuracy = (tp + tn) / len(y_true)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    results[maxim] = {
        'kappa': round(kappa, 4),
        'accuracy': round(accuracy, 4),
        'precision': round(precision, 4),
        'recall': round(recall, 4),
        'f1': round(f1, 4),
        'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn
    }

    print(f"\n{maxim.upper()}:")
    print(f"   Cohen's κ:  {kappa:.4f}")
    print(f"   Accuracy:   {accuracy:.4f}")
    print(f"   Precision:  {precision:.4f}")
    print(f"   Recall:     {recall:.4f}")
    print(f"   F1:         {f1:.4f}")
    print(f"   Confusion:  TP={tp}, TN={tn}, FP={fp}, FN={fn}")

In [None]:
# ============================================================================
# CELL 7: OVERALL KAPPA AND VERDICT
# ============================================================================

print("\n" + "=" * 70)
print("OVERALL AGREEMENT AND VERDICT")
print("=" * 70)

# Pooled kappa: concatenate every maxim's (reference, prediction) pairs and
# score them as a single binary task.
all_true = []
all_pred = []
for maxim in maxims:
    all_true.extend(ground_truth[maxim])
    all_pred.extend(detector_pred[maxim])

overall_kappa = calculate_cohen_kappa(all_true, all_pred)

# Unweighted mean of the per-maxim kappas stored in `results` (those values
# are already rounded to 4 decimals). This mean drives the verdict below.
# Coerced to a plain float so it serializes cleanly.
kappa_values = [results[m]['kappa'] for m in results]
mean_kappa = float(np.mean(kappa_values)) if kappa_values else 0.0

print(f"\nOverall Cohen's κ: {overall_kappa:.4f}")
print(f"Mean κ (per maxim): {mean_kappa:.4f}")

# Interpretation
print("\n" + "-" * 50)
print("κ INTERPRETATION:")
print("   κ < 0.2:  Poor agreement")
print("   0.2-0.4:  Fair agreement")
print("   0.4-0.6:  Moderate agreement")
print("   0.6-0.8:  Substantial agreement")
print("   κ > 0.8:  Almost perfect agreement")

print("\n" + "-" * 50)
print("VERDICT:")

# The verdict thresholds (0.7 / 0.5) are this notebook's own decision rule,
# separate from the interpretation bands printed above. With nothing scored,
# the mean is a placeholder 0.0 and must not be read as "poor agreement".
if not kappa_values:
    verdict = "NO DATA"
    action = "No examples could be scored; check the annotation file before deciding on retraining."
    emoji = "❌"
elif mean_kappa >= 0.7:
    verdict = "EXCELLENT"
    action = "Detector is well-calibrated. Proceed to Phase 4."
    emoji = "✅"
elif mean_kappa >= 0.5:
    verdict = "ACCEPTABLE"
    action = "Detector acceptable but could improve. Monitor errors."
    emoji = "⚠️"
else:
    verdict = "NEEDS RETRAINING"
    action = "Detector needs retraining on human annotations."
    emoji = "❌"

print(f"\n{emoji} {verdict}")
print(f"\nRecommendation: {action}")

In [None]:
# ============================================================================
# CELL 8: ERROR ANALYSIS
# ============================================================================

print("\n" + "=" * 70)
print("ERROR ANALYSIS")
print("=" * 70)

# ground_truth / detector_pred hold one entry per maxim only for items that
# have detector labels, so entry j belongs to the j-th *labeled* item.
# Rebuild that same filtered list and zip against it; indexing by position in
# the full `annotations` list misaligns items and labels as soon as any item
# lacks labels.
labeled_items = [item for item in annotations if item.get('labels', {})]

for maxim in maxims:
    fn_examples = []  # False negatives: violation present, detector missed it
    fp_examples = []  # False positives: detector flagged a non-violation

    for item, true_label, pred_label in zip(
        labeled_items, ground_truth[maxim], detector_pred[maxim]
    ):
        if true_label == 1 and pred_label == 0:
            fn_examples.append(item)
        elif true_label == 0 and pred_label == 1:
            fp_examples.append(item)

    print(f"\n{maxim.upper()}:")
    print(f"   False Negatives (missed): {len(fn_examples)}")
    print(f"   False Positives (wrong): {len(fp_examples)}")

    # Show one missed example; fall back to 'response' and tolerate null
    # text fields instead of crashing on the slice.
    if fn_examples:
        sample = fn_examples[0]
        resp = (sample.get('violated_response') or sample.get('response') or '')[:100]
        print(f"   Sample FN: {resp}...")

In [None]:
# ============================================================================
# CELL 9: SAVE RESULTS
# ============================================================================

# Kappas are coerced to plain floats so numpy scalars never reach json.dump.
output = {
    'per_maxim': results,
    'overall_kappa': round(float(overall_kappa), 4),
    'mean_kappa': round(float(mean_kappa), 4),
    'verdict': verdict,
    'recommendation': action,
    'n_examples': len(annotations)
}

output_path = OUTPUT_DIR / "detector_human_agreement.json"
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(output, f, indent=2)

print(f"\n✅ Results saved to: {output_path}")

print("\n" + "=" * 70)
print("PHASE 3 COMPLETE")
print("=" * 70)
print("\nDownload: detector_human_agreement.json")
print("\nNext Steps:")
# Same 0.5 threshold as the "ACCEPTABLE" verdict boundary.
if mean_kappa >= 0.5:
    print("   → Proceed to Phase 4: Natural Violation Collection")
else:
    print("   → Go to Phase 6: Retrain detector on annotations")