In [None]:
# Run setup from config notebook.
# NOTE(review): names used below such as SEED, USE_LORA and set_seed are not
# defined or imported anywhere in this notebook, so they presumably come from
# this %run - confirm against 0_config_setup.ipynb.
%run 0_config_setup.ipynb

In [None]:
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler
from datasets import Dataset
from tqdm import tqdm
import wandb
import json

# Import LoRA/PEFT for safer training
from peft import LoraConfig, get_peft_model, TaskType

# Seed the run through the project's set_seed helper before anything stochastic happens.
set_seed(SEED)

# Startup banner: one print call, one line per setting.
print(
    "üöÄ PPO Optimization - HPC Optimized",
    f"   LoRA enabled: {USE_LORA}",
    f"   KL Penalty: {KL_PENALTY_COEF}",
    sep="\n",
)

## Load Models

In [None]:
print("Loading models...\n")

# Hugging Face Hub id used for both the trainable policy and the frozen reference copy.
model_name = "ModelSpace/GemmaX2-28-9B-v0.1"

# Per-device memory caps are only built when more than one GPU is configured;
# otherwise max_memory stays None.
max_memory = None
if NUM_GPUS > 1:
    max_memory = {gpu_index: GPU_MEMORY_PER_DEVICE for gpu_index in range(NUM_GPUS)}
    max_memory["cpu"] = "32GB"

# 1. Policy (SFT) model that PPO will optimise
print(f"Loading SFT model: {model_name}")

policy_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
if policy_tokenizer.pad_token is None:
    # No dedicated pad token on this tokenizer: reuse EOS for padding.
    policy_tokenizer.pad_token = policy_tokenizer.eos_token
policy_tokenizer.padding_side = "left"

# Loading options shared by every from_pretrained call on `model_name` in this cell.
shared_load_kwargs = {
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
    "trust_remote_code": True,
    "max_memory": max_memory,
}
# The policy additionally selects an attention implementation; the reference model does not.
policy_load_kwargs = {
    **shared_load_kwargs,
    "attn_implementation": ATTN_IMPLEMENTATION if USE_FLASH_ATTENTION else None,
}

if not USE_LORA:
    # ===========================
    # FULL FINE-TUNING
    # ===========================
    print("‚ö†Ô∏è Full fine-tuning mode - SFT weights WILL be modified")

    # The value-head wrapper loads the checkpoint directly.
    policy_model = AutoModelForCausalLMWithValueHead.from_pretrained(
        model_name,
        **policy_load_kwargs,
    )
    print("‚úì SFT model loaded (full fine-tuning)")

else:
    # ===========================
    # LoRA CONFIGURATION (RECOMMENDED)
    # ===========================
    print("üõ°Ô∏è Using LoRA - Original SFT weights will be PRESERVED")

    # Adapter hyper-parameters all come from the config constants.
    lora_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        target_modules=LORA_TARGET_MODULES,
        lora_dropout=LORA_DROPOUT,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )

    # Plain causal LM first, then LoRA adapters on top of it.
    base_model = AutoModelForCausalLM.from_pretrained(model_name, **policy_load_kwargs)
    base_model = get_peft_model(base_model, lora_config)
    base_model.print_trainable_parameters()

    # Finally wrap the PEFT model with the value head PPO needs.
    policy_model = AutoModelForCausalLMWithValueHead.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
        is_trainable=True,
    )
    print("‚úì SFT model loaded with LoRA adapters")

# 2. Reference model for the KL penalty: same checkpoint, never trained
print(f"\nLoading reference model (frozen) for KL computation...")
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_name,
    **shared_load_kwargs,
)

# Disable gradients on every reference parameter and switch to eval mode.
for ref_param in ref_model.parameters():
    ref_param.requires_grad = False
ref_model.eval()
print("‚úì Reference model loaded (frozen)")

In [None]:
# 3. Reward model, part one: tokenizer and backbone
print(f"\nLoading reward model from {REWARD_MODEL_COLD_START}...")

# The tokenizer is read from the reward-model checkpoint location.
rm_tokenizer = AutoTokenizer.from_pretrained(REWARD_MODEL_COLD_START)

# The backbone is loaded from REWARD_BASE_MODEL; the trained weights are
# restored from the checkpoint further down.
rm_backbone_kwargs = {
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
    "trust_remote_code": True,
}
rm_base_model = AutoModelForCausalLM.from_pretrained(REWARD_BASE_MODEL, **rm_backbone_kwargs)

# Recreate reward model structure (from notebook 2)
from torch import nn

