# PPO training with explainable rewards

### Setup

In [1]:
import torch
from transformers import AutoTokenizer, GenerationConfig, BitsAndBytesConfig
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, setup_chat_format
from peft import LoraConfig, PeftModel
import logging
import time
# from tqdm.notebook import tqdm
from tqdm import tqdm
from datasets import load_dataset, DatasetDict
import wandb
from huggingface_hub import login
import random
import numpy as np

from src import config, utils, reward
from data import preprocess_helpsteer


for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)
logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

torch.manual_seed(config.SEED)
random.seed(config.SEED)
np.random.seed(config.SEED)

if config.HF_TOKEN:
    login(token=config.HF_TOKEN)
    logger.info("Logged into Hugging Face Hub.")
else:
    logger.warning("HF_TOKEN not found. Cannot push models to Hub.")

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.
INFO - Logged into Hugging Face Hub.


### Data loading and preparation

In [2]:
logger.info("Loading and preparing RL dataset")

try:
    rl_dataset = preprocess_helpsteer.load_and_prepare_rl_dataset()
    train_ds = rl_dataset["train"].shuffle(seed=config.SEED).select(range(1000))
    val_ds = rl_dataset["test"].shuffle(seed=config.SEED).select(range(100))
    # train_ds = rl_dataset["train"]
    logger.info(f"Loaded {len(train_ds)} training prompts.")
except Exception as e:
    logger.error(f"Failed to load dataset: {e}")
    raise

INFO - Loading and preparing RL dataset
INFO - Loading dataset: nvidia/HelpSteer2
INFO - Dataset loaded: DatasetDict({
    train: Dataset({
        features: ['prompt', 'response', 'helpfulness', 'correctness', 'coherence', 'complexity', 'verbosity'],
        num_rows: 20324
    })
    validation: Dataset({
        features: ['prompt', 'response', 'helpfulness', 'correctness', 'coherence', 'complexity', 'verbosity'],
        num_rows: 1038
    })
})


Filter:   0%|          | 0/20324 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1038 [00:00<?, ? examples/s]

INFO - Renamed prompt column to 'query'.
INFO - RL Dataset prepared with 20324 training prompts and 1038 test prompts.
INFO - Loaded 1000 training prompts.


In [3]:
train_ds

Dataset({
    features: ['query'],
    num_rows: 1000
})

In [4]:
val_ds

Dataset({
    features: ['query'],
    num_rows: 100
})

In [5]:
val_ds[0]

{'query': 'What does OOPS as a development methodology stand for?\n<extra_id_1>Assistant\nOOPS is an acronym for Object-Oriented Programming with Nested Structures. OOPS is a software development methodology that emphasizes the use of objects, classes, and inheritance in software design. OOPS also supports the use of modular programming techniques and the divide-and-conquer algorithm for problem solving. OOPS has gained popularity as a software development methodology because it supports the creation of modular, reusable code that is well-suited for large software projects.\n<extra_id_1>User\nInvestigate any two the divide-and-conquer algorithms. Select one real life problem in which they are applicable. Implement the solution to the problem. Provide both theoretical and experimental assessment of the performance the algorithm.\n<extra_id_1>Assistant\nOne of the most common divide-and-conquer algorithms is the merge sort algorithm. The merge sort algorithm takes an array of items and r

In [6]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(config.BASE_MODEL_NAME)

def tokenize_fn(examples):
    tokens = tokenizer(
        examples["query"],
        padding="max_length",
        truncation=True,
        max_length=config.MAX_SEQ_LENGTH,
    )
    return tokens

In [7]:
train_ds = train_ds.map(tokenize_fn, batched=True, remove_columns=["query"])
train_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])
val_ds = val_ds.map(tokenize_fn, batched=True, remove_columns=["query"])
val_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])

train_ds

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 1000
})

In [8]:
train_ds[0]["input_ids"]

