### Direct Preference Optimization (DPO)

Direct Preference Optimization (DPO) is an approach to improving the alignment of a language model. Current approaches to training large language models begin with an initial step of unsupervised training on a large corpus of data. This results in a model that generates the text most likely to follow a prompt, given the data the model was trained on. When we want our language model to be used as a chatbot or code assistant, this often produces undesirable text. High-quality conversations or coding examples may be rare in the training corpus, and thus rare in the model's output.

To address this, we can utilize a dataset containing three columns: a prompt, a chosen output, and a rejected output.

DPO is an alternative to reinforcement learning from human feedback (RLHF): it trains directly on these preference pairs rather than first fitting a separate reward model.
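To make the three-column format concrete, here is a minimal sketch of a single preference record. The text values are invented for illustration and are not taken from any dataset.

```python
# One hypothetical preference record (values invented for illustration).
record = {
    "prompt": "\n\nHuman: How do I reverse a list in Python?\n\nAssistant: ",
    "chosen": "Call my_list.reverse() to reverse in place, or use my_list[::-1] for a copy.",
    "rejected": "Lists cannot be reversed.",
}

# A preference dataset is simply a collection of such records.
dataset = [record]
for row in dataset:
    # The prompt concatenated with either completion forms a full conversation.
    print(row["prompt"] + row["chosen"])
```

The chosen and rejected completions share the same prompt, so the only thing that differs between the two full texts is the final response.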

In [None]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    TrainingArguments,
)
from trl import DPOTrainer
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from peft import LoraConfig, get_peft_model

# note: to get bitsandbytes to work on windows, uninstall bitsandbytes and reinstall with
# pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl

In [2]:
def process_hh_rlhf_sample(sample):
    """
    sample is a dictionary with keys 'chosen' and 'rejected'.
    Extract the prompt and the two completions from the sample.
    Find the index of the last occurrence of the substring '\\n\\nAssistant: '.

    Return a dictionary with keys prompt, chosen, and rejected.
    """
    term = '\n\nAssistant: '
    end_of_prompt_index = sample['chosen'].rfind(term)

    # extract the prompt
    prompt = sample['chosen'][:end_of_prompt_index+len(term)]
    # extract the chosen completion
    chosen = sample['chosen'][len(prompt):]
    # extract the rejected completion
    rejected = sample['rejected'][len(prompt):]

    return {'prompt': prompt, 'chosen': chosen, 'rejected': rejected}

def get_anthropic_hh_rlhf_dataset(split='train'):
    """
    The Anthropic HH-RLHF dataset contains 160k training examples and 8k test examples.
    Each example is a dictionary with two keys: 'chosen' and 'rejected'.
    Each of these includes the prompt and the completion.
    I want to extract the prompt, chosen completion, and rejected completion.

    https://arxiv.org/abs/2204.05862
    https://huggingface.co/datasets/Anthropic/hh-rlhf
    """
    dataset = load_dataset('Anthropic/hh-rlhf', split=split)
    return dataset.map(process_hh_rlhf_sample)

bnb_config = BitsAndBytesConfig(
    load_in_8bit=True
)

torch_dtype = torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(
    'stabilityai/stablelm-2-1_6b',
    quantization_config=bnb_config,
    torch_dtype=torch_dtype,
    trust_remote_code=True
)

tokenizer = AutoTokenizer.from_pretrained(
    'stabilityai/stablelm-2-1_6b',
    trust_remote_code=True
)

# https://github.com/huggingface/trl/issues/1073
tokenizer.add_special_tokens({"bos_token": tokenizer.eos_token})
tokenizer.bos_token_id = tokenizer.eos_token_id

train_dataset = get_anthropic_hh_rlhf_dataset(split='train')
test_dataset = get_anthropic_hh_rlhf_dataset(split='test[:1000]') # use a small test set for now

# define the training arguments
training_args = TrainingArguments(
    max_steps=64, # only 64 gradient updates, not even one epoch
    remove_unused_columns=False,
    learning_rate=1e-4,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    output_dir='output',
    logging_strategy='steps',
    logging_dir='logs',
    logging_steps=16,
    lr_scheduler_type='constant' # default is linear
)

# after training, I can view the logs with:

peft_config = LoraConfig(
    r=64, # dimension of the low-rank matrices
    lora_alpha=16, # scaling factor for the weight matrices
    bias='none', # don't train bias params
    task_type='CAUSAL_LM',
    target_modules=[
        'q_proj',
        'k_proj',
        'v_proj',
        'o_proj',
        'gate_proj',
        'up_proj',
        'down_proj',
        'lm_head',
    ]
)

tokenizer.pad_token = tokenizer.eos_token
model = get_peft_model(model, peft_config)

bin c:\Users\danto\anaconda3\lib\site-packages\bitsandbytes\libbitsandbytes_cuda118.dll
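The prompt-splitting logic used in `process_hh_rlhf_sample` can be checked on a toy sample without downloading the dataset. The sketch below restates the same `rfind` approach in a standalone helper; the name `split_sample` and the toy conversation are made up for illustration.

```python
def split_sample(sample, term="\n\nAssistant: "):
    """Split the shared prompt from the chosen and rejected completions."""
    # The prompt ends right after the last "Assistant: " marker in the chosen text.
    end = sample["chosen"].rfind(term)
    prompt = sample["chosen"][: end + len(term)]
    return {
        "prompt": prompt,
        "chosen": sample["chosen"][len(prompt):],
        # Assumes the rejected text starts with exactly the same prompt.
        "rejected": sample["rejected"][len(prompt):],
    }

# Toy multi-turn sample, invented for illustration.
toy = {
    "chosen": "\n\nHuman: Hi\n\nAssistant: Hello!\n\nHuman: Name a fruit.\n\nAssistant: Apple.",
    "rejected": "\n\nHuman: Hi\n\nAssistant: Hello!\n\nHuman: Name a fruit.\n\nAssistant: No.",
}
result = split_sample(toy)
print(repr(result["prompt"]))
print(result["chosen"], "|", result["rejected"])
```

Using `rfind` rather than `find` matters for multi-turn conversations: earlier assistant turns stay inside the prompt, and only the final response is treated as the completion.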


In [3]:
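def tokenize_func(examples):
    # Positional arguments map to text, text_pair, and text_target.
    return tokenizer(examples['prompt'], examples['chosen'], examples['rejected'], padding=True, truncation=True)

# NOTE: map() keeps the raw 'prompt', 'chosen', and 'rejected' text columns, which are
# what DPOTrainer reads and tokenizes itself. The extra token columns and data_collator
# are not passed to the trainer below, so this step appears to be optional.
encoded_dataset_train = train_dataset.map(tokenize_func, batched=True)
encoded_dataset_test = test_dataset.map(tokenize_func, batched=True)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)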

trainer = DPOTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    beta=0.1, # beta param for DPO loss
    train_dataset=encoded_dataset_train,
    eval_dataset=encoded_dataset_test,
    max_length=512,
    max_target_length=128,
    max_prompt_length=128,
    generate_during_eval=False,
    peft_config=peft_config
)
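The `LoraConfig` passed to the trainer uses `r=64` and `lora_alpha=16`. As a rough sketch of what those two numbers imply, the snippet below computes the standard LoRA scaling factor (`lora_alpha / r`) and the number of trainable parameters added to one linear layer. The 2048 x 2048 layer shape is a hypothetical example, not a value read from the model.

```python
def lora_stats(d_in, d_out, r=64, lora_alpha=16):
    """Return (scaling, trainable_params, frozen_params) for one LoRA-adapted linear layer."""
    # The low-rank update B @ A is multiplied by lora_alpha / r before being added to W.
    scaling = lora_alpha / r
    # A has shape (r, d_in) and B has shape (d_out, r).
    trainable = r * (d_in + d_out)
    # The original weight matrix W stays frozen.
    frozen = d_in * d_out
    return scaling, trainable, frozen

# Hypothetical square projection layer, chosen only for illustration.
scaling, trainable, frozen = lora_stats(2048, 2048)
print(f"scaling={scaling}, trainable={trainable}, frozen={frozen}")
print(f"trainable fraction for this layer: {trainable / frozen:.4f}")
```

With these settings the update is scaled by 0.25, and for the assumed layer shape the adapter trains about 6% as many parameters as the frozen weight matrix contains.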



In [4]:
# evaluate the model before training
print(trainer.evaluate()) # evaluation is slow (about 14 minutes in this run)

100%|██████████| 125/125 [13:36<00:00,  6.53s/it] 

{'eval_loss': 0.6931473612785339, 'eval_runtime': 818.5613, 'eval_samples_per_second': 1.222, 'eval_steps_per_second': 0.153, 'eval_rewards/chosen': 0.0, 'eval_rewards/rejected': 0.0, 'eval_rewards/accuracies': 0.0, 'eval_rewards/margins': 0.0, 'eval_logps/rejected': -108.19891357421875, 'eval_logps/chosen': -87.72126770019531, 'eval_logits/rejected': -1.9339709281921387, 'eval_logits/chosen': -2.053931951522827}
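The initial `eval_loss` of about 0.6931 is ln 2, which is what the DPO loss gives when the policy and the reference model assign identical log-probabilities (consistent with the zero rewards reported above). The sketch below assumes the standard sigmoid form of the DPO loss for a single example; the function name `dpo_loss` is illustrative and the log-probabilities are rounded from the output above.

```python
import math

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Sigmoid DPO loss for one (chosen, rejected) pair."""
    # Implicit rewards: beta-scaled log-probability ratios against the reference model.
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    margin = chosen_reward - rejected_reward
    # -log(sigmoid(margin)) written as log(1 + exp(-margin)).
    return math.log1p(math.exp(-margin))

# Before training, the policy equals the reference, so both rewards are zero.
loss = dpo_loss(-87.72, -108.20, -87.72, -108.20)
print(loss, math.log(2))
```

Any loss above ln 2 therefore means the model is, on average, doing worse than the untrained starting point at separating chosen from rejected completions.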





In [5]:
# rewards/chosen is the mean difference between the log probabilities of the policy model and
# the reference model for the chosen completion, scaled by beta.
# rewards/rejected is the mean difference between the log probabilities of the policy model and
# the reference model for the rejected completion, scaled by beta.
# rewards/accuracies is the mean of how often the chosen reward is greater than the corresponding rejected reward.
# rewards/margins is the mean difference between the chosen and corresponding rejected rewards.

trainer.train()

 25%|██▌       | 16/64 [33:08<3:01:16, 226.60s/it]

{'loss': 1.0057, 'learning_rate': 0.001, 'rewards/chosen': -1.0836848020553589, 'rewards/rejected': -0.8774189949035645, 'rewards/accuracies': 0.40625, 'rewards/margins': -0.2062658816576004, 'logps/rejected': -164.83995056152344, 'logps/chosen': -152.500244140625, 'logits/rejected': -2.065371036529541, 'logits/chosen': -2.0074422359466553, 'epoch': 0.0}


 50%|█████     | 32/64 [56:36<47:22, 88.81s/it]   

{'loss': 0.8486, 'learning_rate': 0.001, 'rewards/chosen': -1.3901917934417725, 'rewards/rejected': -1.674965739250183, 'rewards/accuracies': 0.578125, 'rewards/margins': 0.2847740054130554, 'logps/rejected': -139.48309326171875, 'logps/chosen': -131.59625244140625, 'logits/rejected': -2.0241663455963135, 'logits/chosen': -1.993961215019226, 'epoch': 0.0}


 75%|███████▌  | 48/64 [2:32:48<1:06:29, 249.36s/it] 

{'loss': 0.9214, 'learning_rate': 0.001, 'rewards/chosen': -1.5270169973373413, 'rewards/rejected': -1.4379708766937256, 'rewards/accuracies': 0.5, 'rewards/margins': -0.08904620260000229, 'logps/rejected': -131.55795288085938, 'logps/chosen': -154.04161071777344, 'logits/rejected': -2.503689765930176, 'logits/chosen': -2.3918697834014893, 'epoch': 0.0}


100%|██████████| 64/64 [3:03:35<00:00, 172.12s/it]  

{'loss': 1.0459, 'learning_rate': 0.001, 'rewards/chosen': -2.09584379196167, 'rewards/rejected': -2.092801809310913, 'rewards/accuracies': 0.46875, 'rewards/margins': -0.003042057156562805, 'logps/rejected': -148.13485717773438, 'logps/chosen': -136.04318237304688, 'logits/rejected': -2.208425760269165, 'logits/chosen': -2.2577285766601562, 'epoch': 0.0}
{'train_runtime': 11015.4092, 'train_samples_per_second': 0.023, 'train_steps_per_second': 0.006, 'train_loss': 0.9553848505020142, 'epoch': 0.0}





TrainOutput(global_step=64, training_loss=0.9553848505020142, metrics={'train_runtime': 11015.4092, 'train_samples_per_second': 0.023, 'train_steps_per_second': 0.006, 'train_loss': 0.9553848505020142, 'epoch': 0.0})
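The reward metrics in these logs follow the definitions given in the comments before `trainer.train()`. As a minimal sketch, they can be reproduced in plain Python from per-example sequence log-probabilities. The helper name `reward_metrics` and the toy numbers are invented for illustration; the trainer computes the same quantities on tensors.

```python
def reward_metrics(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Compute DPO reward metrics from lists of per-example log-probabilities."""
    chosen = [beta * (p - r) for p, r in zip(policy_chosen, ref_chosen)]
    rejected = [beta * (p - r) for p, r in zip(policy_rejected, ref_rejected)]
    n = len(chosen)
    return {
        "rewards/chosen": sum(chosen) / n,
        "rewards/rejected": sum(rejected) / n,
        # Fraction of pairs where the chosen reward beats the rejected reward.
        "rewards/accuracies": sum(c > r for c, r in zip(chosen, rejected)) / n,
        "rewards/margins": sum(c - r for c, r in zip(chosen, rejected)) / n,
    }

# Toy batch of two pairs (values invented for illustration).
metrics = reward_metrics(
    policy_chosen=[-10.0, -20.0],
    policy_rejected=[-15.0, -9.0],
    ref_chosen=[-12.0, -18.0],
    ref_rejected=[-14.0, -10.0],
)
print(metrics)
```

In this toy batch the first pair is ranked correctly and the second is not, so the accuracy is 0.5, which is what a model with no preference signal would score on average.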

In [7]:
trainer.save_model('out')
eval_results = trainer.evaluate() # about 4 minutes in this run
print(eval_results)

100%|██████████| 125/125 [04:01<00:00,  1.93s/it]

{'eval_loss': 1.1580718755722046, 'eval_runtime': 243.056, 'eval_samples_per_second': 4.114, 'eval_steps_per_second': 0.514, 'eval_rewards/chosen': -3.0772223472595215, 'eval_rewards/rejected': -3.5384018421173096, 'eval_rewards/accuracies': 0.5139999985694885, 'eval_rewards/margins': 0.46117931604385376, 'eval_logps/rejected': -143.58291625976562, 'eval_logps/chosen': -118.49349975585938, 'eval_logits/rejected': -2.3071069717407227, 'eval_logits/chosen': -2.416668653488159, 'epoch': 0.0}





In [10]:
# to view the metrics recorded at each logging step (every 16 steps here), use:

trainer.state.log_history

[{'loss': 1.0057,
  'learning_rate': 0.001,
  'rewards/chosen': -1.0836848020553589,
  'rewards/rejected': -0.8774189949035645,
  'rewards/accuracies': 0.40625,
  'rewards/margins': -0.2062658816576004,
  'logps/rejected': -164.83995056152344,
  'logps/chosen': -152.500244140625,
  'logits/rejected': -2.065371036529541,
  'logits/chosen': -2.0074422359466553,
  'epoch': 0.0,
  'step': 16},
 {'loss': 0.8486,
  'learning_rate': 0.001,
  'rewards/chosen': -1.3901917934417725,
  'rewards/rejected': -1.674965739250183,
  'rewards/accuracies': 0.578125,
  'rewards/margins': 0.2847740054130554,
  'logps/rejected': -139.48309326171875,
  'logps/chosen': -131.59625244140625,
  'logits/rejected': -2.0241663455963135,
  'logits/chosen': -1.993961215019226,
  'epoch': 0.0,
  'step': 32},
 {'loss': 0.9214,
  'learning_rate': 0.001,
  'rewards/chosen': -1.5270169973373413,
  'rewards/rejected': -1.4379708766937256,
  'rewards/accuracies': 0.5,
  'rewards/margins': -0.08904620260000229,
  'logps/reje

In [12]:
import pandas as pd

# view as records

df = pd.DataFrame(trainer.state.log_history)
df

Unnamed: 0,loss,learning_rate,rewards/chosen,rewards/rejected,rewards/accuracies,rewards/margins,logps/rejected,logps/chosen,logits/rejected,logits/chosen,...,eval_samples_per_second,eval_steps_per_second,eval_rewards/chosen,eval_rewards/rejected,eval_rewards/accuracies,eval_rewards/margins,eval_logps/rejected,eval_logps/chosen,eval_logits/rejected,eval_logits/chosen
0,1.0057,0.001,-1.083685,-0.877419,0.40625,-0.206266,-164.839951,-152.500244,-2.065371,-2.007442,...,,,,,,,,,,
1,0.8486,0.001,-1.390192,-1.674966,0.578125,0.284774,-139.483093,-131.596252,-2.024166,-1.993961,...,,,,,,,,,,
2,0.9214,0.001,-1.527017,-1.437971,0.5,-0.089046,-131.557953,-154.041611,-2.50369,-2.39187,...,,,,,,,,,,
3,1.0459,0.001,-2.095844,-2.092802,0.46875,-0.003042,-148.134857,-136.043182,-2.208426,-2.257729,...,,,,,,,,,,
4,,,,,,,,,,,...,,,,,,,,,,
5,,,,,,,,,,,...,3.334,0.417,-3.077222,-3.538402,0.514,0.461179,-143.582916,-118.4935,-2.307107,-2.416669
6,,,,,,,,,,,...,4.114,0.514,-3.077222,-3.538402,0.514,0.461179,-143.582916,-118.4935,-2.307107,-2.416669
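The table is mostly empty because `log_history` mixes training records and evaluation records, which have different keys. A minimal sketch of separating the two kinds of record is shown below in plain Python. The entries are abbreviated copies of values from the output above, with most keys dropped for brevity.

```python
# Abbreviated log records; values copied from the notebook output above.
log_history = [
    {"loss": 1.0057, "rewards/margins": -0.2062658816576004, "step": 16},
    {"loss": 0.8486, "rewards/margins": 0.2847740054130554, "step": 32},
    {"eval_loss": 1.1580718755722046, "eval_rewards/margins": 0.46117931604385376},
]

# Training records carry a 'loss' key; evaluation records carry 'eval_loss'.
train_rows = [row for row in log_history if "loss" in row]
eval_rows = [row for row in log_history if "eval_loss" in row]

for row in train_rows:
    print(row["step"], row["loss"], row["rewards/margins"])
for row in eval_rows:
    print("eval", row["eval_loss"], row["eval_rewards/margins"])
```

With pandas, a similar split could be made by calling `df.dropna(subset=['loss'])` for the training rows, assuming the `loss` column is populated only for training records.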


#### Use GPU for inference

In [14]:
device = torch.device('cuda')
model = model.to(device)
prompt = 'Some popular cities for tourists are'
input_ids = tokenizer(prompt, return_tensors='pt').input_ids
input_ids = input_ids.to(device)
output_ids = model.generate(input_ids, max_length=128, do_sample=True, num_return_sequences=1)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:100257 for open-end generation.


Some popular cities for tourists are New York, Las Vegas, Los Angeles, Chicago, Washington, D.C., and Boston. But many choose to travel to unique, smaller cities, and even to places very far from the coast, like Buffalo, Syracuse, Rochester, and Niagara Falls, in upstate New York. Many people are drawn to these cities because of their more walkable downtowns, their smaller sizes, and their greater diversity of people, shops and restaurants.
