In [1]:
%pwd

'd:\\software_3\\Generative_models\\Text_models\\chat_gpt2\\RLHF_DPO_Preference_alignment\\RLHF_with_GRPO'

In [2]:
import os

os.chdir("../../")

In [3]:
%pwd

'd:\\software_3\\Generative_models\\Text_models\\chat_gpt2'

# Grop Relative Policy Optimization (GRPO)

This Notebook implements the GRPO RLHF tunning based on this [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](http://arxiv.org/abs/2402.03300) research paper.

In [4]:
import os
import time
import tiktoken
import urllib
import urllib.request 
import json
import torch
import numpy as np
import pandas as pd
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
import tensorflow as tf
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from dataclasses import dataclass
from fileinput import filename
from sqlite3 import paramstyle

from gpt import GPTModel
from model_args import BASE_CONFIG
from utils.generate import generate
from utils.download_dataset import download_and_load_dataset, partition_data
from utils.train import train_model
from utils.load_and_save_models import load_model, save_model
from utils.token_converter import get_tokenizer, text_to_token_ids, token_ids_to_text

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
def format_input(entry):

    instruction_text = (
        f"Below is an instruction that describes a task. "
        f"Write a response that appropriately complates the request."
        f"\n\n## Instruction:\n{entry["prompt"]}"
    )
    return instruction_text
    

In [7]:
class GRPODataset(Dataset):

    def __init__(self, data, tokenizer):
        self.data = data
        
        self.encoded_text = []
        for entry in data:
            prompt = format_input(entry)

            prompt_tokens = text_to_token_ids(prompt, tokenizer)

            # Ensure prompt_tokens is a flat tensor
            if prompt_tokens.numel() > 0 and prompt_tokens.ndim > 1:
                # If it's a nested tensor, flatten it
                prompt_tokens = prompt_tokens.flatten()

            self.encoded_text.append(prompt_tokens)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.encoded_text[index]

In [8]:
def custom_collate_fun(
    batch,
    pad_token_id=50256,
    allowed_max_length=None,
    mask_prompt_tokens=True,
    device="cpu"
):
    batch_data = {
        "prompt": []
    }

    max_length_comman = 0
    if batch:
        current_max_length = max(len(item) + 1 for item in batch)  # item is a list
        max_length_comman = max(max_length_comman, current_max_length)

    for item in batch:
        prompt = torch.tensor(item)
        batch_data["prompt"].append(prompt)

    # Pad sequences to the same length
    padded_prompts = nn.utils.rnn.pad_sequence(batch_data["prompt"], batch_first=True, padding_value=pad_token_id)
    if allowed_max_length is not None:
        padded_prompts = padded_prompts[:, :allowed_max_length]

    batch_data["prompt"] = padded_prompts.to(device)

    return batch_data

In [9]:
customized_collate_fn = partial(
    custom_collate_fun,
    device=device,
    allowed_max_length=1024
)

In [10]:
def split_data(data):
    train_portion = int(len(data) * 0.8)
    test_portion = int(len(data) * 0.1)
    val_portion = len(data) - train_portion - test_portion

    train_data = data[:train_portion]
    test_data = data[train_portion:train_portion + test_portion]
    val_data = data[train_portion + test_portion:]

    print("train dataset length:", len(train_data))
    print("length of test data:", len(test_data))
    print("length of val data:", len(val_data))

    return train_data, test_data, val_data


In [11]:
def create_dataloaders(
    train_data,
    test_data,
    val_data,
    batch_size,
    tokenizer
):

    num_workers = 0
    batch_size = 8

    train_data = GRPODataset(train_data, tokenizer)
    train_dataloader = DataLoader(
    train_data,
    batch_size=batch_size,
    collate_fn=customized_collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers)

    test_dataset = GRPODataset(test_data, tokenizer)
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        collate_fn=customized_collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers
    )

    val_dataset = GRPODataset(val_data, tokenizer)
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        collate_fn=customized_collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers
    )

    return train_dataloader, test_dataloader, val_dataloader

