# DPO Implementation using GPT-2

## Model

In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")



In [3]:
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))

In [4]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50258, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50258, bias=False)
)

## Dataset

In [5]:
from datasets import load_dataset

ds = load_dataset("jondurbin/py-dpo-v0.1")

In [6]:
ds

DatasetDict({
    train: Dataset({
        features: ['prompt', 'chosen', 'rejected', 'id'],
        num_rows: 9466
    })
})

In [7]:
type(ds['train']['prompt'])

list

In [8]:
train_dataset = ds['train'].select(range(500)) # smaller subset for faster training

In [9]:
len(train_dataset)

500

In [10]:
train_dataset

Dataset({
    features: ['prompt', 'chosen', 'rejected', 'id'],
    num_rows: 500
})

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

In [12]:
class PreferenceDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, tokenizer, max_length=512):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        prompt = item['prompt']
        chosen = item['chosen']
        rejected = item['rejected']

        # Encode input_ids without padding first to get the original length
        input_ids = self.tokenizer.encode(prompt, add_special_tokens=False, truncation=True, max_length=self.max_length)
        input_length = len(input_ids)

        # Encode and pad all sequences to max_length
        input_ids = self.tokenizer.encode(prompt, max_length=self.max_length, truncation=True, padding='max_length', return_tensors='pt').squeeze(0)
        chosen_ids = self.tokenizer.encode(prompt + chosen, max_length=self.max_length, truncation=True, padding='max_length', return_tensors='pt').squeeze(0)
        rejected_ids = self.tokenizer.encode(prompt + rejected, max_length=self.max_length, truncation=True, padding='max_length', return_tensors='pt').squeeze(0)

        return {
            'input_ids': input_ids,
            'chosen_ids': chosen_ids,
            'rejected_ids': rejected_ids,
            'input_length': input_length,
        }

In [13]:
preference_dataset = PreferenceDataset(train_dataset, tokenizer)
train_loader = DataLoader(preference_dataset, batch_size=2, shuffle=True)

## Train

In [14]:
def dpo_loss(chosen_logits, rejected_logits, attention_mask, beta=0.1):
    # Only consider non-padded tokens
    chosen_logits = chosen_logits[attention_mask.bool()]
    rejected_logits = rejected_logits[attention_mask.bool()]
    
    logits_diff = chosen_logits - rejected_logits
    loss = -torch.log(torch.sigmoid(beta * logits_diff)).mean()
    return loss

In [15]:
def train_dpo(model, train_loader, num_epochs=3, learning_rate=1e-3):
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for batch in train_loader:
            input_ids = batch['input_ids'].to(device)
            chosen_ids = batch['chosen_ids'].to(device)
            rejected_ids = batch['rejected_ids'].to(device)
            input_length = batch['input_length'].to(device)
            
            # Forward pass for chosen and rejected
            chosen_outputs = model(chosen_ids)
            rejected_outputs = model(rejected_ids)
            
            
            batch_size, seq_len, vocab_size = chosen_outputs.logits.size()

            arange = torch.arange(seq_len, device=device)
            mask = (arange >= (input_length.unsqueeze(1) - 1)) & (arange < seq_len - 1)
            mask = mask.unsqueeze(-1).expand(batch_size, seq_len, vocab_size)

            chosen_logits = chosen_outputs.logits * mask
            rejected_logits = rejected_outputs.logits * mask

            attention_mask = chosen_ids != tokenizer.pad_token_id

            # Calculate loss only on the continuation part
            # chosen_logits = chosen_outputs.logits[:, input_length-1:-1, :]
            # rejected_logits = rejected_outputs.logits[:, input_length-1:-1, :]
            # attention_mask = (chosen_ids[:, input_ids.size(1):] != tokenizer.pad_token_id)
            
            loss = dpo_loss(chosen_logits, rejected_logits, attention_mask)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {total_loss/len(train_loader)}")

In [16]:
train_dpo(model, train_loader)

Epoch 1/3, Average Loss: 0.5474512539505959
Epoch 2/3, Average Loss: 0.41895203286409377
Epoch 3/3, Average Loss: 0.34129259134829043
