In [1]:
import os
import torch
from pathlib import Path
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer
)
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from trl import PPOConfig, PPOTrainer

class PPOTrainerWithLoRA:
    """
    Build and drive a TRL ``PPOTrainer`` whose policy is a LoRA-adapted causal LM.

    The constructor loads a left-padding tokenizer, a policy and a reference
    model from ``model_name``, and a reward and a value model (both
    single-label sequence classifiers) from ``reward_model_name``. The policy
    is wrapped with LoRA adapters before being handed to the trainer.

    Args:
        model_name: Hub id or local path of the causal LM used for the
            tokenizer, the policy and the reference policy.
        reward_model_name: Hub id or local path of the checkpoint used for
            both the reward model and the value model.
        train_dataset: Dataset passed to the trainer as ``train_dataset``.
        eval_dataset: Dataset passed to the trainer as ``eval_dataset``.
        output_dir: Trainer output directory; also the default target of
            :meth:`save`.
        lora_r: LoRA rank.
        lora_alpha: LoRA alpha.
        lora_dropout: LoRA dropout probability.
        per_device_train_batch_size: Forwarded to ``PPOConfig``.
        per_device_eval_batch_size: Forwarded to ``PPOConfig``.
        num_ppo_epochs: Forwarded to ``PPOConfig``.
        seed: Forwarded to ``PPOConfig``.
        report_to: Forwarded to ``PPOConfig``.
    """

    def __init__(
        self,
        model_name: str,
        reward_model_name: str,
        train_dataset: Dataset,
        eval_dataset: Dataset,
        output_dir: str,
        lora_r: int = 8,
        lora_alpha: int = 32,
        lora_dropout: float = 0.05,
        per_device_train_batch_size: int = 4,
        per_device_eval_batch_size: int = 2,
        num_ppo_epochs: int = 3,
        seed: int = 42,
        report_to: str = "none"
    ):
        self.model_name = model_name
        self.reward_model_name = reward_model_name
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.output_dir = output_dir
        self.seed = seed

        # Load tokenizer; prompts are left-padded so generation continues
        # directly after the last prompt token.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        added_pad_token = False
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
            added_pad_token = True

        # Load policy and reference
        self.policy = AutoModelForCausalLM.from_pretrained(model_name)
        self.ref_policy = AutoModelForCausalLM.from_pretrained(model_name)
        # Load reward and value models
        self.reward_model = AutoModelForSequenceClassification.from_pretrained(
            reward_model_name, num_labels=1
        )
        self.value_model = AutoModelForSequenceClassification.from_pretrained(
            reward_model_name, num_labels=1
        )

        if added_pad_token:
            # The new "[PAD]" id may lie past the end of an embedding matrix;
            # grow any model that has no row for it. This must happen before
            # the policy is wrapped with LoRA.
            vocab_size = len(self.tokenizer)
            for model in (
                self.policy,
                self.ref_policy,
                self.reward_model,
                self.value_model,
            ):
                self._ensure_embedding_capacity(model, vocab_size)

        # Configure LoRA via PEFT
        self.peft_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM"
        )
        # Apply LoRA to the policy model
        self.policy = get_peft_model(self.policy, self.peft_config)

        # Set up PPOConfig
        self.ppo_config = PPOConfig(
            output_dir=self.output_dir,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            num_ppo_epochs=num_ppo_epochs,
            seed=self.seed,
            report_to=report_to,
        )

        # Initialize TRL PPO trainer.
        # The policy is already a PEFT model, so `peft_config` is deliberately
        # NOT passed here: PPOTrainer applies a given `peft_config` itself, and
        # giving it both a wrapped model and the config makes it merge and
        # re-wrap the adapters, after which `self.policy` (what `save()`
        # writes) would no longer be the model being trained.
        self.trainer = PPOTrainer(
            args=self.ppo_config,
            processing_class=self.tokenizer,
            model=self.policy,
            ref_model=self.ref_policy,
            reward_model=self.reward_model,
            value_model=self.value_model,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
        )

    @staticmethod
    def _ensure_embedding_capacity(model, vocab_size: int) -> None:
        """Grow ``model``'s input embeddings to ``vocab_size`` rows if it has fewer.

        Models whose embedding matrix is already at least ``vocab_size`` rows
        are left untouched (they are never shrunk).
        """
        if model.get_input_embeddings().num_embeddings < vocab_size:
            model.resize_token_embeddings(vocab_size)

    def train(self):
        """
        Run the PPO training loop by delegating to the wrapped trainer.
        """
        self.trainer.train()

    def save(self, save_directory: str = None):
        """
        Save the fine-tuned policy model and tokenizer.

        Args:
            save_directory: Target directory. Falls back to ``output_dir``
                when omitted or empty. Created if it does not exist.
        """
        target_dir = save_directory or self.output_dir
        os.makedirs(target_dir, exist_ok=True)
        # The policy is a PEFT-wrapped model, so this writes the LoRA adapter
        # weights and config rather than a full copy of the base model.
        self.policy.save_pretrained(target_dir)
        self.tokenizer.save_pretrained(target_dir)

