<a href="https://www.kaggle.com/code/aisuko/supervised-fine-tuning-llama2-with-dpo?scriptVersionId=165210041" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Overview

RLHF (Reinforcement Learning from Human Feedback) helps ensure that a language model's outputs are aligned with human expectations such as chattiness or safety. However, it also brings some of the complexity of RL into NLP: we need to build a good reward function, train the model to estimate the value of a state, and at the same time be careful not to stray too far from the original model and produce gibberish instead of sensible text. Such a process is quite involved, requiring a number of complex moving parts, and it is not always easy to get things right.

The [Direct Preference Optimization paper](https://arxiv.org/abs/2305.18290) proposes to cast the RL-based objective used by existing methods into an objective that can be optimized directly via a simple binary cross-entropy loss, which greatly simplifies the process of refining LLMs.

We are going to use [DPO](https://www.kaggle.com/code/aisuko/fine-tune-llm-with-direct-preference-optimization#Direct-Preference-Optimization) to fine-tune Llama v2 7B on one of the [**Preference datasets**](https://www.kaggle.com/code/aisuko/fine-tune-llm-with-direct-preference-optimization#Preparing-datasets). Here we use the [stack-exchange-paired preference](https://huggingface.co/datasets/lvwerra/stack-exchange-paired) dataset.

In [None]:
%%capture
!pip install transformers==4.36.2
!pip install accelerate==0.25.0
!pip install datasets==2.15.0
!pip install peft==0.7.1
!pip install bitsandbytes==0.41.3
!pip install trl==0.7.7
!pip install tqdm==4.66.1

In [None]:
import os
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient

# Read the credentials from the Kaggle secret store so that no token is
# hard-coded in the notebook.
user_secrets = UserSecretsClient()

# Authenticate against the Hugging Face Hub.
login(token=user_secrets.get_secret("HUGGINGFACE_TOKEN"))

# Weights & Biases reads its configuration from these environment variables.
# WANDB_NAME is also reused further down as the run name and Hub repo name.
os.environ["WANDB_API_KEY"] = user_secrets.get_secret("WANDB_API_KEY")
os.environ["WANDB_PROJECT"] = "Fine-tuning Llama2-with-stack-exchange"
# The notes previously described a different model (distilbert); describe this run.
os.environ["WANDB_NOTES"] = "Fine tune Llama 2 7B on stack-exchange-paired with SFT and DPO"
os.environ["WANDB_NAME"] = "ft-Llama2-with-stack-exchange-paired"

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Base checkpoint used for both the tokenizer and the model.
model_name="meta-llama/Llama-2-7b-hf"
# `trust_remote_code` is not needed for this checkpoint (the DPO tokenizer below
# is loaded from the same repo without it), so do not opt in to executing
# repository-provided code.
tokenizer=AutoTokenizer.from_pretrained(model_name)
# Reuse the end-of-sequence token for padding.
tokenizer.pad_token=tokenizer.eos_token
tokenizer.padding_side="right" # this fixed the weird overflow issue with fp16 training

# Preparing the Datasets

In [None]:
from datasets import load_dataset


def prepare_sample_text(example):
    """Render one dataset row as a single question/answer training string.

    Args:
        example: mapping providing the ``question`` and ``response_j`` fields.

    Returns:
        The text ``"Question:<question>\\n\\nAnswer:<response_j>"``.
    """
    template = "Question:{}\n\nAnswer:{}"
    return template.format(example["question"], example["response_j"])

def chars_token_ratio(dataset, tokenizer, nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.

    Args:
        dataset: iterable of examples accepted by ``prepare_sample_text``.
        tokenizer: tokenizer used to count tokens; the fast and slow code
            paths are both handled.
        nb_examples: maximum number of examples to sample from the start of
            ``dataset``.

    Returns:
        Total characters divided by total tokens over the sampled examples.

    Raises:
        ZeroDivisionError: if no tokens were produced (e.g. empty dataset).
    """
    # Imported locally: no other cell brings these names into scope.
    from itertools import islice

    from tqdm import tqdm

    total_characters, total_tokens=0,0
    # islice stops after nb_examples items without materializing the dataset,
    # which matters when it is a streamed (iterable) dataset.
    for example in tqdm(islice(dataset, nb_examples), total=nb_examples):
        text=prepare_sample_text(example)
        total_characters+=len(text)
        if tokenizer.is_fast:
            total_tokens+=len(tokenizer(text).tokens())
        else:
            total_tokens+=len(tokenizer.tokenize(text))
    return total_characters/total_tokens


# Stream the dataset instead of downloading it in full.
streaming=True

dataset=load_dataset(
    path="lvwerra/stack-exchange-paired",
    data_dir="data/finetune", # the subset to use
    split="train",
    # `datasets` does not support `num_proc` together with `streaming=True`
    # (it raises NotImplementedError), so only parallelize the non-streaming load.
    num_proc=None if streaming else 4,
    streaming=streaming,
)

dataset

In [None]:
# Build the train/validation splits. A streamed dataset cannot be split by
# ratio, so the first 4000 examples become the validation set and the rest
# (shuffled through a 5000-example buffer) become the training set.
if not streaming:
    dataset=dataset.train_test_split(test_size=0.005, seed=None)
    train_data, valid_data = dataset["train"], dataset["test"]
else:
    valid_data=dataset.take(4000)
    train_data=dataset.skip(4000).shuffle(buffer_size=5000, seed=None)

# Characters-per-token estimate, reused when packing sequences below.
chars_per_token=chars_token_ratio(train_data, tokenizer)
chars_per_token

In [None]:
from trl.trainer import ConstantLengthDataset

# Options shared by the training and validation packers.
_packing_options=dict(
    formatting_func=prepare_sample_text,
    seq_length=1024,
    chars_per_token=chars_per_token,
)

# The training stream restarts when exhausted (infinite=True) ...
train_dataset=ConstantLengthDataset(tokenizer, train_data, infinite=True, **_packing_options)

# ... while the validation stream is consumed only once.
valid_dataset=ConstantLengthDataset(tokenizer, valid_data, infinite=False, **_packing_options)

## Quantize the model to 4-bit (NF4)

Reducing the GPU memory usage in loading models and inference processes.

In [None]:
import torch
from peft import AutoPeftModelForCausalLM


# 4-bit quantization settings handed to `from_pretrained` below.
bnb_config=BitsAndBytesConfig(
    # It is used to enable 4-bit quantization by replacing the Linear layers with FP4/NF4 layers from bitsandbytes.
    load_in_4bit=True,
    # This sets the quantization data type in the bnb.nn.Linear4Bit layers.
    bnb_4bit_quant_type="nf4",
    # Support nested quantization
    bnb_4bit_use_double_quant=True,
    # This sets the computational type, which might be different from the input type
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the base checkpoint with the quantization settings above and let
# `device_map="auto"` decide where the weights are placed.
base_model=AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=False,
)

# NOTE(review): presumably disabled because the KV cache is not useful during
# training and conflicts with gradient checkpointing - confirm.
base_model.config.use_cache=False
print(base_model)

In [None]:
from peft import LoraConfig, TaskType

# LoRA adapter settings for the supervised fine-tuning stage: rank-8 adapters
# on the attention query/value projections, with no bias terms trained.
peft_config=LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj","v_proj"],
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

# Display the resulting configuration.
peft_config

In [None]:
from transformers import TrainingArguments
from trl import SFTTrainer


# Hyper-parameters for the supervised fine-tuning run.
training_args=TrainingArguments(
    output_dir="./sft",
    max_steps=100,
    logging_steps=10,
    save_steps=10,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    group_by_length=False,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_steps=50,
    weight_decay=0.05,
    optim="paged_adamw_32bit",
    fp16=True,
    # The datasets are already packed/tokenized, so keep every column.
    remove_unused_columns=False,
    run_name=os.getenv("WANDB_NAME"),
    report_to="wandb"
)

sft_trainer=SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    # The packed validation split built above is named `valid_dataset`;
    # `eval_dataset` is not defined until the DPO section.
    eval_dataset=valid_dataset,
    peft_config=peft_config,
    packing=True,
    max_seq_length=None,
    tokenizer=tokenizer,
    args=training_args,
)

sft_trainer.train()

# Persist the trained adapter where the merge step expects to find it.
sft_trainer.model.save_pretrained(os.path.join(training_args.output_dir, "final_checkpoint"))

## Save the merged model

In [None]:
# Metadata forwarded to `push_to_hub` for the uploaded model.
# Optional entries that are currently unused: 'tasks', 'dataset_tags'.
kwargs=dict(
    model_name=f'{os.getenv("WANDB_NAME")}',
    finetuned_from='meta-llama/Llama-2-7b-hf',
    dataset='lvwerra/stack-exchange-paired',
)

# Upload the tokenizer first, then the trainer's model, to the Hub.
tokenizer.push_to_hub(os.getenv("WANDB_NAME"))
sft_trainer.push_to_hub(**kwargs)

In [None]:
import gc

# Drop the references to the trainer and the quantized base model first, then
# run the garbage collector and finally release PyTorch's cached CUDA memory.
# The order matters: the cache can only be emptied once the objects are gone.
del sft_trainer, base_model
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Reload the base model together with the trained adapter. The class name must
# match the imported `AutoPeftModelForCausalLM` (capital "C" in "Causal").
# Expects the adapter to have been saved under ./sft/final_checkpoint.
model=AutoPeftModelForCausalLM.from_pretrained("./sft/final_checkpoint", device_map="auto", torch_dtype=torch.bfloat16)
# Fold the LoRA weights into the base weights and drop the adapter wrappers.
model=model.merge_and_unload()

# Store the merged model in safetensors format for the DPO stage.
model.save_pretrained("./sft/final_merged_checkpoint", safe_serialization=True)

# Direct Preference Optimization

In [None]:
from typing import Dict, List


def return_prompt_and_responses(samples)-> Dict[str, List[str]]:
    """Convert a batch of raw rows into prompt / chosen / rejected columns.

    Args:
        samples: batched mapping with the columns ``question``,
            ``response_j`` and ``response_k``, each a list.

    Returns:
        A dict of equally long lists: ``prompt`` (the question wrapped in the
        ``"Question:...\\n\\nAnswer:"`` template), ``chosen`` (``response_j``)
        and ``rejected`` (``response_k``).
    """
    return {
        "prompt":[
            "Question:"+question+"\n\nAnswer:" for question in samples["question"]
        ],
        "chosen": samples["response_j"],
        "rejected": samples["response_k"],
    }


def get_stack_exchange_paired(data_dir="data/rl",sanity_check=False,cache_dir=None,num_proc=24):
    """Load a stack-exchange-paired subset in prompt / chosen / rejected form.

    Args:
        data_dir: sub-directory of the dataset repository to load.
        sanity_check: when true, keep at most the first 1000 rows.
        cache_dir: optional cache directory forwarded to ``load_dataset``.
        num_proc: number of worker processes used for the ``map`` call.

    Returns:
        The mapped dataset with only the ``prompt``, ``chosen`` and
        ``rejected`` columns.
    """
    dataset=load_dataset(
        "lvwerra/stack-exchange-paired",
        split="train",
        # Honor the caller's choice; this used to be hard-coded to "data/rl",
        # so asking for "data/evaluation" silently returned the RL split.
        data_dir=data_dir,
        cache_dir=cache_dir,
    )
    # Remember the raw columns so they can be dropped after the mapping.
    original_columns=dataset.column_names

    if sanity_check:
        dataset=dataset.select(range(min(len(dataset), 1000)))

    return dataset.map(
        return_prompt_and_responses,
        batched=True,
        num_proc=num_proc,
        remove_columns=original_columns,
    )

Check [Direct Preference Optimization](https://www.kaggle.com/code/aisuko/fine-tuning-mistral-7b-with-dpo?scriptVersionId=158111896#Direct-Preference-Optimization) for more information about the model and the reference model.

In [None]:
# The model to be trained and the reference model are both loaded from the
# same merged SFT checkpoint with identical options.
_sft_checkpoint="./sft/final_merged_checkpoint"
_load_options=dict(
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)

model=AutoModelForCausalLM.from_pretrained(_sft_checkpoint, **_load_options)
model.config.use_cache=False

model_ref=AutoModelForCausalLM.from_pretrained(_sft_checkpoint, **_load_options)

In [None]:
# Tokenizer for the DPO stage.
tokenizer_dpo=AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# Take the pad token from this tokenizer itself rather than from the separate
# SFT `tokenizer` object, so the cell does not depend on that earlier variable.
tokenizer_dpo.pad_token=tokenizer_dpo.eos_token

In [None]:
## Load the Stack-exchange paired dataset
def _fits_length_budget(row, limit=1024):
    """Keep rows whose prompt plus either response is at most `limit` characters."""
    prompt_chars=len(row["prompt"])
    return (
        prompt_chars+len(row["chosen"])<=limit
        and prompt_chars+len(row["rejected"])<=limit
    )


# Training pairs: the full split, minus over-long rows.
train_dataset=get_stack_exchange_paired(
    data_dir="data/rl",
    sanity_check=False
).filter(_fits_length_budget)

# Evaluation pairs: a small sanity-check sample, filtered the same way.
eval_dataset=get_stack_exchange_paired(
    data_dir="data/evaluation",
    sanity_check=True
).filter(_fits_length_budget)

# Initialize training arguments

In [None]:
from peft import LoraConfig
from transformers import TrainingArguments
from trl import DPOTrainer

# Hyper-parameters for the DPO run; checkpoints and logs go to a directory
# named after the W&B run.
training_args=TrainingArguments(
    per_device_train_batch_size=4,
    per_device_eval_batch_size=1,
    max_steps=1000,
    logging_steps=10,
    save_steps=100,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    learning_rate=5e-4,
    evaluation_strategy="steps",
    eval_steps=100,
    output_dir=os.getenv("WANDB_NAME"),
    report_to="wandb",
    lr_scheduler_type="cosine",
    warmup_steps=100,
    optim="paged_adamw_32bit",
    bf16=True,
    # The trainer needs the prompt/chosen/rejected columns, so keep them all.
    remove_unused_columns=False,
    run_name=os.getenv("WANDB_NAME"),
)

# LoRA adapter settings for the DPO stage.
# NOTE(review): 'out_proj', 'fc_in', 'fc_out' and 'wte' do not look like Llama
# module names (base_model was printed above - compare against that output);
# if they match nothing, only the q/k/v projections receive adapters. Confirm
# whether 'o_proj' and the MLP projections were intended.
peft_config=LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=['q_proj','v_proj','k_proj','out_proj','fc_in','fc_out','wte',],
    bias="none",
    task_type="CAUSAL_LM",
)

dpo_trainer=DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=0.1,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # Use the tokenizer that was created for the DPO stage.
    tokenizer=tokenizer_dpo,
    peft_config=peft_config,
    max_prompt_length=512,
    max_length=1024,
)

dpo_trainer.train()
dpo_trainer.save_model(os.getenv("WANDB_NAME"))

In [None]:
# Write the trained DPO model into a "final_checkpoint" folder inside the
# run's output directory.
final_checkpoint_dir=os.getenv("WANDB_NAME")+"/final_checkpoint"
dpo_trainer.model.save_pretrained(final_checkpoint_dir)

# Reference list

* https://huggingface.co/blog/dpo-trl