In [None]:
# Run setup from config notebook.
# Names used later in this notebook without a local definition (e.g. SEED,
# set_seed, SYNTHETIC_PREFERENCES, REWARD_BASE_MODEL, the RM_* settings,
# USE_WANDB, WANDB_PROJECT, REWARD_MODEL_COLD_START) are presumably supplied
# by this notebook -- verify against 0_config_setup.ipynb.
%run 0_config_setup.ipynb

In [None]:
import json
import random
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler

# set_seed and SEED are not defined in this notebook; presumably both come from
# the config notebook run above. NOTE(review): confirm which RNGs set_seed
# actually seeds (random / numpy / torch) before relying on reproducibility.
set_seed(SEED)

## Load Synthetic Preference Data

In [None]:
print(f"Loading synthetic preference data from {SYNTHETIC_PREFERENCES}...")

# The file is read as JSON Lines: one JSON object per line.
# Blank / whitespace-only lines (e.g. a stray empty line at the end of the
# file) are skipped instead of crashing json.loads.
with open(SYNTHETIC_PREFERENCES, 'r', encoding='utf-8') as f:
    preference_data = [json.loads(line) for line in f if line.strip()]

print(f"Loaded {len(preference_data)} preference pairs")

# Split into train/validation.
# The split is positional (first 90% / last 10%) with no shuffling, so the
# validation set is whatever sits at the end of the file.
train_size = int(0.9 * len(preference_data))
train_data = preference_data[:train_size]
val_data = preference_data[train_size:]

print(f"Train: {len(train_data)} pairs")
print(f"Validation: {len(val_data)} pairs")

## Preference Dataset

In [None]:
class PreferenceDataset(Dataset):
    """Pairwise preference examples, tokenized lazily on item access.

    Every record in ``data`` is read through the keys ``'source'``,
    ``'chosen'``, ``'rejected'`` and ``'margin'``.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.data, self.tokenizer, self.max_length = data, tokenizer, max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return tokenized chosen/rejected prompts plus the raw margin.

        Both candidates share the prompt layout
        ``"Source: <source>\\nTranslation: <candidate>"`` and are padded or
        truncated to ``max_length`` by the tokenizer. The leading dimension
        of each returned tensor is squeezed away.
        """
        record = self.data[idx]
        example = {}

        # 'chosen' is processed before 'rejected', so key order matches the
        # layout consumers see: chosen_*, rejected_*, margin.
        for role in ('chosen', 'rejected'):
            encoded = self.tokenizer(
                f"Source: {record['source']}\nTranslation: {record[role]}",
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            example[f'{role}_input_ids'] = encoded['input_ids'].squeeze(0)
            example[f'{role}_attention_mask'] = encoded['attention_mask'].squeeze(0)

        # Passed through untouched; kept for analysis
        example['margin'] = record['margin']
        return example

print("PreferenceDataset class defined")

## Reward Model Architecture

In [None]:
class RewardModel(nn.Module):
    """Reward model: a base LM followed by a scalar reward head.

    The base model is called with ``output_hidden_states=True``; the final
    hidden state of the last attended token of each sequence is fed to the
    head, which produces one scalar per sequence.

    Args:
        base_model: module exposing ``config.hidden_size`` and returning an
            object with a ``hidden_states`` sequence when called.
        hidden_dim: width of the hidden layer of the ``'mlp'`` head.
        head_type: ``'linear'`` (single ``nn.Linear``) or ``'mlp'``
            (Linear -> ReLU -> Dropout(0.1) -> Linear).

    Raises:
        ValueError: if ``head_type`` is neither ``'linear'`` nor ``'mlp'``.
    """

    def __init__(self, base_model, hidden_dim=256, head_type='mlp'):
        super().__init__()
        self.base_model = base_model
        self.head_type = head_type

        # Get hidden size from base model
        self.hidden_size = base_model.config.hidden_size

        # Reward head
        if head_type == 'linear':
            self.reward_head = nn.Linear(self.hidden_size, 1)
        elif head_type == 'mlp':
            self.reward_head = nn.Sequential(
                nn.Linear(self.hidden_size, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim, 1)
            )
        else:
            raise ValueError(f"Unknown head_type: {head_type}")

    def forward(self, input_ids, attention_mask):
        """Score each sequence in the batch.

        Args:
            input_ids: token ids, indexed as [batch, seq_len].
            attention_mask: same layout as ``input_ids``; non-zero marks
                real tokens.

        Returns:
            Tensor of shape [batch] holding one reward per sequence, in the
            dtype of the reward head's parameters.
        """
        # Get base model outputs
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )

        # Get last hidden state
        hidden_states = outputs.hidden_states[-1]  # [batch, seq_len, hidden_size]
        batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1]
        device = hidden_states.device

        # Pool: use the last attended token of each sequence.
        # Weighting the mask by (position + 1) and taking argmax yields the index
        # of the last non-padding token for right- AND left-padded batches
        # (mask.sum() - 1 is only correct for right padding). A fully masked row
        # falls back to position 0 instead of producing a negative index.
        mask = attention_mask.to(device=device, dtype=torch.long)
        positions = torch.arange(1, seq_len + 1, device=device)
        last_token_idx = (mask * positions).argmax(dim=1)
        batch_idx = torch.arange(batch_size, device=device)
        pooled = hidden_states[batch_idx, last_token_idx]

        # The base model may run in reduced precision (e.g. bfloat16) while the
        # head keeps default float32 weights; nn.Linear rejects mixed dtypes, so
        # align the pooled features with the head before applying it.
        head_param = next(self.reward_head.parameters())
        pooled = pooled.to(device=head_param.device, dtype=head_param.dtype)

        # Apply reward head
        reward = self.reward_head(pooled)  # [batch, 1]

        return reward.squeeze(-1)  # [batch]

