# GriceBench Phase 1: Data Preparation

This notebook generates the required data files for Phase 2 evaluation:
1. Relation Evaluation Set (200 examples)
2. Annotation Sample (1000 examples)

## Required Datasets to Add:

Before running this notebook, add the following dataset:

1. **Your GriceBench Data** - Upload as private dataset containing:
   - `repair_data/repair_test.json`
   - `gold_annotation_set.json`
   - `val_examples.json`
   - `topical_corpus.json`

Mount path: `/kaggle/input/gricebench-scientific-fix/`

In [None]:
# Cell 1: Configuration
import os
import json
import random
from pathlib import Path
from collections import defaultdict
import re

# Paths - the input location defaults to the Kaggle mount point and can be
# overridden with the GRICEBENCH_DATA_DIR environment variable (e.g. when
# running locally).
DATA_INPUT = Path(os.environ.get("GRICEBENCH_DATA_DIR", "/kaggle/input/gricebench-scientific-fix"))
OUTPUT_DIR = Path("/kaggle/working")

# Verify dataset is mounted. The later cells check for the individual files
# they need, so a missing dataset is reported here without stopping the run.
if DATA_INPUT.exists():
    print(f"✅ Dataset mounted at {DATA_INPUT}")
    print("Contents:")
    # iterdir() order depends on the filesystem; sort for a stable listing.
    for item in sorted(DATA_INPUT.iterdir()):
        print(f"  - {item.name}")
else:
    print("❌ Dataset not found! Please add gricebench-scientific-fix dataset.")

## Part 1: Create Relation Evaluation Set

Samples 200 test examples tagged `[VIOLATION=RELATION]` for Mean Reciprocal Rank (MRR) evaluation.

In [None]:
# Cell 2: Create Relation Eval Set

def create_relation_eval_set(test_data_path, output_path, num_examples=200, seed=42):
    """
    Sample examples with Relation violations for MRR evaluation.
    Per morechanges.md lines 746-769.

    An example qualifies when its "input_text" contains "[VIOLATION=RELATION]".
    For each qualifying example, the text following the "[CONTEXT]" and
    "[RESPONSE]" tags is also extracted into "context" / "response" fields
    when those tags are present.

    Args:
        test_data_path: Path to a JSON file holding a list of dicts with an
            "input_text" field (and optionally "target_text").
        output_path: Destination JSON file; parent directories are created.
        num_examples: How many examples to sample. If fewer qualify, all of
            them are kept and a warning is printed.
        seed: Seed for a local random generator, so the sample is reproducible
            without reseeding the global `random` module.

    Returns:
        The list of sampled example dicts (also written to `output_path`).
    """
    # A private generator draws exactly the same sample as random.seed(seed)
    # followed by random.sample(), but leaves the module-level RNG untouched.
    rng = random.Random(seed)

    # The context ends at the next bracketed tag such as "[RESPONSE]" or
    # "[VIOLATION=...]", or at the end of the text. Stopping at the first "["
    # would truncate contexts that themselves contain brackets, e.g. "[laughs]".
    context_pattern = re.compile(r'\[CONTEXT\](.*?)(?=\[[A-Z_]+(?:=[^\]]*)?\]|\Z)', re.DOTALL)
    response_pattern = re.compile(r'\[RESPONSE\](.*?)$', re.DOTALL)

    print("=" * 70)
    print("CREATE RELATION EVALUATION SET")
    print("=" * 70)

    # Load test data
    print(f"\nLoading test data from {test_data_path}...")
    with open(test_data_path, 'r', encoding='utf-8') as f:
        test_data = json.load(f)

    print(f"  Total examples: {len(test_data)}")

    # Filter for Relation violations
    relation_examples = []
    for i, item in enumerate(test_data):
        # `or ""` also covers an explicit null input_text.
        input_text = item.get("input_text") or ""
        if "[VIOLATION=RELATION]" not in input_text:
            continue

        example = {
            "id": f"relation_eval_{i}",
            "input_text": input_text,
            "target_text": item.get("target_text", ""),
            "source_index": i,
        }

        context_match = context_pattern.search(input_text)
        if context_match:
            example["context"] = context_match.group(1).strip()
        response_match = response_pattern.search(input_text)
        if response_match:
            example["response"] = response_match.group(1).strip()

        relation_examples.append(example)

    print(f"  Relation violations found: {len(relation_examples)}")

    if len(relation_examples) < num_examples:
        print(f"  WARNING: Only {len(relation_examples)} examples available")
        num_examples = len(relation_examples)

    # Sample
    sampled = rng.sample(relation_examples, num_examples)

    # Save
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(sampled, f, indent=2, ensure_ascii=False)

    print(f"\n✅ Saved {len(sampled)} examples to {output_path}")
    return sampled