In [2]:
# JSONL training file read by load_jsonl_pydantic; the path is relative, so it
# resolves against the notebook's working directory.
DATA_PATH = Path("../data/hellaswag_format/personal_chat_sessions_train_hellaswag.jsonl")
# Minimum number of whitespace-separated words a context must have to be kept
# (threshold used by has_enough_words).
MIN_WORDS = 3

def load_jsonl_pydantic(path: Path):
    """Lazily yield one validated ``HellaSwagEntry`` per non-blank line of a JSONL file.

    Args:
        path: Location of the UTF-8 encoded JSONL file.

    Yields:
        The object returned by ``HellaSwagEntry.model_validate_json`` for each
        line that contains something other than whitespace.

    Raises:
        Whatever ``HellaSwagEntry.model_validate_json`` raises for a line it
        rejects; such errors are not caught here.
    """
    # Imported at call time so that merely defining this function does not
    # require the project-local `shared_models` module to be importable.
    from shared_models import HellaSwagEntry
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            # Blank or whitespace-only lines (e.g. a stray empty line at the
            # end of the file) are not records; skip them instead of failing.
            if not line.strip():
                continue
            yield HellaSwagEntry.model_validate_json(line)

# Row predicate for `Dataset.filter`: keeps examples whose context is long enough.
def has_enough_words(example):
    """Return True when ``example["context"]`` has at least ``MIN_WORDS`` words.

    Words are counted by splitting the context on whitespace.
    """
    word_count = len(example["context"].split())
    return word_count >= MIN_WORDS

# One record per entry: the stripped context plus the stripped ending that the
# entry's `label` selects among ending0..ending4.
data_pairs = [
    {
        "context": entry.context.strip(),
        "human_resp": (
            entry.ending0,
            entry.ending1,
            entry.ending2,
            entry.ending3,
            entry.ending4,
        )[entry.label].strip(),
    }
    for entry in load_jsonl_pydantic(DATA_PATH)
]

# Drop examples whose context is too short.
raw_dataset = Dataset.from_list(data_pairs).filter(has_enough_words)

# Hold out 10% of the examples as the test split (fixed seed).
train_test = raw_dataset.train_test_split(test_size=0.1, seed=42)
train_ds = train_test["train"]
test_ds = train_test["test"]

Filter:   0%|          | 0/22282 [00:00<?, ? examples/s]

In [3]:
BASE_MODEL_NAME = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
REWARD_MODEL_PATH = "../data/models/reward_model_ckpts_test/checkpoint-3753"
OUTPUT_MODEL_PATH = "../data/models/rlhf_ckpts"
# Length every prompt is truncated/padded to; the single source for this value.
MAX_PROMPT_LENGTH = 128

# Load the tokenizer the same way PPOTrainerWithLoRA does (left padding and a
# "[PAD]" fallback), so the pre-tokenized prompts use the same pad id and
# padding side as the tokenizer the trainer works with. Left padding keeps the
# pad tokens in front of the prompt rather than between prompt and generation.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, padding_side="left")
# NOTE(review): `max_length` is not a documented tokenizer setting (that is
# `model_max_length`); nothing in this notebook reads it, and the value is
# passed explicitly in tokenize_fn. Kept in case other code relies on it.
tokenizer.max_length = MAX_PROMPT_LENGTH

if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})


def tokenize_fn(examples):
    """Tokenize a batch of contexts to exactly ``MAX_PROMPT_LENGTH`` tokens.

    Args:
        examples: Batch mapping with a ``"context"`` entry holding the texts.

    Returns:
        The tokenizer output for the batch, truncated and padded to
        ``MAX_PROMPT_LENGTH``.
    """
    return tokenizer(
        examples["context"],
        truncation=True,
        padding="max_length",
        max_length=MAX_PROMPT_LENGTH
    )


# Map from the untouched splits in `train_test` rather than from
# `train_ds`/`test_ds` themselves, so re-running this cell does not fail on the
# "context" column having already been removed by a previous run.
_train_split = train_test["train"]
_test_split = train_test["test"]
train_ds = _train_split.map(tokenize_fn, batched=True, remove_columns=_train_split.column_names)
test_ds  = _test_split.map(tokenize_fn, batched=True, remove_columns=_test_split.column_names)

Map:   0%|          | 0/19955 [00:00<?, ? examples/s]

Map:   0%|          | 0/2218 [00:00<?, ? examples/s]

In [4]:
# Build the LoRA + PPO wrapper from the tokenized splits; every LoRA and PPO
# hyperparameter is left at the class default.
trainer = PPOTrainerWithLoRA(
    train_dataset=train_ds,
    eval_dataset=test_ds,
    model_name=BASE_MODEL_NAME,
    reward_model_name=REWARD_MODEL_PATH,
    output_dir=OUTPUT_MODEL_PATH,
)



In [5]:
# Start PPO training; the wrapper's train() delegates to the underlying TRL trainer.
trainer.train()

===training policy===


Step,Training Loss
