In [None]:
# For google colab
# !pip install trl
# !pip install -U bitsandbytes
# !pip install -U accelerate
# !pip install -U peft

In [None]:
from huggingface_hub import login

# Authenticate with the Hugging Face Hub.
# The token is read from the HF_TOKEN environment variable rather than being
# pasted into the notebook, so it cannot leak when the notebook is shared or
# committed. `os` is imported here because this cell runs before the main
# import cell.
import os

# An unset or empty HF_TOKEN becomes None, in which case `login` prompts for
# the token interactively.
login(token=os.environ.get("HF_TOKEN") or None)

In [None]:
# 1. Import packages
import os
from dataclasses import dataclass

import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import DPOConfig, DPOTrainer

In [None]:
# 2. Training arguments
@dataclass
class ScriptArguments:
    """Tunable settings for the DPO fine-tuning run.

    All fields have defaults, so ``ScriptArguments()`` yields a ready-to-use
    configuration. The remaining cells read these values through the
    module-level ``args`` instance.

    Raises:
        ValueError: if ``model_dtype`` is not one of the supported dtype names.
    """

    # Base checkpoint used for both the tokenizer and the model.
    model_name_or_path: str = "meta-llama/Llama-2-7b-hf"

    # Optimisation settings forwarded to DPOConfig.
    learning_rate: float = 5e-5
    beta: float = 0.1
    per_device_train_batch_size: int = 1
    per_device_eval_batch_size: int = 1
    gradient_accumulation_steps: int = 4
    gradient_checkpointing: bool = True

    # LoRA adapter settings forwarded to LoraConfig.
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05

    # Length limits; max_length is the cap applied when filtering the dataset.
    max_length: int = 1024
    max_prompt_length: int = 512

    # Schedule (all measured in optimiser steps) and output settings.
    max_steps: int = 100
    save_steps: int = 50
    eval_steps: int = 50
    logging_steps: int = 10
    output_dir: str = "./dpo_llama"
    report_to: str = "none"

    # Model loading settings.
    load_in_4bit: bool = True
    model_dtype: str = "float16"  # one of "float16", "bfloat16", "float"

    seed: int = 42

    def __post_init__(self):
        # Fail here with a clear message rather than later, where the
        # model-loading cell looks the name up in a dict and would raise a
        # bare KeyError.
        supported_dtypes = ("float16", "bfloat16", "float")
        if self.model_dtype not in supported_dtypes:
            raise ValueError(
                f"model_dtype must be one of {supported_dtypes}, got {self.model_dtype!r}"
            )

args = ScriptArguments()

In [None]:
# 3. Load tokenizer
# Request the fast tokenizer for the configured checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=args.model_name_or_path,
    use_fast=True,
)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# 4. Load model
# Translate the configured dtype name into the matching torch dtype.
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float": torch.float}[args.model_dtype]

# Passing `load_in_4bit` straight to `from_pretrained` is deprecated in
# transformers; quantisation options belong in a BitsAndBytesConfig handed over
# via `quantization_config`. No quantisation config is used when 4-bit loading
# is switched off.
quantization_config = BitsAndBytesConfig(load_in_4bit=True) if args.load_in_4bit else None

model = AutoModelForCausalLM.from_pretrained(
    args.model_name_or_path,
    torch_dtype=torch_dtype,
    quantization_config=quantization_config,
    # Place the whole model on the device index of the current process.
    device_map={"": Accelerator().local_process_index},
)
# Turn off the generation cache for training.
model.config.use_cache = False

