# GriceBench Phase 5: DPO Rebuild - MAX GPU OPTIMIZATION

## GPU Optimizations Applied
- **Batched Inference** - Process 32 samples simultaneously
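- **Multi-GPU DataParallel** - Wrap the model in `nn.DataParallel` when both T4 GPUs are visible (note: `generate_batch` calls `.generate()` on the unwrapped `model.module`, not through the DataParallel wrapper)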
- **Larger Batch Sizes** - Fill GPU memory
- **Async Data Loading** - Overlap CPU/GPU work
- **Flash Attention** - Faster attention computation

In [None]:
# CELL 1: SETUP WITH MAX GPU CONFIG
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

import torch
import torch.nn as nn
import json
import random
import numpy as np
from pathlib import Path
from typing import Dict, List
from tqdm.auto import tqdm

# --- Device discovery ---
# Prefer CUDA when it is available; later cells move tensors/models to `device`.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
num_gpus = torch.cuda.device_count()
print(f'Device: {device}')
print(f'Number of GPUs: {num_gpus}')

# Report the name and total memory (in GB, base 1e9) of every visible GPU.
for i in range(num_gpus):
    gpu_name = torch.cuda.get_device_name(i)
    gpu_total_bytes = torch.cuda.get_device_properties(i).total_memory
    print(f'  GPU {i}: {gpu_name}')
    print(f'    Memory: {gpu_total_bytes / 1e9:.1f} GB')

# --- Backend switches ---
# Let cuDNN benchmark kernels, and allow TF32 for both matmul and cuDNN.
torch.backends.cudnn.benchmark = True
for tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
    tf32_backend.allow_tf32 = True

# --- Paths and seeding ---
DATA_INPUT = Path('/kaggle/input/gricebench-scientific-fix')
OUTPUT_DIR = Path('/kaggle/working')
for seed_fn in (random.seed, torch.manual_seed):
    seed_fn(42)

In [None]:
# CELL 2: INSTALL DEPENDENCIES
!pip install -q transformers>=4.36.0 accelerate>=0.25.0 trl>=0.7.0 peft>=0.7.0 bitsandbytes

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F

print('Dependencies installed')

In [None]:
# CELL 3: LOAD MODEL WITH MULTI-GPU
print('=' * 70)
print('LOADING MODEL WITH MULTI-GPU SUPPORT')
print('=' * 70)

MODEL_NAME = 'gpt2'

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse EOS as the padding token so batches of prompts can be padded.
tokenizer.pad_token = tokenizer.eos_token
# Pad on the left so every prompt ends right where generation starts.
tokenizer.padding_side = 'left'

# Load model.
# Half precision is only requested when running on CUDA; the CPU fallback
# chosen in Cell 1 gets float32 weights instead of float16.
model_dtype = torch.float16 if device.type == 'cuda' else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=model_dtype,
)

# Multi-GPU with DataParallel.
# NOTE: generate_batch (Cell 5) calls `.generate()` on `model.module`, i.e. on
# the unwrapped model, so this wrapper is not what drives generation.
if num_gpus > 1:
    print(f'Using DataParallel with {num_gpus} GPUs')
    model = nn.DataParallel(model)

model = model.to(device)
model.eval()  # inference only: disable dropout etc.

print('Model loaded and distributed across GPUs')

In [None]:
# CELL 4: LOAD CONTEXTS
print('=' * 70)
print('LOADING CONTEXTS')
print('=' * 70)

contexts = []

val_path = DATA_INPUT / 'val_examples.json'
if val_path.exists():
    with open(val_path, 'r', encoding='utf-8') as f:
        val_data = json.load(f)
    for item in val_data:
        # Prefer 'context_text'; fall back to 'context' when the former is
        # missing *or* empty (a plain .get default only covers "missing").
        ctx = item.get('context_text') or item.get('context', '')
        # Keep only non-trivial string contexts (more than 20 characters).
        if ctx and isinstance(ctx, str) and len(ctx) > 20:
            contexts.append(ctx[:300])  # Shorter for faster processing
    print(f'Loaded {len(contexts)} contexts')
else:
    # Make the empty result visible instead of silently producing 0 pairs.
    print(f'WARNING: {val_path} not found - no contexts loaded')

# De-duplicate while preserving first-seen order. A set would order strings
# by their per-process hash, so the seeded shuffle below would still select
# different contexts from run to run.
contexts = list(dict.fromkeys(contexts))
random.shuffle(contexts)
contexts = contexts[:500]  # cap the number of contexts used downstream
print(f'Using {len(contexts)} unique contexts')

In [None]:
# CELL 5: BATCHED GENERATION FUNCTION (GPU MAXIMIZED)
print('=' * 70)
print('SETTING UP BATCHED GENERATION')
print('=' * 70)

# HIGH BATCH SIZE FOR GPU UTILIZATION
BATCH_SIZE = 32  # Much larger for GPU saturation
MAX_NEW_TOKENS = 80
MAX_INPUT_LENGTH = 200

# Marker that ends every prompt, and the character cap on a cleaned response.
_RESPONSE_MARKER = 'Response:'
_MAX_RESPONSE_CHARS = 150


def _clean_response(text: str) -> str:
    """Return the text after the last response marker, stripped and capped."""
    if _RESPONSE_MARKER in text:
        text = text.split(_RESPONSE_MARKER)[-1]
    return text.strip()[:_MAX_RESPONSE_CHARS]


def generate_batch(prompts: List[str], temperature: float = 0.7) -> List[str]:
    """
    Generate one sampled response per prompt in a single batched call.

    Args:
        prompts: Prompt strings; they are padded to a common length and
            truncated to MAX_INPUT_LENGTH tokens.
        temperature: Sampling temperature passed to ``generate``.

    Returns:
        One cleaned response per prompt, in the same order: the generated
        continuation only (text after the last 'Response:' marker if the
        model emitted one), stripped and capped at 150 characters.
        An empty ``prompts`` list yields an empty list.
    """
    # Nothing to tokenize or generate for an empty batch.
    if not prompts:
        return []

    # Tokenize all prompts at once
    inputs = tokenizer(
        prompts,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=MAX_INPUT_LENGTH
    ).to(device)

    # Generate with autocast for mixed precision, without tracking gradients
    with torch.amp.autocast('cuda'), torch.no_grad():
        # Get the actual model (unwrap DataParallel if needed)
        gen_model = model.module if hasattr(model, 'module') else model

        outputs = gen_model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=temperature,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,
            num_return_sequences=1,
        )

    # The padded batch is rectangular, so every row shares the same prompt
    # length. Drop those prompt tokens before decoding instead of relying on
    # the 'Response:' marker still being present in the (possibly truncated)
    # prompt; otherwise a truncated prompt would be echoed back as the
    # "response".
    prompt_len = inputs['input_ids'].shape[1]
    continuations = tokenizer.batch_decode(
        outputs[:, prompt_len:],
        skip_special_tokens=True,
    )

    return [_clean_response(text) for text in continuations]


print(f'Batch size: {BATCH_SIZE}')
print(f'Max new tokens: {MAX_NEW_TOKENS}')
print('Ready for batched generation')

In [None]:
# CELL 6: GENERATE ALL PREFERENCE PAIRS (GPU MAXIMIZED)
print('=' * 70)
print('GENERATING PREFERENCE PAIRS - GPU MAXIMIZED')
print('=' * 70)

# Prepare all prompts upfront; prompt i belongs to contexts[i].
all_prompts = [f"Context: {ctx}\n\nResponse:" for ctx in contexts]
print(f'Total prompts: {len(all_prompts)}')

# Context index -> responses, one appended per temperature pass below.
all_responses = {i: [] for i in range(len(contexts))}

# Generate 3 responses per context (one per temperature) in large batches
for temperature in (0.6, 0.7, 0.8):
    print(f'\nTemperature {temperature}:')

    for batch_start in tqdm(range(0, len(all_prompts), BATCH_SIZE), desc=f'Temp {temperature}'):
        # Slicing clamps at the end of the list, so the last batch may be short.
        batch_prompts = all_prompts[batch_start:batch_start + BATCH_SIZE]

        # Generate batch
        batch_responses = generate_batch(batch_prompts, temperature)

        # Responses come back in prompt order; offset by the batch start to
        # recover the global context index.
        for ctx_idx, resp in enumerate(batch_responses, start=batch_start):
            all_responses[ctx_idx].append(resp)

print('\nGeneration complete!')

In [None]:
# CELL 7: CREATE PREFERENCE PAIRS
print('=' * 70)
print('CREATING PREFERENCE PAIRS')
print('=' * 70)

# (index used for response_A, index used for response_B, pair_type label),
# listed in the order the pairs are emitted for each context.
PAIR_LAYOUT = ((0, 1, 'AB'), (1, 2, 'BC'), (0, 2, 'AC'))

preference_pairs = []

for ctx_idx, responses in all_responses.items():
    # A context needs at least three responses to form its three pairs.
    if len(responses) < 3:
        continue

    ctx = contexts[ctx_idx]
    for a_idx, b_idx, pair_label in PAIR_LAYOUT:
        # The id is the pair's position in the output list.
        preference_pairs.append({
            'id': len(preference_pairs),
            'context': ctx,
            'response_A': responses[a_idx],
            'response_B': responses[b_idx],
            'pair_type': pair_label,
        })

print(f'Created {len(preference_pairs)} preference pairs')

In [None]:
# CELL 8: SAVE OUTPUTS
print('=' * 70)
print('SAVING OUTPUTS')
print('=' * 70)

# Make sure the destination exists before writing either file.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Save pairs
pairs_path = OUTPUT_DIR / 'preference_pairs_1500.json'
with open(pairs_path, 'w', encoding='utf-8') as f:
    json.dump(preference_pairs, f, indent=2, ensure_ascii=False)
print(f'✅ Saved {len(preference_pairs)} pairs')

# Create annotation template: same id/context/responses as each pair, plus
# empty fields for the annotator to fill in.
annotation_template = [{
    'id': p['id'],
    'context': p['context'],
    'response_A': p['response_A'],
    'response_B': p['response_B'],
    'preference': '',
    'reason': [],
    'annotated': False
} for p in preference_pairs]

template_path = OUTPUT_DIR / 'annotation_template.json'
with open(template_path, 'w', encoding='utf-8') as f:
    json.dump(annotation_template, f, indent=2, ensure_ascii=False)
print('✅ Saved annotation template')

In [None]:
# CELL 9: GPU UTILIZATION CHECK
banner = '=' * 70
print(banner)
print('GPU UTILIZATION')
print(banner)

# Per GPU: memory reported by torch.cuda.memory_allocated versus the device's
# total memory, both in GB (base 1e9), plus the ratio as a percentage.
for i in range(num_gpus):
    mem_total = torch.cuda.get_device_properties(i).total_memory / 1e9
    mem_used = torch.cuda.memory_allocated(i) / 1e9
    pct_used = 100*mem_used/mem_total
    print(f'GPU {i}: {mem_used:.2f} / {mem_total:.1f} GB ({pct_used:.1f}%)')

In [None]:
# CELL 10: SUMMARY
# The section headings previously printed mojibake in place of their emoji;
# the intended characters are restored here.
print('\n' + '=' * 70)
print('PHASE 5 PAIR GENERATION COMPLETE')
print('=' * 70)

print('\n📊 RESULTS:')
print(f'   Preference pairs: {len(preference_pairs)}')
print(f'   Contexts used: {len(contexts)}')

print('\n📁 OUTPUT FILES:')
print('   preference_pairs_1500.json')
print('   annotation_template.json')

print('\n📋 NEXT STEPS:')
print('   1. Download files')
print('   2. Annotate (set preference field)')
print('   3. Upload as annotated_preferences.json')
print('   4. Run DPO training cells')