PPO training


In [None]:
# Restore a previously trained reward model and its tokenizer from disk.
# `RewardModel`, `AutoTokenizer`, `torch` and `DEVICE` are not defined in this
# cell; it has to be run after the cells that provide them.
REWARD_MODEL_PATH = 'reward_model.pt'
REWARD_MODEL_ID = "microsoft/deberta-v3-small"
reward_tokenizer = AutoTokenizer.from_pretrained(REWARD_MODEL_ID)
# Only fall back to the EOS token when the tokenizer ships without a pad token.
if reward_tokenizer.pad_token is None:
    reward_tokenizer.pad_token = reward_tokenizer.eos_token

# Instantiate the RewardModel architecture (weights are filled in below).
loaded_reward_model = RewardModel(REWARD_MODEL_ID)

# Load the saved state dictionary.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider `weights_only=True` if the installed torch supports it).
loaded_reward_model.load_state_dict(torch.load(REWARD_MODEL_PATH, map_location=DEVICE))

# Move the model to the appropriate device and switch to inference behaviour.
loaded_reward_model.to(DEVICE)
loaded_reward_model.eval()

# The later cells (device move, trainer construction) read the name
# `reward_model`; expose the loaded model under that name as well so the
# notebook does not depend on stale kernel state.
reward_model = loaded_reward_model

print(f"Reward model loaded from {REWARD_MODEL_PATH} and moved to {DEVICE}.")
print("Reward tokenizer loaded and configured.")

In [2]:
import copy
import json
import math
import traceback
from typing import List, Dict, Any, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification

In [3]:
# ---- Run configuration -------------------------------------------------
# Checkpoint name handed to `from_pretrained` for the tokenizer, policy and
# reference model.
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"

# Prefer the GPU whenever one is visible to torch.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

LEARNING_RATE = 1e-5   # step size given to the policy optimizer
MAX_LENGTH = 256       # token cap used when tokenizing with truncation
MAX_GRAD_NORM = 1.0    # gradient-norm clipping threshold
BATCH_SIZE = 10        # NOTE(review): not referenced by the visible DataLoader call
REWARD_MODEL_PATH = 'reward_model.pt'
PPO_EPOCHS = 5         # optimisation passes per rollout batch

In [4]:
class PairwisePreferenceDataset(Dataset):
    """In-memory dataset backed by a JSON Lines file.

    Every non-blank line of the file is parsed with ``json.loads`` and stored
    as one example. The downstream ``collate_fn`` reads the keys ``'prompt'``,
    ``'chosen'`` and ``'rejected'`` from each example; this class itself does
    not validate them.
    """

    def __init__(self, file_path: str):
        """Read and parse ``file_path`` eagerly.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if a non-blank line is not valid JSON.
        """
        self.data = []
        # JSON text is UTF-8; be explicit so decoding does not depend on the
        # platform's locale default.
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():  # skip empty / whitespace-only lines
                    self.data.append(json.loads(line))

    def __len__(self) -> int:
        """Return the number of parsed examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the parsed JSON object stored at position ``idx``."""
        return self.data[idx]

def collate_fn(batch: List[Dict[str, str]]) -> Dict[str, Any]:
    """
    Turn a list of preference examples into a dict of parallel string lists.

    No tokenization happens here. ``"prompt"`` holds the raw prompts, while
    ``"chosen"`` and ``"rejected"`` hold each prompt joined to the respective
    completion with a single newline.
    """
    prompts = []
    for example in batch:
        prompts.append(example['prompt'])

    def _prefixed(field):
        # Prepend each example's prompt (plus a newline) to the given field.
        joined = []
        for prefix, example in zip(prompts, batch):
            joined.append(prefix + "\n" + example[field])
        return joined

    return {
        "prompt": prompts,
        "chosen": _prefixed('chosen'),
        "rejected": _prefixed('rejected'),
    }

