# 🎯 STEPS 1–3: Score, Filter, Merge Synthetic Data

**Pipeline:**
1. Load 4151 synthetic pairs
2. Score with Gricean reward models (4 maxims)
3. Keep only pairs whose chosen-minus-rejected margin is > 0 for all four maxims
4. Merge with 411 clean pairs
5. Save final DPO dataset

**Expected Output:** ~3,400 high-quality DPO pairs

**Setup:**
- GPU: T4 x2
- Datasets: `synthetic_candidates.json`, `clean_dpo_pairs.json`
- Models: the 4 partial Gricean reward models (one per maxim)

In [None]:
# Cell 1: Install & Import
!pip install -q transformers torch accelerate datasets

import json, torch, os
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tqdm.auto import tqdm

# Report which accelerator the notebook will use (CUDA device 0, else CPU).
if torch.cuda.is_available():
    _device_label = torch.cuda.get_device_name(0)
else:
    _device_label = 'CPU'
print(f"GPU: {_device_label}")
print("‚úÖ Ready")

In [None]:
# Cell 2: Configuration
#
# Locate the two input JSON files among the known Kaggle dataset mount points
# (the first existing candidate wins), fail fast if either one is missing, and
# fix where the merged dataset will be written.


def _first_existing_path(candidates):
    """Return the first path in *candidates* that exists on disk, or None."""
    return next((candidate for candidate in candidates if os.path.exists(candidate)), None)


# Input files
SYNTHETIC_FILE = _first_existing_path([
    "/kaggle/input/synthetic-candidates/synthetic_candidates.json",
    "/kaggle/input/synthetic-data/synthetic_candidates.json",
])
CLEAN_FILE = _first_existing_path([
    "/kaggle/input/clean-dpo-pairs/clean_dpo_pairs.json",
    "/kaggle/input/clean-pairs/clean_dpo_pairs.json",
])

if SYNTHETIC_FILE is None:
    raise FileNotFoundError("Upload synthetic_candidates.json!")
if CLEAN_FILE is None:
    raise FileNotFoundError("Upload clean_dpo_pairs.json!")

OUTPUT_FILE = "/kaggle/working/final_dpo_dataset.json"

print(f"Synthetic: {SYNTHETIC_FILE}")
print(f"Clean: {CLEAN_FILE}")
print(f"Output: {OUTPUT_FILE}")

In [None]:
# Cell 3: Load Gricean Reward Models
#
# Builds two dicts keyed by maxim name:
#   tokenizers[maxim] -> AutoTokenizer for that maxim's reward model
#   models[maxim]     -> AutoModelForSequenceClassification loaded in float16
#                        with device_map="auto", switched to eval mode.

print("Loading 4 Gricean reward models...\n")

# Hugging Face Hub repo ids, one per maxim -- point these at your own repos.
MODEL_PATHS = {
    'quantity': 'Pushkar27/MaxMargin-RM-Partial-Quantity',
    'quality': 'Pushkar27/MaxMargin-RM-Partial-Quality',
    'relation': 'Pushkar27/MaxMargin-RM-Partial-Relation',
    'manner': 'Pushkar27/MaxMargin-RM-Partial-Manner'
}

tokenizers = {}
models = {}

for maxim_name, repo_id in MODEL_PATHS.items():
    print(f"Loading {maxim_name}...")
    tokenizers[maxim_name] = AutoTokenizer.from_pretrained(repo_id)
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        repo_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    # Inference only: put the model in eval mode before storing it.
    reward_model.eval()
    models[maxim_name] = reward_model

print("\n‚úÖ All 4 models loaded")

In [None]:
# Cell 4: Scoring Functions

# The four Gricean maxims, in the order margins are computed and reported.
GRICEAN_MAXIMS = ('quantity', 'quality', 'relation', 'manner')


def score_response(prompt, response, maxim, max_length=512):
    """Score one (prompt, response) pair with the reward model for *maxim*.

    The pair is rendered as ``"{prompt}\\n\\nResponse: {response}"``.
    NOTE(review): presumably this matches the template the reward models were
    trained on -- confirm.

    Args:
        prompt: Prompt text.
        response: Candidate response text to score.
        maxim: Key into the module-level ``tokenizers`` / ``models`` dicts.
        max_length: Maximum number of tokens fed to the model (default 512).

    Returns:
        float: ``logits[0][0]`` of the model's classification head.
    """
    tokenizer = tokenizers[maxim]
    model = models[maxim]
    text = f"{prompt}\n\nResponse: {response}"

    # Truncate from the LEFT. The default right-side truncation drops the end
    # of the text, which is the response being scored. With a long prompt,
    # chosen and rejected would tokenize identically and give a margin of
    # exactly 0, so the pair would be silently filtered out.
    # The tokenizer's previous setting is restored so shared state is not changed.
    previous_side = tokenizer.truncation_side
    tokenizer.truncation_side = "left"
    try:
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        )
    finally:
        tokenizer.truncation_side = previous_side
    inputs = inputs.to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)
    # NOTE(review): assumes a single-output reward head (num_labels == 1), so
    # logits[0][0] is the scalar reward; with more labels this picks class 0.
    return outputs.logits[0][0].cpu().item()


