In [None]:
# Import all necessary libraries for PPO alignment
import dataclasses
import gc
import json
import os
import time

import numpy as np
import torch
import transformers
from datasets import load_dataset
from peft import PeftModel
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification
)
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead

print("All libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"Transformers version: {transformers.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    print(f"Current GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")

# Set random seeds for reproducibility.
# transformers.set_seed seeds Python's `random`, NumPy and torch (CPU and all
# CUDA devices) in one call, so no RNG source is left unseeded.
# NOTE(review): PPOTrainer reseeds from PPOConfig.seed when it is constructed;
# keep that value in sync with SEED.
SEED = 42
transformers.set_seed(SEED)


In [None]:
# Comprehensive Configuration for PPO Experiments
print("Setting up PPO experiment configuration...")

# Model paths
sft_model_path = './models/sft'
rm_model_base_path = './models/rm'

# Precision levels to experiment with (each is a sub-directory of rm_model_base_path)
rm_precisions_to_run = ['bf16', 'int8', 'int4']

# PPO Configuration optimized for RTX 4060.
# Keys are fields of the classic TRL PPOConfig used with PPOTrainer.generate/step
# (assumed TRL < 0.12 — confirm installed version). The exception is
# 'response_length', which is used as max_new_tokens for generation.
#
# 'forward_batch_size' is intentionally absent. In that PPOConfig it is deprecated,
# and setting it overwrites mini_batch_size, so 16 would silently replace the
# 4-sample mini batch below.
ppo_config_dict = {
    'learning_rate': 1.41e-5,           # Lower learning rate for stable PPO
    'batch_size': 64,                   # Total batch size for PPO
    'mini_batch_size': 4,               # Small mini batch for 8GB VRAM
    'gradient_accumulation_steps': 4,   # Effective mini batch = 4*4 = 16
    'ppo_epochs': 4,                    # Number of PPO epochs per batch
    'max_grad_norm': 0.5,               # Gradient clipping
    'kl_penalty': 'kl',                 # KL penalty type
    'adap_kl_ctrl': True,               # Adaptive KL controller for stability
    'init_kl_coef': 0.1,                # Initial KL coefficient
    'target': 6.0,                      # Target KL for the adaptive controller ('target_kl' is only an early-stopping threshold)
    'gamma': 1.0,                       # Discount factor
    'lam': 0.95,                        # GAE lambda
    'cliprange': 0.2,                   # PPO clip range
    'cliprange_value': 0.2,             # Value function clip range
    'vf_coef': 0.1,                     # Value function coefficient
    'seed': 42,                         # PPOTrainer reseeds from this; match the notebook seed
    'response_length': 128,             # Maximum response length (generation max_new_tokens)
}

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Experiment tracking
experiment_results = {}

print("Configuration Summary:")
print(f"  SFT model path: {sft_model_path}")
print(f"  RM base path: {rm_model_base_path}")
print(f"  Precisions to test: {rm_precisions_to_run}")
print(f"  Device: {device}")
print(f"  PPO batch size: {ppo_config_dict['batch_size']}")
print(f"  PPO mini batch size: {ppo_config_dict['mini_batch_size']}")
print(f"  Effective mini batch: {ppo_config_dict['mini_batch_size'] * ppo_config_dict['gradient_accumulation_steps']}")
print(f"  Response length: {ppo_config_dict['response_length']}")
print("✅ Configuration complete!")


In [None]:
# Load prompt dataset and tokenizer
print("Loading prompts and tokenizer for PPO training...")

# Prompts for PPO come from the test preference split.
# NOTE(review): because PPO trains on this split, any later evaluation on the same
# split is no longer held-out — confirm this is intended.
test_dataset = load_dataset('json', data_files='./data/test_prefs.jsonl')['train']
print(f"Loaded {len(test_dataset)} test examples")

# PPO only needs the prompt text; the remaining preference columns are dropped.
prompts_dataset = test_dataset.select_columns(['prompt'])
print(f"Extracted {len(prompts_dataset)} prompts")

# Cap the prompt pool at 200 to keep each experiment short.
max_prompts = len(prompts_dataset) if len(prompts_dataset) < 200 else 200
prompts_dataset = prompts_dataset.select(range(max_prompts))
print(f"Using {len(prompts_dataset)} prompts for PPO training")

# The SFT model's tokenizer is reused for policy generation and reward scoring.
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(sft_model_path)

# Generation passes tokenizer.pad_token_id, so fall back to EOS when no pad token exists.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print("✓ Pad token set to EOS token")

print(f"Tokenizer loaded - Vocab size: {len(tokenizer)}")

# Preview the first few prompts (truncated to 100 characters).
print("\nSample prompts for PPO training:")
preview_rows = prompts_dataset.select(range(min(3, len(prompts_dataset))))
for i, sample in enumerate(preview_rows):
    prompt_text = sample['prompt']
    print(f"  {i+1}. {prompt_text[:100]}...")

print("✅ Dataset and tokenizer ready!")


In [None]:
# PPO Pipeline Function - The Heart of Our Experiment
def _ppo_collator(examples):
    """Collate dataset rows into ``{column: [value, ...]}`` without stacking.

    The training loop hands PPOTrainer.generate/step *lists* of 1-D prompt
    tensors. Prompts have different token lengths, so torch's default collate,
    which stacks tensors, cannot batch them.
    """
    return {key: [example[key] for example in examples] for key in examples[0]}


def _prepare_ppo_dataset(dataset, tokenizer):
    """Return a torch-formatted copy of ``dataset`` with ``input_ids`` and ``query``.

    Args:
        dataset: datasets.Dataset with a 'prompt' text column.
        tokenizer: tokenizer used to encode the prompts for the policy.

    Returns:
        A new dataset whose 'input_ids' are the encoded prompt and whose 'query'
        is the untouched prompt text, reused verbatim when scoring rewards.
    """
    # NOTE(review): prompts are encoded raw, without the "### Human: / ### Assistant:"
    # template used for reward scoring. If the SFT model was trained on that
    # template, the queries should probably be wrapped the same way — confirm.
    def _encode(example):
        return {
            'input_ids': tokenizer.encode(example['prompt']),
            'query': example['prompt'],
        }

    ppo_dataset = dataset.map(_encode, batched=False)
    ppo_dataset.set_format(type='torch')
    return ppo_dataset


def _score_with_reward_model(reward_model, tokenizer, prompts, responses, max_length=512):
    """Score each (prompt, response) pair, one at a time, with the reward model.

    Returns:
        list of 0-d float32 CPU tensors, one per pair. This is the per-sample
        ``scores`` list that PPOTrainer.step takes.
    """
    scores = []
    with torch.no_grad():
        for prompt, response in zip(prompts, responses):
            # Assumed to match the reward model's training format — confirm.
            full_text = f"### Human:\n{prompt}\n\n### Assistant:\n{response}"
            inputs = tokenizer(
                full_text,
                return_tensors="pt",
                truncation=True,
                max_length=max_length,
            ).to(reward_model.device)  # Inputs must live where the RM lives
            logits = reward_model(**inputs).logits
            # Logit column 0 of the single row is used as the scalar reward.
            scores.append(logits[0, 0].float().cpu())
    return scores


def run_ppo_experiment(rm_precision, config, tokenizer, dataset, max_steps=20, log_every=5):
    """
    Run a complete PPO experiment with the specified reward model precision.

    Args:
        rm_precision: 'bf16', 'int8', or 'int4'. Selects the sub-directory of
            config['rm_model_base_path'] that the reward model is loaded from.
        config: dict with 'sft_model_path', 'rm_model_base_path' and
            'ppo_config_dict'. Keys of 'ppo_config_dict' that are not fields of
            the installed PPOConfig are not passed to it; 'response_length' is
            used as max_new_tokens for generation.
        tokenizer: Pre-loaded tokenizer. Its pad_token_id is passed to generation.
        dataset: datasets.Dataset with a 'prompt' column.
        max_steps: Stop after this many PPO steps. None means one full pass
            over the dataloader.
        log_every: Call ppo_trainer.log_stats and print progress every N steps.

    Returns:
        dict: {'status': 'success', ...statistics...}, or
        {'status': 'failed', 'error': str} when a stage raises. All statistics
        are plain Python numbers, so the dict is JSON-serialisable.
    """
    print(f"\n🚀 Starting PPO experiment with {rm_precision.upper()} reward model")
    start_time = time.time()
    ppo_params = config['ppo_config_dict']

    # === STEP 1: Load Reward Model (on CPU to save VRAM) ===
    print("  Loading reward model on CPU...")
    rm_path = os.path.join(config['rm_model_base_path'], rm_precision)

    try:
        # Load in float32, not float16. bf16 weights are exactly representable in
        # float32 but not in float16 (narrower exponent range), and float16 kernel
        # coverage on CPU is incomplete in many PyTorch builds.
        # NOTE(review): if the int8/int4 checkpoints carry a bitsandbytes
        # quantization_config, loading them on CPU is expected to fail — confirm
        # how those directories were produced.
        reward_model = AutoModelForSequenceClassification.from_pretrained(
            rm_path,
            device_map='cpu',  # Keep RM on CPU to save GPU memory
            torch_dtype=torch.float32,
        )
        reward_model.eval()
        print(f"  ✓ Reward model loaded from {rm_path}")
    except Exception as e:
        print(f"  ❌ Error loading reward model: {e}")
        return {'status': 'failed', 'error': str(e)}

    # === STEP 2: Load Policy Model (SFT + Value Head on GPU) ===
    print("  Loading policy model on GPU...")
    try:
        # Trainable policy: SFT weights wrapped with a value head for PPO.
        policy_model = AutoModelForCausalLMWithValueHead.from_pretrained(
            config['sft_model_path'],
            device_map='auto',
            torch_dtype=torch.bfloat16,
        )

        # Reference model (frozen SFT model) for the KL penalty.
        ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(
            config['sft_model_path'],
            device_map='auto',
            torch_dtype=torch.bfloat16,
        )

        print(f"  ✓ Policy and reference models loaded")
        # The value-head wrapper is a plain nn.Module rather than a PreTrainedModel,
        # so it has no .num_parameters(); count the parameters directly.
        n_params = sum(p.numel() for p in policy_model.parameters())
        print(f"  Policy model parameters: {n_params:,}")

    except Exception as e:
        print(f"  ❌ Error loading policy model: {e}")
        return {'status': 'failed', 'error': str(e)}

    # === STEP 3: Initialize PPO Trainer ===
    print("  Initializing PPO trainer...")
    try:
        # Pass only real PPOConfig dataclass fields. Any other key
        # (e.g. 'response_length' on classic TRL) would raise TypeError in __init__.
        ppo_fields = {field.name for field in dataclasses.fields(PPOConfig)}
        ppo_kwargs = {key: value for key, value in ppo_params.items() if key in ppo_fields}
        skipped_keys = sorted(set(ppo_params) - ppo_fields)
        if skipped_keys:
            print(f"  ℹ️ Not passed to PPOConfig (not a PPOConfig field): {skipped_keys}")
        ppo_config = PPOConfig(**ppo_kwargs)

        ppo_dataset = _prepare_ppo_dataset(dataset, tokenizer)
        ppo_trainer = PPOTrainer(
            config=ppo_config,
            model=policy_model,
            ref_model=ref_model,
            tokenizer=tokenizer,
            dataset=ppo_dataset,
            data_collator=_ppo_collator,
        )

        generation_kwargs = {
            'max_new_tokens': ppo_params['response_length'],
            'do_sample': True,
            'temperature': 0.7,
            'pad_token_id': tokenizer.pad_token_id,
        }
        print("  ✓ PPO trainer initialized")

    except Exception as e:
        print(f"  ❌ Error initializing PPO trainer: {e}")
        return {'status': 'failed', 'error': str(e)}

    # === STEP 4: Training Loop ===
    print(f"  Starting PPO training loop...")
    training_stats = []

    try:
        for step, batch in enumerate(tqdm(ppo_trainer.dataloader, desc=f"PPO {rm_precision}")):
            # List of 1-D prompt token tensors (see _ppo_collator).
            query_tensors = batch['input_ids']

            # Generate responses (prompt tokens stripped via return_prompt=False)
            with torch.no_grad():
                response_tensors = ppo_trainer.generate(
                    query_tensors,
                    return_prompt=False,
                    **generation_kwargs,
                )

            responses = [tokenizer.decode(r, skip_special_tokens=True) for r in response_tensors]
            batch['response'] = responses  # log_stats may log 'query'/'response' columns

            # Score against the original prompt text rather than a decode round-trip.
            rewards = _score_with_reward_model(reward_model, tokenizer, batch['query'], responses)

            # PPO step: scores must be a list with one scalar tensor per sample.
            stats = ppo_trainer.step(query_tensors, response_tensors, rewards)

            reward_tensor = torch.stack(rewards)
            # TRL stats may be numpy values; convert so results stay JSON-serialisable.
            kl = float(stats.get('objective/kl', 0.0))

            if step % log_every == 0:
                ppo_trainer.log_stats(stats, batch, reward_tensor)
                print(f"    Step {step}: Reward mean = {reward_tensor.mean().item():.3f}, "
                      f"KL = {kl:.3f}")

            training_stats.append({
                'epoch': step,  # Key kept for compatibility; it is the PPO step index.
                'reward_mean': reward_tensor.mean().item(),
                'reward_std': reward_tensor.std().item(),
                'kl': kl,
            })

            # Stop once exactly max_steps PPO steps have run.
            if max_steps is not None and step + 1 >= max_steps:
                break

    except Exception as e:
        print(f"  ❌ Error during training: {e}")
        return {'status': 'failed', 'error': str(e)}

    if not training_stats:
        print("  ⚠️ No PPO steps were run — check that the dataset holds at least one full batch")

    # === STEP 5: Save Results ===
    print("  Saving trained policy model...")
    output_dir = f'./models/ppo_policy_{rm_precision}'
    os.makedirs(output_dir, exist_ok=True)

    try:
        ppo_trainer.save_pretrained(output_dir)
        print(f"  ✓ Policy model saved to {output_dir}")
    except Exception as e:
        print(f"  ⚠️ Warning: Could not save model: {e}")

    # Calculate final statistics
    last = training_stats[-1] if training_stats else None
    final_stats = {
        'status': 'success',
        'rm_precision': rm_precision,
        'training_time': time.time() - start_time,
        'final_reward_mean': last['reward_mean'] if last else 0,
        'final_reward_std': last['reward_std'] if last else 0,
        'final_kl': last['kl'] if last else 0,
        'total_steps': len(training_stats),
        'output_dir': output_dir,
        'training_history': training_stats,
    }

    print(f"  ✅ PPO experiment completed in {final_stats['training_time']:.1f} seconds")
    print(f"  Final reward: {final_stats['final_reward_mean']:.3f} ± {final_stats['final_reward_std']:.3f}")

    return final_stats

print("✅ PPO pipeline function defined!")


In [None]:
# Experiment Execution Loop - Run All PPO Experiments
print("🎬 Starting comprehensive PPO alignment experiments!")
print(f"Will run {len(rm_precisions_to_run)} experiments with different RM precisions")
print("=" * 70)


def _to_json_safe(value):
    """Fallback for ``json.dump``: convert numpy/torch scalars and arrays via ``.tolist()``.

    Result dicts may contain numpy or torch values (e.g. stats returned by TRL),
    which the stdlib JSON encoder rejects.

    Raises:
        TypeError: for any other non-serialisable object (the json ``default`` contract).
    """
    if hasattr(value, 'tolist'):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


# Prepare configuration dictionary
config = {
    'sft_model_path': sft_model_path,
    'rm_model_base_path': rm_model_base_path,
    'ppo_config_dict': ppo_config_dict
}

# Track all experiment results
all_results = {}

for i, rm_precision in enumerate(rm_precisions_to_run):
    print(f"\n{'='*25} Experiment {i+1}/{len(rm_precisions_to_run)} {'='*25}")
    print(f"🎯 Running PPO with {rm_precision.upper()} reward model")

    try:
        # Run the PPO experiment
        result = run_ppo_experiment(
            rm_precision=rm_precision,
            config=config,
            tokenizer=tokenizer,
            dataset=prompts_dataset
        )

        # Store results
        all_results[rm_precision] = result

        if result['status'] == 'success':
            print(f"✅ {rm_precision.upper()} experiment completed successfully!")
            print(f"   Training time: {result['training_time']:.1f}s")
            print(f"   Final reward: {result['final_reward_mean']:.3f}")
            print(f"   Final KL: {result['final_kl']:.3f}")
        else:
            print(f"❌ {rm_precision.upper()} experiment failed: {result.get('error', 'unknown error')}")

    except Exception as e:
        print(f"❌ Unexpected error in {rm_precision} experiment: {str(e)}")
        all_results[rm_precision] = {
            'status': 'failed',
            'error': str(e)
        }

    finally:
        # CRITICAL: Memory cleanup between experiments.
        # The models are locals of run_ppo_experiment. After it returns, collect any
        # reference cycles, then release PyTorch's cached CUDA blocks.
        print(f"🧹 Cleaning up memory after {rm_precision} experiment...")
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            current_memory = torch.cuda.memory_allocated() / 1024**3
            print(f"   GPU memory after cleanup: {current_memory:.2f} GB")

        # Small delay to ensure cleanup
        time.sleep(2)

print("\n" + "="*70)
print("🎉 ALL PPO ALIGNMENT EXPERIMENTS COMPLETED!")

# === COMPREHENSIVE RESULTS SUMMARY ===
print("\n📊 COMPREHENSIVE EXPERIMENT SUMMARY:")
print("-" * 50)

success_count = 0
for precision, result in all_results.items():
    status_emoji = "✅" if result['status'] == 'success' else "❌"
    print(f"\n{status_emoji} {precision.upper()} REWARD MODEL:")

    if result['status'] == 'success':
        success_count += 1
        print(f"   Status: SUCCESS")
        print(f"   Training Time: {result['training_time']:.1f} seconds")
        print(f"   Final Reward: {result['final_reward_mean']:.3f} ± {result['final_reward_std']:.3f}")
        print(f"   Final KL Divergence: {result['final_kl']:.3f}")
        print(f"   Total Training Steps: {result['total_steps']}")
        print(f"   Model Saved To: {result['output_dir']}")
    else:
        print(f"   Status: FAILED")
        print(f"   Error: {result.get('error', 'unknown error')}")

print(f"\n🎯 OVERALL SUCCESS RATE: {success_count}/{len(rm_precisions_to_run)} experiments")

# Save results to JSON for further analysis
results_file = './results/ppo_experiment_results.json'
os.makedirs(os.path.dirname(results_file), exist_ok=True)

with open(results_file, 'w') as f:
    # default= handles numpy/torch values that the stdlib encoder cannot serialise.
    json.dump(all_results, f, indent=2, default=_to_json_safe)

print(f"📄 Detailed results saved to: {results_file}")
print(f"\n🚀 Your PPO-aligned models are ready for analysis!")
print(f"📁 Policy models saved in: ./models/ppo_policy_*")