print("RewardModel class defined")

## Load Base Model and Create Reward Model

In [None]:
print(f"Loading base model: {REWARD_BASE_MODEL}...")

# Tokenizer: reuse EOS as the padding token when the checkpoint defines none
rm_tokenizer = AutoTokenizer.from_pretrained(REWARD_BASE_MODEL)
if rm_tokenizer.pad_token is None:
    rm_tokenizer.pad_token = rm_tokenizer.eos_token

# Base LM in bfloat16, placed by device_map="auto".
# NOTE(review): trust_remote_code=True lets the checkpoint ship and execute its
# own Python code -- only use with a model source you trust.
base_model = AutoModelForCausalLM.from_pretrained(
    REWARD_BASE_MODEL,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

# Step 1: freeze every base-model parameter (faster training)
for frozen_param in base_model.parameters():
    frozen_param.requires_grad = False

# Step 2: re-enable gradients for the last few layers only.
# Relies on the model exposing its blocks as base_model.model.layers.
num_unfrozen_layers = 4
for layer in base_model.model.layers[-num_unfrozen_layers:]:
    for tuned_param in layer.parameters():
        tuned_param.requires_grad = True

print(f"✓ Base model loaded (unfrozen last {num_unfrozen_layers} layers)")

# Wrap the base model with the reward head
reward_model = RewardModel(
    base_model=base_model,
    head_type=RM_HEAD_TYPE,
    hidden_dim=RM_HIDDEN_DIM
)

print(f"✓ Reward model created with {RM_HEAD_TYPE} head")

# Parameter counts, reported in millions
rm_total_params = sum(p.numel() for p in reward_model.parameters())
rm_trainable_params = sum(p.numel() for p in reward_model.parameters() if p.requires_grad)
print(f"Total parameters: {rm_total_params / 1e6:.2f}M")
print(f"Trainable parameters: {rm_trainable_params / 1e6:.2f}M")

## Create DataLoaders

In [None]:
# Wrap the raw train/validation records in tokenizing datasets
train_dataset = PreferenceDataset(
    data=train_data,
    tokenizer=rm_tokenizer,
    max_length=RM_MAX_LENGTH
)
val_dataset = PreferenceDataset(
    data=val_data,
    tokenizer=rm_tokenizer,
    max_length=RM_MAX_LENGTH
)

# Batch them: only the training loader shuffles.
# num_workers stays at 0 for Windows compatibility.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=RM_BATCH_SIZE,
    shuffle=True,
    num_workers=0
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=RM_BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

print(f"Train batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

## Training Setup

In [None]:
# Bradley-Terry loss for pairwise preferences
def bradley_terry_loss(chosen_rewards, rejected_rewards):
    """Bradley-Terry model loss: mean of -log(sigmoid(r_chosen - r_rejected)).

    Args:
        chosen_rewards: tensor of rewards for the preferred responses.
        rejected_rewards: tensor of rewards for the rejected responses,
            broadcastable against ``chosen_rewards``.

    Returns:
        Scalar tensor with the mean loss over all pairs.
    """
    # logsigmoid is the numerically stable form of log(sigmoid(x)): computing
    # sigmoid first underflows to 0 for large negative margins, which turns the
    # loss into inf and poisons the gradients.
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()

# Optimizer: only parameters that still require gradients are optimized
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, reward_model.parameters()),
    lr=RM_LEARNING_RATE
)

# Cosine learning-rate schedule with the first 10% of steps as warmup.
# NOTE(review): the step estimate divides the total batch count across all
# epochs by the accumulation factor; the real per-epoch optimizer-step count
# can differ when len(train_loader) is not a multiple of it -- confirm.
num_training_steps = len(train_loader) * RM_EPOCHS // RM_GRADIENT_ACCUMULATION_STEPS
lr_scheduler = get_scheduler(
    name="cosine",
    optimizer=optimizer,
    num_warmup_steps=num_training_steps // 10,
    num_training_steps=num_training_steps
)

