In [1]:
pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [2]:
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Note: you need to install transformers from main to run this script. See https://huggingface.co/docs/transformers/installation#install-from-source
# TODO: bump transformers version in requirements at next release.

# 0. imports

from typing import Dict
from dotenv import load_dotenv
import os
from huggingface_hub import login

import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from peft import LoraConfig

from trl import DPOTrainer


load_dotenv()

beta = 0.1
model_name_or_path = "meta-llama/Llama-2-7b-hf"
learning_rate = 1e-4
per_device_train_batch_size = 1
gradient_accumulation_steps = 1
max_length = 512
max_prompt_length = 384
max_target_length = 128
label_pad_token_id = -100
max_steps = 3000
sanity_check = False
report_to = None
ignore_bias_buffers = False
gradient_checkpointing = False
gradient_checkpointing_kwargs = None
access_token = os.getenv("HF_ACCESS_TOKEN")

login(token=access_token)


def extract_anthropic_prompt(prompt_and_response):
    """Extract the anthropic prompt from a prompt and response pair."""
    search_term = "\n\nAssistant:"
    search_term_idx = prompt_and_response.rfind(search_term)
    assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
    return prompt_and_response[: search_term_idx + len(search_term)]


def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
    """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }

    Prompts should be structured as follows:
      \n\nHuman: <prompt>\n\nAssistant:
    Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
    """
    dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def split_prompt_and_responses(sample) -> Dict[str, str]:
        prompt = extract_anthropic_prompt(sample["chosen"])
        return {
            "prompt": prompt,
            "chosen": sample["chosen"][len(prompt) :],
            "rejected": sample["rejected"][len(prompt) :],
        }

    return dataset.map(split_prompt_and_responses)

def print_parameters(model):
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print("Trainable params: ", trainable_params)
    print("Total params: ", total_params)
    print(f"Percentage of trainable params: {trainable_params/total_params*100}%")


if __name__ == "__main__":
    # 0. PEFT and Quantization config
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    # 1. load a pretrained model
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path, # location of saved SFT model
        quantization_config=nf4_config,
        use_auth_token=True,
    )

    if ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        use_auth_token=True,
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 2. Load the Anthropic Helpful-Harmless dataset
    train_dataset = get_hh("train", sanity_check=False)

    # 3. Load evaluation dataset
    eval_dataset = get_hh("test", sanity_check=True)

    # 4. initialize training arguments:
    training_args = TrainingArguments(
        per_device_train_batch_size=per_device_train_batch_size,
        max_steps=max_steps,
        remove_unused_columns=False,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        evaluation_strategy="steps",
        logging_first_step=True,
        logging_steps=10,  # match results in blog post
        eval_steps=1000,
        output_dir="./test",
        warmup_steps=150,
        report_to=report_to,
        bf16=True,
        gradient_checkpointing=gradient_checkpointing,
        # TODO: uncomment that on the next transformers release
        # gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
    )

    # 5. initialize the DPO trainer
    dpo_trainer = DPOTrainer(
        model,
        args=training_args,
        beta=beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_length=max_length,
        max_target_length=max_target_length,
        max_prompt_length=max_prompt_length,
        peft_config=peft_config,
    )
    print_parameters(dpo_trainer.model)

    # 6. train
    dpo_trainer.train()

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /home/ec2-user/.cache/huggingface/token
Login successful




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Using pad_token, but it is not set yet.


Trainable params:  8388608
Total params:  3508801536
Percentage of trainable params: 0.23907331075678143%


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,Rewards/chosen,Rewards/rejected,Rewards/accuracies,Rewards/margins,Logps/rejected,Logps/chosen,Logits/rejected,Logits/chosen
1000,0.6663,0.789456,-0.36744,-0.253819,0.424,-0.113621,-108.934616,-92.948509,0.221753,0.168724
2000,0.7525,1.08103,-2.471128,-2.681504,0.515,0.210377,-133.211472,-113.985374,-0.069076,-0.118976
3000,0.9581,1.023924,-2.108772,-2.029161,0.471,-0.079611,-126.688042,-110.361824,-0.079663,-0.128623


In [3]:
prompt = "Hey, what's up?"
inputs = tokenizer(f"\n\nHuman: {prompt}\n\nAssistant:", return_tensors='pt')
outputs = dpo_trainer.model.generate(**inputs, max_new_tokens=50)
tokenizer.decode(outputs[0])



"<s> \n\nHuman: Hey, what's up?\n\nAssistant: Hi! Thanks for asking! Great to hear from you! A lot has happened since we last spoke, so here's a quick recap of what's new:\n* We launched a new product that you might be interested in.\n"

In [4]:
prompt = "Hey, what's up?"
inputs = tokenizer(prompt, return_tensors='pt')
outputs = dpo_trainer.model.generate(**inputs, max_new_tokens=50)
tokenizer.decode(outputs[0])

"<s> Hey, what's up? My name is Mike and I'm a 20 year old college student who is majoring in computer science. sierp 2017 ... When a person is looking to meet someone new, they're going to use dating apps"