In [1]:
%pwd

'd:\\software_3\\Generative_models\\Text_models\\chat_gpt2\\RLHF_DPO_Preference_alignment\\RLHF_with_PPO'

In [2]:
import os

os.chdir("../../")

In [3]:
%pwd

'd:\\software_3\\Generative_models\\Text_models\\chat_gpt2'

# Proximal Policy Optimization (PPO)

This Notebook implements the PPO RLHF from this [Pairwise Proximal Policy Optimization: Harnessing Relative Feedback for LLM Alignment](https://arxiv.org/abs/2310.00212) research paper.

In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass

from gpt import GPTModel
from model_args import BASE_CONFIG
from utils.generate import generate
from utils.token_converter import get_tokenizer, text_to_token_ids, token_ids_to_text

In [5]:
import json

with open("oasst1.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# Now you can access the data as a list of dictionaries
print(data[0]["prompt"])  # Prints the prompt of the first object

Can you write a short introduction about the relevance of the term "monopsony" in economics? Please use examples related to potential monopsonies in the labour market and cite relevant research.


In [6]:
def format_input(entry):

    instruction_text = (
        f"Below is an instruction that describes a task. "
        f"Write a response that appropriately complates the request."
        f"\n\n## Instruction:\n{entry["prompt"]}"
    )
    return instruction_text

In [7]:
class PPODataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data

        self.encoded_text = []
        for entry in data:
            prompt = format_input(entry)

            prompt_tokens = tokenizer.encode(prompt)

            self.encoded_text.append(prompt_tokens)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.encoded_text[index]       

In [8]:
def custom_collate_fun(
    batch,
    pad_token_id=50256,
    allowed_max_length=None,
    mask_prompt_tokens=True,
    device="cpu"
):
    batch_data = {
        "prompt": []
    }

    max_length_comman = 0
    if batch:
        current_max_length = max(len(item) + 1 for item in batch)  # item is a list
        max_length_comman = max(max_length_comman, current_max_length)

    for item in batch:
        prompt = torch.tensor(item)
        batch_data["prompt"].append(prompt)

    # Pad sequences to the same length
    padded_prompts = nn.utils.rnn.pad_sequence(batch_data["prompt"], batch_first=True, padding_value=pad_token_id)
    if allowed_max_length is not None:
        padded_prompts = padded_prompts[:, :allowed_max_length]

    batch_data["prompt"] = padded_prompts.to(device)

    return batch_data

In [9]:
from functools import partial
device = "cuda" if torch.cuda.is_available() else "cpu"

customized_collate_fn = partial(
    custom_collate_fun,
    device=device,
    allowed_max_length=1024
)

In [10]:
example_data = data[:5]

In [11]:
tokenizer = get_tokenizer()

example_dataset = PPODataset(example_data, tokenizer)

example_dataloader = DataLoader(
    example_dataset,
    batch_size=2,
    collate_fn = customized_collate_fn,
    shuffle=False
)

In [12]:
for batch in example_dataloader:
    break

print("batch.keys:", batch.keys())

batch.keys: dict_keys(['prompt'])


In [13]:
batch["prompt"]

tensor([[21106,   318,   281, 12064,   326,  8477,   257,  4876,    13, 19430,
           257,  2882,   326, 20431,  2299,   689,   262,  2581,    13,   198,
           198,  2235, 46486,    25,   198,  6090,   345,  3551,   257,  1790,
          9793,   546,   262, 23082,   286,   262,  3381,   366,  2144,   404,
          1559,    88,     1,   287, 12446,    30,  4222,   779,  6096,  3519,
           284,  2785, 15848,  1559,   444,   287,   262, 10515,  1910,   290,
         21729,  5981,  2267,    13, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 5

In [14]:
token_ids_to_text(batch["prompt"][0], tokenizer)

'Below is an instruction that describes a task. Write a response that appropriately complates the request.\n\n## Instruction:\nCan you write a short introduction about the relevance of the term "monopsony" in economics? Please use examples related to potential monopsonies in the labour market and cite relevant research.<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|e

In [15]:
def split_data(data):
    train_portion = int(len(data) * 0.85)
    test_portion = int(len(data) * 0.1)
    val_portion = len(data) - train_portion - test_portion

    train_data = data[:train_portion]
    test_data = data[train_portion:train_portion + test_portion]
    val_data = data[train_portion + test_portion:]
    
    print("Training set length:", len(train_data))
    print("Validation set length:", len(val_data))
    print("Test set length:", len(test_data))

    return train_data, test_data, val_data

train_data, test_data, val_data = split_data(data)

Training set length: 4250
Validation set length: 250
Test set length: 500


In [16]:
num_workers = 0
batch_size = 8

train_dataset = PPODataset(train_data, tokenizer)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    collate_fn=customized_collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers)

test_dataset = PPODataset(test_data, tokenizer)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    collate_fn=customized_collate_fn,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers
)

val_dataset = PPODataset(val_data, tokenizer)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    collate_fn=customized_collate_fn,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers
)

In [17]:
model_path = "gpt_models\\SFT_model.pth"
model = GPTModel(BASE_CONFIG)
model.load_state_dict(
    torch.load(
    model_path,
    map_location=torch.device("cpu"),
    weights_only=True
    )['model_state_dict']
)
model.eval()

GPTModel(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.0, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=768, out_features=768, bias=True)
        (W_key): Linear(in_features=768, out_features=768, bias=True)
        (W_value): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_shortcut): Dropout(p=0.0, inplace=False)
    )
    (1): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=7

