In [1]:
import os

from datasets import load_dataset

from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

from peft import (
    PeftModelForSequenceClassification,
    TaskType, 
    get_peft_model
)

from trl import (
    ModelConfig,
    PPOConfig,
    PPOTrainer,
    get_peft_config,
    get_quantization_config,
)

from accelerate import PartialState

from fed_ppo.utils import prepare_ppo_dataset

### Devices

In [2]:
# GPU selection
# -------------------------------------------------------------------------------------------------
VISIBLE_DEVICES = "2"
# -------------------------------------------------------------------------------------------------

os.environ.update(
    {
        # Order GPUs by their PCI bus IDs
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        # Expose only the devices chosen above
        "CUDA_VISIBLE_DEVICES": str(VISIBLE_DEVICES),
    }
)

### Models & Dataset

In [3]:
def repo_name(path):
    """Return the last component of a Hub id or filesystem path.

    "org/name" yields "name". A bare "name", a nested path or a path with a
    trailing slash is handled too, where indexing `path.split('/')[1]` would
    raise IndexError or return an intermediate component.
    """
    return path.rstrip("/").rsplit("/", 1)[-1]


# Policy model path
# =================================================================================================
POLICY_PATH = "meta-llama/Llama-3.2-1B-Instruct"
# =================================================================================================
POLICY_NAME = repo_name(POLICY_PATH)

# Reward model path
# =================================================================================================
REWARD_PATH = "RLHF-And-Friends/Llama-3.2-1B-Instruct-Reward-LoRA8r"
# =================================================================================================
REWARD_NAME = repo_name(REWARD_PATH)

# Prompts dataset path
# =================================================================================================
DATASET_PATH        = "HuggingFaceH4/ultrachat_200k"
DATASET_TRAIN_SPLIT = "train_gen"
DATASET_VAL_SPLIT   = "test_gen"
# =================================================================================================
DATASET_NAME        = repo_name(DATASET_PATH)

### WandB settings

In [4]:
# W&B project / entity environment variables
os.environ.update(
    WANDB_PROJECT = f"{POLICY_NAME}-PPO-{DATASET_NAME}",
    WANDB_ENTITY  = "RADFAN",
)

### Models' configs

In [5]:
def _lora_settings():
    """Return a fresh dict of the LoRA and loading options shared by the policy and value configs."""
    return dict(
        # LoRA
        use_peft            = True,
        lora_r              = 8,
        lora_alpha          = 16,
        lora_dropout        = 0.0,
        lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
        # Quantization
        load_in_8bit        = False,
        load_in_4bit        = False,
        torch_dtype         = "bfloat16",
    )


# Policy: causal LM with LoRA
# =================================================================================================
policy_model_config = ModelConfig(
    model_name_or_path = POLICY_PATH,
    lora_task_type     = TaskType.CAUSAL_LM,
    **_lora_settings(),
)

# Value model: sequence classifier with LoRA (no checkpoint path of its own)
# =================================================================================================
value_model_config = ModelConfig(
    lora_task_type = TaskType.SEQ_CLS,
    **_lora_settings(),
)

# Reward model: PEFT checkpoint, no quantization
# =================================================================================================
reward_model_config = ModelConfig(
    model_name_or_path = REWARD_PATH,
    use_peft           = True,
    load_in_8bit       = False,
    load_in_4bit       = False,
)

### PPO Trainer config

In [6]:
# Values reused by several fields below
lora_rank     = policy_model_config.lora_r
wandb_project = os.environ['WANDB_PROJECT']

ppo_config = PPOConfig(
    # Run identity and output location
    # ---------------------------------------------------------------------------------------------
    run_name   = f"LoRA-{lora_rank}",
    output_dir = f"{wandb_project}-LoRA-{lora_rank}",

    # Data processing and optimisation schedule
    # ---------------------------------------------------------------------------------------------
    dataset_num_proc            = 16,
    num_mini_batches            = 1,
    learning_rate               = 1e-5,
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 4,
    num_train_epochs            = 1,
    response_length             = 512,

    # Hub upload after training
    # ---------------------------------------------------------------------------------------------
    push_to_hub  = True,
    hub_model_id = f"RLHF-And-Friends/{POLICY_NAME}-PPO-{DATASET_NAME}-LoRA-{lora_rank}",

    # On-policy rollout settings
    # ---------------------------------------------------------------------------------------------
    missing_eos_penalty              = 1.0,
    local_rollout_forward_batch_size = 1,

    # PPO hyperparameters
    # ---------------------------------------------------------------------------------------------
    num_ppo_epochs  = 1,
    whiten_rewards  = False,
    kl_coef         = 0.05,
    cliprange       = 0.2,
    vf_coef         = 0.1,
    cliprange_value = 0.2,
    gamma           = 1.0,
    lam             = 0.95,
)

### Initialize models and tokenizer

In [7]:
# Tokenizer
# -------------------------------------------------------------------------------------------------

