# DPO Data Generation - Self-Play & Auto-Scoring

Generate preference pairs (chosen/rejected) for DPO training using self-play generation and auto-scoring.

In [None]:
import subprocess
import sys

def install_package(package):
    """Install one pip requirement using the interpreter running this code.

    Invokes ``<sys.executable> -m pip install -q <package>`` and prints the
    outcome.

    Args:
        package: Requirement specifier passed to pip as a single argument
            (for example ``"trl>=0.7.0"``).

    Returns:
        True if the pip call completed without raising, False otherwise.
    """
    pip_command = [sys.executable, "-m", "pip", "install", "-q", package]
    try:
        subprocess.check_call(pip_command)
        print(f"[install_package] Successfully installed {package}")
    except Exception as err:
        # Best effort: a failed install is reported, never propagated.
        print(f"[install_package] Failed to install {package}: {err}")
        return False
    return True

# Requirement specifiers handed one at a time to install_package().
required_packages = [
    "unsloth",
    "transformers>=4.36.0",
    "datasets",
    "torch",
    "trl>=0.7.0",
    "accelerate",
    "bitsandbytes"
]

print("========== START: Installing Packages ==========")
# The boolean returned by install_package() is ignored, so one failed
# install does not stop the remaining packages from being attempted.
for pkg in required_packages:
    install_package(pkg)
print("========== END: Installing Packages ==========")

In [None]:
import os
import json
import torch
import random
import ast
import re
from pathlib import Path
from tqdm.auto import tqdm

try:
    from unsloth import FastLanguageModel
except ImportError:
    print("Unsloth not available, using transformers")
    FastLanguageModel = None

try:
    from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError as e:
    raise ImportError(f"transformers required: {e}")

try:
    from datasets import load_dataset
except ImportError as e:
    raise ImportError(f"datasets required: {e}")

# Report the first CUDA device (index 0) and its total memory in GB (bytes / 1e9),
# or warn when CUDA is unavailable.
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("WARNING: No GPU detected. Generation will be slow.")

In [None]:
# ---- Run configuration (module-level names read by the cells below) ----
MODEL_PATH = "./final_model"              # model location passed to load_model()
FIM_DATA_PATH = "fim_dataset.jsonl"       # JSONL file read by load_fim_samples()
OUTPUT_PATH = "dpo_preference_data.jsonl"  # where the preference pairs are written
NUM_SAMPLES = 5000                        # maximum number of FIM samples to process
NUM_GENERATIONS = 5                       # sampled completions per prompt
MAX_NEW_TOKENS = 64                       # generation length cap per completion
TEMPERATURE = 0.8                         # sampling temperature

# Fall back to a hub model identifier when the local model directory is missing.
if not os.path.exists(MODEL_PATH):
    print(f"Model not found at {MODEL_PATH}")
    MODEL_PATH = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
    print(f"Using base model: {MODEL_PATH}")

# Probe a few alternative locations when the default data file is missing.
if not os.path.exists(FIM_DATA_PATH):
    alt_paths = [
        "../phase1_data_engineering/fim_dataset.jsonl",
        "/content/fim_dataset.jsonl",
        "./train.jsonl"
    ]
    for p in alt_paths:
        if os.path.exists(p):
            FIM_DATA_PATH = p
            print(f"Found data at: {FIM_DATA_PATH}")
            break
    else:
        # for/else: reached only when no candidate path exists (loop never hit `break`).
        # FIM_DATA_PATH keeps its original, non-existent value in that case.
        print(f"WARNING: No FIM data found. Please upload fim_dataset.jsonl")

In [None]:
# Convention tables used by check_convention_match().
# Key:   substring looked for in the prompt.
# Value: substrings of which at least one is expected in the completion
#        (compared case-insensitively). An empty list means "no expectation",
#        which check_convention_match() scores as a neutral 0.5.
PYTHON_CONVENTIONS = {
    "import pandas": ["as pd"],
    "import numpy": ["as np"],
    "import matplotlib.pyplot": ["as plt"],
    "import tensorflow": ["as tf"],
    "import seaborn": ["as sns"],
    "import sklearn": [],
    "from typing import": ["List", "Dict", "Optional", "Tuple", "Union", "Any"],
    "import torch": [],
    "import torch.nn": ["as nn"],
    "import torch.nn.functional": ["as F"],
    "from collections import": ["defaultdict", "Counter", "OrderedDict"],
    "import cv2": [],
    "from PIL import": ["Image"],
    "import requests": [],
    "import json": [],
    "import os": [],
    "import sys": [],
    "import re": [],
}

