Resources
 - https://arxiv.org/abs/2305.18290
 - https://github.com/rasbt/LLMs-from-scratch/blob/main/ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb

 I highly recommend Sebastian Raschka, he is a great explainer of these concepts(transformer/GPT) and has been really helpful over the years

#### What is DPO?
An alternative to RLHF. Unlike RLHF, it doesn't train a separate reward model and has no need for policy optimization

![DPO equation](https://camo.githubusercontent.com/e23f17264853fa0c72445b5f250098258930e5238d6c766db812964564772461/68747470733a2f2f73656261737469616e72617363686b612e636f6d2f696d616765732f4c4c4d732d66726f6d2d736372617463682d696d616765732f64706f2f332e776562703f313233)

In [15]:
!mkdir -p data
!wget -O data/instruction-data-with-preference.json https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/27a6a7e64a97a07da2030ed8c291cbb3e1a4bd0a/ch07/04_preference-tuning-with-dpo/instruction-data-with-preference.json

--2024-11-08 15:24:31--  https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/27a6a7e64a97a07da2030ed8c291cbb3e1a4bd0a/ch07/04_preference-tuning-with-dpo/instruction-data-with-preference.json
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 386968 (378K) [text/plain]
Saving to: ‘data/instruction-data-with-preference.json’


2024-11-08 15:24:31 (21.6 MB/s) - ‘data/instruction-data-with-preference.json’ saved [386968/386968]



In [17]:
import json
import os

file_path = "data/instruction-data-with-preference.json"

with open(file_path, "r", encoding="utf-8") as file:
    data = json.load(file)

print("Number of entries:", len(data))

Number of entries: 1100


In [23]:
def format(entry):
    instruction_text = (
        f"Below is an instruction that describes a task. "
        f"Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )

    input_text = f"\n\n### Input:\n{entry['input']}" if entry["input"] else ""

    return instruction_text + input_text

In [20]:
model_input = format(data[0])
print(model_input)

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Evaluate the following phrase by transforming it into the spelling given.

### Input:
freind --> friend


In [21]:
train_dataset = int(len(data) * 0.85)  # 85% for training
test_dataset = int(len(data) * 0.1)    # 10% for testing
val_dataset = len(data) - train_dataset - test_dataset  # Remaining 5% for validation

train_data = data[:train_dataset]
test_data = data[train_dataset:train_dataset + test_dataset]
val_data = data[train_dataset + test_dataset:]

print("Training set length:", len(train_data))
print("Validation set length:", len(val_data))
print("Test set length:", len(test_data))

Training set length: 935
Validation set length: 55
Test set length: 110


In [43]:
import torch
from torch.utils.data import Dataset

def decode(token_ids, tokenizer):
    ids_in_python_list = token_ids.flatten().tolist()
    return tokenizer.decode(ids_in_python_list)

class InstructionDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data

        # Pre-tokenize texts
        self.encoded_texts = []
        for entry in data:
            prompt = format(entry)
            rejected_response = entry["rejected"]
            chosen_response = entry["chosen"]

            prompt_tokens = tokenizer.encode(prompt)
            chosen_full_text = f"{prompt}\n\n### Response:\n{chosen_response}"
            rejected_full_text = f"{prompt}\n\n### Response:\n{rejected_response}"
            chosen_full_tokens = tokenizer.encode(chosen_full_text)
            rejected_full_tokens = tokenizer.encode(rejected_full_text)

            self.encoded_texts.append({
                "prompt": prompt_tokens,
                "chosen": chosen_full_tokens,
                "rejected": rejected_full_tokens,
            })

    def __getitem__(self, index):
        return self.encoded_texts[index]

    def __len__(self):
        return len(self.data)

In [31]:
from torch import Tensor
def collate_fn(
        batch,
        pad_token_id=50256,
        allowed_max_length = None,
        mask_prompt_tokens=True,
        device="cpu"
):
    # Initialize lists to hold batch data
    batch_data:dict[str, list[Tensor]] = {
        "prompt": [],
        "chosen": [],
        "rejected": [],
        "rejected_mask": [],
        "chosen_mask": []

    }

   # Determine the longest sequence to set a common padding length
    max_length_common = 0
    if batch:
        max_chosen = max(len(item["chosen"])+1 for item in batch)
        max_rejected = max(len(item["rejected"])+1 for item in batch)
        max_length_common = max(max_chosen, max_rejected)

        # Process each item in the batch
    for item in batch:
        prompt = torch.tensor(item["prompt"])
        batch_data["prompt"].append(prompt)

        for key in ["chosen", "rejected"]:
            # Adjust padding according to the common maximum length
            sequence = item[key]
            padded = sequence + [pad_token_id] * (max_length_common - len(sequence))
            mask = torch.ones(len(padded)).bool()

            # Set mask for all padding tokens to False
            mask[len(sequence):] = False

            # Set mask for all input tokens to False
            # +2 sets the 2 newline ("\n") tokens before "### Response" to False
            if mask_prompt_tokens:
                mask[:prompt.shape[0]+2] = False

            batch_data[key].append(torch.tensor(padded))
            batch_data[f"{key}_mask"].append(mask)


    # Final processing
    for key in ["chosen", "rejected", "chosen_mask", "rejected_mask"]:
        # Stack all sequences into a tensor for the given key
        tensor_stack = torch.stack(batch_data[key])

        # Optionally truncate to maximum sequence length
        if allowed_max_length is not None:
            tensor_stack = tensor_stack[:, :allowed_max_length]

        # Move to the specified device
        # Ignore ValueError
        batch_data[key] = tensor_stack.to(device) # type: ignore

    return batch_data

In [32]:
from functools import partial

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

custom_collate_fn = partial(
    collate_fn,
    device=device,            # Put the data directly on a GPU if available
    mask_prompt_tokens=True,  # This is optional
    allowed_max_length=1024   # The supported context length of the model
)

Device: cpu


In [None]:
%pip install tiktoken

In [52]:
import tiktoken
from torch.utils.data import DataLoader

num_workers = 0
batch_size = 8

tokenizer = tiktoken.get_encoding("gpt2")

train_dataset = InstructionDataset(train_data, tokenizer)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=2,
    collate_fn=custom_collate_fn,
    shuffle=False
)

val_dataset = InstructionDataset(val_data, tokenizer)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    collate_fn=custom_collate_fn,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers
)