In [18]:
prompt = """Below is an instruction that describes a task. Write a response
that appropriately completes the request.

### Instruction:
Convert the active sentence to passive: 'The chef cooks the meal every day.'
"""

In [19]:
token_ids = generate(
    model=model,
    idx=text_to_token_ids(data[0]["prompt"], tokenizer),
    max_new_tokens=35,
    context_size=BASE_CONFIG["context_length"],
    eos_id=50256
)

response = token_ids_to_text(token_ids, tokenizer)
print(response)

Can you write a short introduction about the relevance of the term "monopsony" in economics? Please use examples related to potential monopsonies in the labour market and cite relevant research.

### Input:
Generate a sentence using the word "economics". 

### Response:
The researchers had a dynamic economy that led to the success


In [20]:
policy_model = model

reference_model = GPTModel(BASE_CONFIG)
reference_model.load_state_dict(
    torch.load(
    model_path,
    map_location=torch.device("cpu"),
    weights_only=True
    )['model_state_dict']
)
reference_model.eval()

policy_model.to(device)
reference_model = model.to(device)

In [21]:
reward_model = GPTModel(BASE_CONFIG)
num_classes = 1
reward_model.out_head = torch.nn.Linear(
    in_features=BASE_CONFIG["emb_dim"], 
    out_features=num_classes
)
model_path = "gpt_models\\RLHF_reward_model.pth"
reward_model.load_state_dict(
    torch.load(
    model_path,
    map_location=torch.device("cpu"),
    weights_only=True
    )
)
reward_model.to(device).eval()

GPTModel(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.0, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=768, out_features=768, bias=True)
        (W_key): Linear(in_features=768, out_features=768, bias=True)
        (W_value): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_shortcut): Dropout(p=0.0, inplace=False)
    )
    (1): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=7

In [22]:
def setup_value_model(base_model, base_config, device):
    """Setup reward model from pre-trained base model"""
    
    # Freeze all parameters first
    for param in base_model.parameters():
        param.requires_grad = False
    
    # Replace output head with reward head (single scalar output)
    num_classes = 1
    base_model.out_head = torch.nn.Linear(
        in_features=base_config["emb_dim"], 
        out_features=num_classes
    )
    
    # Unfreeze last transformer block for fine-tuning
    for param in base_model.trf_blocks[-1].parameters():
        param.requires_grad = True
    
    # Unfreeze final normalization layer
    for param in base_model.final_norm.parameters():
        param.requires_grad = True
    
    # Unfreeze the new output head
    for param in base_model.out_head.parameters():
        param.requires_grad = True
    
    base_model.to(device)
    return base_model

In [23]:
model_path = "gpt_models\\SFT_model.pth"
value_model = GPTModel(BASE_CONFIG)
value_model.load_state_dict(
    torch.load(
    model_path,
    map_location=torch.device("cpu"),
    weights_only=True
    )['model_state_dict']
)
value_model = setup_value_model(base_model=value_model, base_config=BASE_CONFIG, device=device)

In [24]:
print(value_model)

