# 04. Preference Data Generation (TRUE OPTIMIZED V2)
## Real Batch Processing + Hang Prevention

**This version adds batching and safety measures**:
- Real batch processing (2-4 samples per batch)
- Multi-temperature sampling (0.6, 0.8, 1.0, 1.2), run one temperature at a time
- Timeout detection and auto-recovery
- Aggressive memory management

**Expected Runtime**:
- **A100: 2-3 hours** (true speed optimization)
- **Success rate: 80-90%** (with timeout recovery)
- Falls back gracefully if a batch hangs

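**Key Innovation**: A signal-based timeout (`SIGALRM`) aborts generation batches that stall, so one stuck batch does not hang the whole run. The handler fires only when control returns to the Python interpreter, so a call blocked entirely inside native code is not interrupted.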
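
The timeout mechanism can be tried in isolation before any model is loaded. The following is a minimal sketch, assuming a Unix host and the main thread (both are required for `SIGALRM`). `time_limit` and `run_with_limit` are hypothetical names used only for illustration, and `time.sleep` stands in for a generation call.

```python
import signal
import time
from contextlib import contextmanager


class TimeoutException(Exception):
    """Raised when the guarded block exceeds its time budget."""


@contextmanager
def time_limit(seconds: int):
    """Hypothetical helper: raise TimeoutException after `seconds` (Unix, main thread only)."""
    def handler(signum, frame):
        raise TimeoutException(f"timed out after {seconds}s")

    old_handler = signal.signal(signal.SIGALRM, handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, old_handler)  # restore the previous handler


def run_with_limit(work_seconds: float, limit_seconds: int) -> str:
    """Simulate a batch that takes `work_seconds` and report whether it finished."""
    try:
        with time_limit(limit_seconds):
            time.sleep(work_seconds)  # stand-in for model.generate(...)
        return "ok"
    except TimeoutException:
        return "timeout"


fast_result = run_with_limit(0.1, 1)  # finishes well inside the limit
slow_result = run_with_limit(3, 1)    # interrupted after about one second
print(fast_result, slow_result)
```

The `finally` clause matters: without cancelling the alarm, a block that finishes early would still be hit by the signal later, in unrelated code.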

## 1. Setup

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

PROJECT_ROOT = "/content/drive/MyDrive/synthetic-instruction-tuner"

In [None]:
# Load configuration
import json

with open(f"{PROJECT_ROOT}/config.json", 'r') as f:
    config = json.load(f)

print("Configuration loaded!")

In [None]:
# Install libraries
# Quote the version specifiers so the shell does not treat ">=" as output redirection
!pip install -q --upgrade "transformers>=4.41.0" "accelerate>=0.25.0" "bitsandbytes>=0.41.3"

import torch
import numpy as np
from datetime import datetime
from tqdm import tqdm
import gc
import time
import signal
from contextlib import contextmanager

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"GPU Memory: {gpu_mem:.1f} GB")

    # TRUE OPTIMIZED batch sizes
    if "A100" in gpu_name:
        BATCH_SIZE = 4  # Process 4 samples at once
        GEN_BATCH_SIZE = 2  # Generate 2 responses at once
    elif "T4" in gpu_name:
        BATCH_SIZE = 2
        GEN_BATCH_SIZE = 1
    else:
        BATCH_SIZE = 2
        GEN_BATCH_SIZE = 1

    print(f"\n🚀 TRUE OPTIMIZED for {gpu_name}:")
    print(f"   Batch size: {BATCH_SIZE}")
    print(f"   Generation batch: {GEN_BATCH_SIZE}")
    print(f"   Timeout protection: 60s per generation batch")
else:
    BATCH_SIZE = 1
    GEN_BATCH_SIZE = 1

## 2. Load Filtered Data

In [None]:
# Load filtered data
FILTERED_PATH = f"{config['paths']['data_filtered']}/instructions_filtered.json"

with open(FILTERED_PATH, 'r', encoding='utf-8') as f:
    filtered_data = json.load(f)

print(f"Loaded {len(filtered_data)} filtered samples")

## 3. Load Models

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoModelForSequenceClassification

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)

# Generator model
GENERATOR_MODEL_ID = config['models']['data_generation']
print(f"Loading generator: {GENERATOR_MODEL_ID}...")

generator_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL_ID)
generator_tokenizer.pad_token = generator_tokenizer.eos_token
generator_tokenizer.padding_side = "left"

generator_model = AutoModelForCausalLM.from_pretrained(
    GENERATOR_MODEL_ID,
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=True
)
generator_model.eval()

print(f"✓ Generator loaded ({torch.cuda.memory_allocated() / 1e9:.2f} GB)")

# Reward model
REWARD_MODEL_ID = "OpenAssistant/reward-model-deberta-v3-large-v2"
print(f"Loading reward model: {REWARD_MODEL_ID}...")

reward_tokenizer = AutoTokenizer.from_pretrained(REWARD_MODEL_ID)
reward_model = AutoModelForSequenceClassification.from_pretrained(
    REWARD_MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto"
)
reward_model.eval()

print(f"✓ Reward model loaded ({torch.cuda.memory_allocated() / 1e9:.2f} GB)")

## 4. TRUE OPTIMIZED Generator with Timeout Protection

In [None]:
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class PreferencePair:
    instruction: str
    chosen: str
    rejected: str
    chosen_score: float
    rejected_score: float
    margin: float


class TimeoutException(Exception):
    pass


@contextmanager
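def timeout(seconds):
    """Raise TimeoutException if the body runs longer than `seconds`.

    Uses SIGALRM, so it works only on Unix and only in the main thread. The
    handler runs when control returns to the Python interpreter, so a call
    stuck entirely inside native code is not interrupted.
    """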
    def timeout_handler(signum, frame):
        raise TimeoutException("Operation timed out")
    
    # Set alarm
    old_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)
        signal.signal(signal.SIGALRM, old_handler)


class TrueOptimizedBatchGenerator:
    """TRUE OPTIMIZED: Real batch processing with timeout protection."""

    def __init__(self, gen_model, gen_tokenizer, reward_model, reward_tokenizer, config=None):
        self.gen_model = gen_model
        self.gen_tokenizer = gen_tokenizer
        self.reward_model = reward_model
        self.reward_tokenizer = reward_tokenizer
        self.config = config or {}

        self.min_margin = self.config.get('min_score_margin', 0.5)
        self.max_new_tokens = 256

        # Llama templates
        self.instruction_template = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        self.response_template = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

        # Get EOS token IDs
        self.eot_id = self.gen_tokenizer.convert_tokens_to_ids("<|eot_id|>")
        self.eos_id = self.gen_tokenizer.eos_token_id

    def generate_batch_with_timeout(self, instructions: List[str], temperature: float, 
                                    gen_batch_size: int, timeout_sec: int = 60) -> List[str]:
        """Generate responses in batches WITH TIMEOUT."""
        all_responses = []

        for i in range(0, len(instructions), gen_batch_size):
            batch_instructions = instructions[i:i+gen_batch_size]

            # Prepare prompts
            prompts = [f"{self.instruction_template}{inst}{self.response_template}"
                      for inst in batch_instructions]

            try:
                # Use timeout protection
                with timeout(timeout_sec):
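                    # Tokenize (the template already starts with <|begin_of_text|>,
                    # so skip the tokenizer's own BOS to avoid a doubled token)
                    inputs = self.gen_tokenizer(
                        prompts,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=2048,
                        add_special_tokens=False
                    ).to(self.gen_model.device)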

                    # Generate
                    with torch.no_grad():
                        outputs = self.gen_model.generate(
                            **inputs,
                            max_new_tokens=self.max_new_tokens,
                            temperature=temperature,
                            do_sample=True,
                            top_p=0.9,
                            pad_token_id=self.gen_tokenizer.pad_token_id,
                            eos_token_id=[self.eot_id, self.eos_id]
                        )

                    # Decode
                    batch_responses = self.gen_tokenizer.batch_decode(outputs, skip_special_tokens=False)

                    # Parse
                    for response in batch_responses:
                        parsed = self._parse_response(response)
                        all_responses.append(parsed if parsed else "")
                    
                    # Clean up
                    del inputs, outputs
                    torch.cuda.empty_cache()
                    
            except TimeoutException:
                # Count the timeout so callers can report it (it is handled here, not re-raised)
                self.timeout_count = getattr(self, "timeout_count", 0) + 1
                print(f"        ⚠️ Timeout at temp={temperature}, skipping {len(batch_instructions)} samples")
                # Add empty responses for failed batch
                all_responses.extend([""] * len(batch_instructions))
            except Exception as e:
                print(f"        ⚠️ Error at temp={temperature}: {e}")
                all_responses.extend([""] * len(batch_instructions))

        return all_responses

    def _parse_response(self, text: str) -> Optional[str]:
        """Extract response from generated text."""
        try:
            if "<|start_header_id|>assistant<|end_header_id|>" in text:
                parts = text.split("<|start_header_id|>assistant<|end_header_id|>")
                if len(parts) > 1:
                    response = parts[-1]
                    for end_token in ["<|eot_id|>", "<|end_of_text|>"]:
                        if end_token in response:
                            response = response.split(end_token)[0]
                    return response.strip()
        except Exception:  # avoid a bare except, which would also swallow KeyboardInterrupt
            pass
        return None

    def score_batch(self, instructions: List[str], responses: List[str], batch_size: int = 8) -> List[float]:
        """Score responses in batches."""
        all_scores = []

        for i in range(0, len(instructions), batch_size):
            batch_inst = instructions[i:i+batch_size]
            batch_resp = responses[i:i+batch_size]

            # Pass (question, answer) as a text pair, the input format shown on
            # the reward model's card, rather than one hand-formatted string
            inputs = self.reward_tokenizer(
                batch_inst,
                batch_resp,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=2048
            ).to(self.reward_model.device)

            with torch.no_grad():
                outputs = self.reward_model(**inputs)
                scores = outputs.logits[:, 0].cpu().numpy().tolist()

            all_scores.extend(scores)
            
            del inputs, outputs
            torch.cuda.empty_cache()

        return all_scores

    def create_pairs_batch(self, batch_samples: List[dict], gen_batch_size: int = 2) -> List[PreferencePair]:
        """Create preference pairs for a BATCH of samples."""
        instructions = [s['instruction'] for s in batch_samples]

        # Generate responses with different temperatures
        temperatures = [0.6, 0.8, 1.0, 1.2]
        all_responses = {}

        for temp in temperatures:
            responses = self.generate_batch_with_timeout(
                instructions, temp, gen_batch_size, timeout_sec=60
            )
            all_responses[temp] = responses

        # Create pairs
        pairs = []

        for idx, instruction in enumerate(instructions):
            # Collect all responses
            responses = [all_responses[temp][idx] for temp in temperatures]
            responses = [r for r in responses if r and len(r) > 10]

            if len(responses) < 2:
                continue

            # Remove duplicates
            unique_responses = list(dict.fromkeys(responses))
            if len(unique_responses) < 2:
                continue

            # Score
            insts_repeated = [instruction] * len(unique_responses)
            scores = self.score_batch(insts_repeated, unique_responses, batch_size=8)

            # Create pair
            scored = list(zip(unique_responses, scores))
            scored.sort(key=lambda x: x[1], reverse=True)

            chosen, chosen_score = scored[0]
            rejected, rejected_score = scored[-1]
            margin = chosen_score - rejected_score

            if margin >= self.min_margin:
                pairs.append(PreferencePair(
                    instruction=instruction,
                    chosen=chosen,
                    rejected=rejected,
                    chosen_score=chosen_score,
                    rejected_score=rejected_score,
                    margin=margin
                ))

        return pairs


# Initialize generator
pref_config = config.get('preference_generation', {})
batch_generator = TrueOptimizedBatchGenerator(
    generator_model,
    generator_tokenizer,
    reward_model,
    reward_tokenizer,
    pref_config
)

print("✅ TRUE OPTIMIZED Batch Generator initialized!")
print(f"   Max tokens: {batch_generator.max_new_tokens}")
print(f"   Min margin: {batch_generator.min_margin}")
print(f"   Timeout protection: 60s per generation batch")
print(f"   Strategy: Real batch processing with safety")
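
The selection rule inside `create_pairs_batch` (drop short and duplicate responses, rank by reward score, keep best versus worst only if the margin clears a threshold) can be checked without a GPU. Below is a minimal sketch in which hand-written scores stand in for the reward model. `select_pair` and the sample data are hypothetical.

```python
from typing import Dict, List, Optional, Tuple


def select_pair(scored: List[Tuple[str, float]], min_margin: float = 0.5) -> Optional[Dict]:
    """Pick best vs worst response; return None if too few candidates or the margin is too small."""
    # Keep the first occurrence of each response longer than 10 characters
    seen: Dict[str, float] = {}
    for text, score in scored:
        if text and len(text) > 10 and text not in seen:
            seen[text] = score
    if len(seen) < 2:
        return None

    ranked = sorted(seen.items(), key=lambda item: item[1], reverse=True)
    (chosen, chosen_score), (rejected, rejected_score) = ranked[0], ranked[-1]
    margin = chosen_score - rejected_score
    if margin < min_margin:
        return None
    return {"chosen": chosen, "rejected": rejected, "margin": margin}


candidates = [
    ("A detailed, correct explanation.", 2.1),
    ("A vague answer that misses the point.", 0.4),
    ("A detailed, correct explanation.", 2.1),  # duplicate, ignored
    ("ok", 1.0),                                 # too short, ignored
]
pair = select_pair(candidates)

# Two distinct responses whose scores are too close to form a useful pair
close_call = select_pair([("First reasonable answer.", 1.0), ("Second reasonable answer.", 0.8)])
```

Raising `min_margin` trades quantity for cleaner preference signal: fewer instructions yield a pair, but the surviving pairs are ones the reward model separates more clearly.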

## 5. Test Batch

In [None]:
# Test on small batch
print("Testing batch generation...")
print("="*50)

test_batch = filtered_data[:2]
print(f"Testing with {len(test_batch)} samples\n")

start_time = datetime.now()

try:
    pairs = batch_generator.create_pairs_batch(
        test_batch,
        gen_batch_size=GEN_BATCH_SIZE
    )
    
    elapsed = (datetime.now() - start_time).total_seconds()
    
    print(f"\n✅ SUCCESS in {elapsed:.1f}s")
    print(f"Generated {len(pairs)} pairs from {len(test_batch)} samples")
    
    if pairs:
        print(f"\nExample pair:")
        print(f"Margin: {pairs[0].margin:.3f}")
        print(f"Chosen: {pairs[0].chosen[:100]}...")
        print(f"Rejected: {pairs[0].rejected[:100]}...")
        
except Exception as e:
    print(f"\n❌ Test failed: {e}")
    import traceback
    traceback.print_exc()
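
If the test batch yields no pairs, it can help to check the parsing step on its own, since a response that fails to parse is silently replaced by an empty string. The sketch below mirrors the logic of `_parse_response` on a hand-built string. `extract_assistant_reply` is a hypothetical standalone name, and the sample text assumes the pad token decodes to `<|end_of_text|>`.

```python
from typing import Optional

ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>"
END_TOKENS = ("<|eot_id|>", "<|end_of_text|>")


def extract_assistant_reply(text: str) -> Optional[str]:
    """Return the text after the last assistant header, cut at the first end token."""
    if ASSISTANT_HEADER not in text:
        return None
    reply = text.split(ASSISTANT_HEADER)[-1]
    for token in END_TOKENS:
        reply = reply.split(token)[0]
    return reply.strip()


# Made-up decoded output: left padding, prompt, then the generated reply
decoded = (
    "<|end_of_text|><|end_of_text|>"
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "What is 2 + 2?"
    "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    "2 + 2 equals 4.<|eot_id|><|end_of_text|>"
)
reply = extract_assistant_reply(decoded)
missing = extract_assistant_reply("no header here")
print(repr(reply), missing)
```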

## 6. Main Generation Loop

In [None]:
import os
import shutil

def save_checkpoint(data, checkpoint_path):
    """Save checkpoint."""
    with open(checkpoint_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
    print(f"💾 Checkpoint saved: {len(data)} pairs")

def load_checkpoint(checkpoint_path):
    """Load checkpoint."""
    if os.path.exists(checkpoint_path):
        with open(checkpoint_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    return []

# Paths
PREFERENCE_PATH = config['paths']['data_preference']
STABLE_CHECKPOINT = f"{PREFERENCE_PATH}/preference_checkpoint_stable.json"
CHECKPOINT_PATH = f"{PREFERENCE_PATH}/preference_checkpoint.json"
FINAL_PATH = f"{PREFERENCE_PATH}/preference_data.json"

# Load checkpoint
if os.path.exists(STABLE_CHECKPOINT) and not os.path.exists(CHECKPOINT_PATH):
    shutil.copy(STABLE_CHECKPOINT, CHECKPOINT_PATH)
    print(f"✅ Loaded STABLE checkpoint: {STABLE_CHECKPOINT}")
elif os.path.exists(CHECKPOINT_PATH):
    print(f"✅ Loaded checkpoint: {CHECKPOINT_PATH}")

preference_data = load_checkpoint(CHECKPOINT_PATH)
processed_instructions = {p['instruction'] for p in preference_data}

# Settings
TARGET_PAIRS = config.get('preference_generation', {}).get('target_pairs', 600)
CHECKPOINT_INTERVAL = 50

print(f"\nTarget: {TARGET_PAIRS} pairs")
print(f"Already completed: {len(preference_data)}")
print(f"\n🚀 TRUE OPTIMIZED MODE:")
print(f"   • Batch size: {BATCH_SIZE} samples")
print(f"   • Generation batch: {GEN_BATCH_SIZE}")
print(f"   • Timeout protection enabled")
print(f"   • Expected: 2-3 hours (A100)")
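
A checkpoint written straight to its final path can be left truncated if the session dies mid-write. One common mitigation, sketched here as an assumption and not as part of this notebook, is to write to a temporary file in the same directory and then swap it in with `os.replace`. `save_checkpoint_atomic` is a hypothetical name, and the demo writes to a throwaway temp directory.

```python
import json
import os
import tempfile


def save_checkpoint_atomic(data, checkpoint_path):
    """Write JSON to a temp file next to the target, then rename it over the target."""
    directory = os.path.dirname(os.path.abspath(checkpoint_path))
    os.makedirs(directory, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        # Atomic on POSIX when source and target share a filesystem
        os.replace(tmp_path, checkpoint_path)
    except BaseException:
        # Leave no stray temp file behind if the write or rename failed
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


demo_dir = tempfile.mkdtemp()
demo_path = os.path.join(demo_dir, "preference_checkpoint.json")
save_checkpoint_atomic([{"instruction": "hi", "margin": 0.7}], demo_path)

with open(demo_path, encoding="utf-8") as f:
    restored = json.load(f)
```

Whether the rename is truly atomic on a mounted Google Drive folder is not something this sketch verifies; it only guarantees that readers never see a half-written file on filesystems that honor POSIX rename semantics.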

In [None]:
# TRUE OPTIMIZED Main Loop
print(f"\n{'='*50}")
print("STARTING TRUE OPTIMIZED GENERATION")
print(f"{'='*50}")
print(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")

# Filter unprocessed
unprocessed_data = [
    s for s in filtered_data
    if s['instruction'] not in processed_instructions
]

print(f"Unprocessed samples: {len(unprocessed_data)}")
print(f"Processing {BATCH_SIZE} samples at a time\n")

pbar = tqdm(total=TARGET_PAIRS, initial=len(preference_data), desc="Generating pairs")

data_idx = 0
batch_num = 0
total_start = datetime.now()
timeout_count = 0
success_count = 0

while len(preference_data) < TARGET_PAIRS and data_idx < len(unprocessed_data):
    batch_num += 1
    batch_start = datetime.now()

    # Get batch
    batch_end = min(data_idx + BATCH_SIZE, len(unprocessed_data))
    batch = unprocessed_data[data_idx:batch_end]
    data_idx = batch_end

    print(f"\n[Batch {batch_num}] Processing {len(batch)} samples...")

    try:
        # TRUE BATCH PROCESSING with timeout
        pairs = batch_generator.create_pairs_batch(
            batch,
            gen_batch_size=GEN_BATCH_SIZE
        )

        batch_elapsed = (datetime.now() - batch_start).total_seconds()
        print(f"  ✓ Batch completed in {batch_elapsed:.1f}s")

        # Add pairs
        added = 0
        for pair in pairs:
            if len(preference_data) >= TARGET_PAIRS:
                break

            preference_data.append({
                'instruction': pair.instruction,
                'chosen': pair.chosen,
                'rejected': pair.rejected,
                'chosen_score': pair.chosen_score,
                'rejected_score': pair.rejected_score,
                'margin': pair.margin
            })
            processed_instructions.add(pair.instruction)
            pbar.update(1)
            added += 1

        success_count += 1
        print(f"  Added: {added}/{len(batch)} pairs")

        # Timeouts are caught inside generate_batch_with_timeout and never
        # reach this loop, so read the generator's counter if it keeps one
        timeout_count = getattr(batch_generator, 'timeout_count', 0)

        # Progress every 5 batches
        if batch_num % 5 == 0:
            elapsed = (datetime.now() - total_start).total_seconds() / 60
            rate = len(preference_data) / elapsed if elapsed > 0 else 0
            eta = (TARGET_PAIRS - len(preference_data)) / rate if rate > 0 else 0

            print(f"\n  📊 Progress: {len(preference_data)}/{TARGET_PAIRS}")
            print(f"  ⚡ Rate: {rate:.2f} pairs/min | ETA: {eta:.1f}min")
            print(f"  ✅ Success: {success_count}/{batch_num} batches")
            print(f"  ⚠️ Timeouts: {timeout_count}")
            print(f"  💾 GPU: {torch.cuda.memory_allocated()/1e9:.1f}GB\n")

        # Memory cleanup EVERY batch
        gc.collect()
        torch.cuda.empty_cache()

        # Checkpoint when this batch pushed the count across a multiple of
        # CHECKPOINT_INTERVAL (an exact `% == 0` test can be stepped over)
        total = len(preference_data)
        if added > 0 and total // CHECKPOINT_INTERVAL > (total - added) // CHECKPOINT_INTERVAL:
            save_checkpoint(preference_data, CHECKPOINT_PATH)

        # Stop after repeated generation timeouts
        if timeout_count >= 5:
            print(f"\n❌ Too many timeouts ({timeout_count}). Consider using STABLE version.")
            save_checkpoint(preference_data, CHECKPOINT_PATH)
            break

    except KeyboardInterrupt:
        print(f"\n⚠️ Interrupted by user")
        save_checkpoint(preference_data, CHECKPOINT_PATH)
        break

    except Exception as e:
        print(f"\n❌ Batch {batch_num} error: {e}")
        import traceback
        traceback.print_exc()
        
        if len(preference_data) > 0:
            save_checkpoint(preference_data, f"{CHECKPOINT_PATH}.emergency_{batch_num}")
        continue

pbar.close()

total_time = (datetime.now() - total_start).total_seconds() / 60
print(f"\n{'='*50}")
print(f"COMPLETED!")
print(f"{'='*50}")
print(f"Total time: {total_time:.1f} minutes ({total_time/60:.1f} hours)")
print(f"Total pairs: {len(preference_data)}")
print(f"Success rate: {success_count}/{batch_num} batches")
print(f"Timeouts: {timeout_count}")
if len(preference_data) > 0:
    print(f"Average: {total_time*60/len(preference_data):.1f}s per pair")
print(f"{'='*50}")
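
Because pairs are appended several at a time, the running total can step over a multiple of the checkpoint interval without ever landing on it. The following is a hedged sketch of a crossing test that still fires in that case. `crossed_interval` is a hypothetical helper and the batch sizes are made up.

```python
def crossed_interval(total_after: int, added: int, interval: int) -> bool:
    """True if adding `added` items moved the total onto or past a multiple of `interval`."""
    if added <= 0:
        return False
    return total_after // interval > (total_after - added) // interval


# Start at 48 pairs; successive batches add 3, 0, 4 and 2 pairs
total = 48
fired, exact = [], []
for added in [3, 0, 4, 2]:
    total += added
    fired.append(crossed_interval(total, added, interval=50))
    exact.append(total % 50 == 0)  # exact-multiple test, shown for comparison

print(fired)  # crossing test per batch
print(exact)  # exact-multiple test per batch
```

Here the total goes 51, 51, 55, 57: it passes 50 in the first batch, so the crossing test fires once, while the exact-multiple test never does.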

In [None]:
# Save final
save_checkpoint(preference_data, FINAL_PATH)
print(f"✅ Saved to: {FINAL_PATH}")

## 7. Analysis & DPO Format

In [None]:
# Statistics
if preference_data:
    margins = [p['margin'] for p in preference_data]
    chosen_scores = [p['chosen_score'] for p in preference_data]
    rejected_scores = [p['rejected_score'] for p in preference_data]

    print("=" * 50)
    print("STATISTICS")
    print("=" * 50)
    print(f"Total pairs: {len(preference_data)}")
    print(f"\nMargin: {np.mean(margins):.3f} ± {np.std(margins):.3f}")
    print(f"Chosen score: {np.mean(chosen_scores):.3f}")
    print(f"Rejected score: {np.mean(rejected_scores):.3f}")
else:
    print("No preference data generated yet.")

In [None]:
# Convert to DPO format
dpo_data = [
    {
        "prompt": p['instruction'],
        "chosen": p['chosen'],
        "rejected": p['rejected']
    }
    for p in preference_data
]

DPO_PATH = f"{PREFERENCE_PATH}/dpo_data.json"
with open(DPO_PATH, 'w', encoding='utf-8') as f:
    json.dump(dpo_data, f, ensure_ascii=False, indent=2)

print(f"✅ DPO data saved: {DPO_PATH}")

In [None]:
# Train/val split
from sklearn.model_selection import train_test_split

train_data, val_data = train_test_split(dpo_data, test_size=0.1, random_state=42)

with open(f"{PREFERENCE_PATH}/dpo_train.json", 'w', encoding='utf-8') as f:
    json.dump(train_data, f, ensure_ascii=False, indent=2)

with open(f"{PREFERENCE_PATH}/dpo_val.json", 'w', encoding='utf-8') as f:
    json.dump(val_data, f, ensure_ascii=False, indent=2)

print(f"Train: {len(train_data)} pairs")
print(f"Val: {len(val_data)} pairs")
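
`train_test_split` raises a `ValueError` when the split would leave either side empty, which happens if a short run produced zero or one pair. Below is a minimal sketch of a guarded split that uses only the standard library. `split_pairs` is hypothetical and is not a drop-in replacement for scikit-learn: it uses a different random stream, so the same seed gives a different partition.

```python
import random
from typing import Dict, List, Tuple


def split_pairs(records: List[Dict], val_fraction: float = 0.1,
                seed: int = 42) -> Tuple[List[Dict], List[Dict]]:
    """Shuffle a copy of `records` and hold out a validation slice.

    Inputs with fewer than two records go entirely to the training side.
    """
    if len(records) < 2:
        return list(records), []
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    n_val = max(1, round(len(shuffled) * val_fraction))
    return shuffled[n_val:], shuffled[:n_val]


# Made-up records in the DPO layout used above
records = [{"prompt": f"q{i}", "chosen": f"good {i}", "rejected": f"bad {i}"} for i in range(20)]
train_part, val_part = split_pairs(records)
tiny_train, tiny_val = split_pairs(records[:1])

print(len(train_part), len(val_part), len(tiny_train), len(tiny_val))
```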

In [None]:
# Cleanup
del generator_model, generator_tokenizer
del reward_model, reward_tokenizer
del batch_generator
gc.collect()
torch.cuda.empty_cache()

print("✅ Memory cleared!")

## ✅ Complete!

### TRUE OPTIMIZED VERSION:
- **Real batch processing**: 4 samples at once (A100)
- **Timeout protection**: 60s per generation batch
- **Auto-recovery**: Skips timed-out batches
- **Expected runtime**: 2-3 hours (A100)

### Performance:
- **Speed**: 3-4x faster than STABLE (if successful)
- **Success rate**: 80-90% (with timeout handling)
- **Fallback**: Switch to the STABLE version after 5 or more timeouts

### Innovation:
- Signal-based (`SIGALRM`) timeout detection aborts stalled generation batches
- Graceful degradation on timeout
- Real batch efficiency with safety

### Next Steps:
1. `05_sft_training.ipynb`
2. `06_dpo_training.ipynb`