# Same structure as PYTHON_CONVENTIONS, selected when the language is "java".
JAVA_CONVENTIONS = {
    "public static void": ["main(String[] args)", "main(String args[])"],
    "System.out.": ["println", "print", "printf"],
    "public class": [],
    "private ": [],
    "@Override": [],
    "import java.util.": ["List", "ArrayList", "Map", "HashMap", "Set"],
}

# Same structure; check_convention_match() uses this table for every language
# other than "python" and "java".
CPP_CONVENTIONS = {
    "#include <iostream>": [],
    "#include <vector>": [],
    "#include <string>": [],
    "using namespace": ["std"],
    "std::": ["cout", "cin", "endl", "vector", "string", "map"],
    "int main": ["()", "(int argc, char* argv[])"],
}

# Per-language terms that check_hallucination() searches for in the lowercased
# completion; a hit drives the hallucination sub-score to 0.0.
IRRELEVANT_IMPORTS = {
    "python": ["react", "vue", "angular", "express", "node", "jquery", "lodash"],
    "java": ["numpy", "pandas", "tensorflow", "react", "import os"],
    "cpp": ["numpy", "pandas", "import ", "from ", "react"],
}

In [None]:
def detect_language(text):
    """Guess the programming language of a snippet from marker substrings.

    Java and C++ markers are checked before falling back to Python. Generic
    tokens such as ``import `` or ``class `` also occur in Java and C++
    sources, so they must not decide the result on their own (previously
    ``import java.util.List; System.out.println(...)`` was labelled python).

    Args:
        text: Source text to classify. May be empty or None.

    Returns:
        ``"unknown"`` for empty/None input, otherwise one of ``"java"``,
        ``"cpp"`` or ``"python"`` (the default when no marker is found).
    """
    if not text:
        return "unknown"

    java_markers = ("public class", "public static void", "System.out")
    cpp_markers = ("#include", "std::", "int main")

    # Java wins over C++ when both kinds of markers are present, matching the
    # order in which the two families were always tested.
    if any(marker in text for marker in java_markers):
        return "java"
    if any(marker in text for marker in cpp_markers):
        return "cpp"

    # Nothing language-specific was found: treat the snippet as Python.
    return "python"


def check_syntax_python(code):
    """Return True when *code* parses as a complete Python module.

    Args:
        code: Python source text.

    Returns:
        True if ``ast.parse`` accepts the input, False if parsing fails
        (syntax/indentation errors, null bytes, non-source input, or input
        too deeply nested for the parser).
    """
    try:
        ast.parse(code)
    except (SyntaxError, ValueError, TypeError, RecursionError, MemoryError):
        # Only parse-related failures mean "invalid"; KeyboardInterrupt and
        # SystemExit are no longer swallowed as they were by a bare except.
        return False
    return True


def check_syntax_partial(completion, language):
    """Cheap plausibility check for a partial code completion.

    Args:
        completion: Generated text; may be None or empty.
        language: Language label; only ``"python"`` gets pattern checks.

    Returns:
        False for empty/whitespace-only input. For Python, False when the
        stripped text matches one of the rejected shapes below, else True.
        For any other language, True for every non-blank completion.
    """
    text = completion.strip() if completion else ""
    if not text:
        return False

    if language != "python":
        # Non-blank is the only requirement for other languages.
        return True

    rejected_shapes = (
        r'^[\)\]\}]+$',         # nothing but closing brackets
        # NOTE(review): re.match anchors at the start of the stripped text,
        # which never begins with whitespace, so this pattern looks
        # unreachable -- confirm whether re.search was intended.
        r'\s{10,}',
        r'^[^a-zA-Z0-9_\s]+$',  # punctuation only, no identifier characters
    )
    return not any(re.match(shape, text) for shape in rejected_shapes)