GPTModel(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.0, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=768, out_features=768, bias=True)
        (W_key): Linear(in_features=768, out_features=768, bias=True)
        (W_value): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_shortcut): Dropout(p=0.0, inplace=False)
    )
    (1): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=7

## PPO Loss

In [25]:
def ensure_max_context(input_ids, max_context=1024):
    if input_ids.shape[1] > max_context:
        input_ids = input_ids[:, -max_context:]
    return input_ids

In [26]:
class AdaptiveKLController:

    def __init__(self, kl_coef=0.01, target_kl=6.0):
        self.kl_coef = kl_coef
        self.target = target_kl

    def update(self, current_kl, n_steps):
        proportional_error = np.clip(current_kl / self.target -1, -0.2, 0.2)
        mult = 1 + proportional_error * n_steps / 20
        self.kl_coef *= mult

In [None]:
def calculate_logprobs(sequence, scores):
    logprobs = []

    for i, score in enumerate(scores):
        if i < len(sequence[0]) -1:
            next_token = sequence[0][i+1]
            logprob = F.log_softmax(score, dim= -1)[0, next_token]
            logprobs.append(logprob)
    return torch.stack(logprobs)

In [28]:
import torch

def compute_rewards(reward_model, responses, tokenizer, config, device=None):
    """
    Computes scalar rewards for each response using the reward model.

    Args:
        reward_model: The reward model (should output a tensor per input).
        responses: List of response strings.
        tokenizer: Tokenizer to convert text to token ids.
        config: PPOConfig object (for normalization/scaling).
        device: torch.device (optional, if not provided, uses reward_model's device).

    Returns:
        rewards: Tensor of shape [batch_size, 1] (one reward per response).
    """
    if device is None:
        device = next(reward_model.parameters()).device
    
    rewards = []
    with torch.no_grad():
        for response in responses:
            # Convert response to token ids
            input_ids = text_to_token_ids(response, tokenizer)
            if isinstance(input_ids, list):
                input_ids = torch.tensor(input_ids, dtype=torch.long)
            if input_ids.ndim == 1:
                input_ids = input_ids.unsqueeze(0)  # [1, seq_len]
            input_ids = input_ids.to(device)

            # Get reward from model
            input_ids = ensure_max_context(input_ids)
            reward = reward_model(input_ids)
            # If reward is [1, seq_len, 1] or [1, seq_len], take last token's reward
            if reward.ndim > 1:
                reward = reward[0, -1]
            # Ensure reward is a scalar tensor
            reward = reward.unsqueeze(0)  # [1]
            rewards.append(reward)

    rewards = torch.cat(rewards, dim=0)  # [batch_size, 1]

    # Optional: normalize and/or clip rewards
    if getattr(config, "use_score_norm", False):
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
    if getattr(config, "use_score_scaling", False):
        rewards = torch.clamp(rewards, -config.score_clip, config.score_clip)

    return rewards

In [29]:
def compute_advantages(rewards, values, config):
    advantages = []
    returns = []

    rewards_np = rewards.detach().cpu().numpy()
    values_np = values.detach().cpu().numpy()

    gae = 0
    for i in reversed(range(len(rewards_np))):
        if i == len(rewards_np) -1:
            next_value = 0
        else:
            next_value = values_np[i + 1]

        delta = rewards_np[i] + config.gamma * next_value - values_np[i]
        gae = delta + config.gamma * config.lam * gae
        advantages.insert(0, gae)
    
    advantages = torch.tensor(advantages, device=device)
    returns = advantages + values

    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    return advantages, returns