tokenizer = AutoTokenizer.from_pretrained(
    policy_model_config.model_name_or_path,
    use_fast = True,
)
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})


# Loading helpers
# -------------------------------------------------------------------------------------------------
# NOTE(review): `torch_dtype` from the ModelConfig objects is not forwarded to `from_pretrained`
# anywhere in this cell — confirm whether bfloat16 loading was intended.

def _sync_with_tokenizer(model):
    """Resize the embeddings of `model` to the tokenizer size and set its pad token id."""
    model.resize_token_embeddings(len(tokenizer), mean_resizing=False)
    model.config.pad_token_id = tokenizer.pad_token_id
    return model


def _load_policy_base():
    """Load a fresh causal LM from the policy checkpoint, aligned with the tokenizer."""
    return _sync_with_tokenizer(
        AutoModelForCausalLM.from_pretrained(
            policy_model_config.model_name_or_path,
            quantization_config = get_quantization_config(policy_model_config),
        )
    )


def _load_value_head_base(model_config):
    """Load a fresh single-label sequence classifier from the policy checkpoint.

    `model_config` only supplies the quantization settings; the weights always come
    from the policy checkpoint.
    """
    return _sync_with_tokenizer(
        AutoModelForSequenceClassification.from_pretrained(
            policy_model_config.model_name_or_path,
            num_labels = 1,
            quantization_config = get_quantization_config(model_config),
        )
    )


# SFT model
# -------------------------------------------------------------------------------------------------

sft_policy = _load_policy_base()

# Trainable policy
# -------------------------------------------------------------------------------------------------

if policy_model_config.use_peft:
    # NOTE(review): get_peft_model injects the adapter layers into `sft_policy` itself, so
    # `sft_policy` and `policy` share every weight from here on; `sft_policy` only behaves
    # like the original SFT model while the adapter is disabled.
    policy = get_peft_model(sft_policy, get_peft_config(policy_model_config))
else:
    # Separate instance prepared exactly like `sft_policy`, so its embedding size and
    # pad token id match the tokenizer as well.
    policy = _load_policy_base()

# Base model for the Value model
# -------------------------------------------------------------------------------------------------

base_value_head_model = _load_value_head_base(value_model_config)

# Value model with LoRA
# -------------------------------------------------------------------------------------------------

if value_model_config.use_peft:
    value_model = get_peft_model(
        base_value_head_model, get_peft_config(value_model_config))
else:
    value_model = base_value_head_model

# Reward model
# -------------------------------------------------------------------------------------------------

if reward_model_config.use_peft:
    # The reward adapter gets its own base instance. Wrapping `base_value_head_model` a
    # second time would make the reward and value models share the same modules, because
    # PEFT modifies the wrapped base model in place.
    # `num_labels` / quantization are base-model options and are applied by the loader above.
    reward_model = PeftModelForSequenceClassification.from_pretrained(
        _load_value_head_base(reward_model_config),
        reward_model_config.model_name_or_path,
    )
else:
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        reward_model_config.model_name_or_path,
        num_labels = 1,
        quantization_config = get_quantization_config(reward_model_config)
    )

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Llama-3.2-1B-Instruct and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Initialize dataset

In [8]:
# Load each split and keep only its first rows (1000 for training, 100 for evaluation)
train_subset, eval_subset = (
    load_dataset(DATASET_PATH, split=split_name).select(range(row_count))
    for split_name, row_count in (
        (DATASET_TRAIN_SPLIT, 1000),
        (DATASET_VAL_SPLIT, 100),
    )
)

# Let the local main process run the preprocessing first.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
    train_dataset = prepare_ppo_dataset(train_subset, tokenizer)
    eval_dataset  = prepare_ppo_dataset(eval_subset, tokenizer)


### Training

In [None]:
# With PEFT enabled, `policy` is `get_peft_model(sft_policy, ...)`, and PEFT injects the LoRA
# layers into the model it wraps. `sft_policy` would then follow every policy update, so using
# it as the reference would keep the KL penalty at zero. Passing `None` instead makes
# PPOTrainer compute reference log-probs from the PEFT policy with its adapter disabled.
# Without PEFT, `sft_policy` is an independent model and is used directly.
reference_model = None if policy_model_config.use_peft else sft_policy

trainer = PPOTrainer(
    args              = ppo_config,
    processing_class  = tokenizer,
    model             = policy,
    ref_model         = reference_model,
    reward_model      = reward_model,
    value_model       = value_model,
    train_dataset     = train_dataset,
    eval_dataset      = eval_dataset,
)

trainer.train()

### Save model

In [None]:
# Write the trained model to the configured output directory
output_dir = ppo_config.output_dir
trainer.save_model(output_dir)

# Upload to the Hub only when the config asks for it
should_push = ppo_config.push_to_hub
if should_push:
    trainer.push_to_hub(dataset_name=DATASET_PATH)