In [None]:
import os
import copy
import wandb
import huggingface_hub as hub

from datasets import load_dataset

from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

from accelerate import PartialState

from trl import ModelConfig, PPOConfig, ScriptArguments
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

from utils import PolicyCommutator, CustomPPOTrainer


In [None]:
# Runtime environment: which GPU this kernel may use and where wandb logs go.
# NOTE(review): CUDA_VISIBLE_DEVICES is only honoured if CUDA has not been
# initialized yet in this kernel; setting it before the import cell is safer.
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "2",
        "WANDB_PROJECT": "Pythia-FedPPO",
        "WANDB_ENTITY": "RADFAN",
    }
)

In [None]:
# Authenticate with Weights & Biases and the Hugging Face Hub.
# The Hub token must never be hard-coded in the notebook (it ends up in git
# history and in shared copies). It is read from the HF_TOKEN environment
# variable; when that is unset, `hub.login(token=None)` falls back to its
# interactive prompt.
# NOTE(review): the token that was previously written here is exposed and
# should be revoked on huggingface.co.
wandb.login()
hub.login(token=os.environ.get("HF_TOKEN"))

In [None]:
# Number of outer rounds; every agent is trained once per round.
NUM_EPOCHS = 2
NUM_AGENTS = 2

# Matrix passed to `PolicyCommutator` together with the agents' policies.
# NOTE(review): presumably entry [i][j] weights agent j's policy in agent i's
# reference policy — confirm against `utils.PolicyCommutator`.
COMMUTANT  = [
    [0.8, 0.2],
    [0.2, 0.8]
]

# COMMUTANT is assumed to be NUM_AGENTS x NUM_AGENTS. Catch a mismatch (e.g.
# after changing NUM_AGENTS) here instead of deep inside model construction
# or training.
if len(COMMUTANT) != NUM_AGENTS or any(len(row) != NUM_AGENTS for row in COMMUTANT):
    raise ValueError(
        f"COMMUTANT must be a {NUM_AGENTS}x{NUM_AGENTS} matrix, "
        f"got {[len(row) for row in COMMUTANT]} columns per row"
    )

In [None]:
###############################################################################
# Configs
###############################################################################

script_args = ScriptArguments(
    dataset_name        = "trl-internal-testing/descriptiveness-sentiment-trl-style",
    dataset_train_split = "descriptiveness",
)

# Model to use for policies
# =============================================================================

model_config = ModelConfig(
    model_name_or_path  = "EleutherAI/pythia-70m-deduped",
    trust_remote_code   = False,
)

# PPO trainer config
# =============================================================================

# Shared PPO hyper-parameters. They are kept as plain keyword arguments so
# that every agent's PPOConfig below is *constructed* from them rather than
# copied from an already-initialized config object.
ppo_config_kwargs = dict(
    dataset_num_proc    = 1,
    num_ppo_epochs      = 1,
    num_train_epochs    = 0.05,
    num_mini_batches    = 1,
    learning_rate       = 3e-6,
    missing_eos_penalty = 1.0,
    per_device_train_batch_size       = 1,
    gradient_accumulation_steps       = 16,
    local_rollout_forward_batch_size  = 1,
    reward_model_path   = "EleutherAI/pythia-70m-deduped",
    exp_name            = "Pythia-70M",
    output_dir          = "Pythia-70M",
    hub_model_id        = "RLHF-And-Friends/FedPPO-Pythia-70M",
    push_to_hub         = True,
)

# Base (un-suffixed) config; later cells read shared settings from it,
# e.g. `ppo_config.reward_model_path`.
ppo_config = PPOConfig(**ppo_config_kwargs)

# Distinct PPO configs
# =============================================================================

# One PPOConfig per agent, with "-a{idx}"-suffixed experiment, output and hub
# names. A shallow `copy.copy(ppo_config)` would skip `__post_init__`, so any
# value it derives from `output_dir` (in the transformers versions that do so,
# the default `logging_dir`) would keep pointing at the base directory, and
# mutable attributes would be shared between agents. Constructing each config
# from the kwargs avoids both problems.
ppo_configs = [
    PPOConfig(
        **{
            **ppo_config_kwargs,
            "exp_name":     f"{ppo_config_kwargs['exp_name']}-a{agent_idx}",
            "output_dir":   f"{ppo_config_kwargs['output_dir']}-a{agent_idx}",
            "hub_model_id": f"{ppo_config_kwargs['hub_model_id']}-a{agent_idx}",
        }
    )
    for agent_idx in range(NUM_AGENTS)
]

In [None]:
###############################################################################
#  Tokenizer
###############################################################################

# Prompts are padded on the left so that, for a decoder-only model, the
# generated continuation directly follows the prompt tokens.
tokenizer = AutoTokenizer.from_pretrained(
    model_config.model_name_or_path,
    trust_remote_code=model_config.trust_remote_code,
    padding_side="left",
)

# Register a dedicated "[PAD]" token rather than reusing an existing one
# (e.g. EOS), so padding positions stay distinguishable from real tokens.
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