def score_pair(prompt, chosen, rejected):
    """Return per-maxim reward margins ``score(chosen) - score(rejected)``.

    Runs two forward passes per maxim (eight in total). A positive margin
    means that maxim's reward model scores *chosen* above *rejected*.

    Returns:
        dict: maxim name -> float margin, in ``GRICEAN_MAXIMS`` order.
    """
    return {
        maxim: score_response(prompt, chosen, maxim) - score_response(prompt, rejected, maxim)
        for maxim in GRICEAN_MAXIMS
    }


print("‚úÖ Scoring functions defined")

In [None]:
# Cell 5: STEP 1 - Score Synthetic Pairs
#
# Scores every synthetic candidate with score_pair():
#   - item['synthetic_chosen'] is treated as "chosen";
#   - item['original_chosen_failed'] is treated as "rejected".
# Each successfully scored item is mutated in place to carry a
# 'synthetic_margins' dict and is appended to scored_synthetic. Failed items are
# counted and skipped. After more than 50 failures the loop stops early, and
# later cells then work on a partial result.

print("="*80)
print("STEP 1: SCORING SYNTHETIC PAIRS")
print("="*80)

# Load synthetic candidates. JSON is UTF-8 by spec, so don't rely on the
# platform's default locale encoding.
with open(SYNTHETIC_FILE, encoding="utf-8") as f:
    synthetic_data = json.load(f)

print(f"\nLoaded {len(synthetic_data)} synthetic pairs")
# Rough wall-clock estimate assuming ~2 seconds per pair.
print(f"Expected time: ~{len(synthetic_data) * 2 / 3600:.1f} hours (8 model calls per pair)\n")

scored_synthetic = []
stats = {'total': 0, 'errors': 0}

for i, item in enumerate(tqdm(synthetic_data, desc="Scoring")):
    try:
        # Score the pair: synthetic (chosen) vs original failed (rejected)
        margins = score_pair(
            prompt=item['prompt'],
            chosen=item['synthetic_chosen'],
            rejected=item['original_chosen_failed']
        )
    except Exception as e:
        # Deliberately broad: one bad item (e.g. a missing key or a model
        # failure) should not abort a long run. The error budget below bounds
        # the damage.
        stats['errors'] += 1
        # tqdm.write keeps messages from corrupting the progress bar.
        tqdm.write(f"\nError at {i}: {str(e)[:100]}")
        if stats['errors'] > 50:
            tqdm.write("\n‚ö†Ô∏è Too many errors, stopping...")
            break
        continue

    # Add margins to item
    item['synthetic_margins'] = margins
    scored_synthetic.append(item)
    stats['total'] += 1

    # Progress update every 100 items
    if (i+1) % 100 == 0:
        tqdm.write(f"\nProgress: {i+1}/{len(synthetic_data)}")
        tqdm.write(f"Errors: {stats['errors']}")
        tqdm.write(f"Last margins: {margins}")

print(f"\n‚úÖ Scored {stats['total']} pairs (Errors: {stats['errors']})")

In [None]:
# Cell 6: STEP 2 - Filter for Positive Margins
#
# Keep only scored pairs where the synthetic response beats the failed original
# on EVERY maxim (strictly positive margin). Then report the pass rate and the
# mean margin per maxim over the survivors.

print("="*80)
print("STEP 2: FILTERING FOR POSITIVE MARGINS")
print("="*80)

maxims_to_check = ('quantity', 'quality', 'relation', 'manner')

# Filter: ALL margins must be > 0. A NaN margin compares False, so it is dropped too.
filtered_synthetic = [
    item for item in scored_synthetic
    if all(item['synthetic_margins'][m] > 0 for m in maxims_to_check)
]

n_before = len(scored_synthetic)
n_after = len(filtered_synthetic)
print(f"\nBefore filtering: {n_before} pairs")
print(f"After filtering: {n_after} pairs")
# Guard: scored_synthetic is empty if every item failed in Step 1.
if n_before:
    print(f"Pass rate: {n_after/n_before*100:.1f}%")
else:
    print("Pass rate: n/a (no scored pairs)")