class RewardModel(nn.Module):
    """Scalar reward model: a causal-LM backbone followed by a small regression head.

    Args:
        base_model: Module called as ``base_model(input_ids=..., attention_mask=...,
            output_hidden_states=True)``; its output must expose ``hidden_states`` and
            it must expose ``config.hidden_size``.
        hidden_dim: Width of the hidden layer when ``head_type == 'mlp'``.
        head_type: ``'linear'`` for a single projection, ``'mlp'`` for
            Linear -> ReLU -> Dropout(0.1) -> Linear.

    Raises:
        ValueError: If ``head_type`` is neither ``'linear'`` nor ``'mlp'``.
    """

    def __init__(self, base_model, hidden_dim=256, head_type='mlp'):
        super().__init__()
        self.base_model = base_model
        self.head_type = head_type
        self.hidden_size = base_model.config.hidden_size

        if head_type == 'linear':
            self.reward_head = nn.Linear(self.hidden_size, 1)
        elif head_type == 'mlp':
            self.reward_head = nn.Sequential(
                nn.Linear(self.hidden_size, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim, 1)
            )
        else:
            # Fail at construction time instead of leaving the model without a
            # reward_head and crashing later inside forward().
            raise ValueError(
                f"Unsupported head_type {head_type!r}; expected 'linear' or 'mlp'"
            )

    def forward(self, input_ids, attention_mask):
        """Return one reward per sequence, shape ``(batch,)``.

        The last layer's hidden state at position ``attention_mask.sum(dim=1) - 1``
        is pooled and passed through the reward head.
        NOTE(review): that position is the last real token only for right-padded
        batches - confirm the padding side of the tokenizer feeding this model.
        """
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        hidden_states = outputs.hidden_states[-1]

        # Build both index tensors on the hidden states' device so the gather
        # works even when the backbone output lives on a different device than
        # the inputs (e.g. a sharded model).
        sequence_lengths = (attention_mask.sum(dim=1) - 1).to(hidden_states.device)
        batch_index = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        pooled = hidden_states[batch_index, sequence_lengths]

        # The head is created independently of the backbone (default device and
        # dtype), so align the pooled activations with it before projecting.
        # Without this, a bfloat16 or GPU backbone raises a dtype/device mismatch.
        head_param = next(self.reward_head.parameters())
        pooled = pooled.to(device=head_param.device, dtype=head_param.dtype)

        reward = self.reward_head(pooled)
        return reward.squeeze(-1)

# Assemble the reward model around the backbone, then restore the trained weights.
reward_model = RewardModel(
    rm_base_model,
    hidden_dim=RM_HIDDEN_DIM,
    head_type=RM_HEAD_TYPE,
)

# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
checkpoint = torch.load(REWARD_MODEL_COLD_START / "reward_model.pt", map_location='cpu')
reward_model.load_state_dict(checkpoint['model_state_dict'])

# Inference only: eval mode, and gradients disabled on every parameter.
reward_model.eval()
reward_model.requires_grad_(False)

print("‚úì Reward model loaded (frozen)")
print(f"\nAll models loaded successfully!")

## Prepare Training Data

In [None]:
# Load training prompts (no parallel corpus needed)
print("Loading training prompts...")

try:
    all_data = load_test_prompts(TEST_PROMPTS)
    print(f"Loaded {len(all_data)} test prompts")
except Exception as load_error:
    # Best-effort fallback to built-in samples. Catching Exception (not a bare
    # except) lets Ctrl-C / kernel interrupts through, and the reason is printed
    # so a wrong path or malformed file is not silently masked.
    print(f"Could not load test prompts ({type(load_error).__name__}: {load_error})")
    all_data = [
        {"source": "Hello, how are you?", "source_lang": "en"},
        {"source": "Good morning.", "source_lang": "en"},
        {"source": "Thank you very much.", "source_lang": "en"},
        {"source": "Bonjour, comment allez-vous?", "source_lang": "fr"},
        {"source": "Merci beaucoup.", "source_lang": "fr"},
    ] * 1000  # Replicate for training
    print(f"Created {len(all_data)} sample prompts")

# Cap on the number of prompts turned into PPO queries.
max_prompts = 5000  # Adjust size based on resources

# One record per prompt: the formatted query plus the raw source text and its language.
prompts = [
    {
        'query': format_translation_prompt(item['source'], item['source_lang']),
        'source': item['source'],
        'source_lang': item['source_lang'],
    }
    for item in all_data[:max_prompts]
]

# Create HuggingFace dataset
dataset = Dataset.from_list(prompts)

print(f"Loaded {len(dataset)} training prompts")

## PPO Configuration