In [30]:
def generate_responses(prompts, policy_model, value_model, tokenizer):
    policy_model.eval()
    MAX_CONTEXT = 1024
    responses = []
    all_logprobs = []
    all_values = []

    with torch.no_grad():
        for prompt in prompts:
            prompt_text = token_ids_to_text(prompt, tokenizer)
            input_ids = text_to_token_ids(prompt_text, tokenizer)
            if isinstance(input_ids, list):
                input_ids = torch.tensor(input_ids, dtype=torch.long)
            if input_ids.ndim == 1:
                input_ids = input_ids.unsqueeze(0)
            input_ids = ensure_max_context(input_ids, MAX_CONTEXT)
            input_ids = input_ids.to(policy_model.tok_emb.weight.device)

            max_new_tokens = min(100, MAX_CONTEXT - input_ids.shape[1])
            output = generate(
                model=policy_model,
                idx=input_ids,
                max_new_tokens=max_new_tokens,
                context_size=MAX_CONTEXT,
                eos_id=50256
            )
            generated_ids = output[0]
            if generated_ids.ndim == 1:
                generated_ids = generated_ids.unsqueeze(0)
            generated_ids = ensure_max_context(generated_ids, MAX_CONTEXT)
            response = token_ids_to_text(generated_ids[0], tokenizer)
            responses.append(response)

            logits = policy_model(generated_ids)
            if logits.ndim == 2:
                logits = logits.unsqueeze(0)
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = generated_ids[..., 1:].contiguous()
            logprobs_tensor = F.log_softmax(shift_logits, dim=-1)
            logprobs_for_labels = logprobs_tensor.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)
            all_logprobs.append(logprobs_for_labels.mean())

            values = value_model(generated_ids)
            if values.ndim > 1:
                values = values[0, -1]
            all_values.append(values)

    return responses, torch.stack(all_logprobs), torch.stack(all_values)

In [31]:
def forward_pass(responses, policy_model, value_model, tokenizer):
    logprobs = []
    values = []
    for response in responses:
        input_ids = text_to_token_ids(response, tokenizer)
        if isinstance(input_ids, list):
            input_ids = torch.tensor(input_ids, dtype=torch.long)
        if input_ids.ndim == 1:
            input_ids = input_ids.unsqueeze(0)  # [1, seq_len]
        input_ids = input_ids.to(next(policy_model.parameters()).device)

        logits = policy_model(input_ids)  # [batch, seq_len, vocab_size] or [seq_len, vocab_size]
        if logits.ndim == 2:
            logits = logits.unsqueeze(0)  # [1, seq_len, vocab_size]
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()

        import torch.nn.functional as F
        logprobs_tensor = F.log_softmax(shift_logits, dim=-1)
        # Gather logprobs for the actual next tokens
        logprobs_for_labels = logprobs_tensor.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)
        # Optionally, mean over tokens
        logprobs.append(logprobs_for_labels.mean())

        value = value_model(input_ids)
        if value.ndim > 1:
            value = value[0, -1]
        values.append(value)

    return torch.stack(logprobs), torch.stack(values)

In [32]:
def ppo_update(policy_model,
                value_model,
                policy_optimizer,
                value_optimizer,
                kl_ctl,
                stats,
                tokenizer,
                responses,
                old_logprobs,
                rewards,
                advantages,
                returns,
                config):

    batch_size = len(responses)
    indices = torch.randperm(batch_size)

    policy_losses = []
    value_losses = []
    kl_divergence = []

    for epoch in range(config.ppo_epochs):
        for i in range(0, batch_size, config.mini_batch_size):
            mini_batch_indices = indices[i:i + config.mini_batch_size]

            mini_responses = [responses[idx] for idx in mini_batch_indices]
            mini_old_logprobs = old_logprobs[mini_batch_indices]
            mini_advantages = advantages[mini_batch_indices]
            mini_returns = returns[mini_batch_indices]

            policy_model.train()
            value_model.train()

            new_logprobs, new_values = forward_pass(mini_responses, policy_model, value_model, tokenizer)
            ratio = torch.exp(new_logprobs - mini_old_logprobs)
            surr1 = ratio * mini_advantages
            surr2 = torch.clamp(ratio, 1 - config.clip_range, 1 + config.clip_range) * mini_advantages
            policy_loss = -torch.min(surr1, surr2).mean()

            value_loss = F.mse_loss(new_values, mini_returns)

            kl_div = (mini_old_logprobs - new_logprobs).mean()

            total_loss = policy_loss + config.vf_coef * value_loss + config.kl_coef * kl_div

            policy_optimizer.zero_grad()
            value_optimizer.zero_grad()
            total_loss.backward()

            torch.nn.utils.clip_grad_norm_(policy_model.parameters(), config.max_grad_norm)
            torch.nn.utils.clip_grad_norm_(value_model.parameters(), config.max_grad_norm)

            policy_losses.append(policy_loss.item())
            value_losses.append(value_loss.item())
            kl_divergence.append(kl_div.item())

    mean_kl = np.mean(kl_divergence)
    kl_ctl.update(mean_kl, batch_size)

    stats['policy_loss'].extend(policy_losses)
    stats['value_loss'].extend(value_losses)
    stats['kl_divergence'].extend(kl_divergence)
    stats['rewards'].extend(rewards.tolist())
    stats['advantages'].extend(advantages.tolist())

    return {
        'policy_loss': np.mean(policy_losses),
        'value_loss': np.mean(value_losses),
        'kl_divergence': mean_kl,
        'mean_reward': rewards.mean().item(),
        'mean_advantage': advantages.mean().item()
    }

