In [1]:
import os
import warnings
import torch

from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser

from trl import (
    ModelConfig,
    RewardConfig,
    RewardTrainer,
    ScriptArguments,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
    setup_chat_format,
)

# Restrict this kernel to physical GPUs 6 and 7. The value uses the canonical
# comma-separated form without whitespace: the list format is documented as
# comma-separated, and tolerance for embedded spaces is not guaranteed.
# NOTE(review): this assignment runs after `import torch` above. It can only
# take effect if nothing has initialised CUDA yet -- consider moving it ahead
# of the torch/transformers/trl imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "6,7"

# Weights & Biases project and entity under which the run is logged.
os.environ["WANDB_PROJECT"] = "Pythia-PPO"
os.environ["WANDB_ENTITY"] = "RADFAN"

In [2]:
# =============================================================================
# Configs
# =============================================================================

# Base checkpoint plus LoRA adapter settings (rank 8, alpha 8, SEQ_CLS task).
model_config = ModelConfig(
    model_name_or_path="EleutherAI/pythia-70m",
    trust_remote_code=False,
    # load_in_8bit=False,
    # load_in_4bit=False,
    use_peft=True,
    lora_task_type="SEQ_CLS",
    lora_r=8,
    lora_alpha=8,
)

# Training hyperparameters, grouped by concern.
training_args = RewardConfig(
    # -- where results go and how the run is named
    output_dir="RLHF-And-Friends/Pythia-70M-Reward",
    run_name="RLHF-And-Friends/Pythia-70M-Reward-LoRA8-lr-4",
    # -- optimisation
    learning_rate=1e-4,
    num_train_epochs=1,
    per_device_train_batch_size=8,
    max_length=2048,
    # gradient_checkpointing=True,
    # gradient_accumulation_steps=1,
    # center_rewards_coefficient=None,
    # -- logging and periodic evaluation (every 25 / 125 steps)
    logging_steps=25,
    eval_strategy="steps",
    eval_steps=125,
    # -- Hugging Face Hub upload
    push_to_hub=True,
    hub_model_id="RLHF-And-Friends/Pythia-70M-Reward",
)

# Dataset selection; the split names are left at the ScriptArguments defaults.
script_args = ScriptArguments(
    dataset_name="trl-lib/ultrafeedback_binarized",
    # dataset_train_split="train",
    # dataset_test_split="test",
)

In [3]:
# =============================================================================
# Model & Tokenizer
# =============================================================================

# "auto" and None are forwarded unchanged; any other configured value is
# looked up by name as an attribute of the torch module.
if model_config.torch_dtype in ["auto", None]:
    torch_dtype = model_config.torch_dtype
else:
    torch_dtype = getattr(torch, model_config.torch_dtype)

quantization_config = get_quantization_config(model_config)

# Keyword arguments forwarded to the model's `from_pretrained` call below.
model_kwargs = {
    "revision": model_config.model_revision,
    # A k-bit device map is requested only when a quantization config exists.
    "device_map": None if quantization_config is None else get_kbit_device_map(),
    "quantization_config": quantization_config,
    # The cache is switched off whenever gradient checkpointing is enabled.
    "use_cache": not training_args.gradient_checkpointing,
    "torch_dtype": torch_dtype,
}

tokenizer = AutoTokenizer.from_pretrained(
    model_config.model_name_or_path,
    trust_remote_code=model_config.trust_remote_code,
    use_fast=True,
)

# Sequence-classification model with a single label (`num_labels=1`).
model = AutoModelForSequenceClassification.from_pretrained(
    model_config.model_name_or_path,
    num_labels=1,
    trust_remote_code=model_config.trust_remote_code,
    **model_kwargs,
)

# Keep the model's padding token id in sync with the tokenizer's.
model.config.pad_token_id = tokenizer.pad_token_id

# A tokenizer without a chat template (e.g. when post-training a base model)
# gets the default one from `setup_chat_format`, which returns the updated
# (model, tokenizer) pair.
if tokenizer.chat_template is None:
    model, tokenizer = setup_chat_format(model, tokenizer)

# With PEFT enabled, any task type other than SEQ_CLS triggers a warning.
if model_config.use_peft and model_config.lora_task_type != "SEQ_CLS":
    warnings.warn(
        "You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs"
        " Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT."
    )

Some weights of GPTNeoXForSequenceClassification were not initialized from the model checkpoint at EleutherAI/pythia-70m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# =============================================================================
# Load dataset
# =============================================================================

dataset = load_dataset(script_args.dataset_name)
train_dataset = dataset[script_args.dataset_train_split]

# The evaluation split is only selected when evaluation is enabled.
if training_args.eval_strategy == "no":
    eval_dataset = None
else:
    eval_dataset = dataset[script_args.dataset_test_split]

In [None]:
# =============================================================================
# Training
# =============================================================================

# Derive the PEFT configuration from the model config before building the trainer.
peft_config = get_peft_config(model_config)

trainer = RewardTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
)

# Run the training loop.
trainer.train()

In [None]:
# =============================================================================
# Save model and push to Hub
# =============================================================================

# Save the model to the configured output directory.
trainer.save_model(training_args.output_dir)
# Upload only when `push_to_hub` is enabled in the training config; the
# dataset name is forwarded to the upload call.
if training_args.push_to_hub:
    trainer.push_to_hub(dataset_name=script_args.dataset_name)

In [None]:
# =============================================================================
# Evaluate Model
# =============================================================================

# Final evaluation pass, skipped when the evaluation strategy is "no".
# NOTE(review): in notebook order this cell follows the save/push cell, so
# metrics saved here are presumably not part of that upload -- confirm
# whether evaluation should run before pushing.
if training_args.eval_strategy != "no":
    metrics = trainer.evaluate()
    # Log the metrics under the "eval" split name, then save them.
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)