In [None]:
# PPO configuration - HPC Optimized with Safety Measures
# Tracker backend and run name are resolved first so the constructor call stays flat.
ppo_tracker = "wandb" if USE_WANDB else None
ppo_run_name = f"ppo-coldstart-{'lora' if USE_LORA else 'full'}"

ppo_config = PPOConfig(
    # Identity / bookkeeping
    model_name=model_name,
    seed=SEED,
    log_with=ppo_tracker,
    tracker_project_name=WANDB_PROJECT,
    tracker_kwargs={"name": ppo_run_name},

    # Optimisation schedule
    learning_rate=PPO_LEARNING_RATE,
    batch_size=PPO_BATCH_SIZE,
    mini_batch_size=PPO_MINI_BATCH_SIZE,
    gradient_accumulation_steps=PPO_GRADIENT_ACCUMULATION_STEPS,
    ppo_epochs=PPO_EPOCHS,

    # KL control against the reference model
    init_kl_coef=KL_PENALTY_COEF,
    target_kl=KL_TARGET,

    # Clipping for the policy ratio and the value function
    cliprange=CLIP_RANGE,
    cliprange_value=VALUE_CLIP_RANGE,

    # Value-loss weight, then the discount (gamma) and GAE lambda
    vf_coef=0.1,
    gamma=GAMMA,
    lam=GAE_LAMBDA,

    # CUDA cache handling between steps
    optimize_cuda_cache=True,
)

# Echo the effective settings as read back from the config object.
print("PPO Configuration (HPC Optimized):")
print(f"  Learning rate: {ppo_config.learning_rate}")
print(f"  Batch size: {ppo_config.batch_size}")
print(f"  Mini-batch size: {ppo_config.mini_batch_size}")
print(f"  KL penalty (init): {ppo_config.init_kl_coef}")
print(f"  KL target: {ppo_config.target_kl}")
print(f"  Clip range: {ppo_config.cliprange}")
print(f"  LoRA: {USE_LORA}")

## Reward Function