In [None]:
# 5. Load and preprocess dataset (Anthropic hh-rlhf)
def split_prompt_response(example):
    """Split an hh-rlhf record into a shared prompt and two final responses.

    Conversations look like ``"\\n\\nHuman: ...\\n\\nAssistant: ..."`` and may
    contain several turns. The split is made at the LAST ``"\\n\\nAssistant:"``
    marker, so the prompt carries the whole dialogue history and the
    chosen/rejected values are only the final assistant replies. Splitting at
    the first marker would truncate the prompt and, for multi-turn dialogues,
    return the same earlier assistant turn for both responses.

    Args:
        example: mapping with string fields ``"chosen"`` and ``"rejected"``.

    Returns:
        A dict with keys ``"prompt"``, ``"chosen"`` and ``"rejected"``. The
        prompt is taken from the chosen text and ends with the separator; the
        responses are whitespace-stripped. If the separator is missing from
        either text, all three values are empty strings so the record can be
        filtered out downstream.
    """
    sep = "\n\nAssistant:"
    # rpartition yields an empty separator element when `sep` is absent.
    chosen_head, chosen_sep, chosen_tail = example["chosen"].rpartition(sep)
    _, rejected_sep, rejected_tail = example["rejected"].rpartition(sep)

    if not (chosen_sep and rejected_sep):
        return {
            "prompt": "",
            "chosen": "",
            "rejected": ""
        }

    return {
        "prompt": chosen_head + sep,
        "chosen": chosen_tail.strip(),
        "rejected": rejected_tail.strip(),
    }

def preprocess_dataset(split):
    """Load one slice of Anthropic/hh-rlhf and turn it into preference pairs.

    Each record is passed through ``split_prompt_response``. A record is kept
    only if its prompt is non-empty and the prompt concatenated with each
    response is at most ``args.max_length`` characters long.

    Args:
        split: split expression passed to ``load_dataset`` (e.g. ``"train[:1%]"``).

    Returns:
        The mapped and filtered dataset.
    """

    def _keep(row):
        # Length here is the character count of the concatenated strings.
        if len(row["prompt"] + row["chosen"]) > args.max_length:
            return False
        if len(row["prompt"] + row["rejected"]) > args.max_length:
            return False
        return row["prompt"] != ""

    pairs = load_dataset("Anthropic/hh-rlhf", split=split).map(split_prompt_response)
    return pairs.filter(_keep)


# Keep experiments quick: train on the first 1% of the train split and
# evaluate on the following 1% of the same split.
train_dataset, eval_dataset = (
    preprocess_dataset(split_spec) for split_spec in ("train[:1%]", "train[1%:2%]")
)

In [None]:
# 6. Setup DPO Config
# The seed and the length limits are declared in ScriptArguments; forward them
# so that changing them there actually affects training.
training_args = DPOConfig(
    per_device_train_batch_size=args.per_device_train_batch_size,
    per_device_eval_batch_size=args.per_device_eval_batch_size,
    max_steps=args.max_steps,
    logging_steps=args.logging_steps,
    save_steps=args.save_steps,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    gradient_checkpointing=args.gradient_checkpointing,
    learning_rate=args.learning_rate,
    eval_strategy="steps",
    eval_steps=args.eval_steps,
    output_dir=args.output_dir,
    beta=args.beta,
    max_length=args.max_length,
    max_prompt_length=args.max_prompt_length,
    seed=args.seed,
    report_to=args.report_to,
    lr_scheduler_type="cosine",
    warmup_steps=10,
    optim="paged_adamw_32bit",
    # Enable bf16 mixed precision only when the model itself is loaded in bfloat16.
    bf16=torch_dtype == torch.bfloat16,
    remove_unused_columns=False,
    run_name="dpo_llama_colab",
)

# LoRA adapter configuration for causal language modelling; rank, scaling and
# dropout come from the script arguments.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=args.lora_r,
    lora_alpha=args.lora_alpha,
    lora_dropout=args.lora_dropout,
    bias="none",
    # Module names that receive adapters.
    target_modules=[
        "q_proj",
        "v_proj",
        "k_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

In [None]:
# 7. Train using DPOTrainer
# Assemble the trainer from the model, tokenizer, datasets and both configs.
dpo_trainer = DPOTrainer(
    model=model,
    # No separate reference model is supplied.
    # NOTE(review): presumably TRL derives the reference policy from the base
    # weights when a PEFT config is given; confirm for the TRL version in use.
    ref_model=None,
    peft_config=peft_config,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=training_args,
)

In [None]:
# 8. Train and Save
# Run the optimisation loop configured above.
dpo_trainer.train()
# Write the result to <output_dir>/final_checkpoint.
dpo_trainer.save_model(output_dir=os.path.join(args.output_dir, "final_checkpoint"))