In [33]:
def train_step(
    batch,
    policy_model,
    reward_model,
    value_model,
    policy_optimizer,
    value_optimizer,
    kl_ctl,
    stats,
    tokenizer,
    config

):

    prompts = batch['prompt']
    responses, old_logprobs, old_values = generate_responses(prompts, policy_model, value_model, tokenizer)
    rewards = compute_rewards(reward_model, responses, tokenizer, config)
    advantages, returns = compute_advantages(rewards, old_values, config)
    metrics = ppo_update(policy_model,
                value_model,
                policy_optimizer,
                value_optimizer,
                kl_ctl,
                stats,
                tokenizer,
                responses,
                old_logprobs,
                rewards,
                advantages,
                returns,
                config)

    return metrics


In [34]:
from tqdm import tqdm

class PPOTrainer:

    def __init__(self,
                config,
                policy_model,
                ref_model,
                reward_model,
                value_model,
                tokenizer):

        self.config = config
        self.policy_model = policy_model
        self.ref_model = ref_model
        self.reward_model = reward_model
        self.value_model = value_model
        self.tokenizer = tokenizer

        for param in self.ref_model.parameters():
            param.requires_grad = False

        self.policy_optimizer = torch.optim.Adam(
            self.policy_model.parameters(), lr=self.config.learning_rate
        )

        self.value_optimizer = torch.optim.Adam(
            self.value_model.parameters(), lr=self.config.learning_rate
        )

        self.kl_ctl = AdaptiveKLController(config.kl_coef, config.kl_target)

        self.stats = {
                    'policy_loss': [],
                    'value_loss': [],
                    'kl_divergence': [],
                    'rewards': [],
                    'advantages': []
                    }

    def train(self, dataloader, num_epochs=1):

        for epoch in range(num_epochs):
            epoch_metrics = []

            for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}")):
                metrics = train_step(
                                        batch,
                                        self.policy_model,
                                        self.reward_model,
                                        self.value_model,
                                        self.policy_optimizer,
                                        self.value_optimizer,
                                        self.kl_ctl,
                                        self.stats,
                                        self.tokenizer,
                                        self.config
                                    )
                epoch_metrics.append(metrics)

                if batch_idx % 10 == 0:
                    avg_metrics = {k: np.mean([m[k] for m in epoch_metrics[-10:]]) for k in metrics.keys()}
                    print(f"Batch {batch_idx}: {avg_metrics}")
            epoch_avg = {k: np.mean([m[k] for m in epoch_metrics]) for k in epoch_metrics[0].keys()}
            print(f"Epoch {epoch+1} Summary: {epoch_avg}")

In [35]:
@dataclass
class PPOConfig:
    """Configuration for PPO training"""
    model_name: str = "gpt2"
    learning_rate: float = 1e-6
    batch_size: int = 8
    mini_batch_size: int = 2
    gradient_accumulation_steps: int = 1
    ppo_epochs: int = 4
    max_grad_norm: float = 1.0
    clip_range: float = 0.2
    clip_range_vf: float = 0.2
    vf_coef: float = 0.1
    kl_coef: float = 0.01
    target_kl: float = 0.1
    max_length: int = 512
    temperature: float = 1.0
    top_k: int = 50
    top_p: float = 0.95
    gamma: float = 0.99
    lam: float = 0.95
    entropy_coef: float = 0.01
    adaptive_kl: bool = True
    kl_target: float = 6.0
    horizon: int = 10000
    use_score_scaling: bool = True
    use_score_norm: bool = True
    score_clip: float = 5.0

In [36]:
config = PPOConfig(
    learning_rate=1e-6,
    batch_size=8,
    ppo_epochs=4,
    max_length=1024
)

