# DPO Training

### Imports and initial setup

Code adapted from: https://www.philschmid.de/dpo-align-llms-in-2024-with-trl

In [None]:
import json
import copy

from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, \
    TrainingArguments, Conversation, pipeline
from peft import AutoPeftModelForCausalLM, LoraConfig
from trl import DPOTrainer
import torch
import multiprocessing
from datasets import Dataset
import matplotlib.pyplot as plt

from huggingface_hub import notebook_login

# Interactive Hugging Face Hub login (prompts for a token in the notebook);
# presumably required to download the gated meta-llama checkpoint used below.
notebook_login()

In [None]:
# System prompt prepended to every training example and to the test conversations.
system_msg = {
    "role": "system",
    "content": "You are an expert professor, teaching a student how to solve a problem by providing a full explanation of the solution.",
}

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # reuse the EOS token for padding
tokenizer.padding_side = "left"            # left padding (original note: avoids errors with flash attention)
tokenizer.truncation_side = "left"         # truncate from the start so the end of the sequence is kept


## Data

##### Setup

In [None]:
# Load the annotated preference records (one record per question, see the loop
# below). Decode explicitly as UTF-8 so the result does not depend on the
# platform's default encoding (e.g. cp1252 on Windows).
with open("datasets/M1_preference_data_15052024.json", "r", encoding="utf-8") as f:
    data = json.load(f)

print(len(data))

In [None]:
# Flatten the annotated data into DPO triplets of chat-message lists: every
# preference judgement of a question becomes one (prompt, chosen, rejected) row.
dpo_dataset_dict = {
    "prompt": [],
    "chosen": [],
    "rejected": [],
}

for dp in tqdm(data):
    # The user turn is identical for every preference of this question.
    msg_qn = {"role": "user", "content": dp["question_complete"]}

    for pref in dp["preference"]:
        winner = pref["overall"]
        # Validate explicitly: an `assert` would silently be skipped under `python -O`.
        if winner not in ("A", "B"):
            raise ValueError(f"Unexpected preference label {winner!r}; expected 'A' or 'B'.")
        loser = "B" if winner == "A" else "A"

        dpo_dataset_dict["prompt"].append([system_msg, msg_qn])
        dpo_dataset_dict["chosen"].append([{"role": "assistant", "content": pref[winner]}])
        dpo_dataset_dict["rejected"].append([{"role": "assistant", "content": pref[loser]}])

In [None]:
# Persist the raw (not yet chat-templated) message triplets for reference.
with open("datasets/input_dpo_dataset.json", "w") as out_file:
    out_file.write(json.dumps(dpo_dataset_dict))

In [None]:
dpo_ds = Dataset.from_dict(dpo_dataset_dict)


def process(row):
    """Render one example's message lists into chat-template strings.

    Replaces the "prompt", "chosen" and "rejected" message lists of ``row`` with
    the text produced by the module-level ``tokenizer``'s chat template, and
    returns the mutated ``row``.

    The completions are rendered on their own. Any BOS token the template puts in
    front of the first message (Llama 3's template does this) would therefore
    land in the middle of the ``prompt + completion`` sequence that DPOTrainer
    trains on, so it is stripped from "chosen" and "rejected". The prompt keeps
    its leading BOS.
    NOTE(review): some TRL versions also prepend a BOS to the prompt themselves;
    check the tokenized prompts for a duplicated BOS.
    """
    bos = tokenizer.bos_token
    for column in ("prompt", "chosen", "rejected"):
        text = tokenizer.apply_chat_template(row[column], tokenize=False)
        if column != "prompt" and bos and text.startswith(bos):
            text = text[len(bos):]
        row[column] = text
    return row


ds = dpo_ds.map(
    process,
    num_proc=multiprocessing.cpu_count(),
    load_from_cache_file=False,
)

ds.to_json("datasets/dpo_hf_dataset.json")

##### Analysis

In [None]:
def _tokenize_column(column):
    """Tokenize one text column of `ds` in batches and return the mapped dataset."""
    return ds.map(lambda batch: tokenizer(batch[column]), batched=True)


ds_prompt_tokens = _tokenize_column("prompt")
ds_chosen_tokens = _tokenize_column("chosen")
ds_rejected_tokens = _tokenize_column("rejected")

In [None]:
def _token_counts(tokenized_ds):
    """Return the number of token ids in every row of a tokenized dataset."""
    return [len(ids) for ids in tokenized_ds["input_ids"]]


# Per-row token counts, printing the maximum of each as it is computed.
prompt_lens = _token_counts(ds_prompt_tokens)
print(max(prompt_lens))

chosen_lens = _token_counts(ds_chosen_tokens)
print(max(chosen_lens))

rejected_lens = _token_counts(ds_rejected_tokens)
print(max(rejected_lens))