# Run: the test split may live under repair_data/ or at the dataset root;
# use the first location that exists.
repair_test_candidates = [
    DATA_INPUT / "repair_data" / "repair_test.json",
    DATA_INPUT / "repair_test.json",
]
repair_test_path = next((p for p in repair_test_candidates if p.exists()), None)

if repair_test_path is not None:
    relation_eval_set = create_relation_eval_set(
        test_data_path=str(repair_test_path),
        output_path=str(OUTPUT_DIR / "relation_eval_set.json"),
        num_examples=200
    )
else:
    # Reset so a result left over from an earlier kernel run is not reused.
    relation_eval_set = None
    print("❌ repair_test.json not found at expected paths")
    for candidate in repair_test_candidates:
        print(f"   Checked: {candidate}")

## Part 2: Create Annotation Sample (1000 examples)

Creates a stratified sample for human annotation, following the scheme in morechanges.md.

In [None]:
# Cell 3: Create Annotation Sample

def _annotation_item_key(item):
    """Key used to de-duplicate pooled items: the item's "id" if present,
    otherwise "<source_file>_<source_idx>"."""
    return item.get('id', f"{item.get('source_file', 'unknown')}_{item.get('source_idx', 0)}")


def _load_annotation_pool(path, source_name):
    """Load a JSON list from `path` and tag each item with `source_file` /
    `source_idx`. Returns [] (and says so) when the file does not exist."""
    path = Path(path)
    if not path.exists():
        print(f"  {source_name.capitalize()}: not found at {path} (skipped)")
        return []
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"  {source_name.capitalize()}: {len(data)} examples")
    for idx, item in enumerate(data):
        item['source_file'] = source_name
        item['source_idx'] = idx
    return data


