In [None]:
# Run the shared setup notebook inside this kernel's namespace.
# This notebook uses several names it never defines itself (e.g. SEED,
# SFT_MODEL_PATH, REWARD_MODEL_COLD_START, the PPO_* constants, set_seed,
# format_translation_prompt, load_test_prompts) — presumably they are all
# provided by this setup notebook; verify there before changing any of them.
%run 0_config_setup.ipynb

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler
from datasets import Dataset
from tqdm import tqdm
import wandb
import json

# Seed the random number generators before any model loading or sampling.
# NOTE(review): `set_seed` is not imported in this cell — presumably it comes
# from 0_config_setup.ipynb; confirm which libraries it actually seeds.
set_seed(SEED)

## Load Models

In [None]:
print("Loading models...\n")

# Loading options shared by both copies of the SFT checkpoint.
sft_load_kwargs = dict(
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

# --- 1. Policy: the SFT checkpoint that PPO will update -----------------
print(f"Loading SFT model from {SFT_MODEL_PATH}...")
policy_tokenizer = AutoTokenizer.from_pretrained(SFT_MODEL_PATH)
# Reuse EOS as the padding token when the tokenizer defines none.
if policy_tokenizer.pad_token is None:
    policy_tokenizer.pad_token = policy_tokenizer.eos_token

# The value-head wrapper adds the value estimator PPO trains alongside the LM.
policy_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    SFT_MODEL_PATH,
    peft_config=None,  # Add LoRA config if needed for memory efficiency
    **sft_load_kwargs,
)
print("✓ SFT model loaded")

# --- 2. Reference: a second copy of the same checkpoint (for KL penalty) --
print(f"\nLoading reference model (frozen SFT) from {SFT_MODEL_PATH}...")
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    SFT_MODEL_PATH,
    **sft_load_kwargs,
)
# The reference copy is never trained: switch off gradients on every parameter.
for ref_param in ref_model.parameters():
    ref_param.requires_grad_(False)
print("✓ Reference model loaded (frozen)")

In [None]:
# 3. Load reward model
print(f"\nLoading reward model from {REWARD_MODEL_COLD_START}...")

# Tokenizer stored next to the reward-model checkpoint.
rm_tokenizer = AutoTokenizer.from_pretrained(REWARD_MODEL_COLD_START)
# Same fallback as for the policy tokenizer: the reward function asks the
# tokenizer to pad, which fails when no pad token is defined.
if rm_tokenizer.pad_token is None:
    rm_tokenizer.pad_token = rm_tokenizer.eos_token

# Backbone the reward head sits on; its trained weights are restored from the
# reward checkpoint further down in this cell.
rm_base_model = AutoModelForCausalLM.from_pretrained(
    REWARD_BASE_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
)

# Recreate reward model structure (from notebook 2)
from torch import nn

class RewardModel(nn.Module):
    """Scalar reward model: a causal-LM backbone followed by a scoring head.

    The hidden state of the last attended token of each sequence is fed to
    ``reward_head``, which produces one score per sequence.

    Args:
        base_model: Backbone exposing ``config.hidden_size`` and returning
            ``hidden_states`` when called with ``output_hidden_states=True``.
        hidden_dim: Width of the hidden layer of the ``'mlp'`` head.
        head_type: ``'linear'`` (single linear layer) or ``'mlp'``
            (Linear -> ReLU -> Dropout(0.1) -> Linear).

    Raises:
        ValueError: If ``head_type`` is neither ``'linear'`` nor ``'mlp'``.
    """

    def __init__(self, base_model, hidden_dim=256, head_type='mlp'):
        super().__init__()
        self.base_model = base_model
        self.head_type = head_type
        self.hidden_size = base_model.config.hidden_size

        if head_type == 'linear':
            self.reward_head = nn.Linear(self.hidden_size, 1)
        elif head_type == 'mlp':
            self.reward_head = nn.Sequential(
                nn.Linear(self.hidden_size, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim, 1)
            )
        else:
            # Fail at construction time instead of with an AttributeError on
            # the first forward pass.
            raise ValueError(
                f"Unknown head_type {head_type!r}; expected 'linear' or 'mlp'"
            )

    def forward(self, input_ids, attention_mask):
        """Score each sequence in the batch.

        Args:
            input_ids: Token ids, one row per sequence.
            attention_mask: 1 for real tokens, 0 for padding, same layout as
                ``input_ids``.

        Returns:
            Tensor with one reward per sequence (trailing size-1 dim removed).
        """
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        hidden_states = outputs.hidden_states[-1]

        # Index of the last attended token per row. This equals the position
        # of the final real token only when padding sits on the right.
        last_token_idx = attention_mask.sum(dim=1) - 1
        batch_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        pooled = hidden_states[batch_idx, last_token_idx.to(hidden_states.device)]

        # The head may live on a different device / in a different dtype than
        # the backbone (e.g. float32 head on top of a bfloat16 backbone);
        # align the pooled features with the head so the matmul is valid.
        head_param = next(self.reward_head.parameters())
        pooled = pooled.to(device=head_param.device, dtype=head_param.dtype)

        reward = self.reward_head(pooled)
        return reward.squeeze(-1)