## GRPO Loss Function

In [12]:
class AdaptiveKLController:

    def __init__(self, kl_coef=0.01, target_kl=6.0):
        self.kl_coef = kl_coef
        self.target_kl = target_kl

    def update(self, current_kl, n_steps):
        proportional_error = np.clip((current_kl / self.target_kl) - 1.0, -0.2, 0.2)
        multiplier = 1.0 + (proportional_error * n_steps / 20.0)
        self.kl_coef *= multiplier


In [13]:
def calculate_logprobs(sequences, scores):
    """
    Calculates log-probabilities of the next tokens for multiple agents in batched format.

    Args:
        sequences (Tensor): shape (batch_size, seq_len)
        scores (Tensor): shape (batch_size, seq_len - 1, vocab_size) — model logits (before softmax)

    Returns:
        logprobs (Tensor): shape (batch_size, seq_len - 1) — log-probability of each next token
    """
    # Get the next tokens (targets) at each position
    next_tokens = sequences[:, 1:]  # shape: (batch_size, seq_len - 1)

    # Compute log-softmax over vocab dimension
    log_probs = F.log_softmax(scores, dim=-1)  # shape: (batch_size, seq_len - 1, vocab_size)

    # Gather the log-prob of the actual next token
    logprobs = torch.gather(log_probs, dim=2, index=next_tokens.unsqueeze(-1)).squeeze(-1)
    # shape: (batch_size, seq_len - 1)

    return logprobs


In [14]:
def compute_rewards(reward_model, responses, tokenizer, config, device=None):

    if device is None:
        device = next(reward_model.parameters()).device

    rewards = []
    with torch.no_grad():
        for response in responses:
            input_ids = text_to_token_ids(response, tokenizer)
            if isinstance(input_ids, list):
                input_ids = torch.tensor(input_ids, dtype=torch.long)
            if input_ids.ndim == 1:
                input_ids = input_ids.unsqueeze(0)
            input_ids = input_ids.to(device)

            reward = reward_model(input_ids)

            if reward.ndim > 1:
                reward = reward[0, -1]

            reward = reward.unsqueeze(0)
            rewards.append(reward)

    rewards = torch.cat(rewards, dim=0)

    if getattr(config, "use_score_norm", False):
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

    if getattr(config, "use_score_scaling", False):
        rewards = torch.clamp(rewards, -config.score_clip, config.score_clip)

    return rewards

In [15]:
def compute_returns(rewards, gamma=0.99):
    """
    Compute discounted returns for a batch of episodes.
    Args:
        rewards (Tensor): shape (batch_size, episode_len) or (batch_size,)
        gamma (float): discount factor
    Returns:
        Tensor: discounted returns, shape (batch_size, episode_len) or (batch_size,)
    """
    if rewards.ndim == 1:
        # Single-step episode: returns = rewards
        return rewards
    batch_size, episode_len = rewards.shape
    returns = torch.zeros_like(rewards)
    for t in reversed(range(episode_len)):
        if t == episode_len - 1:
            returns[:, t] = rewards[:, t]
        else:
            returns[:, t] = rewards[:, t] + gamma * returns[:, t + 1]
    return returns

In [16]:
def compute_normalized_grpo_advantages(agent_rewards, gamma=0.99, baseline_type="mean", eps=1e-8):
    """
    Compute normalized group-relative advantages for GRPO.
    Args:
        agent_rewards (List[Tensor]): list of rewards for each agent, each of shape (batch_size,) or (batch_size, episode_len)
        ...
    """
    num_agents = len(agent_rewards)
    agent_returns = [compute_returns(r, gamma) for r in agent_rewards]  # list of tensors: [batch_size] or [batch_size, episode_len]
    returns = torch.stack(agent_returns)
    if baseline_type == "mean":
        group_baseline = returns.mean(dim=0)
    elif baseline_type == "median":
        group_baseline = returns.median(dim=0).values
    else:
        raise ValueError("baseline_type must be 'mean' or 'median'")
    raw_advantages = [agent_returns[i] - group_baseline for i in range(num_agents)]
    all_adv = torch.stack(raw_advantages)
    flat_adv = all_adv.view(-1)
    mean_adv = flat_adv.mean()
    std_adv = flat_adv.std()
    normalized_advantages = [(adv - mean_adv) / (std_adv + eps) for adv in raw_advantages]
    return normalized_advantages