In [None]:
def compute_reward(source_texts, translations):
    """
    Score each (source, translation) pair with the frozen reward model.

    Pairs are formed with zip, so any surplus items in the longer list are ignored.

    Args:
        source_texts: List of source texts
        translations: List of generated translations

    Returns:
        List with one CPU tensor per pair, in input order
    """

    def _score_pair(src, hyp):
        # Render the pair in the "Source / Translation" layout and tokenize it,
        # padded/truncated to RM_MAX_LENGTH, on the backbone's device.
        encoded = rm_tokenizer(
            f"Source: {src}\nTranslation: {hyp}",
            max_length=RM_MAX_LENGTH,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(reward_model.base_model.device)

        # Pure inference: no autograd graph is needed for rewards.
        with torch.no_grad():
            score = reward_model(encoded['input_ids'], encoded['attention_mask'])
        return score.cpu()

    return [_score_pair(src, hyp) for src, hyp in zip(source_texts, translations)]

print("Reward function defined")

## Initialize PPO Trainer

In [None]:
# Wire the policy, the frozen reference, the tokenizer and the prompt dataset
# into the PPO trainer. No custom collator is supplied (data_collator=None).
ppo_trainer_kwargs = {
    "config": ppo_config,
    "model": policy_model,
    "ref_model": ref_model,
    "tokenizer": policy_tokenizer,
    "dataset": dataset,
    "data_collator": None,
}
ppo_trainer = PPOTrainer(**ppo_trainer_kwargs)

print("PPO Trainer initialized!")

## PPO Training Loop

In [None]:
# Generation settings
generation_kwargs = {
    "max_new_tokens": PPO_MAX_NEW_TOKENS,
    "temperature": PPO_TEMPERATURE,
    "top_k": 50,
    "top_p": 0.95,
    "do_sample": True,
    "pad_token_id": policy_tokenizer.pad_token_id,
    "eos_token_id": policy_tokenizer.eos_token_id,
}

# Prompt markers used to recover the raw source sentence from a formatted prompt.
SOURCE_PROMPT_MARKERS = ("English text to Arabic:", "French text to Arabic:")
TARGET_PROMPT_MARKER = "\n\nArabic translation:"


def _extract_source(prompt_text):
    """Return the source sentence embedded in a formatted prompt.

    The text between the first matching source marker and the target marker is
    returned, stripped. If no marker is present the prompt is returned unchanged.
    """
    for marker in SOURCE_PROMPT_MARKERS:
        if marker in prompt_text:
            return prompt_text.split(marker)[1].split(TARGET_PROMPT_MARKER)[0].strip()
    return prompt_text


print("üöÄ Starting PPO training (HPC Optimized with Safety Measures)...\n")
print(f"Total steps: {PPO_STEPS}")
print(f"Max KL threshold: {MAX_KL_THRESHOLD} (will stop if exceeded)")
print("=" * 80)

# Track metrics for early stopping
best_reward = float('-inf')
initial_kl = None
rewards_history = []
kl_history = []

# Training loop with safety measures
for step, batch in enumerate(tqdm(ppo_trainer.dataloader, total=PPO_STEPS)):
    if step >= PPO_STEPS:
        break

    # Get queries (prompts) as a list of 1-D token tensors.
    # The dataset built in this notebook only carries text columns, so
    # batch['input_ids'] does not exist; tokenize the "query" strings here.
    # A pre-tokenized batch is still honoured if one is ever supplied.
    if "input_ids" in batch:
        query_tensors = list(batch["input_ids"])
        batch_texts = policy_tokenizer.batch_decode(query_tensors, skip_special_tokens=True)
    else:
        batch_texts = list(batch["query"])
        query_tensors = [
            policy_tokenizer(prompt_text, return_tensors="pt")["input_ids"].squeeze(0)
            for prompt_text in batch_texts
        ]

    # Generate responses (translations)
    response_tensors = ppo_trainer.generate(
        query_tensors,
        return_prompt=False,
        **generation_kwargs
    )

    # Decode responses
    response_texts = policy_tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

    # log_stats below is asked to log the "query" and "response" columns,
    # so both must be present in the batch.
    if "query" not in batch:
        batch["query"] = batch_texts
    batch["response"] = response_texts

    # Extract source texts (for reward computation)
    source_texts = [_extract_source(prompt_text) for prompt_text in batch_texts]

    # Compute rewards
    rewards = compute_reward(source_texts, response_texts)

    # Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)

    # Track metrics (one entry per completed PPO step)
    current_kl = stats['objective/kl']
    current_reward = torch.tensor(rewards).mean().item()
    rewards_history.append(current_reward)
    kl_history.append(current_kl)

    if initial_kl is None:
        initial_kl = current_kl

    # ===========================
    # SAFETY CHECK: KL Divergence
    # ===========================
    if current_kl > MAX_KL_THRESHOLD:
        print(f"\n‚ö†Ô∏è SAFETY STOP: KL divergence ({current_kl:.4f}) exceeded threshold ({MAX_KL_THRESHOLD})")
        print("   Model is drifting too far from SFT. Stopping to preserve quality.")
        break

    # Log statistics
    if step % 10 == 0:
        ppo_trainer.log_stats(
            stats,
            batch,
            rewards,
            columns_to_log=["query", "response"]
        )

    # Print progress
    if step % 50 == 0:
        print(f"\nStep {step}:")
        print(f"  Mean reward: {current_reward:.4f}")
        print(f"  Mean KL: {current_kl:.4f} (threshold: {MAX_KL_THRESHOLD})")
        print(f"  Policy loss: {stats['ppo/loss/policy']:.4f}")
        print(f"  Value loss: {stats['ppo/loss/value']:.4f}")

        # KL health indicator: fraction of the hard threshold already used
        kl_ratio = current_kl / MAX_KL_THRESHOLD
        if kl_ratio < 0.5:
            kl_status = "‚úÖ Healthy"
        elif kl_ratio < 0.8:
            kl_status = "‚ö†Ô∏è Moderate"
        else:
            kl_status = "üî¥ High - approaching limit"
        print(f"  KL status: {kl_status}")

        # Show sample
        print(f"\n  Sample translation:")
        print(f"  Source: {source_texts[0][:100]}...")
        print(f"  Generated: {response_texts[0][:100]}...")
        print(f"  Reward: {rewards[0].item():.4f}")

    # Save checkpoint periodically
    if step > 0 and step % CHECKPOINT_EVERY_N_STEPS == 0:
        checkpoint_path = PPO_MODEL_COLD_START / f"checkpoint-{step}"
        checkpoint_path.mkdir(exist_ok=True, parents=True)

        if USE_LORA:
            # Save only LoRA adapters (small file)
            policy_model.pretrained_model.save_pretrained(checkpoint_path)
        else:
            ppo_trainer.model.save_pretrained(checkpoint_path)

        policy_tokenizer.save_pretrained(checkpoint_path)
        print(f"\n‚úì Checkpoint saved to {checkpoint_path}")

        # Save training metrics (rolling means over the last 100 steps)
        metrics = {
            'step': step,
            'mean_reward': float(np.mean(rewards_history[-100:])),
            'mean_kl': float(np.mean(kl_history[-100:])),
            'max_kl_seen': float(max(kl_history))
        }
        with open(checkpoint_path / "metrics.json", 'w') as f:
            json.dump(metrics, f, indent=2)

print("\n" + "=" * 80)
print("‚úÖ PPO training complete!")
if rewards_history:
    print(f"   Final mean reward: {np.mean(rewards_history[-50:]):.4f}")
    print(f"   Final mean KL: {np.mean(kl_history[-50:]):.4f}")
    print(f"   Max KL observed: {max(kl_history):.4f}")
else:
    # max() on an empty history would raise; report the situation instead.
    print("   No PPO steps were completed; no metrics to report.")

## Save Final Model

In [None]:
# Save final optimized model
print(f"Saving final model to {PPO_MODEL_COLD_START}...")

PPO_MODEL_COLD_START.mkdir(exist_ok=True, parents=True)

if USE_LORA:
    # Save LoRA adapters only (original SFT weights preserved)
    policy_model.pretrained_model.save_pretrained(PPO_MODEL_COLD_START)
    print("‚úì LoRA adapters saved (original SFT weights preserved)")
else:
    # Save full model
    ppo_trainer.model.save_pretrained(PPO_MODEL_COLD_START)
    print("‚úì Full model saved")

policy_tokenizer.save_pretrained(PPO_MODEL_COLD_START)

# Save training info
training_info = {
    'base_model': model_name,
    'reward_model': str(REWARD_MODEL_COLD_START),
    'ppo_steps': PPO_STEPS,
    # rewards_history gains exactly one entry per completed PPO step, so its
    # length is the true step count. The loop variable is not: it is one too
    # high when the loop exits via the step limit, and undefined if the loop
    # body never ran.
    'steps_completed': len(rewards_history),
    'ppo_config': {
        'learning_rate': PPO_LEARNING_RATE,
        'kl_penalty': KL_PENALTY_COEF,
        'max_kl_threshold': MAX_KL_THRESHOLD,
        'clip_range': CLIP_RANGE
    },
    'lora_config': {
        'enabled': USE_LORA,
        'r': LORA_R if USE_LORA else None,
        'alpha': LORA_ALPHA if USE_LORA else None,
        'target_modules': LORA_TARGET_MODULES if USE_LORA else None
    },
    # Means over the last 50 steps; zeros when no step was completed.
    'final_metrics': {
        'mean_reward': float(np.mean(rewards_history[-50:])) if rewards_history else 0,
        'mean_kl': float(np.mean(kl_history[-50:])) if kl_history else 0,
        'max_kl': float(max(kl_history)) if kl_history else 0
    },
    'stage': 'cold_start'
}

with open(PPO_MODEL_COLD_START / "training_info.json", 'w') as f:
    json.dump(training_info, f, indent=2)

print("‚úì Model saved successfully!")
print(f"\nPath: {PPO_MODEL_COLD_START}")
print(f"LoRA adapters: {USE_LORA}")
print(f"Original SFT weights preserved: {USE_LORA}")

## Test Optimized Model

In [None]:
# Smoke-test the optimised policy on a few fixed sentences
print("Testing optimized model...\n")
print("=" * 80)

test_samples = [
    {"text": "Hello, how are you today?", "lang": "en"},
    {"text": "The weather is beautiful this morning.", "lang": "en"},
    {"text": "Bonjour, comment allez-vous?", "lang": "fr"},
]

ppo_trainer.model.eval()

for example_no, sample in enumerate(test_samples, start=1):
    source_text = sample['text']
    source_lang = sample['lang']

    # Build the prompt and tokenize it on the model's device.
    prompt = format_translation_prompt(source_text, source_lang)
    encoded_prompt = policy_tokenizer(prompt, return_tensors="pt").to(ppo_trainer.model.device)

    # Sampled generation without gradient tracking.
    with torch.no_grad():
        generated_ids = ppo_trainer.model.generate(
            **encoded_prompt,
            max_new_tokens=128,
            temperature=0.8,
            do_sample=True,
            pad_token_id=policy_tokenizer.pad_token_id
        )

    # The decoded text is split on the marker; whatever follows its last occurrence is the translation.
    decoded_output = policy_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    translation = decoded_output.split("Arabic translation:")[-1].strip()

    # Score the single pair with the reward model.
    reward = compute_reward([source_text], [translation])[0].item()

    print(f"\nExample {example_no}:")
    print(f"Source ({source_lang}): {source_text}")
    print(f"Translation: {translation}")
    print(f"Reward: {reward:.4f}")
    print("=" * 80)

# Close the tracking run when wandb logging was enabled.
if USE_WANDB:
    wandb.finish()

## Next Step

Proceed to **notebook 4** for inference and user feedback collection.