# Wrap the backbone with a scoring head configured like the saved model.
reward_model = RewardModel(
    rm_base_model,
    hidden_dim=RM_HIDDEN_DIM,
    head_type=RM_HEAD_TYPE,
)

# Restore the trained weights; the checkpoint is deserialised onto the CPU.
checkpoint = torch.load(REWARD_MODEL_COLD_START / "reward_model.pt", map_location='cpu')
reward_model.load_state_dict(checkpoint['model_state_dict'])

# The reward model is only ever used for scoring: evaluation mode, no gradients.
reward_model.eval()
for rm_param in reward_model.parameters():
    rm_param.requires_grad_(False)

print("✓ Reward model loaded (frozen)")
print(f"\nAll models loaded successfully!")

## Prepare Training Data

In [None]:
# Load training prompts (no parallel corpus needed)
# NOTE(review): the prompts are read from TEST_PROMPTS — confirm that training
# PPO on the test prompt file is intended.
print("Loading training prompts...")

MAX_PPO_PROMPTS = 5000  # Adjust size based on resources

try:
    all_data = load_test_prompts(TEST_PROMPTS)
    print(f"Loaded {len(all_data)} test prompts")
except Exception as exc:
    # Best-effort fallback to a few toy prompts, but say why: a silent
    # fallback would let a full PPO run train on sample data unnoticed.
    # (A bare `except:` would also have swallowed KeyboardInterrupt.)
    print(f"WARNING: could not load prompts from {TEST_PROMPTS} ({exc!r}); using sample prompts")
    all_data = [
        {"source": "Hello, how are you?", "source_lang": "en"},
        {"source": "Good morning.", "source_lang": "en"},
        {"source": "Thank you very much.", "source_lang": "en"},
        {"source": "Bonjour, comment allez-vous?", "source_lang": "fr"},
        {"source": "Merci beaucoup.", "source_lang": "fr"},
    ] * 1000  # Replicate for training
    print(f"Created {len(all_data)} sample prompts")

# One record per prompt: the formatted query plus the raw source it came from.
prompts = [
    {
        'query': format_translation_prompt(item['source'], item['source_lang']),
        'source': item['source'],
        'source_lang': item['source_lang'],
    }
    for item in all_data[:MAX_PPO_PROMPTS]
]

# Create HuggingFace dataset
dataset = Dataset.from_list(prompts)

print(f"Loaded {len(dataset)} training prompts")

## PPO Configuration

In [None]:
# PPO configuration
ppo_config = PPOConfig(
    model_name=SFT_MODEL_PATH,
    learning_rate=PPO_LEARNING_RATE,
    batch_size=PPO_BATCH_SIZE,
    mini_batch_size=PPO_MINI_BATCH_SIZE,
    gradient_accumulation_steps=PPO_GRADIENT_ACCUMULATION_STEPS,
    ppo_epochs=PPO_EPOCHS,

    # KL penalty to stay close to reference model
    init_kl_coef=KL_PENALTY_COEF,
    # NOTE(review): per the TRL PPOConfig docs, `target_kl` is the
    # early-stopping threshold and only takes effect with early_stopping=True;
    # the adaptive-KL controller's target is the separate `target` field.
    # Confirm which of the two was intended here.
    target_kl=0.1,

    # PPO clipping
    cliprange=CLIP_RANGE,
    cliprange_value=VALUE_CLIP_RANGE,

    # Weight of the value-function loss in the total PPO loss
    vf_coef=0.1,

    # Other settings
    seed=SEED,
    log_with="wandb" if USE_WANDB else None,
    tracker_project_name=WANDB_PROJECT,
    # Tracker init kwargs are looked up per tracker name, so the run name must
    # be nested under "wandb"; a flat {"name": ...} dict is silently ignored.
    tracker_kwargs={"wandb": {"name": "ppo-coldstart"}},
)

print("PPO Configuration:")
print(f"  Learning rate: {ppo_config.learning_rate}")
print(f"  Batch size: {ppo_config.batch_size}")
print(f"  KL penalty: {ppo_config.init_kl_coef}")
print(f"  Clip range: {ppo_config.cliprange}")

## Reward Function