def check_convention_match(prompt, completion, language):
    """Score whether the completion follows the convention implied by the prompt.

    The convention table for the language is scanned in its defined order for
    keys that occur in the prompt. A matched key is ignored when a longer
    matched key contains it, so that a specific entry such as
    ``"import torch.nn"`` is not shadowed by the generic ``"import torch"``
    entry that precedes it in the table.

    Args:
        prompt: Text preceding the completion.
        completion: Generated text to score.
        language: ``"python"``, ``"java"``, or anything else (C++ table).

    Returns:
        1.0 if one of the expected substrings for the deciding key occurs in
        the completion (case-insensitive), 0.2 if none does, and 0.5 when the
        deciding key has no expectations or no key matches the prompt.
    """
    prompt = prompt.strip()
    completion_lower = completion.strip().lower()

    if language == "python":
        conventions = PYTHON_CONVENTIONS
    elif language == "java":
        conventions = JAVA_CONVENTIONS
    else:
        conventions = CPP_CONVENTIONS

    matched = [pattern for pattern in conventions if pattern in prompt]
    for pattern in matched:
        # Defer to a more specific key that also matched the prompt.
        if any(other != pattern and pattern in other for other in matched):
            continue

        expected_list = conventions[pattern]
        if not expected_list:
            return 0.5
        if any(expected.lower() in completion_lower for expected in expected_list):
            return 1.0
        return 0.2

    return 0.5


def check_hallucination(prompt, completion, language):
    """Penalise completions that mention terms irrelevant to the language.

    Terms come from ``IRRELEVANT_IMPORTS[language]`` and are matched against
    the lowercased completion as whole words. Plain substring search flagged
    ordinary identifiers (``expression`` for "express", ``triangular`` for
    "angular", ``nodes`` for "node") as hallucinations.

    Args:
        prompt: Unused; kept so all scoring checks share a call shape.
        completion: Generated text to inspect.
        language: Key into ``IRRELEVANT_IMPORTS``; unknown languages have no
            terms and always score 1.0.

    Returns:
        0.0 if any irrelevant term occurs in the completion, otherwise 1.0.
    """
    completion_lower = completion.lower()

    for term in IRRELEVANT_IMPORTS.get(language, []):
        # Guard only the edges that are word characters, so terms that end in
        # a space (e.g. "import ") can still be followed by an identifier.
        head = r"(?<!\w)" if (term[:1].isalnum() or term[:1] == "_") else ""
        tail = r"(?!\w)" if (term[-1:].isalnum() or term[-1:] == "_") else ""
        if re.search(head + re.escape(term) + tail, completion_lower):
            return 0.0

    return 1.0


def check_length(completion):
    """Score a completion by its whitespace-separated word count.

    Args:
        completion: Generated text.

    Returns:
        0.0 for no words, 1.0 for 1-50 words, 0.7 for 51-100 words and
        0.3 for more than 100 words.
    """
    word_count = len(completion.split())

    if word_count == 0:
        return 0.0
    if word_count <= 50:
        return 1.0
    if word_count <= 100:
        return 0.7
    return 0.3


def check_context_relevance(prompt, completion):
    """Score how much of the completion's vocabulary already appears in the prompt.

    Both texts are lowercased and split on whitespace; the score grows with
    the fraction of distinct completion tokens that also occur in the prompt.

    Args:
        prompt: Text preceding the completion.
        completion: Generated text to score.

    Returns:
        0.0 when the completion has no tokens, otherwise a value from 0.3
        (no shared tokens) to 1.0 (every completion token is in the prompt).
    """
    prompt_vocab = set(prompt.lower().split())
    completion_vocab = set(completion.lower().split())

    if not completion_vocab:
        return 0.0

    shared = prompt_vocab & completion_vocab
    overlap = len(shared) / len(completion_vocab)

    # 0.3 floor plus up to 0.7 for full overlap.
    return 0.3 + 0.7 * min(overlap, 1.0)