# Statistics on filtered data
if filtered_synthetic:
    avg_margins = {
        m: sum(item['synthetic_margins'][m] for item in filtered_synthetic) / n_after
        for m in maxims_to_check
    }
    print(f"\nAverage margins (filtered):")
    for m, val in avg_margins.items():
        print(f"  {m}: {val:.4f}")

print(f"\n‚úÖ Filtering complete: {len(filtered_synthetic)} high-quality synthetic pairs")

In [None]:
# Cell 7: STEP 3 - Merge with Clean Pairs
#
# Loads the clean (human) DPO pairs and converts the filtered synthetic items to
# a prompt/chosen/rejected layout. Concatenates the two lists with clean pairs
# first, and writes the result to OUTPUT_FILE as indented JSON.

print("="*80)
print("STEP 3: MERGING WITH CLEAN PAIRS")
print("="*80)

# Load clean pairs. JSON is UTF-8 by spec, so don't rely on the locale default.
with open(CLEAN_FILE, encoding="utf-8") as f:
    clean_data = json.load(f)

print(f"\nClean pairs: {len(clean_data)}")
print(f"Filtered synthetic: {len(filtered_synthetic)}")

# Prepare synthetic pairs in DPO format.
# NOTE(review): 'margins' and 'source' are only set on synthetic rows. Confirm
# that clean_dpo_pairs.json rows carry the same keys, or that the downstream
# DPO loader tolerates a mixed schema.
synthetic_dpo = [
    {
        'prompt': item['prompt'],
        'chosen': item['synthetic_chosen'],
        'rejected': item['original_chosen_failed'],
        'margins': item['synthetic_margins'],
        'source': 'synthetic'
    }
    for item in filtered_synthetic
]

# Merge: Clean first, then synthetic
final_dataset = clean_data + synthetic_dpo

# Guard the percentages against an empty merged dataset.
total_pairs = len(final_dataset)
clean_pct = len(clean_data) / total_pairs * 100 if total_pairs else 0.0
synthetic_pct = len(synthetic_dpo) / total_pairs * 100 if total_pairs else 0.0

print(f"\n‚úÖ Final DPO dataset: {total_pairs} pairs")
print(f"   - Clean (human): {len(clean_data)} ({clean_pct:.1f}%)")
print(f"   - Synthetic: {len(synthetic_dpo)} ({synthetic_pct:.1f}%)")

# Save final dataset
with open(OUTPUT_FILE, 'w', encoding="utf-8") as f:
    json.dump(final_dataset, f, indent=2)

print(f"\n‚úÖ Saved to: {OUTPUT_FILE}")
print(f"   File size: {os.path.getsize(OUTPUT_FILE) / (1024**2):.2f} MB")

In [None]:
# Cell 8: Summary & Validation
#
# Prints a summary of Steps 1-3 from the variables produced by Cells 5-7 and
# shows the first synthetic item of the final dataset. Ratios and the sample
# lookup are guarded, so an empty intermediate result does not crash the summary.

print("="*80)
print("PIPELINE SUMMARY")
print("="*80)

print(f"\nüìä Processing Results:")
print(f"   Synthetic candidates: {len(synthetic_data)}")
print(f"   Successfully scored: {len(scored_synthetic)}")
print(f"   Passed strict filter: {len(filtered_synthetic)}")
if scored_synthetic:
    print(f"   Filter pass rate: {len(filtered_synthetic)/len(scored_synthetic)*100:.1f}%")
else:
    print("   Filter pass rate: n/a (nothing was scored)")

print(f"\nüéØ Final Dataset:")
print(f"   Total pairs: {len(final_dataset)}")
print(f"   Human pairs: {len(clean_data)}")
print(f"   Synthetic pairs: {len(synthetic_dpo)}")
if clean_data:
    print(f"   Ratio: {len(synthetic_dpo)/len(clean_data):.1f}x synthetic data")
else:
    print("   Ratio: n/a (no clean pairs)")

# Validate sample. Cell 7 puts clean pairs first, so the first synthetic item
# (if any survived filtering) sits at index len(clean_data).
print(f"\nüîç Sample from final dataset:")
if len(final_dataset) > len(clean_data):
    sample = final_dataset[len(clean_data)]  # First synthetic item
    print(f"   Prompt: {sample['prompt'][:100]}...")
    print(f"   Chosen: {sample['chosen'][:100]}...")
    print(f"   Margins: {sample['margins']}")
else:
    print("   (no synthetic pairs passed the filter)")

print(f"\n‚úÖ STEPS 1-3 COMPLETE")
print(f"\nüì• Download {OUTPUT_FILE} and proceed to DPO training!")
print(f"\nExpected DPO performance:")
print(f"   - Better than 411-only baseline (96.8% accuracy)")
print(f"   - More robust generalization")
print(f"   - Stronger alignment signal")