# Experiment tracking is optional and gated on USE_WANDB
if USE_WANDB:
    wandb.init(
        project=WANDB_PROJECT,
        name="reward-model-coldstart",
        config=dict(
            learning_rate=RM_LEARNING_RATE,
            batch_size=RM_BATCH_SIZE,
            epochs=RM_EPOCHS,
            base_model=REWARD_BASE_MODEL,
            head_type=RM_HEAD_TYPE
        )
    )

print("Training setup complete!")
print(f"Total training steps: {num_training_steps}")

## Training Loop

In [None]:
def train_epoch(model, loader, optimizer, scheduler, device, gradient_accumulation_steps=1):
    """Run one training epoch over pairwise preference batches.

    Every batch is scored twice (chosen and rejected inputs), the
    Bradley-Terry loss is scaled by ``1 / gradient_accumulation_steps`` and
    back-propagated, and an optimizer + scheduler step (with gradient-norm
    clipping at 1.0) is taken every ``gradient_accumulation_steps`` batches.
    Gradients left over from a final, incomplete accumulation group are
    applied after the loop rather than being thrown away.

    Args:
        model: callable as ``model(input_ids, attention_mask)`` returning one
            reward per sequence.
        loader: iterable of batches holding the keys ``'chosen_input_ids'``,
            ``'chosen_attention_mask'``, ``'rejected_input_ids'`` and
            ``'rejected_attention_mask'``.
        optimizer: optimizer stepped on each accumulated update.
        scheduler: learning-rate scheduler stepped alongside the optimizer.
        device: device the batch tensors are moved to.
        gradient_accumulation_steps: number of batches per optimizer step.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` averaged over batches, where the
        loss is the unscaled per-batch loss.

    Raises:
        ZeroDivisionError: if ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0
    total_accuracy = 0
    num_batches = 0
    # True while backward() has produced gradients not yet consumed by a step
    pending_grads = False

    def _apply_update():
        """Clip gradients, step optimizer and scheduler, then reset gradients."""
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    optimizer.zero_grad()

    pbar = tqdm(loader, desc="Training")
    for step, batch in enumerate(pbar):
        # Move to device
        chosen_input_ids = batch['chosen_input_ids'].to(device)
        chosen_attention_mask = batch['chosen_attention_mask'].to(device)
        rejected_input_ids = batch['rejected_input_ids'].to(device)
        rejected_attention_mask = batch['rejected_attention_mask'].to(device)

        # Forward pass
        chosen_rewards = model(chosen_input_ids, chosen_attention_mask)
        rejected_rewards = model(rejected_input_ids, rejected_attention_mask)

        # Compute loss (scaled so accumulated gradients average over the group)
        loss = bradley_terry_loss(chosen_rewards, rejected_rewards)
        loss = loss / gradient_accumulation_steps

        # Backward pass
        loss.backward()
        pending_grads = True

        # Accuracy: chosen should have higher reward
        accuracy = (chosen_rewards > rejected_rewards).float().mean()

        # Undo the accumulation scaling so the reported loss is per-batch
        total_loss += loss.item() * gradient_accumulation_steps
        total_accuracy += accuracy.item()
        num_batches += 1

        # Update weights
        if (step + 1) % gradient_accumulation_steps == 0:
            _apply_update()
            pending_grads = False

        # Update progress bar
        pbar.set_postfix({
            'loss': f"{total_loss / num_batches:.4f}",
            'acc': f"{total_accuracy / num_batches:.4f}",
            'lr': f"{scheduler.get_last_lr()[0]:.2e}"
        })

    # Flush the trailing partial accumulation group. Without this, when the
    # number of batches is not a multiple of gradient_accumulation_steps the
    # last batches never influence the weights: their gradients would simply
    # be erased by the zero_grad() at the start of the next epoch. The partial
    # group keeps the 1/gradient_accumulation_steps scaling, so this final
    # update is slightly smaller than a full one.
    if pending_grads:
        _apply_update()

    return total_loss / num_batches, total_accuracy / num_batches


def validate(model, loader, device):
    """Evaluate pairwise loss and accuracy without updating the model.

    Puts the model in eval mode, scores the chosen and rejected inputs of
    every batch under ``torch.no_grad()``, and averages the per-batch
    Bradley-Terry loss and the fraction of pairs where the chosen reward is
    strictly greater than the rejected one.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` over batches.

    Raises:
        ZeroDivisionError: if ``loader`` yields no batches.
    """
    model.eval()

    tensor_keys = (
        'chosen_input_ids',
        'chosen_attention_mask',
        'rejected_input_ids',
        'rejected_attention_mask',
    )
    batch_losses = []
    batch_accuracies = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validation"):
            # Move the four input tensors to the target device, in key order
            c_ids, c_mask, r_ids, r_mask = (batch[key].to(device) for key in tensor_keys)

            chosen_rewards = model(c_ids, c_mask)
            rejected_rewards = model(r_ids, r_mask)

            batch_losses.append(bradley_terry_loss(chosen_rewards, rejected_rewards).item())
            batch_accuracies.append((chosen_rewards > rejected_rewards).float().mean().item())

    num_batches = len(batch_losses)
    return sum(batch_losses) / num_batches, sum(batch_accuracies) / num_batches

print("Training functions defined")

In [None]:
# Train the model: prefer the GPU when one is available
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
reward_model = reward_model.to(device)

print("Starting training...\n")
best_val_accuracy = 0

for epoch in range(RM_EPOCHS):
    print(f"\nEpoch {epoch + 1}/{RM_EPOCHS}")
    print("=" * 80)

    # One pass over the training set, then one over the validation set
    train_loss, train_acc = train_epoch(
        model=reward_model,
        loader=train_loader,
        optimizer=optimizer,
        scheduler=lr_scheduler,
        device=device,
        gradient_accumulation_steps=RM_GRADIENT_ACCUMULATION_STEPS
    )
    val_loss, val_acc = validate(model=reward_model, loader=val_loader, device=device)

    print(f"\nTrain Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    # Per-epoch metrics go to wandb when tracking is enabled (1-based epoch)
    if USE_WANDB:
        wandb.log(dict(
            epoch=epoch + 1,
            train_loss=train_loss,
            train_accuracy=train_acc,
            val_loss=val_loss,
            val_accuracy=val_acc
        ))

    # Only a strict improvement in validation accuracy triggers a checkpoint
    if not (val_acc > best_val_accuracy):
        continue

    best_val_accuracy = val_acc
    print(f"\n✓ New best validation accuracy: {best_val_accuracy:.4f}")
    print(f"Saving model to {REWARD_MODEL_COLD_START}...")

    # Checkpoint: weights, optimizer state, 0-based epoch index, and the
    # settings needed to rebuild the reward head
    REWARD_MODEL_COLD_START.mkdir(parents=True, exist_ok=True)
    head_config = {
        'base_model': REWARD_BASE_MODEL,
        'head_type': RM_HEAD_TYPE,
        'hidden_dim': RM_HIDDEN_DIM
    }
    torch.save(
        {
            'model_state_dict': reward_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'val_accuracy': val_acc,
            'config': head_config
        },
        REWARD_MODEL_COLD_START / "reward_model.pt"
    )

    # The tokenizer is stored next to the checkpoint
    rm_tokenizer.save_pretrained(REWARD_MODEL_COLD_START)

print(f"\n{'=' * 80}")
print("Training complete!")
print(f"Best validation accuracy: {best_val_accuracy:.4f}")

if USE_WANDB:
    wandb.finish()

## Test Reward Model

In [None]:
# Spot-check the trained reward model on a few validation pairs
reward_model.eval()

print("Testing reward model on sample translations...\n")
print("=" * 80)

# Up to five validation pairs, drawn at random
test_samples = random.sample(val_data, min(5, len(val_data)))

for i, sample in enumerate(test_samples, start=1):
    print(f"\nExample {i}:")
    print(f"Source: {sample['source'][:100]}...")

    # Score both candidates with the same prompt layout used for training
    sample_rewards = {}
    for role in ('chosen', 'rejected'):
        encoded = rm_tokenizer(
            f"Source: {sample['source']}\nTranslation: {sample[role]}",
            max_length=RM_MAX_LENGTH,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(device)

        with torch.no_grad():
            sample_rewards[role] = reward_model(
                encoded['input_ids'],
                encoded['attention_mask']
            ).item()

    chosen_reward = sample_rewards['chosen']
    rejected_reward = sample_rewards['rejected']

    # Report each candidate next to the score stored with the sample
    print(f"\nChosen translation: {sample['chosen'][:100]}...")
    print(f"Chosen reward: {chosen_reward:.4f} (original score: {sample['chosen_score']:.4f})")

    print(f"\nRejected translation: {sample['rejected'][:100]}...")
    print(f"Rejected reward: {rejected_reward:.4f} (original score: {sample['rejected_score']:.4f})")

    # The preference counts as correct only when chosen scores strictly higher
    print(f"\nReward margin: {chosen_reward - rejected_reward:.4f}")
    print(f"Correct preference: {'✓' if chosen_reward > rejected_reward else '✗'}")
    print("=" * 80)

## Next Step

Proceed to **notebook 3** to run PPO optimization using this trained reward model.