In [None]:
def score_completion(prompt, completion, language):
    """Combine the individual heuristics into one weighted score.

    Each criterion is evaluated independently; if a check raises, that
    criterion falls back to a neutral 0.5 so one failing heuristic cannot
    abort scoring.

    Args:
        prompt: Text preceding the completion.
        completion: Generated text to score.
        language: Language label forwarded to the language-aware checks.

    Returns:
        A tuple ``(final_score, scores)`` where ``final_score`` is the
        weighted sum scaled to 0-100 and ``scores`` maps each criterion name
        to its raw sub-score.
    """
    # (criterion, weight, deferred check) in the order sub-scores are recorded.
    criteria = (
        ("syntax", 0.30,
         lambda: 1.0 if check_syntax_partial(completion, language) else 0.0),
        ("convention", 0.25,
         lambda: check_convention_match(prompt, completion, language)),
        ("relevance", 0.25,
         lambda: check_context_relevance(prompt, completion)),
        ("length", 0.10,
         lambda: check_length(completion)),
        ("hallucination", 0.10,
         lambda: check_hallucination(prompt, completion, language)),
    )

    scores = {}
    for criterion, _, run_check in criteria:
        try:
            scores[criterion] = run_check()
        except Exception:
            scores[criterion] = 0.5

    weighted_total = sum(scores[criterion] * weight for criterion, weight, _ in criteria)
    return weighted_total * 100, scores

In [None]:
def load_model(model_path):
    """Load a causal language model and tokenizer, preferring Unsloth.

    Unsloth is tried first (4-bit, 2048-token max sequence length) when it was
    importable; on any failure, or when it is unavailable, the model is
    loaded through transformers in float16 with ``device_map="auto"``.

    Args:
        model_path: Local directory or model identifier.

    Returns:
        A ``(model, tokenizer)`` tuple ready for inference.

    Raises:
        RuntimeError: If the transformers fallback also fails.
    """
    end_banner = "========== END: load_model =========="
    print("========== START: load_model ==========")
    print(f"[load_model] Input: model_path={model_path}")

    # Preferred path: Unsloth, when the optional import succeeded.
    if FastLanguageModel is not None:
        try:
            fast_model, fast_tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_path,
                max_seq_length=2048,
                dtype=None,
                load_in_4bit=True,
            )
            FastLanguageModel.for_inference(fast_model)
            print("[load_model] Output: Loaded with Unsloth")
            print(end_banner)
            return fast_model, fast_tokenizer
        except Exception as err:
            print(f"[load_model] Unsloth failed: {err}. Trying transformers...")

    # Fallback path: plain transformers.
    try:
        hf_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        hf_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        hf_model.eval()
        print("[load_model] Output: Loaded with transformers")
        print(end_banner)
        return hf_model, hf_tokenizer
    except Exception as err:
        raise RuntimeError(f"[load_model] Failed to load model: {err}")

In [None]:
def generate_completions(model, tokenizer, prompt, num_generations=5, max_tokens=64, temperature=0.8):
    """Sample several completions for one prompt.

    The prompt is tokenized once (truncated to 1024 tokens) and the model is
    sampled ``num_generations`` times with nucleus/top-k sampling.

    Only the newly generated tokens are decoded. Slicing the decoded text by
    ``len(prompt)`` is wrong whenever the decoded prompt differs from the raw
    prompt string, e.g. when special tokens are dropped by
    ``skip_special_tokens=True`` or the prompt was truncated.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching the model.
        prompt: Prompt text to continue.
        num_generations: Number of sampling attempts.
        max_tokens: Maximum new tokens per attempt.
        temperature: Sampling temperature.

    Returns:
        List of non-empty, stripped completion strings; may hold fewer than
        ``num_generations`` items (or be empty) when attempts fail or produce
        blank text.
    """
    completions = []

    try:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        # Number of prompt tokens actually fed to the model (after truncation).
        prompt_token_count = inputs["input_ids"].shape[1]
    except Exception as e:
        print(f"[generate_completions] Tokenization error: {e}")
        return []

    for i in range(num_generations):
        try:
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    do_sample=True,
                    top_p=0.95,
                    top_k=50,
                    pad_token_id=tokenizer.eos_token_id,
                    num_return_sequences=1,
                )

            # The returned sequence starts with the prompt tokens; keep only
            # what the model appended.
            new_tokens = outputs[0][prompt_token_count:]
            completion = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

            if completion:
                completions.append(completion)
        except Exception as e:
            # Best effort: one failed attempt must not discard the others.
            print(f"[generate_completions] Generation error {i}: {e}")
            continue

    return completions