In [5]:
def get_batch_logps(model, input_ids: torch.Tensor, attention_mask: torch.Tensor, response_start_indices: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute per-token log-probabilities of the response part of each sequence.

    Args:
        model: callable invoked as ``model(input_ids=..., attention_mask=...)``
            whose result exposes ``.logits`` of shape [B, L, V].
        input_ids: LongTensor [B, L].
        attention_mask: tensor [B, L]; non-zero marks real (non-padding) tokens.
        response_start_indices: list[int] of length B; index (0-based, in the
            unshifted ``input_ids``) of the first response token of each row.

    Returns:
        A tuple ``(token_log_probs, final_mask)``, both float tensors of shape
        [B, L-1]:
          * ``token_log_probs[b, t]`` is the log-probability of token
            ``input_ids[b, t + 1]`` given the preceding tokens, already
            multiplied by the mask (so it is 0 outside the response);
          * ``final_mask[b, t]`` is 1 where ``input_ids[b, t + 1]`` is a
            non-padding response token and 0 elsewhere.
        No reduction over the sequence dimension is performed; callers sum or
        average with ``final_mask`` themselves.

    Raises:
        ValueError: if ``response_start_indices`` does not have one entry per
            batch row (a length-1 list would otherwise broadcast silently).
    """
    B, L = input_ids.shape
    if len(response_start_indices) != B:
        raise ValueError(
            f"response_start_indices has {len(response_start_indices)} entries "
            f"but the batch has {B} rows"
        )

    model_output = model(input_ids=input_ids, attention_mask=attention_mask)
    # logits: [B, L, V]
    logits = model_output.logits

    # Shift logits and targets for causal LM: predict token t from logits at t-1
    logits = logits[:, :-1, :].float()   # [B, L-1, V] as float32 for numerical stability
    targets = input_ids[:, 1:]           # [B, L-1]
    attn_shifted = attention_mask[:, 1:] # [B, L-1]

    # log-softmax over vocab, then pick the log-prob of each realised target token
    log_probs = F.log_softmax(logits, dim=-1)  # [B, L-1, V]
    token_log_probs = torch.gather(log_probs, dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)  # [B, L-1]

    # Build the response mask in the full (unshifted) index space ...
    full_indices = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, -1)  # [B, L]
    start_tensor = torch.tensor(response_start_indices, device=input_ids.device).unsqueeze(1)  # [B, 1]
    response_mask_full = (full_indices >= start_tensor).float()  # [B, L]

    # ... then drop the first column so it lines up with `targets`.
    response_mask_shifted = response_mask_full[:, 1:]  # [B, L-1]

    # Keep only response positions that are also non-padding.
    final_mask = (response_mask_shifted * attn_shifted).to(token_log_probs.dtype)  # [B, L-1]

    # Zero-out everything outside the response; values stay per-token.
    masked_token_log_probs = token_log_probs * final_mask  # [B, L-1]

    return masked_token_log_probs, final_mask

In [13]:
class CustomPPOTrainer:
    """Small PPO-style trainer for a causal language model.

    One ``train_step`` samples responses from the policy, scores the decoded
    texts with a reward model, turns the (optionally KL-penalised) sequence
    rewards into batch-normalised advantages, and runs ``ppo_epoch`` clipped
    surrogate-loss updates on the same rollout.

    Attributes set in ``__init__``:
        clip_eps: half-width of the PPO ratio clipping interval.
        kl_beta: weight of the per-sequence KL penalty (0.0 disables it).
        debug: when True, per-epoch statistics are printed.
        ppo_epoch: number of optimisation passes per rollout (``PPO_EPOCHS``).
    """

    def __init__(self, policy_model, ref_model, reward_model, tokenizer, policy_optimizer):
        """Store collaborators and freeze the reference model.

        ``reward_model`` must be a 2-tuple ``(model, tokenizer)``.
        """
        self.policy_model = policy_model
        self.ref_model = ref_model
        self.reward_model, self.reward_tokenizer = reward_model
        self.tokenizer = tokenizer
        self.optimizer = policy_optimizer
        self.clip_eps = 0.2
        self.kl_beta = 0.0
        self.debug = False
        self.ppo_epoch = PPO_EPOCHS

        # The reference model is only ever evaluated, never trained.
        self.ref_model.eval()
        for p in self.ref_model.parameters():
            p.requires_grad = False

    def compute_rewards(self, texts):
        """Score ``texts`` with the reward model (no gradients).

        The return value is whatever the reward model's forward pass yields
        for ``(input_ids, attention_mask)``; presumably one scalar per text.
        """
        inputs = self.reward_tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=MAX_LENGTH).to(DEVICE)
        with torch.no_grad():
            rewards = self.reward_model(inputs.input_ids, inputs.attention_mask)
        return rewards

    def _rollout_responses(self, prompts):
        """Sample one response per prompt and record the behaviour-policy log-probs.

        Returns ``(sequence, generated_texts, prompt_lens, old_log_probs, mask)``
        where ``sequence`` is the tensor returned by ``generate`` and
        ``prompt_lens[b]`` is the column at which row ``b``'s response starts.
        """
        inputs = self.tokenizer(prompts, return_tensors='pt', padding=True).to(DEVICE)

        # `generate` appends new tokens after the *padded* prompt block, so the
        # response starts at the same column for every row regardless of how
        # many real tokens the prompt has. Counting non-pad tokens instead
        # (attention_mask.sum) would, with left padding, label trailing prompt
        # tokens of shorter prompts as response tokens.
        batch_size, prompt_width = inputs.input_ids.shape
        prompt_lens = [prompt_width] * batch_size

        with torch.no_grad():
            sequence = self.policy_model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=30,
                do_sample=True,
            )
            old_policy_log_probs, final_mask = self._compute_log_probs_and_mask(self.policy_model, sequence, prompt_lens)

        generated_texts = self.tokenizer.batch_decode(sequence, skip_special_tokens=True)
        return sequence, generated_texts, prompt_lens, old_policy_log_probs, final_mask

    def _compute_log_probs_and_mask(self, model, sequence, prompt_lens):
        """Return per-token response log-probs and the response mask for ``model``.

        Every position equal to the tokenizer's pad token id is treated as
        padding; when pad and EOS share an id, generated EOS tokens are
        therefore excluded from the mask as well.
        """
        attention_mask = (sequence != self.tokenizer.pad_token_id).long()
        log_probs, final_mask = get_batch_logps(model, sequence, attention_mask, prompt_lens)
        return log_probs, final_mask

    def _calculate_kl_and_advantage(self, rewards, policy_model_old_log_probs, ref_log_probs, final_mask):
        """Compute the per-sequence KL estimate and batch-normalised advantages.

        Returns ``(kl_div, advantages)`` with shapes [B] and [B, 1].

        Raises:
            ValueError: if ``rewards`` does not hold exactly one value per sequence.
        """
        token_kl_div = policy_model_old_log_probs - ref_log_probs
        # A row with no response tokens has a zero mask sum; clamp so the mean
        # is 0 for that row instead of 0/0 = NaN.
        token_counts = final_mask.sum(dim=1).clamp(min=1.0)
        kl_div = (token_kl_div * final_mask).sum(dim=1) / token_counts

        # Accept [B] as well as [B, 1] rewards; a [B, 1] tensor would otherwise
        # broadcast against the [B] KL term into a [B, B] matrix.
        if rewards.numel() != kl_div.numel():
            raise ValueError(
                f"expected one reward per sequence ({kl_div.numel()}), got shape {tuple(rewards.shape)}"
            )
        rewards = rewards.reshape(-1)

        R_total = rewards - (self.kl_beta * kl_div)

        advantages = R_total - R_total.mean()
        # The sample standard deviation is undefined (NaN) for a single
        # element, so only scale when there are at least two sequences and the
        # spread is not degenerate.
        if R_total.numel() > 1:
            spread = R_total.std()
            if spread >= 1e-8:
                advantages = advantages / (spread + 1e-8)
        return kl_div, advantages.unsqueeze(1)

    def _calculate_ppo_loss(self, policy_model_new_log_probs, policy_model_old_log_probs, advantages, final_mask):
        """Clipped PPO surrogate loss averaged over response tokens.

        Returns ``(loss, ratio)`` where ``ratio`` is the per-token probability
        ratio between the current and the behaviour policy.
        """
        ratio = torch.exp(policy_model_new_log_probs - policy_model_old_log_probs)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
        loss = -(torch.min(surr1, surr2) * final_mask).sum()
        # Guard against an all-zero mask (no response tokens in the batch).
        loss = loss / final_mask.sum().clamp(min=1.0)
        return loss, ratio

    def _ppo_optimization_loop(self, sequence, prompt_lens, policy_model_old_log_probs, advantages, final_mask, rewards, kl_div):
        """Run ``self.ppo_epoch`` gradient steps on one rollout.

        Returns ``(mean_loss, last_ratio)``. If a loss becomes NaN/Inf the loop
        stops immediately and ``(0.0, None)`` is returned.
        """
        total_ppo_loop_loss = 0.0
        last_ratio = None  # ratio tensor from the most recent epoch

        for i in range(self.ppo_epoch):
            # Re-evaluate the current policy on the fixed rollout.
            policy_model_new_log_probs, _ = self._compute_log_probs_and_mask(self.policy_model, sequence, prompt_lens)
            loss, ratio = self._calculate_ppo_loss(policy_model_new_log_probs, policy_model_old_log_probs, advantages, final_mask)
            last_ratio = ratio

            if self.debug:
                print(f"  PPO Epoch {i+1}/{self.ppo_epoch}")
                print(f"    advantages: {advantages.shape}, ratio: {ratio.shape}, mask: {final_mask.shape}")
                print(f"    rewards mean/std: {rewards.mean().item():.4f}/{rewards.std().item():.4f}")
                print(f"    kl_seq mean/std: {kl_div.mean().item():.4f}/{kl_div.std().item():.4f}")
                print(f"    advantages mean/std: {advantages.mean().item():.4f}/{advantages.std().item():.4f}")
                print(f"    ratio mean/std: {ratio.mean().item():.4f}/{ratio.std().item():.4f}")

            if torch.isnan(loss) or torch.isinf(loss):
                print("Warning: PPO Loss is NaN/Inf during optimization. Returning 0.0 for this batch.")
                return 0.0, None

            # Model update with gradient-norm clipping.
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.policy_model.parameters(), MAX_GRAD_NORM)
            self.optimizer.step()
            total_ppo_loop_loss += loss.item()
            if self.debug:
                print(f"    ppo loss = {loss.item():.4f}")

        return total_ppo_loop_loss / self.ppo_epoch if self.ppo_epoch > 0 else 0.0, last_ratio

    def train_step(self, prompts):
        """Run one full PPO step for ``prompts`` and return the mean loss.

        Any exception is caught, printed with its traceback, and reported as a
        loss of 0.0 so that a single bad batch does not abort training.
        """
        try:
            # 1. Rollout
            sequence, generated_texts, prompt_lens, policy_model_old_log_probs, final_mask = self._rollout_responses(prompts)

            # Nothing to learn from if no response token survived masking;
            # skip rather than taking optimizer steps on an all-zero loss.
            if final_mask.sum().item() == 0:
                print("Warning: rollout produced no response tokens. Skipping batch and returning 0.0.")
                return 0.0

            # 2a. Compute Rewards
            rewards = self.compute_rewards(generated_texts)

            # 2b. Calculate log_probs for ref model
            with torch.no_grad():
                ref_log_probs, _ = self._compute_log_probs_and_mask(self.ref_model, sequence, prompt_lens)

            # 3. Calculate Advantage
            kl_div, advantages = self._calculate_kl_and_advantage(rewards, policy_model_old_log_probs, ref_log_probs, final_mask)

            # 4. PPO optimization steps
            batch_avg_ppo_loss, last_ratio = self._ppo_optimization_loop(sequence, prompt_lens, policy_model_old_log_probs, advantages, final_mask, rewards, kl_div)

            if not math.isfinite(batch_avg_ppo_loss):
                print("Warning: Batch average PPO Loss is NaN/Inf. Skipping batch and returning 0.0.")
                return 0.0

            return batch_avg_ppo_loss

        except Exception as e:
            print(f"Error in train_step: {e}")
            traceback.print_exc()
            self.optimizer.zero_grad()
            return 0.0

In [6]:
# Read the preference pairs from disk and serve them in shuffled mini-batches
# of raw strings (tokenization is left to the trainer).
# NOTE(review): the batch size is hard-coded here; the BATCH_SIZE constant
# defined in the configuration cell is not used.
dataset = PairwisePreferenceDataset('financial_rewards_500.jsonl')
loader = DataLoader(
    dataset,
    batch_size=5,
    shuffle=True,
    collate_fn=collate_fn,
)

In [7]:

# Configuration values (these repeat the assignments of the earlier
# configuration cell so this cell can be run on its own).
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
LEARNING_RATE = 1e-5
MAX_LENGTH = 256
MAX_GRAD_NORM = 1.0

# Tokenizer plus two independent copies of the same checkpoint: `policy` is
# the model handed to the optimizer, `ref` is a second, separately loaded copy.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
policy = AutoModelForCausalLM.from_pretrained(MODEL_ID)
ref = AutoModelForCausalLM.from_pretrained(MODEL_ID)

# Only the policy's parameters are registered with the optimizer.
policy_optimizer = optim.AdamW(policy.parameters(), lr=LEARNING_RATE)


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/659 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/988M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

In [8]:
# Move the policy, the reference model and the reward model onto DEVICE.
policy.to(DEVICE)
ref.to(DEVICE)
# NOTE(review): no visible cell assigns `reward_model` (the loading cell binds
# `loaded_reward_model`) — confirm where this name comes from.
# As the last expression of the cell, this call's return value is what the
# notebook displays.
reward_model.to(DEVICE)

RewardModel(
  (model): DebertaV2ForSequenceClassification(
    (deberta): DebertaV2Model(
      (embeddings): DebertaV2Embeddings(
        (word_embeddings): Embedding(128100, 768, padding_idx=0)
        (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): DebertaV2Encoder(
        (layer): ModuleList(
          (0-5): 6 x DebertaV2Layer(
            (attention): DebertaV2Attention(
              (self): DisentangledSelfAttention(
                (query_proj): Linear(in_features=768, out_features=768, bias=True)
                (key_proj): Linear(in_features=768, out_features=768, bias=True)
                (value_proj): Linear(in_features=768, out_features=768, bias=True)
                (pos_dropout): Dropout(p=0.1, inplace=False)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): DebertaV2SelfOutput(
                (dense): Linear(in_features=768, o

In [9]:
# Apply the same padding setup to both tokenizers: pad on the left and reuse
# each tokenizer's own EOS token as its padding token.
for _tok in (tokenizer, reward_tokenizer):
    _tok.padding_side = "left"
    _tok.pad_token = _tok.eos_token

In [None]:
# Point both models' config-level pad and EOS ids at the tokenizer's pad id.
for _model in (policy, ref):
    _model.config.pad_token_id = tokenizer.pad_token_id
    _model.config.eos_token_id = tokenizer.pad_token_id
print("Model token IDs aligned.")

In [14]:
# Build the trainer; the reward model travels together with its tokenizer.
ppo_trainer = CustomPPOTrainer(policy, ref, (reward_model, reward_tokenizer), tokenizer, policy_optimizer)

# Running sum of the per-batch mean losses and the number of batches seen.
total_epoch_loss = 0.0
total_batches_processed = 0

for step, minibatch in enumerate(loader):
    # Only the raw prompts are needed; responses are sampled by the trainer.
    # The returned value is already averaged over the inner PPO epochs.
    step_loss = ppo_trainer.train_step(minibatch['prompt'])

    total_epoch_loss += step_loss
    total_batches_processed += 1

    print(f"\nBatch {step+1} completed. Average PPO Loss for this batch: {step_loss:.4f}\n")

    # Stop once the batch with index 5 (the sixth batch) has been handled.
    if step >= 5:
        break

# Report the mean of the per-batch losses over the batches that were run.
if total_batches_processed == 0:
    print("\nNo batches processed.")
else:
    avg_overall_training_loss = total_epoch_loss / total_batches_processed
    print(f"\nTraining completed. Overall Average PPO Loss: {avg_overall_training_loss:.4f}")


Batch 1 completed. Average PPO Loss for this batch: -0.0955


Batch 2 completed. Average PPO Loss for this batch: -0.2080


Batch 3 completed. Average PPO Loss for this batch: -0.0541


Batch 4 completed. Average PPO Loss for this batch: 0.1392


Batch 5 completed. Average PPO Loss for this batch: -0.0751


Batch 6 completed. Average PPO Loss for this batch: -0.1645


Training completed. Overall Average PPO Loss: -0.0763


In [15]:
# Qualitative smoke test: sample a short completion from the trained policy
# for a handful of fixed prompts and print them.
print("\n--- Testing PPO Policy Model ---")

# Example prompts for testing
test_prompts = [
    "User: What is a safe investment strategy?",
    "User: Explain the concept of inflation to a beginner.",
    "User: What are the risks of investing in cryptocurrency?"
]

# Set the policy model to evaluation mode
policy.eval()

# Tokenize prompts
inputs = tokenizer(test_prompts, return_tensors='pt', padding=True, truncation=True, max_length=MAX_LENGTH).to(DEVICE)

# Generate responses
with torch.no_grad():
    generated_sequences = policy.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=50,  # Generate up to 50 new tokens
        do_sample=True,
        temperature=0.7,  # Control randomness
        top_k=50,         # Top-k sampling
        pad_token_id=tokenizer.pad_token_id,  # Ensure pad token is correctly set for generation
        eos_token_id=tokenizer.eos_token_id   # Ensure eos token is correctly set for generation
    )

# The returned sequences start with the (padded) prompt tokens, followed by
# the newly generated tokens. Slicing at the prompt width isolates the
# response exactly, instead of searching for the prompt text inside the
# decoded string (which fails whenever decoding does not reproduce the prompt
# character-for-character).
prompt_width = inputs.input_ids.shape[1]

# Decode and print results
for i, prompt in enumerate(test_prompts, start=1):
    response_ids = generated_sequences[i - 1][prompt_width:]
    response_text = tokenizer.decode(response_ids, skip_special_tokens=True).strip()

    print(f"\n--- Prompt {i} ---")
    print(f"Prompt: {prompt}")
    print(f"Generated Response: {response_text}")

print("\n--- Testing Complete ---")


--- Testing PPO Policy Model ---

--- Prompt 1 ---
Prompt: User: What is a safe investment strategy?
Generated Response: In the context of financial terms, a safe investment strategy is an approach used by individuals to manage risks associated with their position. The strategy involves maintaining some level of exposure to common stocks while also holding other assets that have lower correlation with common stocks (e

--- Prompt 2 ---
Prompt: User: Explain the concept of inflation to a beginner.
Generated Response: Inflation is defined as a period of prolonged economic decline characterized by higher prices over time. When used as a term in relation to currency, it is defined as a period of extended price increase that can be caused by factors such as higher interest rates or

--- Prompt 3 ---
Prompt: User: What are the risks of investing in cryptocurrency?
Generated Response: In addition, please provide a step-by-step guide for identifying potential scams related to cryptocurrencies