def create_annotation_sample(val_data_path, gold_data_path, output_path, num_samples=1000, seed=42):
    """
    Create stratified sample for annotation.
    Per morechanges.md lines 658-679.

    The validation and gold files are pooled, then items are drawn without
    duplicates in this order:
      * num_samples // 5 items flagged for each maxim (quantity, quality,
        relation, manner) -> category "<maxim>_positive";
      * num_samples // 5 items with no flagged maxim -> category "clean";
      * random items from the rest of the pool to fill up to num_samples
        -> category "random".
    With the default num_samples=1000 each stratum gets 200 items.

    Args:
        val_data_path: JSON list of validation examples (skipped if missing).
        gold_data_path: JSON list of gold examples (skipped if missing).
        output_path: Destination JSON file; parent directories are created.
        num_samples: Upper bound on the number of sampled examples.
        seed: Seed for a local random generator (global RNG is untouched).

    Returns:
        The list of sampled example dicts (also written to `output_path`).
        Each gets "annotation_category", "sample_id" and a new "id" of the
        form "annotation_NNNN"; "source_file"/"source_idx" keep provenance.
    """
    rng = random.Random(seed)

    print("=" * 70)
    print(f"CREATE ANNOTATION SAMPLE ({num_samples:,} examples)")
    print("=" * 70)

    all_examples = (_load_annotation_pool(val_data_path, 'validation')
                    + _load_annotation_pool(gold_data_path, 'gold'))

    print(f"\nTotal pool: {len(all_examples)} examples")
    if not all_examples:
        print("  WARNING: no input data found; the sample will be empty")

    # Categorize
    maxims = ['quantity', 'quality', 'relation', 'manner']
    detector_positives = defaultdict(list)
    detector_negatives = []

    for item in all_examples:
        # NOTE(review): assumes labels / detector_predictions is a dict keyed
        # by maxim name; any truthy value counts as a flagged violation.
        labels = item.get('labels', item.get('detector_predictions')) or {}
        flagged = [m for m in maxims if labels.get(m, 0)]
        if flagged:
            for maxim in flagged:
                detector_positives[maxim].append(item)
        else:
            detector_negatives.append(item)

    print(f"\nCategorization:")
    print(f"  Clean examples: {len(detector_negatives)}")
    for maxim in maxims:
        print(f"  {maxim} positives: {len(detector_positives[maxim])}")

    # Sample
    final_sample = []
    seen_ids = set()

    def add_samples(pool, count, category):
        """Append up to `count` not-yet-sampled items from a shuffled copy of
        `pool`, tagging them with `category`. Returns how many were added."""
        if count <= 0:
            return 0
        shuffled = pool.copy()
        rng.shuffle(shuffled)
        added = 0
        for item in shuffled:
            if added >= count:
                break
            key = _annotation_item_key(item)
            if key in seen_ids:
                continue
            item['annotation_category'] = category
            item['sample_id'] = f"sample_{len(final_sample)}"
            final_sample.append(item)
            seen_ids.add(key)
            added += 1
        return added

    # One equal share per maxim plus one for clean examples.
    per_stratum = num_samples // (len(maxims) + 1)
    for maxim in maxims:
        added = add_samples(detector_positives[maxim], per_stratum, f"{maxim}_positive")
        print(f"  Added {added} {maxim} positives")

    added = add_samples(detector_negatives, per_stratum, "clean")
    print(f"  Added {added} clean examples")

    # Random items fill whatever budget remains (e.g. when a stratum is short).
    # NOTE(review): earlier versions drew a fixed 100 random items on top of
    # five 200-item strata (1,100 total, exceeding num_samples); confirm the
    # intended split against morechanges.md.
    remaining = [item for item in all_examples if _annotation_item_key(item) not in seen_ids]
    added = add_samples(remaining, num_samples - len(final_sample), "random")
    print(f"  Added {added} random examples")

    print(f"\nTotal sampled: {len(final_sample)}")

    # Shuffle so categories are interleaved, then assign final IDs. This
    # replaces any original "id"; source_file/source_idx retain provenance.
    rng.shuffle(final_sample)
    for i, item in enumerate(final_sample):
        item['id'] = f"annotation_{i:04d}"

    # Save
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(final_sample, f, indent=2, ensure_ascii=False)

    print(f"\n✅ Saved {len(final_sample)} examples to {output_path}")
    return final_sample

# Run: draw the stratified annotation sample from the validation + gold pools.
val_path = DATA_INPUT.joinpath("val_examples.json")
gold_path = DATA_INPUT.joinpath("gold_annotation_set.json")
annotation_output = OUTPUT_DIR.joinpath("annotation_sample_1000.json")

annotation_sample = create_annotation_sample(
    str(val_path),
    str(gold_path),
    str(annotation_output),
    num_samples=1000,
)

## Part 3: Verify Outputs

In [None]:
# Cell 4: Verify outputs before downloading them

print("=" * 70)
print("OUTPUT VERIFICATION")
print("=" * 70)

outputs = [
    OUTPUT_DIR / "relation_eval_set.json",
    OUTPUT_DIR / "annotation_sample_1000.json",
]
missing_outputs = []

for output_file in outputs:
    if not output_file.exists():
        missing_outputs.append(output_file.name)
        print(f"\n❌ {output_file.name} - NOT CREATED")
        continue

    size_kb = output_file.stat().st_size / 1024
    # The files are written as UTF-8 with ensure_ascii=False, so read them
    # back as UTF-8 rather than relying on the platform's default encoding.
    with open(output_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"\n✅ {output_file.name}")
    print(f"   Size: {size_kb:.1f} KB")
    print(f"   Examples: {len(data)}")
    if data and isinstance(data[0], dict):
        print(f"   Sample keys: {list(data[0].keys())[:5]}")

print("\n" + "=" * 70)
if missing_outputs:
    print(f"PHASE 1 INCOMPLETE - missing: {', '.join(missing_outputs)}")
else:
    print("PHASE 1 COMPLETE")
print("=" * 70)
print("\n📥 Download and add to your dataset for Phase 2:")
print("   1. relation_eval_set.json")
print("   2. annotation_sample_1000.json")