In [37]:
tokenizer = get_tokenizer()
device = "cuda" if torch.cuda.is_available() else "cpu"

In [38]:
trainer = PPOTrainer(
                    config,
                    policy_model,
                    reference_model,
                    reward_model,
                    value_model,
                    tokenizer)



In [None]:
trainer.train(train_dataloader, num_epochs=1)
torch.save(policy_model.state_dict(), "ppo_trained_model.pth")
print("Training complted and model saved")

  advantages = torch.tensor(advantages, device=device)
Epoch 1:   0%|          | 1/531 [01:48<15:58:27, 108.51s/it]

Batch 0: {'policy_loss': np.float64(0.20140691008418798), 'value_loss': np.float64(1.1066711321473122), 'kl_divergence': np.float64(1.5953607559204102), 'mean_reward': np.float64(0.0), 'mean_advantage': np.float64(0.0)}


Epoch 1:   2%|▏         | 11/531 [12:06<8:19:16, 57.61s/it] 

Batch 10: {'policy_loss': np.float64(0.11008045729249716), 'value_loss': np.float64(3.131737431138754), 'kl_divergence': np.float64(1.0589583575725556), 'mean_reward': np.float64(-1.9371509552001954e-08), 'mean_advantage': np.float64(-2.2351741790771484e-09)}


Epoch 1:   4%|▍         | 21/531 [28:36<14:53:41, 105.14s/it]

Batch 20: {'policy_loss': np.float64(0.11892779916524887), 'value_loss': np.float64(3.6807915715500714), 'kl_divergence': np.float64(0.9123098909854889), 'mean_reward': np.float64(2.2351741790771484e-09), 'mean_advantage': np.float64(-1.1175870895385742e-09)}


Epoch 1:   6%|▌         | 31/531 [53:47<34:42:02, 249.84s/it]

Batch 30: {'policy_loss': np.float64(0.07020856142044067), 'value_loss': np.float64(3.4865598507225513), 'kl_divergence': np.float64(0.8262212872505188), 'mean_reward': np.float64(-1.2665987014770507e-08), 'mean_advantage': np.float64(1.3411045074462891e-08)}


Epoch 1:   8%|▊         | 41/531 [1:16:17<20:46:39, 152.65s/it]

Batch 40: {'policy_loss': np.float64(0.10168765243142844), 'value_loss': np.float64(2.5973120304523034), 'kl_divergence': np.float64(0.7822708308696746), 'mean_reward': np.float64(5.140900611877442e-08), 'mean_advantage': np.float64(-8.568167686462402e-09)}


Epoch 1:  10%|▉         | 51/531 [1:40:16<16:55:42, 126.96s/it]

Batch 50: {'policy_loss': np.float64(0.09853548267856241), 'value_loss': np.float64(4.22994108973071), 'kl_divergence': np.float64(0.6958517163991929), 'mean_reward': np.float64(4.470348358154297e-09), 'mean_advantage': np.float64(1.2665987014770507e-08)}


Epoch 1:  11%|█▏        | 61/531 [2:05:44<22:10:30, 169.85s/it]

Batch 60: {'policy_loss': np.float64(0.1313460787758231), 'value_loss': np.float64(3.741183688491583), 'kl_divergence': np.float64(1.0680326581001283), 'mean_reward': np.float64(-8.195638656616212e-09), 'mean_advantage': np.float64(-2.4959444999694825e-08)}


Epoch 1:  13%|█▎        | 71/531 [2:26:18<10:14:12, 80.11s/it] 

Batch 70: {'policy_loss': np.float64(0.11460776887834072), 'value_loss': np.float64(4.2555758389644325), 'kl_divergence': np.float64(0.96179758310318), 'mean_reward': np.float64(1.0058283805847169e-08), 'mean_advantage': np.float64(-1.564621925354004e-08)}


Epoch 1:  15%|█▌        | 81/531 [2:52:30<20:40:28, 165.40s/it]

Batch 80: {'policy_loss': np.float64(0.09744078107178211), 'value_loss': np.float64(3.59176789005287), 'kl_divergence': np.float64(0.9229157567024231), 'mean_reward': np.float64(8.195638656616212e-09), 'mean_advantage': np.float64(1.862645149230957e-08)}


Epoch 1:  17%|█▋        | 89/531 [3:14:30<21:15:29, 173.14s/it]