In [None]:
def compute_reward(source_texts, translations):
    """
    Compute rewards for generated translations using the reward model.

    Each (source, translation) pair is rendered as
    "Source: ...\\nTranslation: ..." and scored on its own by `reward_model`.

    Args:
        source_texts: List of source texts
        translations: List of generated translations (same length)

    Returns:
        List of float32 CPU tensors of shape (1,), one per input pair.

    Raises:
        ValueError: If the two inputs differ in length.
    """
    source_texts = list(source_texts)
    translations = list(translations)
    if len(source_texts) != len(translations):
        # zip() would silently drop the surplus items and return too few rewards.
        raise ValueError(
            f"Got {len(source_texts)} source texts but {len(translations)} translations"
        )

    device = reward_model.base_model.device
    rewards = []

    for source, translation in zip(source_texts, translations):
        # Format for reward model
        text = f"Source: {source}\nTranslation: {translation}"

        # Tokenize without padding: a single sequence needs none, and the
        # reward model pools at `attention_mask.sum() - 1`, which is only the
        # last real token when no left padding is present. Skipping padding
        # also avoids running the backbone over RM_MAX_LENGTH pad positions.
        inputs = rm_tokenizer(
            text,
            max_length=RM_MAX_LENGTH,
            truncation=True,
            return_tensors='pt'
        ).to(device)

        # Get reward
        with torch.no_grad():
            reward = reward_model(
                inputs['input_ids'],
                inputs['attention_mask']
            )

        # float32 on CPU so downstream aggregation does not run in bfloat16.
        rewards.append(reward.float().cpu())

    return rewards

print("Reward function defined")

## Initialize PPO Trainer

In [None]:
# Initialize PPO trainer
#
# The prompt dataset only holds text columns, but the training loop reads
# `batch['input_ids']`. Tokenize the queries here and hand the trainer a
# collator that keeps each column as a plain list, because prompts have
# different lengths and cannot be stacked by the default collate function.

def _tokenize_query(sample):
    """Add the token ids of the prompt text under `input_ids`."""
    sample["input_ids"] = policy_tokenizer.encode(sample["query"])
    return sample


def _collate_as_lists(samples):
    """Turn a list of records into a dict of per-column lists (no padding)."""
    return {key: [sample[key] for sample in samples] for key in samples[0]}


# New name on purpose: `dataset` stays untouched so this cell can be re-run.
ppo_dataset = dataset.map(_tokenize_query, batched=False)
ppo_dataset.set_format(type="torch")

ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=policy_model,
    ref_model=ref_model,
    tokenizer=policy_tokenizer,
    dataset=ppo_dataset,
    data_collator=_collate_as_lists,
)

print("PPO Trainer initialized!")

## PPO Training Loop

In [None]:
# Generation settings
generation_kwargs = {
    "max_new_tokens": PPO_MAX_NEW_TOKENS,
    "temperature": PPO_TEMPERATURE,
    "top_k": 50,
    "top_p": 0.95,
    "do_sample": True,
    "pad_token_id": policy_tokenizer.pad_token_id,
    "eos_token_id": policy_tokenizer.eos_token_id,
}

print("Starting PPO training...\n")
print(f"Total steps: {PPO_STEPS}")
print("=" * 80)

