# GriceBench Improvements: Parts 1 & 2

This notebook implements:
- **Part 1**: Fixing Relation Repair (Retrieval-Augmented Generation)
- **Part 2**: Human Evaluation Framework Setup

## How to Use:
1. Upload this notebook to Kaggle
2. Enable GPU (Settings → Accelerator → GPU T4 x2)
3. Enable Internet (Settings → Internet → On)
4. Run All Cells
5. Download outputs from `/kaggle/working/`

**Estimated Runtime**: ~15-20 minutes

---
## Cell 1: Install Dependencies

In [None]:
%%time
print("Installing dependencies...")
!pip install -q sentence-transformers faiss-cpu datasets gradio krippendorff
print("‚úÖ Dependencies installed!")

---
## Cell 2: Create Output Directories

In [None]:
import os
from pathlib import Path

# Create every output directory used by later cells (idempotent: exist_ok=True).
dirs = [
    '/kaggle/working/data_processed',
    '/kaggle/working/human_eval_results',
    '/kaggle/working/results/relation_repair_evaluation',
    '/kaggle/working/reports'
]

for d in dirs:
    Path(d).mkdir(parents=True, exist_ok=True)
    print(f"✅ Created: {d}")

# Later cells write relative to this directory as well as to absolute paths.
os.chdir('/kaggle/working')
print(f"\nWorking directory: {os.getcwd()}")

---
# PART 1: Fixing Relation Repair Problem

The current repair model achieves only 9.3% BLEU on Relation violations because it tries to "edit" off-topic text into on-topic text. This is essentially impossible: an off-topic response contains no on-topic content to preserve.

**Solution**: Use retrieval-augmented generation instead of editing.

## Cell 3: Part 1 Step 1 - Create Response Corpus

In [None]:
%%time
"""
Part 1, Step 1: Create Topical Response Corpus
Downloads free dialogue datasets and organizes by topic.
"""

import json
import re
from collections import defaultdict
import hashlib

# Topic taxonomy with keywords
# Topic taxonomy with keywords.
# Order matters: extract_topic returns the FIRST topic (in this insertion order)
# that has a matching keyword, so e.g. "music" resolves to "entertainment"
# before "hobbies" is ever checked.
TOPIC_TAXONOMY = {
    "weather": ["weather", "rain", "sunny", "cold", "hot", "temperature", "storm"],
    "food": ["food", "eat", "restaurant", "cook", "meal", "dinner", "lunch", "breakfast"],
    "work": ["work", "job", "office", "boss", "meeting", "project", "career"],
    "family": ["family", "mother", "father", "sister", "brother", "parents", "kids"],
    "travel": ["travel", "trip", "vacation", "flight", "hotel", "visit", "beach"],
    "health": ["health", "doctor", "sick", "medicine", "hospital", "exercise"],
    "entertainment": ["movie", "film", "music", "game", "show", "concert", "book"],
    "sports": ["sport", "team", "play", "win", "match", "football", "basketball"],
    "education": ["school", "study", "learn", "class", "teacher", "student"],
    "technology": ["computer", "phone", "internet", "app", "software", "tech"],
    "pets": ["pet", "dog", "cat", "animal", "puppy"],
    "hobbies": ["hobby", "art", "painting", "music", "garden"],
    "shopping": ["shop", "buy", "store", "price", "sale"],
    "relationship": ["friend", "relationship", "date", "love", "partner"],
}

# Inflectional endings accepted after a keyword ("rain" -> "raining",
# "friend" -> "friends") while still requiring a whole-word match.
_KEYWORD_SUFFIX = r"(?:s|es|ed|ing)?"

# One precompiled pattern per topic, kept in taxonomy order. Whole-word matching
# avoids the substring false positives of `keyword in text` ("app" in "happy",
# "hot" in "hotel", "eat" in "great"). Built once at definition time: rebuild
# it if TOPIC_TAXONOMY is edited afterwards.
_TOPIC_PATTERNS = [
    (
        topic,
        re.compile(r"\b(?:" + "|".join(map(re.escape, keywords)) + ")" + _KEYWORD_SUFFIX + r"\b"),
    )
    for topic, keywords in TOPIC_TAXONOMY.items()
]


def extract_topic(text):
    """Assign a coarse topic to `text` by keyword matching.

    Args:
        text: Free-form text (the corpus cell passes "context response").
            May be empty or None.

    Returns:
        The first topic of TOPIC_TAXONOMY with a keyword occurring as a whole
        word (optionally followed by s/es/ed/ing), or "general" if none does.
    """
    if not text:
        return "general"
    text_lower = text.lower()
    for topic, pattern in _TOPIC_PATTERNS:
        if pattern.search(text_lower):
            return topic
    return "general"

def clean_text(text):
    """Clean and normalize one dialogue utterance.

    Steps, in order:
      1. replace the "_comma_" placeholder (used by empathetic_dialogues in
         place of commas) with ",";
      2. drop bracketed annotations such as "[laughs]" (may span newlines);
      3. collapse whitespace runs to single spaces and strip the ends.

    Brackets are removed *before* whitespace is collapsed, so removing an
    annotation from mid-sentence does not leave a double space behind.

    Returns:
        The cleaned string, or "" for None/empty input.
    """
    if not text:
        return ""
    text = text.replace("_comma_", ",")
    # [^\]]* stops at the first "]" (like the non-greedy .*?) but also crosses newlines.
    text = re.sub(r'\[[^\]]*\]', '', text)
    return re.sub(r'\s+', ' ', text).strip()

def is_quality_response(response, min_len=10, max_len=150):
    """Decide whether `response` is worth keeping in the corpus.

    A response qualifies when its whitespace-separated word count lies in the
    inclusive range [min_len, max_len] and it contains at most two '?'.
    """
    n_words = len(response.split())
    length_ok = min_len <= n_words <= max_len
    return length_ok and response.count('?') <= 2

# Load datasets
print("Loading dialogue datasets from HuggingFace...")
import random

from datasets import load_dataset

MAX_PER_TOPIC = 5000  # cap per topic so frequent topics don't dominate the index
BALANCE_SEED = 42     # makes the per-topic subsample reproducible

corpus = defaultdict(list)
total_examples = 0
seen_pairs = set()    # (context, response) pairs already stored; exact duplicates are skipped


def add_pairs(turns, source):
    """Store every consecutive (previous turn -> next turn) pair of one conversation.

    Args:
        turns: Cleaned utterances of a single conversation, in speaking order.
        source: Dataset name recorded on each stored example.

    Returns:
        Number of examples added to `corpus`. Pairs with an empty side, a
        response failing `is_quality_response`, or already seen are skipped.
    """
    added = 0
    for context, response in zip(turns, turns[1:]):
        if not (context and response and is_quality_response(response)):
            continue
        if (context, response) in seen_pairs:
            continue
        seen_pairs.add((context, response))
        topic = extract_topic(f"{context} {response}")
        corpus[topic].append({"context": context, "response": response, "source": source})
        added += 1
    return added


def daily_dialog_conversations():
    """Yield each DailyDialog dialogue as a list of cleaned utterances."""
    for item in load_dataset("daily_dialog", split="train", trust_remote_code=True):
        yield [clean_text(u) for u in item.get("dialog", [])]


def empathetic_dialogues_conversations():
    """Yield each EmpatheticDialogues conversation as a list of cleaned utterances.

    The dataset has one row per utterance, so rows are grouped by conv_id and
    ordered by utterance_idx (stable sort: rows lacking it keep dataset order).
    """
    convos = defaultdict(list)
    for item in load_dataset("empathetic_dialogues", split="train", trust_remote_code=True):
        conv_id = item.get("conv_id", "")
        utterance = clean_text(item.get("utterance", ""))
        if conv_id and utterance:
            convos[conv_id].append((item.get("utterance_idx", 0), utterance))
    for turns in convos.values():
        yield [utt for _, utt in sorted(turns, key=lambda t: t[0])]


def blended_skill_talk_conversations():
    """Yield each Blended Skill Talk dialogue as a list of cleaned utterances.

    The two speakers' turns are stored in separate lists (free_messages /
    guided_messages). They are interleaved to restore speaking order, so each
    pair is a real context -> reply instead of two turns by the same speaker.
    NOTE(review): assumes the free speaker speaks first in each round after the
    seeded previous_utterance turns - confirm against the dataset card.
    """
    for item in load_dataset("blended_skill_talk", split="train", trust_remote_code=True):
        free = list(item.get("free_messages", []))
        guided = list(item.get("guided_messages", []))
        turns = list(item.get("previous_utterance", []))
        for i in range(max(len(free), len(guided))):
            turns.extend(side[i] for side in (free, guided) if i < len(side))
        yield [clean_text(u) for u in turns]


for source, load_conversations in [
    ("daily_dialog", daily_dialog_conversations),
    ("empathetic_dialogues", empathetic_dialogues_conversations),
    ("blended_skill_talk", blended_skill_talk_conversations),
]:
    print(f"\n📥 Loading {source}...")
    added = 0
    try:
        for turns in load_conversations():
            added += add_pairs(turns, source)
        print(f"   ✅ Extracted {added} examples")
    except Exception as e:
        # Best effort: a dataset that fails to download is skipped, not fatal.
        print(f"   ⚠️ Could not load: {e}")
    total_examples += added

if not corpus:
    raise RuntimeError(
        "No dialogue examples were loaded - check that Internet is enabled and "
        "the datasets can be downloaded (see the warnings above)."
    )

# Balance corpus (max MAX_PER_TOPIC per topic)
print("\n⚖️ Balancing corpus...")
rng = random.Random(BALANCE_SEED)
for topic, examples in corpus.items():
    if len(examples) > MAX_PER_TOPIC:
        # Entries are grouped by source in load order, so keeping the first N
        # would over-represent the dataset loaded first; sample uniformly instead.
        corpus[topic] = rng.sample(examples, MAX_PER_TOPIC)

# Save
corpus_path = '/kaggle/working/data_processed/topical_corpus.json'
with open(corpus_path, 'w', encoding='utf-8') as f:
    json.dump(dict(corpus), f, ensure_ascii=False, indent=2)

print("\n" + "="*60)
print("CORPUS CREATION COMPLETE!")
print("="*60)
print(f"\nTotal examples: {sum(len(v) for v in corpus.values()):,}")
print(f"Topics: {len(corpus)}")
print(f"\nSaved to: {corpus_path}")
print("\nPer-topic counts:")
for topic in sorted(corpus.keys(), key=lambda x: len(corpus[x]), reverse=True):
    print(f"  {topic}: {len(corpus[topic]):,}")

## Cell 4: Part 1 Step 2 - Build FAISS Retrieval System

In [None]:
%%time
"""
Part 1, Step 2: Build FAISS Retrieval System
Creates vector index for semantic search.
"""

import json
import pickle
from pathlib import Path

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

print("Loading corpus...")
with open('/kaggle/working/data_processed/topical_corpus.json', 'r', encoding='utf-8') as f:
    corpus = json.load(f)

# Flatten corpus: index row i corresponds to all_responses[i] / response_metadata[i].
all_responses = []
response_metadata = []

for topic, responses in corpus.items():
    for resp in responses:
        all_responses.append(resp["response"])
        response_metadata.append({
            "topic": topic,
            "context": resp["context"],
            "response": resp["response"],
            "source": resp.get("source", "unknown")
        })

if not all_responses:
    raise RuntimeError("Corpus is empty - re-run the corpus creation cell (Cell 3) first.")

print(f"Total responses to index: {len(all_responses):,}")

# Load encoder
print("\nLoading sentence encoder (all-MiniLM-L6-v2)...")
encoder = SentenceTransformer('all-MiniLM-L6-v2')
print(f"Embedding dimension: {encoder.get_sentence_embedding_dimension()}")

# Encode all responses
print("\nEncoding responses (this takes a few minutes)...")
embeddings = encoder.encode(
    all_responses,
    batch_size=64,
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True  # unit vectors, so inner product == cosine similarity
)

# Create FAISS index
print("\nBuilding FAISS index...")
dimension = embeddings.shape[1]
index = faiss.IndexFlatIP(dimension)  # Inner product = cosine after normalization
index.add(embeddings.astype(np.float32))
print(f"Index built: {index.ntotal} vectors, {dimension}D")

# Save index and metadata
index_path = '/kaggle/working/data_processed/faiss_index.pkl'

# faiss.serialize_index returns the index as a uint8 numpy array; its raw bytes
# are stored here and turned back into an index with faiss.deserialize_index.
index_bytes = faiss.serialize_index(index).tobytes()

save_data = {
    "index_bytes": index_bytes,
    "response_metadata": response_metadata,
    "all_responses": all_responses
}

with open(index_path, 'wb') as f:
    pickle.dump(save_data, f)

print(f"\n✅ Index saved to: {index_path}")
print(f"   File size: {Path(index_path).stat().st_size / 1024 / 1024:.2f} MB")

## Cell 5: Part 1 Step 3 - Test Retrieval System

In [None]:
"""
Part 1, Step 3: Test the Retrieval System
Demonstrates that Relation repair now works!
"""

class RelationRepairRetriever:
    """Retrieval-based Relation repair.

    Rather than editing an off-topic response, it embeds the dialogue context
    and returns the closest response from the FAISS index built in Step 2.
    """

    def __init__(self, index_path='/kaggle/working/data_processed/faiss_index.pkl'):
        """Load the sentence encoder and the pickled FAISS index + metadata.

        Args:
            index_path: Pickle written by Step 2 with keys "index_bytes" (raw
                bytes from faiss.serialize_index) and "response_metadata" (one
                dict per indexed response, in index order).
        """
        import pickle

        import faiss
        import numpy as np
        from sentence_transformers import SentenceTransformer

        print("Loading retrieval system...")
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')

        # pickle.load can execute arbitrary code: only load index files you built yourself.
        with open(index_path, 'rb') as f:
            data = pickle.load(f)

        # Copy into an owned uint8 array before handing it to faiss.
        raw = np.frombuffer(data["index_bytes"], dtype=np.uint8).copy()
        self.index = faiss.deserialize_index(raw)
        self.metadata = data["response_metadata"]
        print(f"Loaded {self.index.ntotal} vectors")

    def repair_relation_violation(self, context, violated_response, k=5):
        """Return the indexed response most similar to `context`.

        Args:
            context: Dialogue context the response should be on-topic for.
            violated_response: Off-topic response being repaired; returned
                unchanged when the search yields no hit.
            k: Neighbours requested from FAISS (clamped to >= 1; only the best
                one is used).

        Returns:
            The best-matching corpus response, or `violated_response`.
        """
        import numpy as np

        query = self.encoder.encode([context], normalize_embeddings=True).astype(np.float32)
        _, indices = self.index.search(query, max(1, int(k)))

        best = int(indices[0][0])
        # FAISS pads missing results with -1 (e.g. when the index is empty).
        if best == -1:
            return violated_response
        return self.metadata[best]["response"]

    def get_relevance_score(self, context, response):
        """Cosine similarity between the context and response embeddings.

        Embeddings are L2-normalized, so their dot product is the cosine.
        """
        import numpy as np

        embeddings = self.encoder.encode([context, response], normalize_embeddings=True)
        return float(np.dot(embeddings[0], embeddings[1]))

# Test!
print("\n" + "="*60)
print("TESTING RELATION REPAIR")
print("="*60)

retriever = RelationRepairRetriever()

# NOTE: get_relevance_score uses the same encoder that drives retrieval, so a
# higher score after repair is expected by construction. Treat this as a sanity
# check, not an independent evaluation.
test_cases = [
    ("What is your favorite food?", "The stock market closed up 2% yesterday."),
    ("Do you have any pets?", "I think the weather will be nice tomorrow."),
    ("How was your weekend?", "The capital of France is Paris."),
]

for i, (context, violated) in enumerate(test_cases, 1):
    print(f"\n--- Test Case {i} ---")
    print(f"Context: {context}")
    print(f"Violated (off-topic): {violated}")

    orig_score = retriever.get_relevance_score(context, violated)
    repaired = retriever.repair_relation_violation(context, violated)
    new_score = retriever.get_relevance_score(context, repaired)

    print(f"Repaired (on-topic): {repaired}")
    print(f"Relevance: {orig_score:.3f} → {new_score:.3f} (Δ{new_score-orig_score:+.3f})")

print("\n" + "="*60)
print("✅ PART 1 COMPLETE: Relation Repair Fixed!")
print("="*60)

---
# PART 2: Human Evaluation Framework

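Creates blinded samples (plus a separate answer key) for human evaluation of response quality. The annotation UI and the analysis are run locally with the scripts listed in Cell 7.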

## Cell 6: Part 2 - Create Human Evaluation Samples

In [None]:
"""
Part 2: Create Human Evaluation Samples
Creates blinded samples for human evaluation.
"""

import json
import random
from datetime import datetime

random.seed(42)

N_CONTEXTS = 200          # contexts kept after shuffling (each yields two samples)
CONTEXTS_PER_TOPIC = 10   # contexts taken from each topic before shuffling

# Create sample evaluation data from the corpus loaded in Cell 4.
# In production, this would use your actual test data and real system outputs;
# the "system" labels below are placeholders.
print("Creating human evaluation samples...")

# Keep each context together with the response it was actually paired with.
sample_contexts = [
    {
        "context": resp["context"],
        "evidence": "",
        "topic": topic,
        "response": resp["response"],
    }
    for topic, responses in corpus.items()
    for resp in responses[:CONTEXTS_PER_TOPIC]
]

random.shuffle(sample_contexts)
sample_contexts = sample_contexts[:N_CONTEXTS]

topics = list(corpus.keys())
if len(topics) < 2:
    raise ValueError("Need at least two topics in the corpus to build off-topic samples.")

# Two samples per context, from different "systems"
all_samples = []

for ctx in sample_contexts:
    context = ctx["context"]

    # On-topic response: the corpus response paired with THIS context (using the
    # topic's first response made every on-topic sample within a topic identical).
    all_samples.append({
        "context": context,
        "evidence": ctx["evidence"],
        "response": ctx["response"],
        "system": "gricebench_repair"
    })

    # Off-topic response (Relation violation): a random response from another
    # topic, so annotators don't see the same few responses over and over.
    other_topic = random.choice([t for t in topics if t != ctx["topic"]])
    other_responses = corpus[other_topic]
    bad_response = random.choice(other_responses)["response"] if other_responses else "Off-topic response."
    all_samples.append({
        "context": context,
        "evidence": ctx["evidence"],
        "response": bad_response,
        "system": "original_violated"
    })

# Shuffle so the two samples of a context are not adjacent
random.shuffle(all_samples)

# Blind: annotators only see id/context/evidence/response; the id -> system
# mapping goes to a separate key file.
system_key = {str(i): sample["system"] for i, sample in enumerate(all_samples)}
blinded_samples = [
    {
        "id": i,
        "context": sample["context"],
        "evidence": sample["evidence"],
        "response": sample["response"]
    }
    for i, sample in enumerate(all_samples)
]

# Save
samples_path = '/kaggle/working/human_eval_samples.json'
key_path = '/kaggle/working/human_eval_key_DO_NOT_SHARE.json'

with open(samples_path, 'w', encoding='utf-8') as f:
    json.dump(blinded_samples, f, indent=2)

with open(key_path, 'w', encoding='utf-8') as f:
    json.dump(system_key, f, indent=2)

print(f"\n✅ Created {len(blinded_samples)} blinded samples")
print(f"   Samples: {samples_path}")
print(f"   Key: {key_path} (DO NOT SHARE with annotators!)")

## Cell 7: Summary & Download Instructions

In [None]:
import os
from pathlib import Path

print("="*70)
print("üéâ PARTS 1 & 2 COMPLETE!")
print("="*70)

print("\nüìÅ FILES CREATED:")
print("-"*50)

files = [
    '/kaggle/working/data_processed/topical_corpus.json',
    '/kaggle/working/data_processed/faiss_index.pkl',
    '/kaggle/working/human_eval_samples.json',
    '/kaggle/working/human_eval_key_DO_NOT_SHARE.json',
]

for f in files:
    if Path(f).exists():
        size = Path(f).stat().st_size / 1024 / 1024
        print(f"‚úÖ {f.split('/')[-1]:40s} ({size:.2f} MB)")
    else:
        print(f"‚ùå {f} NOT FOUND")

print("\nüì• DOWNLOAD INSTRUCTIONS:")
print("-"*50)
print("1. Click on 'Output' tab on the right sidebar")
print("2. Download these files:")
print("   - data_processed/topical_corpus.json")
print("   - data_processed/faiss_index.pkl")
print("   - human_eval_samples.json")
print("   - human_eval_key_DO_NOT_SHARE.json")
print("")
print("3. Place them in your GriceBench folder:")
print("   GriceBench/")
print("   ‚îú‚îÄ‚îÄ data_processed/")
print("   ‚îÇ   ‚îú‚îÄ‚îÄ topical_corpus.json")
print("   ‚îÇ   ‚îî‚îÄ‚îÄ faiss_index.pkl")
print("   ‚îú‚îÄ‚îÄ human_eval_samples.json")
print("   ‚îî‚îÄ‚îÄ human_eval_key_DO_NOT_SHARE.json")

print("\nüöÄ NEXT STEPS:")
print("-"*50)
print("On your laptop:")
print("1. Run: python scripts/human_eval_gradio.py")
print("   (Opens web UI for human evaluation)")
print("")
print("2. After getting ratings, run:")
print("   python scripts/analyze_human_eval.py")
print("")
print("3. Continue to Part 3: Baseline Comparisons")