### Code Reference:

- Hugging Face: https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama_2/scripts
- Helpful Notebooks:
    - https://colab.research.google.com/drive/12dVqXZMIVxGI0uutU6HG9RWbWPXL3vts?usp=sharing#scrollTo=yWEM89A48NrU
    - https://colab.research.google.com/drive/1PEQyJO1-f6j0S_XJ8DV50NkpzasXkrzd?usp=sharing#scrollTo=OJXpOgBFuSrc
    
- Interesting Reads:
    - https://www.assemblyai.com/blog/how-rlhf-preference-model-tuning-works-and-how-things-may-go-wrong/ 
    - https://huyenchip.com/2023/05/02/rlhf.html
    - https://lightning.ai/pages/community/lora-insights/?utm_medium=social&utm_source=twitter&utm_campaign=Education_10132023

### Imports

In [None]:
import os
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer
from datasets import load_dataset
from transformers import TrainingArguments
from peft import LoraConfig, get_peft_model
from trl import DPOTrainer

# Cache directory under the current working directory; passed as `cache_dir`
# to the dataset, model and tokenizer loaders below.
CACHE_DIR = f"{os.getcwd()}/cache"

### Set Hugging Face Token

In [1]:
import os

from huggingface_hub import login

# Read the access token from the HF_TOKEN environment variable rather than
# pasting it into the notebook, so the secret is never saved with the file.
token = os.environ.get("HF_TOKEN")

# Pass the token explicitly; when it is None, login() falls back to its
# interactive prompt.
login(token=token)

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

### Comparison Data Set-Up

#### SHP

In [None]:
def get_dataset(test=False):
    """Load stanfordnlp/shp and convert it to prompt / chosen / rejected columns.

    Args:
        test: when True load the "test" split, otherwise the "train" split.

    Returns:
        The mapped dataset with exactly three columns:
        "prompt"   - "###Question:\\n" + history + "\\n\\n###Answer:\\n"
        "chosen"   - human_ref_A when labels == 1, otherwise human_ref_B
        "rejected" - the other reference.

    Note:
        A later cell in this notebook defines another `get_dataset`
        (Anthropic/hh-rlhf); whichever cell runs last wins.
    """
    split = "test" if test else "train"
    dataset = load_dataset("stanfordnlp/shp", cache_dir=CACHE_DIR, split=split)

    original_columns = dataset.column_names

    def return_prompt_and_responses(samples):
        # `samples` is a batch (dict of lists) because map() runs with
        # batched=True, so the preferred side must be decided per example.
        # Comparing the whole `labels` list to 1 is always False and would
        # mark human_ref_B as chosen for every row.
        prompts, chosen, rejected = [], [], []
        for history, ref_a, ref_b, label in zip(
            samples["history"],
            samples["human_ref_A"],
            samples["human_ref_B"],
            samples["labels"],
        ):
            prompts.append("###Question:\n" + history + "\n\n###Answer:\n")
            if label == 1:
                chosen.append(ref_a)
                rejected.append(ref_b)
            else:
                chosen.append(ref_b)
                rejected.append(ref_a)
        return {"prompt": prompts, "chosen": chosen, "rejected": rejected}

    return dataset.map(
        return_prompt_and_responses,
        batched=True,
        remove_columns=original_columns,
    )

#### Anthropic

In [4]:
def ha_to_qa(samples):
    """Rewrite the speaker tags of one example's "chosen" and "rejected" text.

    "Human:" becomes "###Question:\\n" and "Assistant:" becomes "###Answer:\\n".
    Returns a dict holding only the two rewritten fields.
    """
    def _retag(dialogue):
        question_tagged = dialogue.replace("Human:", "###Question:\n")
        return question_tagged.replace("Assistant:", "###Answer:\n")

    return {field: _retag(samples[field]) for field in ("chosen", "rejected")}

def get_dataset(test=False):
    """Load Anthropic/hh-rlhf and split each dialogue into prompt and responses.

    Args:
        test: when True load the "test" split, otherwise the "train" split.

    Returns:
        The dataset after `ha_to_qa` re-tagging, where
        "prompt"   is the chosen dialogue up to and including its last "Answer:\\n",
        "chosen"   is the chosen dialogue after its last "Answer:\\n",
        "rejected" is the rejected dialogue after its last "Answer:\\n".
    """
    split_name = "test" if test else "train"
    dataset = load_dataset("Anthropic/hh-rlhf", cache_dir=CACHE_DIR, split=split_name)
    dataset = dataset.map(ha_to_qa)

    marker = "Answer:\n"

    def _cut_point(text):
        # Index of the first character after the last marker occurrence.
        return text.rfind(marker) + len(marker)

    def _head(text):
        return text[:_cut_point(text)]

    def _tail(text):
        return text[_cut_point(text):]

    def return_prompt_and_responses(samples):
        # Batched map: every value in `samples` is a list of strings.
        chosen_texts = samples['chosen']
        rejected_texts = samples['rejected']
        return {
            "prompt": [_head(text) for text in chosen_texts],
            "chosen": [_tail(text) for text in chosen_texts],
            "rejected": [_tail(text) for text in rejected_texts],
        }

    return dataset.map(return_prompt_and_responses, batched=True)

