# GriceBench Part 4: Ablation Studies

This notebook runs comprehensive ablation studies to measure the contribution of each GriceBench component.

## Ablation Studies:
1. **Component Ablation** - full_system vs dpo_only vs detect_repair vs baseline
2. **Repair Strategy** - edit vs retrieval vs hybrid router *(planned; not yet implemented in this notebook — the cells below run studies 1, 3 and 4)*
3. **Threshold Sensitivity** - 0.3, 0.4, 0.5, 0.6, 0.7
4. **Maxim Importance** - contribution of each maxim

## Setup Requirements:
1. **GPU**: Enable T4 x2 in Settings
2. **Internet**: Enable for model downloads
3. **Datasets**: Add your trained models as Kaggle datasets

## Required Kaggle Datasets:
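- `gricean-maxim-detector-model` - Detector weights (`best_model.pt`, a state dict for the DeBERTa-v3-base detector)
- `gricebench-repair-model` - Repair model folder (T5 checkpoint loadable with `from_pretrained`)
- `dpo-generator-model` - DPO LoRA adapters for `gpt2-medium`
- `gricebench-test-data` - Test examples (`test_examples.json`, `val_examples.json`, or `dpo_val.json`: a JSON list of objects with a `context` or `prompt` field)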

**Estimated Runtime: ~45-60 minutes**

In [None]:
# Cell 1: Install Dependencies
!pip install -q transformers>=4.40.0 accelerate peft sentence-transformers
!pip install -q faiss-cpu
print("‚úÖ Dependencies installed!")

In [None]:
# Cell 2: Setup and Configuration
import gc
import json
import os
import random
import sys
import time
from collections import defaultdict
from contextlib import nullcontext
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from tqdm.auto import tqdm

# Timestamped logging; flushing right away keeps Kaggle's log view live
# while long-running cells are still executing.
def log(msg):
    """Print ``msg`` prefixed with an ``[HH:MM:SS]`` timestamp and flush stdout."""
    timestamp = datetime.now().strftime('%H:%M:%S')
    print(f"[{timestamp}] {msg}", flush=True)

# GPU setup: use CUDA when a device is visible, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
log(f"Device: {device}")
if device.type == "cuda":
    _gpu_props = torch.cuda.get_device_properties(0)
    log(f"GPU: {torch.cuda.get_device_name(0)}")
    log(f"VRAM: {_gpu_props.total_memory / 1e9:.1f} GB")

# Configuration — every knob a reader may want to tune lives here.
CONFIG = {
    "num_samples": 100,  # Samples per ablation
    "batch_size": 8,     # Batch size for GPU (NOTE(review): not read by any cell yet)
    "max_new_tokens": 100,
    "thresholds": [0.3, 0.4, 0.5, 0.6, 0.7],
    "threshold_samples": 50,  # Samples used by the threshold sweep (subset for speed)
    "seed": 42,               # Seed for the Python / NumPy / torch RNGs
    "output_dir": "/kaggle/working/ablation_results",
    
    # Model paths (update these to match your Kaggle dataset names)
    "detector_path": "/kaggle/input/gricean-maxim-detector-model",
    "repair_path": "/kaggle/input/gricebench-repair-model",
    "dpo_path": "/kaggle/input/dpo-generator-model",
    "test_data_path": "/kaggle/input/gricebench-test-data",
}

Path(CONFIG["output_dir"]).mkdir(parents=True, exist_ok=True)

# Seed every RNG the notebook uses. Responses are *sampled* (do_sample=True in
# generation), so without a torch seed the ablation numbers change run to run.
random.seed(CONFIG["seed"])
np.random.seed(CONFIG["seed"])
torch.manual_seed(CONFIG["seed"])  # seeds the CPU and all CUDA generators

# Memory management helpers
def cleanup():
    """Run Python's garbage collector and, when CUDA exists, drop cached allocator blocks."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

def get_gpu_memory():
    """Return GB currently allocated by tensors on the current CUDA device (0 without CUDA)."""
    if not torch.cuda.is_available():
        return 0
    return torch.cuda.memory_allocated() / 1e9

# Confirm the config cell ran and show where results will be written.
log("Configuration loaded!")
log("Output: " + CONFIG['output_dir'])

In [None]:
# Cell 3: Define Detector Model Architecture
from transformers import AutoModel, AutoTokenizer

class MaximDetectorV2(nn.Module):
    """Multi-head detector for Gricean maxim violations.

    A shared transformer encoder produces the first-token (CLS) embedding,
    which is scored by one small MLP head per maxim. ``forward`` returns raw
    logits of shape ``(batch, num_maxims)``; apply ``sigmoid`` for per-maxim
    probabilities.
    """
    
    def __init__(self, model_name="microsoft/deberta-v3-base", num_maxims=4, dropout=0.15):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden_size = self.encoder.config.hidden_size
        
        # Separate classification head per maxim: hidden -> hidden/2 -> hidden/4 -> 1
        self.classifiers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, hidden_size // 2),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_size // 2, hidden_size // 4),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_size // 4, 1)
            )
            for _ in range(num_maxims)
        ])
    
    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return per-maxim logits of shape ``(batch, num_maxims)``.

        ``token_type_ids`` is accepted but ignored. Hugging Face tokenizers may
        return that key, and accepting it lets callers pass ``**tokenizer(...)``
        without a TypeError. The encoder is still fed only ids + mask, exactly as
        before this parameter existed (presumably how the heads were trained —
        confirm against the training code before forwarding it).
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0, :]  # CLS token
        logits = [classifier(pooled) for classifier in self.classifiers]
        return torch.cat(logits, dim=1)

# Log that the detector class definition cell has finished running.
log("Detector architecture defined!")

In [None]:
# Cell 4: Load Models (GPU Optimized)
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer
)
from peft import PeftModel

log("="*60)
log("LOADING MODELS")
log("="*60)

models = {}
tokenizers = {}

# FP16 halves memory on the GPUs, but many fp16 kernels are missing or slow on
# CPU, so use half precision only when running on CUDA.
use_fp16 = device.type == "cuda"
model_dtype = torch.float16 if use_fp16 else torch.float32

# 1. Load Detector
log("\n1. Loading Detector...")
try:
    detector_tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
    detector = MaximDetectorV2("microsoft/deberta-v3-base")
    
    # Load trained weights
    detector_weights = os.path.join(CONFIG["detector_path"], "best_model.pt")
    if os.path.exists(detector_weights):
        # torch.load unpickles the file: only point this at checkpoints you trust.
        checkpoint = torch.load(detector_weights, map_location="cpu")
        # Accept a bare state_dict or a training checkpoint that wraps one.
        # NOTE(review): the wrapper key names are common conventions, not taken
        # from the training code — confirm against how best_model.pt was saved.
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for wrapper_key in ("model_state_dict", "state_dict"):
                if isinstance(checkpoint.get(wrapper_key), dict):
                    state_dict = checkpoint[wrapper_key]
                    break
        detector.load_state_dict(state_dict)
        log(f"   ‚úÖ Loaded detector weights from {detector_weights}")
    else:
        log(f"   ‚ö†Ô∏è No detector weights found at {detector_weights}")
        log("   The classification heads are randomly initialised; detector-based metrics will be meaningless.")
    
    detector = detector.to(device)
    if use_fp16:
        detector = detector.half()  # FP16 only on GPU
    detector.eval()
    models["detector"] = detector
    tokenizers["detector"] = detector_tokenizer
    log(f"   GPU Memory: {get_gpu_memory():.1f} GB")
except Exception as e:
    log(f"   ‚ùå Failed to load detector: {e}")

cleanup()

# 2. Load Repair Model
log("\n2. Loading Repair Model...")
try:
    repair_path = CONFIG["repair_path"]
    if os.path.exists(repair_path):
        repair_tokenizer = T5Tokenizer.from_pretrained(repair_path)
        repair_model = T5ForConditionalGeneration.from_pretrained(
            repair_path,
            torch_dtype=model_dtype,
            device_map="auto"
        )
        repair_model.eval()
        models["repair"] = repair_model
        tokenizers["repair"] = repair_tokenizer
        log(f"   ‚úÖ Repair model loaded")
        log(f"   GPU Memory: {get_gpu_memory():.1f} GB")
    else:
        log(f"   ‚ö†Ô∏è Repair model not found at {repair_path}")
except Exception as e:
    log(f"   ‚ùå Failed to load repair: {e}")

cleanup()

# 3. Load DPO Generator
log("\n3. Loading DPO Generator...")
try:
    base_model_name = "gpt2-medium"
    gen_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    if gen_tokenizer.pad_token is None:
        gen_tokenizer.pad_token = gen_tokenizer.eos_token
    
    # Load base model
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=model_dtype,
        device_map="auto"
    )
    
    # Load LoRA adapters
    dpo_path = CONFIG["dpo_path"]
    if os.path.exists(dpo_path):
        # PEFT injects the LoRA layers into `base_model`'s modules in place, so
        # `base_model` is deliberately NOT registered as models["baseline"]: it
        # would silently include the adapter. An un-adapted baseline has to be
        # obtained by disabling the adapter (e.g. `dpo_model.disable_adapter()`).
        dpo_model = PeftModel.from_pretrained(base_model, dpo_path)
        dpo_model.eval()
        models["dpo"] = dpo_model
        tokenizers["dpo"] = gen_tokenizer
        log(f"   ‚úÖ DPO model loaded with LoRA")
    else:
        models["baseline"] = base_model
        tokenizers["baseline"] = gen_tokenizer
        log(f"   ‚ö†Ô∏è DPO not found, using baseline GPT2")
    
    log(f"   GPU Memory: {get_gpu_memory():.1f} GB")
except Exception as e:
    log(f"   ‚ùå Failed to load generator: {e}")

cleanup()
log(f"\n‚úÖ Models loaded: {list(models.keys())}")
log(f"Total GPU Memory: {get_gpu_memory():.1f} GB")

In [None]:
# Cell 5: Load Test Data
log("Loading test data...")

test_data = []

# Candidate files, tried in order; the first one that exists wins.
possible_paths = [
    os.path.join(CONFIG["test_data_path"], "test_examples.json"),
    os.path.join(CONFIG["test_data_path"], "val_examples.json"),
    os.path.join(CONFIG["test_data_path"], "dpo_val.json"),
    "/kaggle/input/gricebench-test-data/test_examples.json",
]

for path in possible_paths:
    if os.path.exists(path):
        with open(path, encoding="utf-8") as f:
            test_data = json.load(f)
        if isinstance(test_data, dict):
            # NOTE(review): assumes a wrapper like {"examples": [...]} — the first
            # list-valued entry is taken as the examples; confirm for your export.
            test_data = next((v for v in test_data.values() if isinstance(v, list)), [])
        log(f"‚úÖ Loaded {len(test_data)} examples from {path}")
        break

if not test_data:
    log("‚ö†Ô∏è No test data found, creating synthetic prompts...")
    # Create synthetic test data (3 prompts repeated)
    test_data = [
        {"context": "What is the capital of France?", "prompt": "What is the capital of France?"},
        {"context": "How does photosynthesis work?", "prompt": "How does photosynthesis work?"},
        {"context": "What's your favorite food?", "prompt": "What's your favorite food?"},
    ] * 34  # 102 samples

# Later cells call item.get("context") / item.get("prompt"); fail fast here with a
# clear message instead of deep inside an ablation loop.
assert all(isinstance(item, dict) for item in test_data), \
    "test data must be a JSON list of objects with a 'context' or 'prompt' field"

# Sample for speed; seeded from CONFIG so the subset is reproducible.
import random
random.seed(CONFIG.get("seed", 42))
if len(test_data) > CONFIG["num_samples"]:
    test_data = random.sample(test_data, CONFIG["num_samples"])

log(f"Using {len(test_data)} test samples")

In [None]:
# Cell 6: Core Functions (GPU Optimized)

MAXIMS = ["quantity", "quality", "relation", "manner"]

@torch.no_grad()
def detect_violations(context, response, threshold=0.5):
    """Score a (context, response) pair with the maxim detector.

    Args:
        context: Dialogue context / prompt text.
        response: Candidate response text.
        threshold: A maxim counts as violated when its probability is strictly
            greater than this value.

    Returns:
        ``(violations, probs)``: ``violations`` maps each maxim in ``MAXIMS`` to a
        bool and ``probs`` maps it to its sigmoid probability (float). When no
        detector is loaded, every maxim is reported as not violated with
        probability 0.0. The tuple shape is the same either way, so callers
        can always unpack it.
    """
    if "detector" not in models:
        return {m: False for m in MAXIMS}, {m: 0.0 for m in MAXIMS}
    
    detector = models["detector"]
    tokenizer = tokenizers["detector"]
    
    text = f"Context: {context}\nResponse: {response}"
    inputs = tokenizer(
        text, 
        return_tensors="pt", 
        truncation=True, 
        max_length=512,
        padding=True
    ).to(device)
    
    # Pass only the two inputs the detector's forward() declares. Tokenizers may
    # also emit extra keys (e.g. token_type_ids) that would raise if splatted.
    logits = detector(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    # Upcast before the sigmoid: the detector may run in fp16.
    probs = torch.sigmoid(logits.float()).cpu().numpy()[0]
    
    violations = {m: bool(probs[i] > threshold) for i, m in enumerate(MAXIMS)}
    scores = {m: float(probs[i]) for i, m in enumerate(MAXIMS)}
    return violations, scores

@torch.no_grad()
def generate_response(context, use_dpo=True):
    """Sample a response to ``context`` from the DPO model or the un-adapted baseline.

    Model choice:
      * ``use_dpo=True``: the DPO (LoRA) model if loaded, else the baseline.
      * ``use_dpo=False``: ``models["baseline"]`` if registered. Otherwise, when
        only the PEFT-wrapped DPO model is loaded, that same model is used with
        its adapter disabled (i.e. the base weights underneath it).

    Returns:
        The decoded continuation (prompt tokens excluded), stripped and capped
        at 500 characters, or "[Model not available]" when no generator exists.
    """
    disable_adapter = False
    if use_dpo and "dpo" in models:
        model_key = "dpo"
    elif "baseline" in models:
        model_key = "baseline"
    elif "dpo" in models and hasattr(models["dpo"], "disable_adapter"):
        # The loading cell registers only "dpo" when the adapter loads, so
        # without this branch the baseline configurations had no generator.
        model_key = "dpo"
        disable_adapter = True
    else:
        return "[Model not available]"
    
    model = models[model_key]
    tokenizer = tokenizers.get(model_key) or tokenizers.get("dpo") or tokenizers.get("baseline")
    
    prompt = f"Context: {context}\nGenerate a cooperative response:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256).to(device)
    prompt_len = inputs["input_ids"].shape[1]
    
    adapter_ctx = model.disable_adapter() if disable_adapter else nullcontext()
    with adapter_ctx:
        outputs = model.generate(
            **inputs,
            max_new_tokens=CONFIG["max_new_tokens"],
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            num_return_sequences=1
        )
    
    # Decode only the newly generated tokens. Splitting the full text on the
    # instruction phrase failed when the model repeated it, when its casing
    # differed, or when truncation had cut it off.
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
    return response[:500]

@torch.no_grad()
def repair_response(context, response, violations):
    """Rewrite ``response`` with the T5 repair model so it stops violating the flagged maxims.

    Args:
        context: Dialogue context the response answers.
        response: Response text to repair.
        violations: Mapping of maxim name -> bool (truthy = violated).

    Returns:
        ``response`` unchanged when no repair model is loaded or nothing is
        flagged. Otherwise, the decoded 4-beam output for the tagged input
        ``[REPAIR] [VIOLATION=A,B] [CONTEXT] <context> [RESPONSE] <response>``.
    """
    # Nothing to do without a repair model ...
    if "repair" not in models:
        return response
    # ... or when no maxim is flagged.
    flagged = [name.upper() for name, hit in violations.items() if hit]
    if not flagged:
        return response

    t5 = models["repair"]
    t5_tokenizer = tokenizers["repair"]

    tagged_input = (
        f"[REPAIR] [VIOLATION={','.join(flagged)}] "
        f"[CONTEXT] {context} [RESPONSE] {response}"
    )
    encoded = t5_tokenizer(tagged_input, return_tensors="pt", truncation=True, max_length=512).to(device)

    generated = t5.generate(**encoded, max_new_tokens=150, num_beams=4, early_stopping=True)
    return t5_tokenizer.decode(generated[0], skip_special_tokens=True)

# Log that the core detect / generate / repair helpers are now defined.
log("Core functions defined!")

In [None]:
# Cell 7: Ablation Study 1 - Component Ablation
log("="*70)
log("ABLATION STUDY 1: COMPONENT ABLATION")
log("="*70)

component_results = {}

# Pipeline switches per configuration:
#   use_dpo      - draft with the DPO (LoRA) generator instead of the baseline
#   use_detector - run the detector inside the pipeline
#   use_repair   - repair drafts the detector flags (only with use_detector)
# Every configuration is *scored* with the detector regardless of these flags.
configurations = {
    "full_system": {"use_dpo": True, "use_detector": True, "use_repair": True},
    "dpo_only": {"use_dpo": True, "use_detector": False, "use_repair": False},
    "detect_repair": {"use_dpo": False, "use_detector": True, "use_repair": True},
    "baseline": {"use_dpo": False, "use_detector": False, "use_repair": False},
}

ablation_seed = CONFIG.get("seed", 42)

for config_name, config in configurations.items():
    log(f"\n{'‚îÄ'*50}")
    log(f"Testing: {config_name.upper()}")
    log(f"Config: {config}")
    log(f"{'‚îÄ'*50}")
    
    # Reseed so every configuration sees the same sampling RNG stream. Configs
    # that share a generator (full_system/dpo_only, detect_repair/baseline)
    # should then draft the same responses (up to GPU nondeterminism), which
    # makes their difference a paired measurement of detect+repair.
    torch.manual_seed(ablation_seed)
    
    violations_count = {m: 0 for m in MAXIMS}
    cooperative_count = 0  # responses with no violated maxim after the pipeline
    repair_attempts = 0
    total_samples = 0
    start_time = time.time()
    
    for i, item in enumerate(test_data):
        context = item.get("context") or item.get("prompt", "")
        
        # Generate response
        response = generate_response(context, use_dpo=config["use_dpo"])
        
        # Score the draft; optionally repair a flagged draft and re-score it
        violations, _ = detect_violations(context, response)
        if config["use_detector"] and config["use_repair"] and any(violations.values()):
            response = repair_response(context, response, violations)
            repair_attempts += 1
            violations, _ = detect_violations(context, response)
        
        # Count violations
        violated = [m for m in MAXIMS if violations.get(m, False)]
        for m in violated:
            violations_count[m] += 1
        if not violated:
            cooperative_count += 1
        total_samples += 1
        
        # Progress
        if (i + 1) % 10 == 0:
            elapsed = time.time() - start_time
            rate = (i + 1) / elapsed
            eta = (len(test_data) - i - 1) / rate
            log(f"   [{i+1:3d}/{len(test_data)}] {rate:.1f} samples/sec | ETA: {eta:.0f}s | GPU: {get_gpu_memory():.1f}GB")
    
    # Calculate rates (max(..., 1) guards against an empty test set)
    denom = max(total_samples, 1)
    violation_rates = {m: violations_count[m] / denom for m in MAXIMS}
    
    component_results[config_name] = {
        "violation_rates": violation_rates,
        # Fraction of responses with NO violated maxim. Note that
        # 1 - mean(per-maxim rates) is only an upper bound on this quantity.
        "cooperative_rate": cooperative_count / denom,
        "mean_violation_rate": sum(violation_rates.values()) / len(MAXIMS),
        "repair_attempts": repair_attempts,
        "samples": total_samples,
        "time_sec": time.time() - start_time
    }
    
    log(f"\n   üìä Results for {config_name}:")
    for m, rate in violation_rates.items():
        log(f"      {m:10s}: {rate*100:5.1f}% violation")
    log(f"      {'Overall':10s}: {component_results[config_name]['cooperative_rate']*100:.1f}% cooperative")
    
    cleanup()

# Save checkpoint
with open(f"{CONFIG['output_dir']}/component_ablation.json", "w") as f:
    json.dump(component_results, f, indent=2)
log(f"\nüíæ Saved to {CONFIG['output_dir']}/component_ablation.json")

In [None]:
# Cell 8: Ablation Study 2 - Threshold Sensitivity
log("="*70)
log("ABLATION STUDY 2: THRESHOLD SENSITIVITY")
log("="*70)

threshold_results = {}
threshold_subset = test_data[:CONFIG.get("threshold_samples", 50)]  # subset for speed

# Generate and score every response ONCE, then sweep the thresholds over the
# stored probabilities. Re-sampling responses per threshold would mix the
# threshold effect with sampling noise and cost one generation pass per threshold.
torch.manual_seed(CONFIG.get("seed", 42))
sample_probs = []
start = time.time()
for item in threshold_subset:
    context = item.get("context") or item.get("prompt", "")
    response = generate_response(context, use_dpo=True)
    _, probs = detect_violations(context, response)
    sample_probs.append(probs)
log(f"Scored {len(sample_probs)} responses in {time.time() - start:.0f}s")

total = max(len(sample_probs), 1)  # guard against an empty subset
for threshold in CONFIG["thresholds"]:
    log(f"\nTesting threshold: {threshold}")
    
    # Same strict '>' comparison detect_violations applies to its threshold.
    violations_count = {m: sum(p[m] > threshold for p in sample_probs) for m in MAXIMS}
    violation_rates = {m: violations_count[m] / total for m in MAXIMS}
    overall_rate = sum(violation_rates.values()) / len(MAXIMS)
    
    threshold_results[str(threshold)] = {
        "violation_rates": violation_rates,
        "overall_violation_rate": overall_rate,
        "samples": len(sample_probs)
    }
    
    log(f"   Threshold {threshold}: {overall_rate*100:.1f}% average violation rate")

# Save
with open(f"{CONFIG['output_dir']}/threshold_ablation.json", "w") as f:
    json.dump(threshold_results, f, indent=2)
log(f"\nüíæ Saved threshold results")

In [None]:
# Cell 9: Ablation Study 3 - Maxim Importance
log("="*70)
log("ABLATION STUDY 3: MAXIM IMPORTANCE")
log("="*70)

maxim_results = {}

# Tallies gathered over the test set:
#   total_violations[m]  - responses where maxim m is flagged
#   solo_violations[m]   - responses where m is the ONLY flagged maxim
#   co_violations[m][n]  - responses where m and n are both flagged (diagonal = total)
solo_violations = dict.fromkeys(MAXIMS, 0)
co_violations = {m: dict.fromkeys(MAXIMS, 0) for m in MAXIMS}
total_violations = dict.fromkeys(MAXIMS, 0)
total_samples = 0

log("\nAnalyzing maxim patterns...")
for step, example in enumerate(test_data, start=1):
    prompt_text = example.get("context") or example.get("prompt", "")
    draft = generate_response(prompt_text, use_dpo=True)
    flags, _ = detect_violations(prompt_text, draft)

    flagged = [m for m in MAXIMS if flags.get(m, False)]
    is_solo = len(flagged) == 1
    for m in flagged:
        total_violations[m] += 1
        if is_solo:
            solo_violations[m] += 1
        for partner in flagged:
            co_violations[m][partner] += 1

    total_samples += 1
    if step % 20 == 0:
        log(f"   [{step:3d}/{len(test_data)}] processed")


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0


# Calculate importance scores
importance_scores = {}
for m in MAXIMS:
    rate = _safe_ratio(total_violations[m], total_samples)
    solo_rate = _safe_ratio(solo_violations[m], total_samples)
    independence = _safe_ratio(solo_violations[m], total_violations[m])

    importance_scores[m] = {
        "violation_rate": rate,
        "solo_rate": solo_rate,
        "independence": independence,
        "co_occurrences": {other: count for other, count in co_violations[m].items() if other != m},
    }

    log(f"\n   {m.upper()}:")
    for label, value in (("Violation Rate", rate), ("Solo Rate", solo_rate), ("Independence", independence)):
        log(f"      {label}: {value*100:.1f}%")

maxim_results = {"importance_scores": importance_scores, "total_samples": total_samples}

# Save
with open(f"{CONFIG['output_dir']}/maxim_ablation.json", "w") as fh:
    json.dump(maxim_results, fh, indent=2)
log(f"\nüíæ Saved maxim importance results")

In [None]:
# Cell 10: Generate Comprehensive Report
log("="*70)
log("GENERATING ABLATION REPORT")
log("="*70)


def _coop_rate(config_name):
    """Return the cooperative rate recorded for ``config_name``, or None if it was not run."""
    entry = component_results.get(config_name)
    return None if entry is None else entry.get("cooperative_rate")


def _gain_pp(with_cfg, without_cfg):
    """Cooperative-rate difference ``with_cfg - without_cfg`` in percentage points (None if unavailable)."""
    a, b = _coop_rate(with_cfg), _coop_rate(without_cfg)
    return None if a is None or b is None else (a - b) * 100


def _gain_bullet(label, with_cfg, without_cfg):
    """Format one markdown bullet with the measured gain of ``with_cfg`` over ``without_cfg``."""
    gain = _gain_pp(with_cfg, without_cfg)
    if gain is None:
        return f"- **{label}**: not measured ({with_cfg} or {without_cfg} missing)"
    return f"- **{label}** ({with_cfg} vs {without_cfg}): {gain:+.1f} pp cooperative rate"


def _data_driven_conclusions():
    """Build the numbered Conclusions list from the measured results only."""
    lines = []
    if component_results:
        ranking = sorted(component_results,
                         key=lambda k: component_results[k]["cooperative_rate"], reverse=True)
        lines.append("Configurations ranked by cooperative rate: " + " > ".join(
            f"{k} ({component_results[k]['cooperative_rate']*100:.1f}%)" for k in ranking))
    dpo_gain = _gain_pp("dpo_only", "baseline")
    repair_gain = _gain_pp("full_system", "dpo_only")
    if dpo_gain is not None and repair_gain is not None:
        larger = "DPO training" if dpo_gain >= repair_gain else "Detection + Repair"
        lines.append(f"{larger} contributes more to the cooperative rate "
                     f"(DPO: {dpo_gain:+.1f} pp, repair on top of DPO: {repair_gain:+.1f} pp)")
    scores = maxim_results.get("importance_scores", {})
    if scores:
        most = max(scores, key=lambda m: scores[m]["violation_rate"])
        lines.append(f"Most frequently violated maxim: {most} ({scores[most]['violation_rate']*100:.1f}%)")
    samples = max((d.get("samples", 0) for d in component_results.values()), default=0)
    lines.append(f"All rates are detector-based estimates on {samples} samples per configuration; "
                 "small differences may be within sampling noise")
    return "\n".join(f"{i}. {line}" for i, line in enumerate(lines, 1))


report = f"""# GriceBench Ablation Study Report

Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Executive Summary

This report presents ablation studies measuring the contribution of each GriceBench component.

---

## 1. Component Ablation

Testing which system components contribute to reducing violations.

| Configuration | Quantity | Quality | Relation | Manner | Cooperative Rate |
|--------------|----------|---------|----------|--------|------------------|
"""

for config, data in component_results.items():
    rates = data["violation_rates"]
    coop = data["cooperative_rate"]
    report += f"| {config} | {rates['quantity']*100:.1f}% | {rates['quality']*100:.1f}% | {rates['relation']*100:.1f}% | {rates['manner']*100:.1f}% | {coop*100:.1f}% |\n"

# Findings are computed from the measured numbers rather than asserted up front.
key_findings = "\n".join([
    f"- **Full System** achieves {component_results.get('full_system', {}).get('cooperative_rate', 0)*100:.1f}% cooperative rate",
    _gain_bullet("DPO contribution", "dpo_only", "baseline"),
    _gain_bullet("Repair on top of DPO", "full_system", "dpo_only"),
    _gain_bullet("Detect+Repair on the baseline", "detect_repair", "baseline"),
])

report += f"""
### Key Findings:
{key_findings}

---

## 2. Threshold Sensitivity

How detection threshold affects violation rates.

| Threshold | Overall Violation Rate |
|-----------|------------------------|
"""

for thresh, data in threshold_results.items():
    report += f"| {thresh} | {data['overall_violation_rate']*100:.1f}% |\n"

report += f"""
---

## 3. Maxim Importance

Contribution of each maxim to overall performance.

| Maxim | Violation Rate | Solo Rate | Independence |
|-------|----------------|-----------|---------------|
"""

for m, data in maxim_results.get("importance_scores", {}).items():
    report += f"| {m.capitalize()} | {data['violation_rate']*100:.1f}% | {data['solo_rate']*100:.1f}% | {data['independence']*100:.1f}% |\n"

report += f"""
### Interpretation:
- **Violation Rate**: How often this maxim is violated
- **Solo Rate**: How often ONLY this maxim is violated
- **Independence**: Proportion of violations that occur alone

---

## Conclusions

Derived from the measured results above (not assumed in advance):

{_data_driven_conclusions()}

---

*Report generated by GriceBench Ablation Study Framework*
"""

# Save report
report_path = f"{CONFIG['output_dir']}/ablation_report.md"
with open(report_path, "w") as f:
    f.write(report)

log(f"\nüìÑ Report saved to {report_path}")
print("\n" + "="*70)
print(report)
print("="*70)

In [None]:
# Cell 11: Save All Results
log("="*70)
log("SAVING FINAL RESULTS")
log("="*70)

# Everything needed to reproduce / compare runs, gathered in one JSON document
all_results = dict(
    config=CONFIG,
    component_ablation=component_results,
    threshold_ablation=threshold_results,
    maxim_ablation=maxim_results,
    timestamp=datetime.now().isoformat(),
    device=str(device),
    gpu_name=(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A"),
)

# Save comprehensive results (default=str covers any non-JSON-native values)
results_path = f"{CONFIG['output_dir']}/ablation_results.json"
with open(results_path, "w") as fh:
    json.dump(all_results, fh, indent=2, default=str)

out_dir = CONFIG['output_dir']
log(f"\n‚úÖ All results saved to {out_dir}/")
log("\nFiles created:")
for fname in os.listdir(out_dir):
    size_kb = os.path.getsize(os.path.join(out_dir, fname)) / 1024
    log(f"   üìÅ {fname} ({size_kb:.1f} KB)")

log("\n" + "="*70)
log("üéâ ABLATION STUDIES COMPLETE!")
log("="*70)
log("\nDownload the files from /kaggle/working/ablation_results/")