In [17]:
import torch
import torch.nn.functional as F

def generate_responses(prompts, policy_model, tokenizer):
    max_context = 1024
    responses = []
    all_log_probs = []

    with torch.no_grad():
        for prompt in prompts:
            if isinstance(prompt, list):
                prompt = torch.tensor(prompt, dtype=torch.long)
            if prompt.ndim == 1:
                prompt = prompt.unsqueeze(0)

            prompt = prompt.to(policy_model.tok_emb.weight.device)
            max_new_tokens = min(100, max_context - prompt.shape[1])

            # Generate a response using your autoregressive sampling method
            output = generate(
                model=policy_model,
                idx=prompt,
                max_new_tokens=max_new_tokens,
                context_size=max_context,
                eos_id=50256
            )

            generated_ids = output[0] if isinstance(output, tuple) else output
            if generated_ids.ndim == 1:
                generated_ids = generated_ids.unsqueeze(0)

            response = token_ids_to_text(generated_ids, tokenizer)
            responses.append(response)

            logits = policy_model(generated_ids)
            if logits.ndim == 2:
                logits = logits.unsqueeze(0)

            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = generated_ids[..., 1:].contiguous()
            logprobs_tensor = F.log_softmax(shift_logits, dim=-1)
            logprobs_for_labels = logprobs_tensor.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)

            # Mean log-prob over generated tokens
            all_log_probs.append(logprobs_for_labels.mean())

    return responses, torch.stack(all_log_probs)


In [18]:
import torch
import torch.nn.functional as F

def forward_pass(responses, policy_model, tokenizer):
    logprobs = []

    for response in responses:
        # Tokenize the response
        input_ids = text_to_token_ids(response, tokenizer)
        if isinstance(input_ids, list):
            input_ids = torch.tensor(input_ids, dtype=torch.long)
        if input_ids.ndim == 1:
            input_ids = input_ids.unsqueeze(0)
        input_ids = input_ids.to(next(policy_model.parameters()).device)

        # Forward pass through the policy model
        logits = policy_model(input_ids)  # shape: [B, T, Vocab]
        if logits.ndim == 2:
            logits = logits.unsqueeze(0)

        # Align logits and labels for next-token prediction
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()

        # Log-softmax for log-probs over vocab
        logprobs_tensor = F.log_softmax(shift_logits, dim=-1)

        # Gather log-probs of actual tokens
        logprobs_for_labels = logprobs_tensor.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)

        # Mean log-prob for the response
        logprobs.append(logprobs_for_labels.mean())

    return torch.stack(logprobs)


