In [1]:
import os
import copy
import wandb

from datasets import load_dataset

from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from trl import (
    ModelConfig,
    PPOConfig,
    ScriptArguments,
    get_quantization_config,
    get_peft_config,
)
from peft import (
    PeftModelForSequenceClassification,
    TaskType, 
    get_peft_model,
)
from accelerate import PartialState

from fed_ppo.utils import (
    PolicyCommutator,
    CustomPPOTrainer,
    frozen_copy,
    tokenize_as_chat
)

In [2]:
# Experiment-wide constants.
# BASE_MODEL is reused below for the checkpoint path, W&B project, output
# directory and Hub repository names.
BASE_MODEL = "LLama-3.2-1B-Instruct"
NUM_EPOCHS = 3  # number of outer training rounds (see the training loop)
NUM_AGENTS = 2  # number of policy/value model pairs trained per round
# Weight matrix handed to PolicyCommutator together with one policy per agent.
# NOTE(review): rows look like per-agent mixing weights that sum to 1 —
# confirm against fed_ppo.utils.PolicyCommutator.
COMMUTANT  = [
    [0.8, 0.2],
    [0.2, 0.8]
]

# Fail fast if NUM_AGENTS and COMMUTANT get out of sync: one row per agent and
# one column per policy. Without this check the mismatch only surfaces after
# all models have been loaded.
if len(COMMUTANT) != NUM_AGENTS or any(len(row) != NUM_AGENTS for row in COMMUTANT):
    raise ValueError(
        f"COMMUTANT must be a {NUM_AGENTS}x{NUM_AGENTS} matrix to match NUM_AGENTS"
    )

In [3]:
# Process-level settings: GPU enumeration order / visibility and the
# Weights & Biases project and entity that runs are logged under.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0,1",
    "WANDB_PROJECT": f"FedPPO-{BASE_MODEL}",
    "WANDB_ENTITY": "RADFAN",
})

In [4]:
###############################################################################
# Configs
###############################################################################

# Dataset identifier and the split names read by the dataset-loading cell.
script_args = ScriptArguments(
    dataset_name="HuggingFaceH4/ultrachat_200k",
    dataset_train_split="train_gen",
    dataset_test_split="test_gen",
)

# Model
# =============================================================================

# Policy configuration: base checkpoint, LoRA adapter settings and precision.
# Also the source of the checkpoint path / dtype for the tokenizer, the SFT
# model and the base value model loaded further down.
model_config = ModelConfig(
    model_name_or_path=f"meta-llama/{BASE_MODEL}",
    # LoRA adapter on the attention projections, causal-LM task type.
    use_peft=True,
    lora_task_type=TaskType.CAUSAL_LM,
    lora_target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_r=8,
    lora_alpha=16,
    lora_dropout=0.0,
    # No quantization; weights are loaded in bfloat16.
    load_in_4bit=False,
    load_in_8bit=False,
    torch_dtype="bfloat16",
)

# Value model
# =================================================================================================

# Value-model configuration: same LoRA shape as the policy, but with the
# sequence-classification task type. No checkpoint path is set here; the
# model-loading cell takes the path from `model_config`.
value_model_config = ModelConfig(
    # LoRA adapter on the attention projections, sequence-classification task.
    use_peft=True,
    lora_task_type=TaskType.SEQ_CLS,
    lora_target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_r=8,
    lora_alpha=16,
    lora_dropout=0.0,
    # No quantization; bfloat16 precision.
    load_in_4bit=False,
    load_in_8bit=False,
    torch_dtype="bfloat16",
)

# Reward model
# =================================================================================================

# Reward-model configuration: points at a pre-trained LoRA adapter repository
# on the Hub; quantization is disabled.
reward_model_config = ModelConfig(
    model_name_or_path="RLHF-And-Friends/Llama-3.2-1B-Instruct-Reward-LoRA8r",
    use_peft=True,
    load_in_4bit=False,
    load_in_8bit=False,
)

# PPO trainer config
# =================================================================================================

# Template PPO configuration. It is copied once per agent further down, where
# the experiment name, output directory and Hub id get an agent suffix.
ppo_config = PPOConfig(
    # Run identity and output location.
    exp_name=f"FedPPO-{BASE_MODEL}",
    output_dir=f"FedPPO-{BASE_MODEL}",
    # Hub upload after training.
    push_to_hub=True,
    hub_model_id=f"RLHF-And-Friends/FedPPO-{BASE_MODEL}",
    # Optimisation and batching.
    learning_rate=1e-5,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    num_mini_batches=1,
    dataset_num_proc=16,
    # On-policy generation.
    response_length=512,
    missing_eos_penalty=1.0,
    local_rollout_forward_batch_size=1,
    # PPO objective.
    num_ppo_epochs=1,
    whiten_rewards=False,
    kl_coef=0.05,
    cliprange=0.2,
    vf_coef=0.1,
    cliprange_value=0.2,
    gamma=1.0,
    lam=0.95,
)

# Distinct PPO configs
# =================================================================================================

# One shallow copy of the template per agent, each tagged with an "-A<idx>"
# suffix so runs, output directories and Hub repositories do not collide.
ppo_configs = []
for agent_idx in range(NUM_AGENTS):
    agent_config = copy.copy(ppo_config)
    agent_suffix = f"-A{agent_idx}"
    agent_config.exp_name = f"{agent_config.exp_name}{agent_suffix}"
    agent_config.output_dir = f"{agent_config.output_dir}{agent_suffix}"
    agent_config.hub_model_id = f"{agent_config.hub_model_id}{agent_suffix}"
    ppo_configs.append(agent_config)