tensor([  6023,  38971,    678,   3681,  55090, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 

In [9]:
batch = next(iter(torch.utils.data.DataLoader(train_ds, batch_size=2)))
print(type(batch["input_ids"]))

<class 'torch.Tensor'>


### Load models

#### Policy model

In [10]:
# logger.info(f"Loading policy model: {config.BASE_MODEL_NAME}")

# lora_config_ppo = LoraConfig(
#     r=config.LORA_R,
#     lora_alpha=config.LORA_ALPHA,
#     lora_dropout=config.LORA_DROPOUT,
#     target_modules=config.LORA_TARGET_MODULES,
#     bias="none",
#     task_type="CAUSAL_LM",
# )

# policy_model = AutoModelForCausalLMWithValueHead.from_pretrained(
#     config.BASE_MODEL_NAME,
#     trust_remote_code=True,
#     load_in_4bit=True,
#     torch_dtype=torch.bfloat16,
#     device_map="auto",
#     peft_config=lora_config_ppo,
# )
# logger.info("Policy model loaded.")

In [11]:
# from transformers import GenerationConfig

# if not hasattr(policy_model, "generation_config"):
#     # You can load defaults from the base model or craft your own:
#     policy_model.generation_config = GenerationConfig.from_pretrained(
#         config.BASE_MODEL_NAME
#     )  # controls generate() behavior :contentReference[oaicite:0]{index=0}

# # 2. Tell PPO what your end‐of‐generation token is:
# # ppo_config.stop_token_id = tokenizer.eos_token_id

In [12]:
from trl.models.modeling_value_head import AutoModelForCausalLMWithValueHead

logger.info(f"Loading policy model: {config.BASE_MODEL_NAME}")

lora_config_ppo = LoraConfig(
    r=config.LORA_R,
    lora_alpha=config.LORA_ALPHA,
    lora_dropout=config.LORA_DROPOUT,
    target_modules=config.LORA_TARGET_MODULES,
    bias="none",
    task_type="CAUSAL_LM",
)

policy_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    config.BASE_MODEL_NAME,
    trust_remote_code=True,
    load_in_4bit=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    peft_config=lora_config_ppo,
)
logger.info("Policy model loaded.")

INFO - Loading policy model: Qwen/Qwen2.5-0.5B-Instruct
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
INFO - peft adapter initialised
INFO - Policy model loaded.


In [13]:
from transformers import GenerationConfig

policy_model.config.use_cache = False
policy_model.gradient_checkpointing_disable()

if not hasattr(policy_model, "generation_config"):
    policy_model.generation_config = GenerationConfig.from_pretrained(config.BASE_MODEL_NAME)

#### Reference model

In [14]:
# logger.info(f"Loading reference model: {config.BASE_MODEL_NAME}")

# ref_model, tokenizer = utils.load_model_and_tokenizer(
#     config.BASE_MODEL_NAME,
#     load_4bit=True,
#     add_lora=False
# )
# ref_model.eval()
# logger.info("Reference model and tokenizer loaded.")

# if tokenizer.pad_token is None:
#     tokenizer.pad_token = tokenizer.eos_token
#     policy_model.config.pad_token_id = tokenizer.eos_token_id

In [15]:
from transformers import BitsAndBytesConfig
from trl.models.modeling_value_head import AutoModelForCausalLMWithValueHead


bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    config.BASE_MODEL_NAME,
    quantization_config=bnb_cfg,
    device_map="auto",
    trust_remote_code=True
)
ref_model.eval()
ref_model.config.use_cache = False
ref_model.gradient_checkpointing_disable()

INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


#### Reward model

In [16]:
logger.info("Initializing Explainable Reward Model...")
explainable_reward = reward.ExplainableRewardModel(model_name=config.JUDGE_MODEL_NAME, device=config.DEVICE)
logger.info("Reward model initialized.")

INFO - Initializing Explainable Reward Model...
INFO - Initializing explainable RM using judge: Qwen/Qwen2.5-0.5B-Instruct on device cuda
INFO - Loading model: Qwen/Qwen2.5-0.5B-Instruct for mode: causal
INFO - 4-bit quantization enabled.
INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
INFO - Prepared model for 4-bit training.
INFO - Model and tokenizer loading complete.
INFO - Judge model loaded and set to evaluation mode.
INFO - Explainable RM initialized.
INFO - Reward model initialized.


### Value model

In [17]:
logger.info(f"Loading value model: {config.BASE_MODEL_NAME}")

from transformers import AutoModelForSequenceClassification

value_model = AutoModelForSequenceClassification.from_pretrained(
    config.BASE_MODEL_NAME,
    num_labels=1,
    trust_remote_code=True,
    load_in_4bit=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
value_model.eval()
logger.info("Value model loaded.")

INFO - Loading value model: Qwen/Qwen2.5-0.5B-Instruct
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen2.5-0.5B-Instruct and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO - Value model loaded.


### Training

In [18]:
ppo_config = PPOConfig(
    num_ppo_epochs=config.PPO_NUM_EPOCHS,
    learning_rate=config.PPO_LEARNING_RATE,
    report_to=config.LOG_WITH if config.LOG_WITH else None,
    batch_size=config.RL_BATCH_SIZE,
    mini_batch_size=config.PPO_MINI_BATCH_SIZE,
    gradient_accumulation_steps=config.PPO_GRAD_ACCUMULATION_PPO,
    stop_token_id=tokenizer.eos_token_id,
    seed=config.SEED,
    logging_dir="ppo-runs/",
)

In [19]:
ppo_trainer = PPOTrainer(
    args=ppo_config,
    model=policy_model,
    ref_model=ref_model,
    reward_model=explainable_reward,
    value_model=value_model,
    processing_class=tokenizer,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=None,
)

In [None]:
logger.info("Starting PPO training loop...")

import wandb
wandb_api = "86cd74d37ebed39035c6b54365fe1b6a76f36839"
wandb.login(key=wandb_api)

if config.LOG_WITH == "wandb":
    try:
        wandb.init(
            project="xai-ppo-explainable",
            name=f"ppo-{config.BASE_MODEL_NAME.split('/')[-1]}"
        )
        logger.info("WandB initialized.")
    except Exception as e:
        logger.error(f"Failed to initialize WandB: {e}")

INFO - Starting PPO training loop...
[34m[1mwandb[0m: Currently logged in as: [33mmiliusha2801[0m ([33mmiliusha2801-innopolis-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO - WandB initialized.


In [21]:
generation_kwargs = {
    "max_new_tokens": config.RL_MAX_NEW_TOKENS,
    "min_length": -1,
    "top_k": config.RL_TOP_K,
    "top_p": config.RL_TOP_P,
    "temperature": config.RL_TEMPERATURE,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}

In [22]:
ppo_trainer.train()

===training policy===


OutOfMemoryError: CUDA out of memory. Tried to allocate 10.23 GiB. GPU 0 has a total capacity of 8.00 GiB of which 0 bytes is free. Of the allocated memory 9.15 GiB is allocated by PyTorch, and 356.06 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)