# Training loop: one pass over the trainer's dataloader, capped at PPO_STEPS.
# If the dataloader yields fewer batches than PPO_STEPS the loop ends early.
for step, batch in enumerate(tqdm(ppo_trainer.dataloader, total=PPO_STEPS)):
    if step >= PPO_STEPS:
        break

    # Get queries (prompts) as a list of 1-D token-id tensors.
    if "input_ids" in batch:
        query_tensors = list(batch["input_ids"])
    else:
        # The dataloader only delivered the prompt text: tokenize it here so
        # the loop does not depend on the dataset having been pre-tokenized.
        query_device = ppo_trainer.accelerator.device
        query_tensors = [
            policy_tokenizer(query, return_tensors="pt").input_ids[0].to(query_device)
            for query in batch["query"]
        ]

    # Generate responses (translations)
    response_tensors = ppo_trainer.generate(
        query_tensors,
        return_prompt=False,
        **generation_kwargs
    )

    # Decode responses
    batch_texts = policy_tokenizer.batch_decode(query_tensors, skip_special_tokens=True)
    response_texts = policy_tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

    # log_stats() below is asked to log the "response" column, so it has to be
    # present in the batch.
    batch["response"] = response_texts

    # Extract source texts (for reward computation)
    # NOTE(review): the marker strings below must match the template produced
    # by format_translation_prompt — verify against its definition.
    source_texts = []
    for text in batch_texts:
        # Extract source from prompt
        if "English text to Arabic:" in text:
            source = text.split("English text to Arabic:")[1].split("\n\nArabic translation:")[0].strip()
        elif "French text to Arabic:" in text:
            source = text.split("French text to Arabic:")[1].split("\n\nArabic translation:")[0].strip()
        else:
            # Unknown template: fall back to scoring against the whole prompt.
            source = text
        source_texts.append(source)

    # Compute rewards
    rewards = compute_reward(source_texts, response_texts)

    # Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)

    # Log statistics
    if step % 10 == 0:
        ppo_trainer.log_stats(
            stats,
            batch,
            rewards,
            columns_to_log=["query", "response"]
        )

    # Print progress
    if step % 50 == 0:
        # Flatten each reward so 0-dim and shape-(1,) tensors are both handled,
        # and average in float32.
        mean_reward = torch.cat([r.reshape(-1).float() for r in rewards]).mean().item()
        print(f"\nStep {step}:")
        print(f"  Mean reward: {mean_reward:.4f}")
        print(f"  Mean KL: {stats['objective/kl']:.4f}")
        print(f"  Policy loss: {stats['ppo/loss/policy']:.4f}")
        print(f"  Value loss: {stats['ppo/loss/value']:.4f}")

        # Show sample
        print(f"\n  Sample translation:")
        print(f"  Source: {source_texts[0][:100]}...")
        print(f"  Generated: {response_texts[0][:100]}...")
        print(f"  Reward: {rewards[0].item():.4f}")

    # Save checkpoint periodically
    if step > 0 and step % 200 == 0:
        checkpoint_path = PPO_MODEL_COLD_START / f"checkpoint-{step}"
        checkpoint_path.mkdir(exist_ok=True, parents=True)
        ppo_trainer.model.save_pretrained(checkpoint_path)
        policy_tokenizer.save_pretrained(checkpoint_path)
        print(f"\n✓ Checkpoint saved to {checkpoint_path}")

print("\n" + "=" * 80)
print("PPO training complete!")

## Save Final Model

In [None]:
# Save final optimized model
print(f"Saving final model to {PPO_MODEL_COLD_START}...")

PPO_MODEL_COLD_START.mkdir(exist_ok=True, parents=True)
ppo_trainer.model.save_pretrained(PPO_MODEL_COLD_START)
policy_tokenizer.save_pretrained(PPO_MODEL_COLD_START)

# Save training info
training_info = {
    # str() keeps the record JSON-serialisable whether the configured path is
    # a plain string or a pathlib.Path (same treatment as 'reward_model').
    'base_model': str(SFT_MODEL_PATH),
    'reward_model': str(REWARD_MODEL_COLD_START),
    'ppo_steps': PPO_STEPS,
    'ppo_config': {
        'learning_rate': PPO_LEARNING_RATE,
        'kl_penalty': KL_PENALTY_COEF,
        'clip_range': CLIP_RANGE
    },
    'stage': 'cold_start'
}

with open(PPO_MODEL_COLD_START / "training_info.json", 'w') as f:
    json.dump(training_info, f, indent=2)

print("✓ Model saved successfully!")
print(f"\nPath: {PPO_MODEL_COLD_START}")

## Test Optimized Model

In [None]:
# Test the optimized model
print("Testing optimized model...\n")
print("=" * 80)

test_samples = [
    {"text": "Hello, how are you today?", "lang": "en"},
    {"text": "The weather is beautiful this morning.", "lang": "en"},
    {"text": "Bonjour, comment allez-vous?", "lang": "fr"},
]

ppo_trainer.model.eval()

# The value-head wrapper is a plain nn.Module and does not expose `.device`;
# take the device from its parameters instead.
model_device = next(ppo_trainer.model.parameters()).device

for i, sample in enumerate(test_samples, 1):
    prompt = format_translation_prompt(sample['text'], sample['lang'])

    inputs = policy_tokenizer(prompt, return_tensors="pt").to(model_device)

    with torch.no_grad():
        outputs = ppo_trainer.model.generate(
            **inputs,
            max_new_tokens=128,
            temperature=0.8,
            do_sample=True,
            pad_token_id=policy_tokenizer.pad_token_id
        )

    # Decode only the newly generated tokens. Splitting the full decoded text
    # on a marker string breaks when the source or the output contains it.
    prompt_length = inputs["input_ids"].shape[1]
    translation = policy_tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()

    # Compute reward
    reward = compute_reward([sample['text']], [translation])[0].item()

    print(f"\nExample {i}:")
    print(f"Source ({sample['lang']}): {sample['text']}")
    print(f"Translation: {translation}")
    print(f"Reward: {reward:.4f}")
    print("=" * 80)

# Close the Weights & Biases run opened for PPO logging.
if USE_WANDB:
    wandb.finish()

## Next Step

Proceed to **notebook 4** for inference and user feedback collection.