Reward Model

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
from torch.utils.data import DataLoader, Dataset
import copy
import json
from typing import List, Dict, Any, Optional
import math
import traceback


class PairwisePreferenceDataset(Dataset):
    """
    Expects list of dicts with keys: 'prompt', 'chosen', 'rejected'
    """
    def __init__(self, file_path):
        self.data = []
        with open(file_path, "r") as f:
            for line in f:
                if line.strip():                     # skip empty lines
                    self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch: List[Dict[str, str]]) -> Dict[str, Any]:
    """
    Returns a dict with lists; trainer will tokenize using the tokenizer so we keep raw strings here.
    """
    prompts = [item['prompt'] for item in batch]
    chosen = [p + "\n" + item['chosen'] for p, item in zip(prompts, batch)]
    rejected = [p + "\n" + item['rejected'] for p, item in zip(prompts, batch)]
    return {"prompt": prompts, "chosen": chosen, "rejected": rejected}

In [3]:
dataset = PairwisePreferenceDataset('financial_rewards_500.jsonl')
loader = DataLoader(dataset, batch_size=5, shuffle=True, collate_fn=collate_fn)

In [4]:

POLICY_MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct" # Renamed MODEL_ID for clarity
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-5
MAX_LENGTH = 256
MAX_GRAD_NORM = 1.0
tokenizer = AutoTokenizer.from_pretrained(POLICY_MODEL_ID)            # Policy tokenizer
policy = AutoModelForCausalLM.from_pretrained(POLICY_MODEL_ID)       # Policy model
ref = AutoModelForCausalLM.from_pretrained(POLICY_MODEL_ID)
policy_optimizer = optim.AdamW(policy.parameters(), lr=LEARNING_RATE)

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/659 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/988M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

In [5]:
policy.to(DEVICE)
ref.to(DEVICE)
reward_model.to(DEVICE)