In [None]:
# Tokenizer
# -------------------------------------------------------------------------------------------------

# Fast tokenizer of the policy checkpoint, with "[PAD]" registered as the
# padding token. The models loaded below resize their embeddings to
# len(tokenizer) to match.
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

# SFT
# -------------------------------------------------------------------------------------------------

# Causal-LM checkpoint that the per-agent policies are built from.
sft = AutoModelForCausalLM.from_pretrained(
    model_config.model_name_or_path,
    torch_dtype=model_config.torch_dtype,
    quantization_config=get_quantization_config(model_config),
)
# Make room for the added "[PAD]" token and tell the model which id pads.
sft.resize_token_embeddings(len(tokenizer), mean_resizing=False)
sft.config.pad_token_id = tokenizer.pad_token_id

# Policy Models
# -------------------------------------------------------------------------------------------------
# Every agent gets its own deep copy of the SFT model before the LoRA adapter
# is attached. `get_peft_model` injects the adapter layers into the model it
# receives (in place), so wrapping the same `sft` object NUM_AGENTS times
# would leave all agents pointing at one shared set of adapter weights, and
# training one agent would silently update the others.

policy_models = [
    get_peft_model(
        copy.deepcopy(sft),
        get_peft_config(model_config)
    )
    for _ in range(NUM_AGENTS)
]

# Base Value Model
# -------------------------------------------------------------------------------------------------

# Single-output sequence classifier that the reward and value models are built
# on. The checkpoint path and dtype come from `model_config`; the quantization
# settings come from `value_model_config`.
base_value_model = AutoModelForSequenceClassification.from_pretrained(
    model_config.model_name_or_path,
    num_labels=1,
    torch_dtype=model_config.torch_dtype,
    quantization_config=get_quantization_config(value_model_config),
)
# Match the tokenizer's vocabulary (extra "[PAD]" token) and padding id.
base_value_model.resize_token_embeddings(len(tokenizer), mean_resizing=False)
base_value_model.config.pad_token_id = tokenizer.pad_token_id

# Reward model
# -------------------------------------------------------------------------------------------------
# The pre-trained reward adapter is loaded onto a private deep copy of the
# base classifier. PEFT injects adapter layers into the model it is given
# (in place); if the reward model and the value models all wrapped the same
# `base_value_model` object, attaching the value adapters afterwards would
# re-initialise the reward model's adapter weights.

reward_model = PeftModelForSequenceClassification.from_pretrained(
    copy.deepcopy(base_value_model),
    reward_model_config.model_name_or_path,
    num_labels = 1,
    quantization_config = get_quantization_config(reward_model_config)
)

# Value models
# -------------------------------------------------------------------------------------------------
# One independent value model per agent: each wraps its own deep copy of the
# base classifier so that agents do not share (and overwrite) adapter weights.

value_models = [
    get_peft_model(
        copy.deepcopy(base_value_model),
        get_peft_config(value_model_config)
    )
    for _ in range(NUM_AGENTS)
]

In [6]:
###############################################################################
#  Dataset
###############################################################################

# Raw splits, truncated to the first 1000 training / 100 evaluation rows.
# Named `*_split` rather than `train` / `eval` so the Python builtin `eval`
# is not shadowed for the rest of the kernel session.
train_split = load_dataset(
    script_args.dataset_name,
    split=script_args.dataset_train_split
).select(range(1000))

eval_split = load_dataset(
    script_args.dataset_name,
    split=script_args.dataset_test_split
).select(range(100))


# Let the local main process run the tokenization first.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
    train_dataset = tokenize_as_chat(train_split, tokenizer)
    eval_dataset = tokenize_as_chat(eval_split, tokenizer)

# One training shard per agent; the evaluation set is shared by all agents.
train_datasets = [
    train_dataset.shard(num_shards = NUM_AGENTS, index = agent_idx)
    for agent_idx in range(NUM_AGENTS)
]

In [None]:
###############################################################################
#  Training
###############################################################################

# Outer rounds: in each round every agent is trained once with PPO against a
# reference model taken from the commutator built at the start of the round.
for epoch in range(NUM_EPOCHS):

    # Copy every agent's current policy before any agent trains this round
    # (presumably parameter-frozen copies — see fed_ppo.utils.frozen_copy).
    frozen_policies = list(map(frozen_copy, policy_models))

    # Built from the copies and COMMUTANT; indexed per agent below to obtain
    # that agent's reference model.
    reference_models = PolicyCommutator(policies=frozen_policies, commutant=COMMUTANT)

    for idx in range(NUM_AGENTS):
        agent_args = ppo_configs[idx]

        # A fresh trainer per agent and round; the reward model, tokenizer and
        # evaluation set are shared, everything else is agent-specific.
        trainer = CustomPPOTrainer(
            args=agent_args,
            processing_class=tokenizer,
            model=policy_models[idx],
            ref_model=reference_models[idx],
            reward_model=reward_model,
            value_model=value_models[idx],
            train_dataset=train_datasets[idx],
            eval_dataset=eval_dataset,
        )
        trainer.train()
        # Close the current W&B run before moving on.
        wandb.finish()

        if agent_args.push_to_hub:
            trainer.push_to_hub(dataset_name=script_args.dataset_name)