In [None]:
def load_fim_samples(data_path, num_samples):
    """Load FIM training records from a JSONL file.

    At most ``num_samples * 2`` lines are read. A line is kept when it is a
    JSON object whose ``"text"`` value is a string containing
    ``<fim_middle>``. Lines that are not valid JSON, are not objects, or have
    a non-string ``"text"`` are skipped individually; previously one such
    line raised inside the loop and discarded every sample.

    Args:
        data_path: Path to the JSONL file.
        num_samples: Maximum number of records to return.

    Returns:
        Up to ``num_samples`` records in random order, or an empty list when
        the file is missing or cannot be read.
    """
    print(f"========== START: load_fim_samples ==========")
    print(f"[load_fim_samples] Input: path={data_path}, num_samples={num_samples}")
    samples = []

    try:
        with open(data_path, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                if i >= num_samples * 2:
                    break
                try:
                    data = json.loads(line.strip())
                except json.JSONDecodeError:
                    continue
                # Valid JSON is not necessarily an object (e.g. `5` or `[]`).
                if not isinstance(data, dict):
                    continue
                text = data.get('text')
                if isinstance(text, str) and '<fim_middle>' in text:
                    samples.append(data)
    except FileNotFoundError:
        print(f"[load_fim_samples] File not found: {data_path}")
        return []
    except Exception as e:
        print(f"[load_fim_samples] Error loading data: {e}")
        return []

    random.shuffle(samples)
    selected = samples[:num_samples]
    print(f"[load_fim_samples] Output: {len(selected)} samples loaded")
    print(f"========== END: load_fim_samples ==========")
    return selected


def extract_prompt_from_fim(fim_text):
    """Return the FIM text up to and including the first ``<fim_middle>`` marker.

    Args:
        fim_text: Full FIM-formatted training string.

    Returns:
        Everything before the first marker plus the marker itself, or the
        input unchanged when it contains no marker.
    """
    head, marker, _ = fim_text.partition('<fim_middle>')
    if not marker:
        return fim_text
    return head + marker

In [None]:
# Load the generation model once; `model` and `tokenizer` are used by the generation loop below.
model, tokenizer = load_model(MODEL_PATH)

In [None]:
# Read up to NUM_SAMPLES FIM records; load_fim_samples() returns [] on any read failure.
print(f"Loading {NUM_SAMPLES} samples from {FIM_DATA_PATH}...")
samples = load_fim_samples(FIM_DATA_PATH, NUM_SAMPLES)
print(f"Loaded {len(samples)} samples")

# Only a notice: execution continues, and the generation loop simply has nothing to iterate.
if len(samples) == 0:
    print("No samples loaded. Please check data path.")

In [None]:
# Build (prompt, chosen, rejected) records: sample several completions per
# prompt, score them, and pair the best-scoring one with the worst-scoring one.
preference_data = []
failed_count = 0  # samples whose processing raised an exception
skip_count = 0    # samples dropped deliberately (short prompt, too few or too similar completions)

print(f"========== START: Preference Generation Loop ==========")
for i, sample in tqdm(enumerate(samples), total=len(samples), desc="Generating preference pairs"):
    try:
        fim_text = sample.get('text', '')
        # Keep everything up to and including the <fim_middle> marker.
        prompt = extract_prompt_from_fim(fim_text)

        # Prompts shorter than 10 characters are skipped.
        if not prompt or len(prompt) < 10:
            skip_count += 1
            continue

        language = detect_language(prompt)

        completions = generate_completions(
            model, tokenizer, prompt,
            num_generations=NUM_GENERATIONS,
            max_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE
        )
        print(f"[Loop] Sample {i}: Prompt='{prompt[:30]}...', Lang={language}, Generated={len(completions)}")

        # A preference pair needs at least two candidates.
        if len(completions) < 2:
            skip_count += 1
            continue

        scored = []
        for comp in completions:
            score, details = score_completion(prompt, comp, language)
            scored.append((comp, score, details))

        # Highest score first: index 0 is the best candidate, index -1 the worst.
        scored.sort(key=lambda x: x[1], reverse=True)

        chosen = scored[0][0]
        rejected = scored[-1][0]
        chosen_score = scored[0][1]
        rejected_score = scored[-1][1]

        print(f"   Scores: Chosen={chosen_score:.1f}, Rejected={rejected_score:.1f}")

        # Require distinct texts and a gap of at least 5 points (scores are on a 0-100 scale).
        if chosen == rejected or abs(chosen_score - rejected_score) < 5:
            print("   Skipping: Scores too close or identical")
            skip_count += 1
            continue

        preference_data.append({
            "prompt": prompt,
            "chosen": chosen,
            "rejected": rejected,
            "chosen_score": chosen_score,
            "rejected_score": rejected_score,
            "language": language,
        })

    except Exception as e:
        failed_count += 1
        # Only the first five failures are printed; later ones are counted silently.
        if failed_count <= 5:
            print(f"[Loop] Error processing sample {i}: {e}")
        continue

print(f"========== END: Preference Generation Loop ==========")
print(f"\nGenerated {len(preference_data)} preference pairs")
print(f"Skipped: {skip_count}, Failed: {failed_count}")

In [None]:
# Preview the first generated pair, truncating the prompt to 100 characters
# and each completion to 80 characters.
if preference_data:
    print("\n========== Sample Preference Pair ==========")
    sample = preference_data[0]  # rebinds the loop variable `sample` from the generation cell
    print(f"Prompt: {sample['prompt'][:100]}...")
    print(f"Chosen ({sample['chosen_score']:.1f}): {sample['chosen'][:80]}...")
    print(f"Rejected ({sample['rejected_score']:.1f}): {sample['rejected'][:80]}...")
    print("==========================================")

In [None]:
def _write_jsonl(path, items):
    """Write each item as one UTF-8 JSON line to *path*, replacing any existing file."""
    with open(path, 'w', encoding='utf-8') as f:
        for item in items:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')


# Save the preference pairs to OUTPUT_PATH; if that fails, try a backup
# location and report why when the backup fails as well.
try:
    _write_jsonl(OUTPUT_PATH, preference_data)
    print(f"[Save] Successfully saved {len(preference_data)} pairs to {OUTPUT_PATH}")
except Exception as e:
    print(f"[Save] Error: {e}")
    backup_path = "/content/dpo_data_backup.jsonl"
    try:
        _write_jsonl(backup_path, preference_data)
        print(f"[Save] Saved backup to {backup_path}")
    except Exception as backup_error:
        # Previously a bare `except:` hid the reason and also swallowed
        # KeyboardInterrupt/SystemExit.
        print(f"[Save] Failed to save data anywhere: {backup_error}")

In [None]:
# Summary statistics over the generated pairs; skipped entirely when there are none,
# which also avoids dividing by zero in the means below.
if preference_data:
    scores = [p['chosen_score'] for p in preference_data]
    rejected_scores = [p['rejected_score'] for p in preference_data]

    print(f"\n========== Statistics ==========")
    print(f"Chosen scores - Mean: {sum(scores)/len(scores):.1f}, Min: {min(scores):.1f}, Max: {max(scores):.1f}")
    print(f"Rejected scores - Mean: {sum(rejected_scores)/len(rejected_scores):.1f}, Min: {min(rejected_scores):.1f}, Max: {max(rejected_scores):.1f}")

    # Count pairs per detected language label.
    lang_counts = {}
    for p in preference_data:
        lang = p.get('language', 'unknown')
        lang_counts[lang] = lang_counts.get(lang, 0) + 1
    print(f"Language distribution: {lang_counts}")
    print("===============================")