In [None]:
# Overlayed histograms (50 bins) of token counts, clipped to the 0-1000 range.
length_series = {"Prompt": prompt_lens, "Chosen": chosen_lens, "Rejected": rejected_lens}

plt.hist(list(length_series.values()), bins=50, histtype="bar")
plt.legend(list(length_series.keys()))
plt.xlim([0, 1000])
plt.savefig("data_dist.pdf")

##### Preparation for training

In [None]:
import random

from datasets import DatasetDict

full_ds = Dataset.from_json("datasets/dpo_hf_dataset.json")

# Each question contributes one row per preference annotation, all sharing the
# same prompt. Split by prompt rather than by row, because a random row-level
# split can put the same question in both train and eval and inflate eval metrics.
SPLIT_SEED = 42      # fixed so that eval_dataset indices are reproducible across runs
EVAL_FRACTION = 0.1  # fraction of unique prompts held out for evaluation

unique_prompts = sorted(set(full_ds["prompt"]))
random.Random(SPLIT_SEED).shuffle(unique_prompts)
n_eval_prompts = max(1, round(EVAL_FRACTION * len(unique_prompts)))
eval_prompts = set(unique_prompts[:n_eval_prompts])

ds = DatasetDict({
    "train": full_ds.filter(lambda x: x["prompt"] not in eval_prompts),
    "test": full_ds.filter(lambda x: x["prompt"] in eval_prompts),
})

train_dataset = ds["train"]
eval_dataset = ds["test"]

In [None]:
#### Maximum sequence lengths for DPOTrainer ####
# The `else` branch holds hard-coded results of an earlier run of the
# recalculation below (p95 token lengths rounded up to an even number). Set the
# flag to True to recompute them from `train_dataset`. Doing so also drops
# train/eval samples whose prompt + completion exceeds the new `max_seq_length`.
RECALCULATE_MAX_LENGTHS = False

