In [None]:
%pip install wandb
%pip install datasets 

In [3]:
import random 
import torch   
import torch.nn as nn 
import torch.nn.functional as F
import math  
from dataclasses import dataclass 
from tokenizers import Tokenizer 
from pathlib import Path 
from torch.utils.data import Dataset , DataLoader 
import wandb 
from transformers import AutoTokenizer , AutoModelForCausalLM

In [29]:
@dataclass 
class Config:
    dpo_model_name : str = "mistralai/Mistral-7B-Instruct"
    model_name : str = "Qwen/Qwen2-0.5B-Instruct"
    dataset_name: str = "trl-lib/ultrafeedback_binarized"
    batch_size : int = 2 
    beta : float = 0.1 
    learning_rate : float = 1e-4 
    HF_TOKEN : str = "PASS"
    device : str = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
wandb.init(project="dpo-finetune" , 
           config = Config)


In [10]:
class DPO:
    def __init__(self , config: Config):
        self.config = Config()
        self.ref_model = AutoModelForCausalLM.from_pretrained(self.config.dpo_model_name , token = self.config.HF_TOKEN)
        self.model = AutoModelForCausalLM.from_pretrained(self.config.model_name , token = self.config.HF_TOKEN) 
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name , token = self.config.HF_TOKEN) 
        self.ref_model.eval()
    
    def DPOLoss(self , datapoint):
        self.win_prompt  = datapoint['chosen']
        self.lose_prompt = datapoint["rejected"]
        """Compute the DPO loss for a single datapoint"""
        with torch.no_grad():
            self.win_ref_model = torch.nn.functional.log_softmax(self.ref_model(self.win_prompt).logits , dim = -1)
            self.lose_ref_model = torch.nn.functional.log_softmax(self.ref_model(self.lose_prompt).logits , dim = -1)
            self.win_ref_model = torch.gather(self.win_ref_model , dim = -1 , index = self.win_prompt['input_ids'].unsqueeze(-1)).squeeze(-1)
            self.lose_ref_model = torch.gather(self.lose_ref_model , dim = -1 , index =  self.lose_prompt['input_ids'].unsqueeze(-1)).squeeze(-1)
            self.win_ref_model = self.win_ref_model * self.win_prompt['attention_mask']
            self.lose_ref_model = self.lose_ref_model * self.lose_prompt['attention_mask']
            self.win_ref_model = self.win_ref_model.sum(dim = -1)
            self.lose_ref_model = self.lose_ref_model.sum(dim = -1)
            
            self.win_model = torch.nn.functional.log_softmax(self.model(self.win_prompt).logits , dim = -1)
            self.lose_model = torch.nn.functional.log_softmax(self.model(self.lose_prompt).logits , dim = -1)
            self.win_model = torch.gather(self.win_model , dim = -1  , index = self.win_prompt['input_ids'].unsqueeze(-1)).squeeze(-1)
            self.lose_model = torch.gather(self.lose_model , dim = -1  , index = self.lose_prompt['input_ids'].unsqueeze(-1)).squeeze(-1)
            self.win_model = self.win_model * self.win_prompt['attention_mask']
            self.lose_model = self.lose_model * self.lose_prompt['attention_mask']
            self.win_model = self.win_model.sum(dim = -1)
            self.lose_model = self.lose_model.sum(dim = -1)
            
            self.diff1 = self.win_model - self.win_ref_model
            self.diff2 = self.lose_model - self.lose_ref_model 
            self.loss = - nn.functional.logsigmoid(self.config.beta * (self.diff1 - self.diff2)).mean()
            
            return self.loss

In [16]:
print(Config.device)

cpu


In [None]:
#Device - Do not run this cell untill you get Device SSH key 
torch.cuda.set_device(Config.device)

In [18]:
from datasets import load_dataset, Dataset

train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
val_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="test" )

README.md:   0%|          | 0.00/643 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/131M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/2.14M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/62135 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [25]:
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

In [23]:
sample = train_dataset[0]
print(sample)


{'chosen': [{'content': 'Use the pygame library to write a version of the classic game Snake, with a unique twist', 'role': 'user'}, {'content': "Sure, I'd be happy to help you write a version of the classic game Snake using the pygame library! Here's a basic outline of how we can approach this:\n\n1. First, we'll need to set up the game display and create a game object that we can use to handle the game's state.\n2. Next, we'll create the game's grid, which will be used to represent the game board. We'll need to define the size of the grid and the spaces within it.\n3. After that, we'll create the snake object, which will be used to represent the player's movement. We'll need to define the size of the snake and the speed at which it moves.\n4. We'll also need to create a food object, which will be used to represent the food that the player must collect to score points. We'll need to define the location of the food and the speed at which it moves.\n5. Once we have these objects set up,

#### Training and dataset final CELLs

In [24]:
def dpo_collate_fn_merged_prompt(batch):

    merged_chosen_prompts = []
    merged_rejected_prompts = []

    for sample in batch:

        # Extract and merge chosen response
        prompt = sample['prompt']
        chosen_data = sample['chosen']
        chosen_data = "Instruction: " + prompt + "\n" + "Output: " + chosen_data[1]['content'] + "\n"
        # Extract and merge rejected response
        rejected_data = sample['rejected']
        rejected_data =  "Instruction: " + prompt + "\n" + "Output: " + rejected_data[1]['content'] + "\n"
        # print(chosen_data)
        # print(rejected_data)
        merged_chosen_prompts.append(chosen_data)
        merged_rejected_prompts.append(rejected_data)
    tokenized_win_prompt = tokenizer(merged_chosen_prompts, max_length = 1024, padding='max_length', truncation=True, return_tensors="pt")
    tokenized_lose_prompt = tokenizer(merged_rejected_prompts, max_length = 1024, truncation=True, padding='max_length', return_tensors="pt")

    return {
        'chosen': tokenized_win_prompt,
        'rejected': tokenized_lose_prompt 
    }
     


In [30]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True, collate_fn=dpo_collate_fn_merged_prompt)
val_loader = DataLoader(val_dataset, batch_size=Config.batch_size, shuffle=True, collate_fn=dpo_collate_fn_merged_prompt)
     

In [None]:
model.train()
optimizer = torch.optim.AdamW(model.parameters() , lr = Config.learning_rate)
total_steps = len(train_dataset) // Config.batch_size
eval_iterations = 100 
dpo_loss = DPO(Config)
train_iterator = iter(train_loader)
val_iterator = iter(val_loader)

@torch.inference_mode()
def estimate_loss():
    loader = None 
    out = {}
    model.eval()
    with torch.no_grad():
        for split in ['train' , 'val']:
            if (split == 'train'):
                loader = train_loader
            elif(split == 'val'):
                loader = val_loader
            losses = torch.zeros(eval_iterations)
            for k in range(eval_iterations):
                datapoint = next(loader)
                loss = dpo_loss.DPOLoss(datapoint)
                losses[k] = loss.item()
            out[split] = losses.mean()
        model.train()
        return out

In [None]:
#train the model 
from tqdm import tqdm 
train_iterator = iter(train_loader)

for step in tqdm(range(total_steps)):
    losses = estimate_loss()
    print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    wandb.log({
            "step": step,
            "training_loss": losses['train'],
            "val_loss": losses['val']
        })
    text = next(train_iterator)
    loss = dpo_loss.DPOLoss(text)
    optimizer.zero_grad(set_to_none = True)
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        wandb.log({
            "step": step,
            "training_loss": losses['train'],
            "val_loss": losses['val']})
        
print("Training Complete!")