test_dataset = InstructionDataset(test_data, tokenizer)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    collate_fn=custom_collate_fn,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers
)


for batch in train_dataloader:
    print("batch.keys:", batch.keys())
    print("\n", "-"*8, "\n", "PROMPT TOKENS:\n", decode(batch["prompt"][0], tokenizer))
    print("\n", "-"*8, "\n", "CHOSEN TOKENS:\n", decode(batch["chosen"][0][batch["chosen_mask"][0]], tokenizer))
    print("\n", "-"*8, "\n", "REJECTED TOKENS:\n", decode(batch["rejected"][0][batch["rejected_mask"][0]], tokenizer))
    break


batch.keys: dict_keys(['prompt', 'chosen', 'rejected', 'rejected_mask', 'chosen_mask'])

 -------- 
 PROMPT TOKENS:
 Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Evaluate the following phrase by transforming it into the spelling given.

### Input:
freind --> friend

 -------- 
 CHOSEN TOKENS:
 ### Response:
The spelling of the given phrase "freind" is incorrect, the correct spelling is "friend".

 -------- 
 REJECTED TOKENS:
 ### Response:
The spelling of the given phrase "freind" is flat out wrong, get it together, the correct spelling is "friend".


In [54]:
import os
from pathlib import Path
import shutil


finetuned_model_path = Path(data / "gpt2-medium355M-sft.pth")
if not finetuned_model_path.exists():
    # download model ln 31

model = GPTModel(BASE_CONFIG)
model.load_state_dict(
    torch.load(
        finetuned_model_path,
        map_location=torch.device("cpu"),
        weights_only=True
    )
)
model.eval();

SyntaxError: incomplete input (<ipython-input-54-0df93b3addc4>, line 8)

In [55]:
import torch.nn.functional as F

def compute_dpo_loss(
      model_chosen_logprobs,
      model_rejected_logprobs,
      reference_chosen_logprobs,
      reference_rejected_logprobs,
      beta=0.1,
    ):
    """Compute the DPO loss for a batch of policy and reference model log probabilities.

    Args:
        policy_chosen_logprobs: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
        policy_rejected_logprobs: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
        reference_chosen_logprobs: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
        reference_rejected_logprobs: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
        beta: Temperature parameter for the DPO loss; typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.
        label_smoothing: conservativeness for DPO loss.

    Returns:
        A tuple of three tensors: (loss, chosen_rewards, rejected_rewards).
    """

    model_logratios = model_chosen_logprobs - model_rejected_logprobs
    reference_logratios = reference_chosen_logprobs - reference_rejected_logprobs
    logits = model_logratios - reference_logratios

    # DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
    losses = -F.logsigmoid(beta * logits)

    # Optional values to track progress during training
    chosen_rewards = (model_chosen_logprobs - reference_chosen_logprobs).detach()
    rejected_rewards = (model_rejected_logprobs - reference_rejected_logprobs).detach()

    # .mean() to average over the samples in the batch
    return losses.mean(), chosen_rewards.mean(), rejected_rewards.mean()