if RECALCULATE_MAX_LENGTHS:
    from numpy import percentile

    def _n_tokens(text):
        """Return how many token ids `tokenizer` produces for `text`."""
        return len(tokenizer(text)["input_ids"])

    def _fits(example):
        """True if BOTH prompt+chosen and prompt+rejected fit in `max_seq_length`."""
        return all(
            _n_tokens(example["prompt"] + example[key]) <= max_seq_length
            for key in ("chosen", "rejected")
        )

    # p95 length of the prompt, and the larger of the p95 lengths of the two pairs
    prompt_length = int(percentile([_n_tokens(p) for p in train_dataset["prompt"]], 95))
    max_seq_length = max(
        int(percentile([_n_tokens(x["prompt"] + x[key]) for x in train_dataset], 95))
        for key in ("chosen", "rejected")
    )

    # Remove samples that are too long. The previous recipe only checked
    # prompt + chosen, which let over-long rejected completions through.
    train_dataset = train_dataset.filter(_fits)
    eval_dataset = eval_dataset.filter(_fits)
    print(f"len(train_dataset): {len(train_dataset)}")
    print(f"len(eval_dataset): {len(eval_dataset)}")

    # Round both limits up to the next even number. This is kept from the adapted
    # recipe; NOTE(review): nothing in this notebook requires it.
    prompt_length = ((prompt_length + 1) // 2) * 2
    max_seq_length = ((max_seq_length + 1) // 2) * 2
    print(f"p95 prompt length: {prompt_length}")
    print(f"p95 prompt + completion length: {max_seq_length}")
else:
    prompt_length = 402
    max_seq_length = 912

## Training

##### Model definition

In [None]:
# 4-bit NF4 quantization with double quantization; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model_load_kwargs = {
    "device_map": "auto",
    "use_cache": False,  # KV cache is not needed for training (gradient checkpointing is enabled below)
    "attn_implementation": "flash_attention_2",
    "torch_dtype": torch.bfloat16,
    "quantization_config": bnb_config,
}
model = AutoModelForCausalLM.from_pretrained(model_id, **model_load_kwargs)

# LoRA adapter on every linear layer (original note: based on the QLoRA paper
# and Sebastian Raschka's experiments). Scaling factor is lora_alpha / r = 0.5.
peft_config = LoraConfig(
    r=256,
    lora_alpha=128,
    lora_dropout=0.05,
    target_modules="all-linear",
    bias="none",
    task_type="CAUSAL_LM",
)

##### `DPOTrainer` setup

In [None]:
args = TrainingArguments(
    output_dir="llama3",                    # directory for checkpoints and the final (adapter) model
    num_train_epochs=1,                     # number of training epochs
    per_device_train_batch_size=4,          # batch size per device during training
    per_device_eval_batch_size=4,           # batch size per device during evaluation
    gradient_accumulation_steps=1,          # micro-batches accumulated per optimizer update (1 = update every step)
    gradient_checkpointing=True,            # recompute activations during backward to save memory
    optim="adamw_torch_fused",              # fused AdamW implementation
    learning_rate=5e-5,                     # peak learning rate, reached after warmup
    max_grad_norm=0.3,                      # gradient-clipping norm, as in the QLoRA paper
    warmup_ratio=0.1,                       # fraction of training steps used for warmup
    lr_scheduler_type="cosine",             # cosine decay after warmup
    logging_steps=25,                       # log every 25 steps
    save_steps=500,                         # save a checkpoint every 500 steps
    save_total_limit=2,                     # keep at most 2 checkpoints on disk
    evaluation_strategy="steps",            # evaluate every `eval_steps` steps (renamed `eval_strategy` in newer transformers)
    eval_steps=700,                         # evaluate every 700 steps
    bf16=True,                              # bfloat16 mixed precision
    tf32=True,                              # allow TF32 matmuls (requires an Ampere or newer GPU)
    push_to_hub=False,                      # do not push the model to the Hugging Face Hub
    report_to="tensorboard",                # log metrics to TensorBoard
)
 
# Hyper-parameters of the DPO objective itself.
dpo_args = dict(
    beta=0.1,             # DPO beta: a higher beta keeps the policy closer to the reference model
    loss_type="sigmoid",  # the standard (logistic) DPO loss
)

In [None]:
# ref_model=None because a peft_config is given: per the TRL documentation, the
# trainer then uses the model with the adapter disabled as the reference model.
trainer = DPOTrainer(
    model,
    args=args,
    ref_model=None,
    peft_config=peft_config,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    max_prompt_length=prompt_length,
    max_length=max_seq_length,
    **dpo_args,  # beta and loss_type
)

##### Training process

In [None]:
# Start training. Checkpoints go to `args.output_dir` every `save_steps` steps;
# push_to_hub=False, so nothing is uploaded to the Hugging Face Hub.
trainer.train()

# Save the final model to `args.output_dir`. Because a peft_config was used, this
# is presumably the LoRA adapter that the Testing section reloads with
# AutoPeftModelForCausalLM.
trainer.save_model()

In [None]:
# Free GPU memory before the model is reloaded for testing. `del` alone does not
# release objects that are kept alive by reference cycles, so collect them
# explicitly; `empty_cache()` can only return memory whose tensors are freed.
import gc

del model
del trainer
gc.collect()
torch.cuda.empty_cache()

## Testing

In [None]:
# Path to the saved PEFT adapter (the `output_dir` used for training)
peft_model_id = "./llama3"

# Load the base model with the trained adapter on top. Use bfloat16 to match the
# precision used in training (bf16=True, bnb_4bit_compute_dtype=torch.bfloat16);
# float16 has a much smaller dynamic range and can overflow.
model = AutoPeftModelForCausalLM.from_pretrained(
    peft_model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)

# NOTE(review): the "conversational" pipeline and `Conversation` are deprecated
# in newer transformers releases (and later removed). Pin the transformers version,
# or migrate to the "text-generation" pipeline fed with a list of chat messages.
pipe = pipeline("conversational", model=model, tokenizer=tokenizer)

##### Test multiple-choice question (MCQ) capabilities

In [None]:
# eval_dataset indices worth inspecting: 21, 63, 90 (work well), 196 (works ok,
# interesting failure case). These only point at the same examples as long as
# the train/test split is reproducible.
sample_idx = 196

# Llama 3 Instruct ends each assistant turn with <|eot_id|>. Stop on it as well as
# on tokenizer.eos_token_id, which may be <|end_of_text|> depending on the
# tokenizer config; otherwise generation can run on until max_new_tokens.
terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]

# Recover the raw user question from the chat-templated prompt: the text after
# the last role header, up to the end-of-turn token.
user_question = eval_dataset[sample_idx]["prompt"].split("<|end_header_id|>\n\n")[-1].split("<|eot_id|>")[0]

conversation = Conversation()
conversation.add_message(system_msg)
conversation.add_message({"role": "user", "content": user_question})

conversation = pipe(conversation, max_new_tokens=2048, do_sample=True, temperature=1.0, top_k=50, top_p=0.9, eos_token_id=terminators, pad_token_id=tokenizer.pad_token_id)
conversation

In [None]:
# Ask for the final answer letter in a copy of the conversation, so re-running
# this cell does not keep appending the same follow-up question to `conversation`.
answer_request = copy.deepcopy(conversation)
answer_request.add_message({"role": "user", "content": "Now, based off of the explanation you have given, give the correct answer as a single letter only, e.g. `A`, `B`, `C` or `D`."})

# Llama 3 Instruct ends each assistant turn with <|eot_id|>; stop on it as well as
# on tokenizer.eos_token_id, which may be <|end_of_text|> depending on the config.
terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]

conversation_answer = pipe(answer_request, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=50, top_p=0.9, eos_token_id=terminators, pad_token_id=tokenizer.pad_token_id)
conversation_answer.messages[-1]["content"]