In [19]:
def grpo_update(
    policy_model,
    policy_optimizer,
    kl_ctl,
    stats,
    tokenizer,
    responses,
    old_logprobs,
    rewards,
    advantages,
    config
):
    batch_size = len(responses)
    indices = torch.randperm(batch_size)

    policy_losses = []
    kl_divergence = []

    device = next(policy_model.parameters()).device  # <-- Add this line

    for epoch in range(config.grpo_epochs):
        for i in range(0, batch_size, config.mini_batch_size):
            mini_batch_indices = indices[i:i + config.mini_batch_size]

            mini_responses = [responses[idx] for idx in mini_batch_indices]
            mini_old_logprobs = old_logprobs[mini_batch_indices]
            mini_advantages = advantages[mini_batch_indices]

            policy_model.train()

            new_logprobs, _ = forward_pass(mini_responses, policy_model, tokenizer)

            ratio = torch.exp(new_logprobs - mini_old_logprobs)  # π/π_old
            surr1 = ratio * mini_advantages
            surr2 = torch.clamp(ratio, 1 - config.clip_range, 1 + config.clip_range) * mini_advantages
            policy_loss = -torch.min(surr1, surr2).mean()

            kl_div = (mini_old_logprobs - new_logprobs).mean()

            # Value loss (optional — only if using a value head)
            value_loss = torch.tensor(0.0, device=device)  # <-- Use device here

            total_loss = policy_loss + config.vf_coef * value_loss + kl_ctl.kl_coef * kl_div

            policy_optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(policy_model.parameters(), config.max_grad_norm)
            policy_optimizer.step()

            policy_losses.append(policy_loss.item())
            kl_divergence.append(kl_div.item())

    

    # Update KL controller
    mean_kl = np.mean(kl_divergence)
    kl_ctl.update(mean_kl, batch_size)

    # Update training stats
    stats['policy_loss'].extend(policy_losses)
    stats['kl_divergence'].extend(kl_divergence)
    stats['rewards'].extend(rewards.tolist())
    stats['advantages'].extend(advantages.tolist())

    return {
        'policy_loss': np.mean(policy_losses),
        'kl_divergence': mean_kl,
        'mean_rewards': rewards.mean().item(),
        'mean_advantages': advantages.mean().item()
    }


In [20]:
def train_step(
    batch,
    policy_model,
    reward_model,
    policy_optimizer,
    kl_ctl,
    stats,
    tokenizer,
    config
):
    # Step 1: Generate responses and logprobs using current policy
    prompts = batch['prompt']
    responses, old_logprobs = generate_responses(prompts, policy_model, tokenizer)

    # Step 2: Compute rewards using the reward model
    rewards = compute_rewards(reward_model, responses, tokenizer, config)
    # rewards: shape [batch_size]

    # Step 3: Compute group-normalized advantages (GRPO specific)
    normalized_advantages = compute_normalized_grpo_advantages([rewards])
    normalized_advantages = normalized_advantages[0]

    # Step 4: Perform GRPO policy update
    metrics = grpo_update(
        policy_model=policy_model,
        policy_optimizer=policy_optimizer,
        kl_ctl=kl_ctl,
        stats=stats,
        tokenizer=tokenizer,
        responses=responses,
        old_logprobs=old_logprobs,
        rewards=rewards,
        advantages=normalized_advantages,
        config=config
    )

    return metrics

In [21]:
import torch
import numpy as np
from tqdm import tqdm

class GRPOTrainer:

    def __init__(
        self,
        config,
        policy_model,
        ref_model,
        reward_model,
        tokenizer
    ):
        self.config = config
        self.policy_model = policy_model
        self.ref_model = ref_model
        self.reward_model = reward_model
        self.tokenizer = tokenizer

        # Freeze reference model parameters
        for param in self.ref_model.parameters():
            param.requires_grad = False

        self.policy_optimizer = torch.optim.Adam(
            self.policy_model.parameters(), lr=self.config.learning_rate
        )

        self.kl_ctl = AdaptiveKLController(self.config.kl_coef, self.config.kl_target)

        self.stats = {
            'policy_loss': [],
            'kl_divergence': [],
            'rewards': [],
            'advantages': []
        }

    def train(self, dataloader, num_epochs=1):

        for epoch in range(num_epochs):
            epoch_metrics = []

            for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}")):
                metrics = train_step(
                    batch,
                    self.policy_model,
                    self.reward_model,
                    self.policy_optimizer,
                    self.kl_ctl,
                    self.stats,
                    self.tokenizer,
                    self.config
                )
                epoch_metrics.append(metrics)

                if batch_idx % 10 == 0:
                    avg_metrics = {k: np.mean([m[k] for m in epoch_metrics[-10:]]) for k in metrics.keys()}
                    print(f"Batch {batch_idx}: {avg_metrics}")

            # Epoch summary
            epoch_avg = {k: np.mean([m[k] for m in epoch_metrics]) for k in epoch_metrics[0].keys()}
            print(f"Epoch {epoch+1} Summary: {epoch_avg}")