RewardModel(
  (model): DebertaV2ForSequenceClassification(
    (deberta): DebertaV2Model(
      (embeddings): DebertaV2Embeddings(
        (word_embeddings): Embedding(128100, 768, padding_idx=0)
        (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): DebertaV2Encoder(
        (layer): ModuleList(
          (0-5): 6 x DebertaV2Layer(
            (attention): DebertaV2Attention(
              (self): DisentangledSelfAttention(
                (query_proj): Linear(in_features=768, out_features=768, bias=True)
                (key_proj): Linear(in_features=768, out_features=768, bias=True)
                (value_proj): Linear(in_features=768, out_features=768, bias=True)
                (pos_dropout): Dropout(p=0.1, inplace=False)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): DebertaV2SelfOutput(
                (dense): Linear(in_features=768, o

In [6]:
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token

# Ensure the reward_model_tokenizer (from cell LnBx9XggB4D4) also has padding set
if reward_model_tokenizer.pad_token is None:
    reward_model_tokenizer.pad_token = reward_model_tokenizer.eos_token
reward_model_tokenizer.padding_side = "left"


policy.config.pad_token_id = tokenizer.pad_token_id
policy.config.eos_token_id = tokenizer.pad_token_id
ref.config.pad_token_id = tokenizer.pad_token_id
ref.config.eos_token_id = tokenizer.pad_token_id
print("Model token IDs aligned.")

Model token IDs aligned.


In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import copy

# Constants (Adjust based on your hardware)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_GRAD_NORM = 1.0
LEARNING_RATE = 1e-6

def get_batch_logps(model, sequences, attention_mask, prompt_lens):
    """
    Calculates the log probabilities of the generated text, masking out the prompt.
    """
    # 1. Forward pass (Input Mask: Attend to Prompt + Response, ignore Pad)
    outputs = model(sequences, attention_mask=attention_mask)

    # 2. Shift logits and labels (Next-token prediction)
    # Logits match the output of the model (predictions)
    # Labels match the input tokens shifted by 1 (ground truth)
    logits = outputs.logits[:, :-1, :]
    labels = sequences[:, 1:]

    # 3. Get log probs of the actual tokens chosen
    # Shape: [Batch, Seq_Len-1]
    per_token_logps = torch.gather(logits.log_softmax(-1), 2, labels.unsqueeze(2)).squeeze(2)

    # 4. Create the Loss Mask (Grading Mask: Grade only Response, ignore Prompt & Pad)
    loss_mask = attention_mask[:, 1:].clone()

    for i, start_idx in enumerate(prompt_lens):
        # Zero out the prompt tokens (start_idx - 1 because of shifting)
        loss_mask[i, :start_idx - 1] = 0.0

    # Return per-token log probs and the mask (for averaging later)
    return per_token_logps, loss_mask

class CustomGRPOTrainer:
    def __init__(self, policy_model, ref_model, reward_model, tokenizer, policy_optimizer, debug=True):
        self.policy_model = policy_model
        self.ref_model = ref_model
        self.reward_model, self.reward_tokenizer = reward_model # Tuple unpack
        self.tokenizer = tokenizer
        self.optimizer = policy_optimizer
        self.debug = debug # Debug flag

        # --- GRPO Hyperparameters ---
        self.group_size = 4        # G: Number of samples per prompt
        self.clip_eps = 0.2        # PPO Clipping Epsilon
        self.kl_beta = 0.04        # KL Penalty Coefficient
        self.grpo_epochs = 1       # Inner optimization loops

        # Freeze Reference Model to save memory/compute
        self.ref_model.eval()
        for p in self.ref_model.parameters():
            p.requires_grad = False

    def compute_rewards(self, texts):
        """
        Computes the reward (Scalar) for each sequence in the batch.
        Assumes reward_model returns a single logit/score per sequence.
        """
        inputs = self.reward_tokenizer(
            texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=512
        ).to(DEVICE)

        if self.debug:
            print(f"  [DEBUG] Reward Model Input IDs shape: {inputs.input_ids.shape}, dtype: {inputs.input_ids.dtype}")

        with torch.no_grad():
            # Adjust based on your specific reward model architecture
            # If classifier, you might need outputs.logits[:, 1] for the positive class
            rewards = self.reward_model(inputs.input_ids, inputs.attention_mask)
            print(rewards)

            # Ensure output is flattened [Batch_Size]
            if rewards.dim() > 1:
                rewards = rewards.squeeze()

        if self.debug:
            print(f"  [DEBUG] Computed Rewards shape: {rewards.shape}, dtype: {rewards.dtype}")
        return rewards

    def _expand_prompts(self, prompts):
        """
        Duplicates prompts to create the group batch.
        [A, B] -> [A, A, A, A, B, B, B, B] (if group_size=4)
        """
        expanded_prompts = [p for p in prompts for _ in range(self.group_size)]
        if self.debug:
            print(f"  [DEBUG] Expanded prompts count: {len(expanded_prompts)}")
        return expanded_prompts

    def _rollout_responses(self, prompts):
        # 1. Expand Inputs
        group_prompts = self._expand_prompts(prompts)
        inputs = self.tokenizer(group_prompts, return_tensors='pt', padding=True).to(DEVICE)

        if self.debug:
            print(f"  [DEBUG] Rollout Input IDs shape: {inputs.input_ids.shape}, dtype: {inputs.input_ids.dtype}")

        # Determine where the prompt ends (for masking later)
        # We calculate this BEFORE generation so we know exactly where the answer starts
        prompt_lens = inputs.attention_mask.sum(dim=1).tolist()
        if self.debug:
            print(f"  [DEBUG] Prompt lengths (first 5): {prompt_lens[:5]}")

        # 2. Generate G samples per prompt
        with torch.no_grad():
            sequence = self.policy_model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=64,   # Adjust based on task
                do_sample=True,      # Essential for diversity in GRPO
                temperature=1.0,     # Standard for RLHF
                pad_token_id=self.tokenizer.pad_token_id
            )

        if self.debug:
            print(f"  [DEBUG] Generated Sequence shape: {sequence.shape}, dtype: {sequence.dtype}")

        generated_texts = self.tokenizer.batch_decode(sequence, skip_special_tokens=True)
        if self.debug:
            print(f"  [DEBUG] Generated texts (first): {generated_texts[0][:100]}...")

        # 3. Calculate Log Probs of the Generated Text (Old Policy)
        # Proper attention mask for the full sequence
        seq_attention_mask = (sequence != self.tokenizer.pad_token_id).long()

        with torch.no_grad():
            old_log_probs, final_mask = get_batch_logps(
                self.policy_model, sequence, seq_attention_mask, prompt_lens
            )

        if self.debug:
            print(f"  [DEBUG] Old Log Probs shape: {old_log_probs.shape}, dtype: {old_log_probs.dtype}")
            print(f"  [DEBUG] Final Mask shape: {final_mask.shape}, dtype: {final_mask.dtype}")

        return sequence, generated_texts, prompt_lens, old_log_probs, final_mask, seq_attention_mask

    def _calculate_group_advantage(self, rewards, num_unique_prompts):
        """
        Computes the relative advantage of each sample compared to its group mean.
        """
        # Reshape to [Num_Prompts, Group_Size]
        # e.g., if batch=8 and group=4, reshape to [2, 4]
        rewards_reshaped = rewards.view(num_unique_prompts, self.group_size)

        # Calculate stats per group
        group_means = rewards_reshaped.mean(dim=1, keepdim=True)
        group_stds = rewards_reshaped.std(dim=1, keepdim=True)

        # Normalize: (R - Mean) / Std
        # Add small epsilon to std to prevent division by zero
        advantages = (rewards_reshaped - group_means) / (group_stds + 1e-8)

        # Flatten back to [Batch_Size]
        flattened_advantages = advantages.view(-1)
        if self.debug:
            print(f"  [DEBUG] Calculated Advantages shape: {flattened_advantages.shape}, dtype: {flattened_advantages.dtype}")
        return flattened_advantages

    def _calculate_grpo_loss(self, new_log_probs, old_log_probs, advantages, token_kl, final_mask):
        """
        Calculates the surrogate loss with KL penalty.
        """
        # 1. Calculate Ratio (New / Old)
        ratio = torch.exp(new_log_probs - old_log_probs)

        # 2. Broadcast Scalar Advantage to Vector Sequence
        # advantages: [Batch] -> [Batch, 1]
        # This repeats the scalar advantage across all tokens in that sequence
        adv_expanded = advantages.unsqueeze(1)

        # 3. Standard PPO Clipping
        surr1 = ratio * adv_expanded
        surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * adv_expanded

        # 4. Combine Objective with KL Penalty
        # We subtract Beta * KL. Since we want to MAXIMIZE objective, this minimizes KL.
        # token_kl is [Batch, Length], clipped_obj is [Batch, Length]
        objective = torch.min(surr1, surr2) - (self.kl_beta * token_kl)

        # 5. Mask and Average
        # We only want to train on the Response tokens (masked by final_mask)
        # Negative sign because PyTorch optimizers minimize loss
        loss = -(objective * final_mask).sum() / final_mask.sum()

        if self.debug:
            print(f"  [DEBUG] GRPO Loss: {loss.item():.4f}")
        return loss

    def train_step(self, prompts):
        if self.debug:
            print(f"[DEBUG] Starting train_step for batch with {len(prompts)} unique prompts.")
        try:
            # --- 1. Experience Collection (Rollout) ---
            if self.debug:
                print("  [DEBUG] --- Rolling out responses ---")
            sequence, generated_texts, prompt_lens, old_log_probs, final_mask, seq_attention_mask = self._rollout_responses(prompts)

            # --- 2. Scoring ---
            if self.debug:
                print("  [DEBUG] --- Computing rewards ---")
            rewards = self.compute_rewards(generated_texts) # [Batch]

            # --- 3. Advantage Calculation (Group Relative) ---
            if self.debug:
                print("  [DEBUG] --- Calculating advantages ---")
            advantages = self._calculate_group_advantage(rewards, len(prompts)) # [Batch]

            # --- 4. Reference Model Log Probs (for KL) ---
            if self.debug:
                print("  [DEBUG] --- Calculating reference model log probs ---")
            with torch.no_grad():
                ref_log_probs, _ = get_batch_logps(
                    self.ref_model, sequence, seq_attention_mask, prompt_lens
                )
            if self.debug:
                print(f"  [DEBUG] Reference Log Probs shape: {ref_log_probs.shape}, dtype: {ref_log_probs.dtype}")

            # Calculate KL per token (Vector)
            # Approx KL: log_p_model - log_p_ref
            token_kl = old_log_probs - ref_log_probs
            if self.debug:
                print(f"  [DEBUG] Token KL shape: {token_kl.shape}, dtype: {token_kl.dtype}")

            # --- 5. Optimization Loop ---
            if self.debug:
                print("  [DEBUG] --- Starting optimization loop ---")
            batch_loss = 0.0

            for i in range(self.grpo_epochs):
                if self.debug:
                    print(f"  [DEBUG]   Optimization Epoch {i+1}/{self.grpo_epochs}")
                # Recalculate new log probs (with gradients enabled)
                new_log_probs, _ = get_batch_logps(
                    self.policy_model, sequence, seq_attention_mask, prompt_lens
                )
                if self.debug:
                    print(f"  [DEBUG]   New Log Probs shape: {new_log_probs.shape}, dtype: {new_log_probs.dtype}")

                loss = self._calculate_grpo_loss(
                    new_log_probs,
                    old_log_probs,
                    advantages,
                    token_kl,     # Vector KL
                    final_mask
                )

                if torch.isnan(loss) or torch.isinf(loss):
                    print("Warning: Loss is NaN/Inf. Skipping batch.")
                    return 0.0

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.policy_model.parameters(), MAX_GRAD_NORM)
                self.optimizer.step()

                batch_loss += loss.item()
            if self.debug:
                print("  [DEBUG] --- Optimization loop finished ---")
            return batch_loss / self.grpo_epochs

        except Exception as e:
            print(f"Error in train_step: {e}")
            # Optional: Print traceback for debugging
            # import traceback; traceback.print_exc()
            self.optimizer.zero_grad()
            return 0.0


In [15]:
ppo_trainer = CustomGRPOTrainer(policy, ref, (reward_model, reward_model_tokenizer), tokenizer, policy_optimizer)

total_epoch_loss = 0.0 # This will accumulate the average batch losses
total_batches_processed = 0

for batch_idx, batch in enumerate(loader):
  prompts_only = batch['prompt']

  # train_step now returns the average PPO loss for this batch (averaged over self.ppo_epoch inner loops)
  batch_avg_loss_for_this_batch = ppo_trainer.train_step(prompts_only)

  # Accumulate batch losses and count for overall average
  total_epoch_loss += batch_avg_loss_for_this_batch
  total_batches_processed += 1

  print(f"\nBatch {batch_idx+1} completed. Average PPO Loss for this batch: {batch_avg_loss_for_this_batch:.4f}\n")
  if batch_idx >= 1:
    break

# Calculate and print overall average loss after processing all batches
if total_batches_processed > 0:
    avg_overall_training_loss = total_epoch_loss / total_batches_processed
    print(f"\nTraining completed. Overall Average PPO Loss: {avg_overall_training_loss:.4f}")
else:
    print("\nNo batches processed.")

[DEBUG] Starting train_step for batch with 5 unique prompts.
  [DEBUG] --- Rolling out responses ---
  [DEBUG] Expanded prompts count: 20
  [DEBUG] Rollout Input IDs shape: torch.Size([20, 38]), dtype: torch.int64
  [DEBUG] Prompt lengths (first 5): [13, 13, 13, 13, 13]
  [DEBUG] Generated Sequence shape: torch.Size([20, 102]), dtype: torch.int64
  [DEBUG] Generated texts (first): User: Define 'inflation' in financial terms.
Assistant: Financially, inflation is an increase in the...
  [DEBUG] Old Log Probs shape: torch.Size([20, 101]), dtype: torch.float32
  [DEBUG] Final Mask shape: torch.Size([20, 101]), dtype: torch.int64
  [DEBUG] --- Computing rewards ---
  [DEBUG] Reward Model Input IDs shape: torch.Size([20, 101]), dtype: torch.int64
tensor([2.6041, 1.7461, 2.3454, 2.3776, 2.6649, 2.7021, 2.7281, 2.8312, 1.7853,
        2.5517, 2.6861, 1.1162, 2.4688, 1.2603, 1.0587, 0.6503, 1.8653, 1.9093,
        2.3776, 2.0019], device='cuda:0')
  [DEBUG] Computed Rewards shape: torch.Size([2