### Load SFT Model for DPO

In [5]:
model_name = "sft/llama2_sft_orca_1024/merged_model"
# Other checkpoints used with this cell:
#"sft/llama2_sft_lima_2epochs/merged_model"
#"sft/llama2_sft_shp"
#"meta-llama/Llama-2-7b-hf"

# Both the policy and the reference model are loaded from the same checkpoint
# with the same options, so the options are written once.
shared_load_kwargs = dict(
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
    cache_dir=CACHE_DIR,
)

# Model that will be trained
model = AutoModelForCausalLM.from_pretrained(model_name, **shared_load_kwargs)
model.config.use_cache = False

# Reference model: a second copy of the same checkpoint, passed to the DPO
# trainer alongside `model`
model_ref = AutoModelForCausalLM.from_pretrained(model_name, **shared_load_kwargs)

# Tokenizer; end-of-sequence doubles as the padding token
tokenizer = AutoTokenizer.from_pretrained("sft/llama2_sft_orca_1024", cache_dir=CACHE_DIR)
tokenizer.pad_token = tokenizer.eos_token

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

### Load Dataset for DPO

In [9]:
# Keep only pairs whose prompt + response text is at most MAX_LENGTH characters
# (len() is applied to the raw strings, not to token ids).
MAX_LENGTH = 1024


def _train_pair_fits(example):
    """True when prompt+chosen and prompt+rejected are both within MAX_LENGTH."""
    prompt_chars = len(example["prompt"])
    if prompt_chars + len(example["chosen"]) > MAX_LENGTH:
        return False
    return prompt_chars + len(example["rejected"]) <= MAX_LENGTH


train_dataset = get_dataset()

train_dataset = train_dataset.filter(_train_pair_fits)

In [10]:
# Same character-length filter as the training split, applied to the test split.
eval_dataset = get_dataset(test=True)

# Both prompt+chosen and prompt+rejected fit exactly when the prompt plus the
# longer of the two responses fits.
eval_dataset = eval_dataset.filter(
    lambda row: len(row["prompt"]) + max(len(row["chosen"]), len(row["rejected"])) <= MAX_LENGTH
)

### Training Arguments

In [11]:
# Initialize training arguments:
# 2 samples per device x 4 accumulation steps per optimizer update.
# NOTE(review): output_dir is "sft" although the final model is saved under
# "dpo/..." and run_name mentions "shp" - confirm both are intended.
training_args = TrainingArguments(
    per_device_train_batch_size = 2,
    per_device_eval_batch_size = 1,
#     max_steps = 1000,
    num_train_epochs = 1,
    logging_steps = 5000,
    save_steps = 5000,
    gradient_accumulation_steps = 4,
    gradient_checkpointing = True,  # boolean flag; previously passed as the integer 4
    learning_rate = 5e-4,
    evaluation_strategy = "steps",
    eval_steps = 5000,
    output_dir = "sft",
    lr_scheduler_type = "cosine",
    warmup_steps = 100,
    optim = "paged_adamw_32bit",
    bf16 = True,
    remove_unused_columns = False,  # keep the prompt/chosen/rejected columns
    run_name = "dpo_llama2_shp",
)

### PEFT Config

In [12]:
# LoRA adapter configuration.
# Only the q_proj and v_proj modules are targeted; k_proj, out_proj, fc_in,
# fc_out and wte are not.
lora_target_modules = ["q_proj", "v_proj"]

peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=lora_target_modules,
    bias="none",
    task_type="CAUSAL_LM",
)

### DPO Trainer

In [13]:
# Build the DPO trainer from the model, the reference copy, the datasets and
# the LoRA configuration.
# NOTE(review): max_prompt_length equals max_length here - confirm intended.
dpo_trainer = DPOTrainer(
    model = model,
    ref_model = model_ref,
    args = training_args,
    beta = 0.1,
    train_dataset = train_dataset,
    eval_dataset = eval_dataset,
    tokenizer = tokenizer,
    peft_config = peft_config,
    max_prompt_length = MAX_LENGTH,
    max_length = MAX_LENGTH,
)

# Run training
dpo_trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33myashsharma0906[0m. Use [1m`wandb login --relogin`[0m to force relogin


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,Rewards/chosen,Rewards/rejected,Rewards/accuracies,Rewards/margins,Logps/rejected,Logps/chosen,Logits/rejected,Logits/chosen
5000,0.7231,0.699168,-0.276585,-0.951571,0.636544,0.674986,-124.498116,-122.243507,-1.06754,-1.024356


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



### Save Model for Evaluation

In [15]:
# Save through the trainer, then save the trainer's model again under a
# final_checkpoint/ sub-directory of the same output path.
dpo_output_dir = "dpo/llama2_sft_orca_1024"
dpo_trainer.save_model(dpo_output_dir)
dpo_trainer.model.save_pretrained(f"{dpo_output_dir}/final_checkpoint")