In [22]:
@dataclass
class GRPOConfig:
    learning_rate: float = 1e-6
    batch_size: int = 2
    mini_batch_size: int = 2
    grpo_epochs: int = 4
    max_grad_norm: float = 1.0
    clip_range: float = 0.2
    clip_range_vf: float = 0.2
    vf_coef: float = 0.1
    kl_coef: float = 0.01
    target_kl: float = 0.1
    max_length: int = 512
    temperature: float = 1.0
    top_k: int = 50
    top_p: float = 0.95
    gamma: float = 0.99
    lam: float = 0.95
    entropy_coef: float = 0.01
    adaptive_kl: bool = True
    kl_target: float = 6.0
    horizon: int = 10000
    use_score_scaling: bool = True
    use_score_norm: bool = True
    score_clip: float = 5.0

In [23]:
config = GRPOConfig(
    learning_rate=1e-6,
    batch_size=2,
    grpo_epochs=2,
    max_length=1024
)

In [24]:
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = get_tokenizer()

In [25]:
model_path = "gpt_models\\SFT_model.pth"
policy_model = GPTModel(BASE_CONFIG)
policy_model.load_state_dict(
    torch.load(
    model_path,
    map_location=torch.device(device),
    weights_only=True
    )['model_state_dict']
)
policy_model.to(device)
policy_model.eval()

GPTModel(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.0, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=768, out_features=768, bias=True)
        (W_key): Linear(in_features=768, out_features=768, bias=True)
        (W_value): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_shortcut): Dropout(p=0.0, inplace=False)
    )
    (1): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=7

In [26]:
reference_model = GPTModel(BASE_CONFIG)
reference_model.load_state_dict(
    torch.load(
    model_path,
    map_location=torch.device(device),
    weights_only=True
    )['model_state_dict']
)
reference_model.to(device)
reference_model.eval()

GPTModel(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.0, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=768, out_features=768, bias=True)
        (W_key): Linear(in_features=768, out_features=768, bias=True)
        (W_value): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_shortcut): Dropout(p=0.0, inplace=False)
    )
    (1): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=7

In [27]:
reward_model = GPTModel(BASE_CONFIG)
num_classes = 1
reward_model.out_head = torch.nn.Linear(
    in_features=BASE_CONFIG["emb_dim"], 
    out_features=num_classes
)
model_path = "gpt_models\\RLHF_reward_model.pth"
reward_model.load_state_dict(
    torch.load(
    model_path,
    map_location=torch.device(device),
    weights_only=True
    )
)
reward_model.to(device)
reward_model.eval()

GPTModel(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.0, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=768, out_features=768, bias=True)
        (W_key): Linear(in_features=768, out_features=768, bias=True)
        (W_value): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_shortcut): Dropout(p=0.0, inplace=False)
    )
    (1): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=7

In [28]:
with open("oasst1.json", "r", encoding="utf-8") as f:
    data = json.load(f)
train_data, test_data, val_data = split_data(data)
train_dataloader, test_dataloader, val_dataloader = create_dataloaders(
    train_data,
    test_data,
    val_data,
    config.batch_size,
    tokenizer
)

train dataset length: 4000
length of test data: 500
length of val data: 500


In [None]:
trainer = GRPOTrainer(
    config,
    policy_model,
    reference_model,
    reward_model,++++
    tokenizer
)

In [None]:
NUM_EPOCHS = 1

trainer.train(train_dataloader, num_epochs=NUM_EPOCHS)
torch.save(policy_model.state_dict(), 'GRPO_trained_model.pth')
print("-------------------------------------------------------------------")
print()
print("Training has been completed and model has beed saved")


  prompt = torch.tensor(item)
Epoch 1:   0%|          | 1/500 [02:23<19:52:51, 143.43s/it]

Batch 0: {'policy_loss': np.float64(0.0), 'kl_divergence': np.float64(0.44504278898239136), 'mean_rewards': np.float64(-2.9802322387695312e-08), 'mean_advantages': np.float64(0.0)}


Epoch 1:   0%|          | 2/500 [37:23<179:01:19, 1294.14s/it]