# Install TRL's minimal chat template only when the checkpoint ships none.
needs_template = tokenizer.chat_template is None
if needs_template:
    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE


In [None]:
###############################################################################
#  Models
###############################################################################

# Every load below passes the same `trust_remote_code` flag the tokenizer
# cell uses, so a checkpoint that needs remote code is handled consistently.

# NOTE(review): `sft_model` is not referenced by any cell visible here (the
# PPO reference policies come from `reference_models`); it only costs memory.
# Confirm it is still needed.
sft_model = AutoModelForCausalLM.from_pretrained(
    model_config.model_name_or_path,
    trust_remote_code=model_config.trust_remote_code,
)

# Single-output scorer shared by all agents.
# NOTE(review): `reward_model_path` is the base causal-LM checkpoint, so the
# score head is presumably freshly initialized. Confirm this placeholder
# reward is intended.
reward_model = AutoModelForSequenceClassification.from_pretrained(
    ppo_config.reward_model_path,
    num_labels=1,
    trust_remote_code=model_config.trust_remote_code,
)

# One independently loaded policy per agent.
policy_models = [
    AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path,
        trust_remote_code=model_config.trust_remote_code,
    )
    for _ in range(NUM_AGENTS)
]

# One single-output value model per agent.
value_models = [
    AutoModelForSequenceClassification.from_pretrained(
        model_config.model_name_or_path,
        num_labels=1,
        trust_remote_code=model_config.trust_remote_code,
    )
    for _ in range(NUM_AGENTS)
]

# Built from the agents' policies and COMMUTANT; indexed per agent
# (`reference_models[idx]`) as the PPO reference policy in the training loop.
reference_models = PolicyCommutator(
    policies  = policy_models,
    commutant = COMMUTANT,
)

In [None]:
###############################################################################
#  Dataset
###############################################################################

dataset = load_dataset(
    script_args.dataset_name,
    split=script_args.dataset_train_split,
)

# The last `eval_samples` rows are held out for evaluation; the rest is train.
eval_samples = 100
if not 0 < eval_samples < len(dataset):
    raise ValueError(
        f"eval_samples={eval_samples} must be in (0, {len(dataset)}) so that "
        "both the train and eval splits are non-empty"
    )
split_point = len(dataset) - eval_samples
train_dataset = dataset.select(range(split_point))
eval_dataset = dataset.select(range(split_point, len(dataset)))
dataset_text_field = "prompt"


def prepare_dataset(dataset, tokenizer, text_field=None, num_proc=None):
    """Pre-tokenize ``dataset`` once so that training only has to collate.

    Args:
        dataset: dataset whose ``text_field`` column holds the prompts.
        tokenizer: callable tokenizer; only its ``input_ids`` output is kept.
        text_field: column to tokenize; ``None`` uses the module-level
            ``dataset_text_field``.
        num_proc: worker processes for ``dataset.map``; ``None`` keeps the
            library default.

    Returns:
        The mapped dataset with every original column removed and a single,
        unpadded ``input_ids`` column.
    """
    field = dataset_text_field if text_field is None else text_field

    def tokenize(element):
        outputs = tokenizer(element[field], padding=False)
        return {"input_ids": outputs["input_ids"]}

    return dataset.map(
        tokenize,
        batched=True,
        remove_columns=dataset.column_names,
        num_proc=num_proc,
    )


# `local_main_process_first` lets the local main process run the mapping
# before the others, so they can reuse its cached result instead of redoing
# the work. See: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
    train_dataset = prepare_dataset(
        train_dataset, tokenizer, num_proc=ppo_config.dataset_num_proc
    )
    eval_dataset = prepare_dataset(
        eval_dataset, tokenizer, num_proc=ppo_config.dataset_num_proc
    )

# Each agent trains on its own disjoint shard of the training prompts.
train_datasets = [
    train_dataset.shard(num_shards=NUM_AGENTS, index=agent_idx)
    for agent_idx in range(NUM_AGENTS)
]

In [None]:
###############################################################################
#  Training
###############################################################################

for epoch in range(NUM_EPOCHS):
    for idx in range(NUM_AGENTS):
        # NOTE(review): a fresh trainer is built for every (epoch, agent)
        # pair, so optimizer / LR-schedule state presumably restarts each
        # round. Confirm this is intended for the federated setup.
        trainer = CustomPPOTrainer(
            config            = ppo_configs[idx],
            processing_class  = tokenizer,
            policy            = policy_models[idx],
            ref_policy        = reference_models[idx],
            reward_model      = reward_model,
            value_model       = value_models[idx],
            train_dataset     = train_datasets[idx],
            eval_dataset      = eval_dataset,
        )
        try:
            trainer.train()

            # Push while this agent's wandb run is still active: TRL's
            # model-card helper links `wandb.run` when one exists, which is
            # lost if the run has already been finished.
            if ppo_configs[idx].push_to_hub:
                trainer.push_to_hub(dataset_name=script_args.dataset_name)
        finally:
            # Always close the run, even if training failed, so the next
            # trainer starts a new wandb run instead of logging into this one.
            wandb.finish()
