In [1]:
!pip install wandb
!pip install datasets

import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from tokenizers import Tokenizer
from pathlib import Path

from torch.utils.data import Dataset, DataLoader
import wandb
from transformers import AutoModelForCausalLM, AutoTokenizer




In [2]:
wandb.init(
            # entity = 'rajceo2031',
                        project = 'GPTJ-DPO',
                        # config = CFG,
                        # save_code = True,
                        #group = 'ANN',
                        #job_type = 'train'
)

[34m[1mwandb[0m: Currently logged in as: [33mrajceo2031[0m ([33mrentio[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [3]:
#Hyperparameters

batch_size = 2
beta = 0.1
max_lr = 1e-6

In [4]:
class DPO:
  def __init__(self, ref_model, sft_model, device, beta, tokenizer):


    self.ref_model = ref_model
    self.sft_model = sft_model
    self.device=device
    self.beta = beta
    self.tokenizer = tokenizer
    self.ref_model.eval()



  def DPOloss(self, datapoint):



    self.win_prompt = datapoint['chosen']
    self.lose_prompt = datapoint['rejected']

    with torch.no_grad():
      self.win_log_ref = torch.nn.functional.log_softmax(self.ref_model(**self.win_prompt).logits, dim=-1)
      self.lose_log_ref = torch.nn.functional.log_softmax(self.ref_model(**self.lose_prompt).logits, dim=-1)

    self.win_log_sft = torch.nn.functional.log_softmax(self.sft_model(**self.win_prompt).logits, dim=-1)
    self.lose_log_sft = torch.nn.functional.log_softmax(self.sft_model(**self.lose_prompt).logits, dim=-1)

    self.diff1 = self.win_log_sft - self.win_log_ref
    self.diff2 = self.win_log_sft - self.lose_log_ref

    self.final = -nn.functional.logsigmoid(self.beta *(self.diff1 - self.diff2)).mean() #Remember we have to maximize the rewards thus minimizing the negative sign! Also, since the var of rewards could be very much, we take mean so as to have a notion of normalizing it!

    # sft_model.train()
    return self.final



In [5]:
# !huggingface-cli login
from google.colab import userdata
HF_TOKEN = userdata.get('HF_TOKEN')


In [6]:
device='cuda:0'

In [7]:
torch.cuda.set_device(device)

In [8]:

sft_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct", token=HF_TOKEN, device_map=device)
ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct", token=HF_TOKEN, device_map=device)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct", token=HF_TOKEN, device_map=device)

In [9]:
from datasets import load_dataset, Dataset

train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train", token=HF_TOKEN)
val_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="test", token=HF_TOKEN)


In [11]:
def dpo_collate_fn_merged_prompt(batch):

    merged_chosen_prompts = []
    merged_rejected_prompts = []

    for sample in batch:


        # Extract and merge chosen response
        chosen_data = sample['chosen']
        chosen_data = "Instruction: " + chosen_data[0]['content'] + "\n" + "Output: " + chosen_data[1]['content'] + "\n"
        # Extract and merge rejected response
        rejected_data = sample['rejected']
        rejected_data =  "Instruction: " + rejected_data[0]['content'] + "\n" + "Output: " + rejected_data[1]['content'] + "\n"

        merged_chosen_prompts.append(chosen_data)


        merged_rejected_prompts.append(rejected_data)

    tokenized_win_prompt = tokenizer(merged_chosen_prompts, max_length = 1024, padding='max_length', truncation=True, return_tensors="pt").to(device)

    tokenized_lose_prompt = tokenizer(merged_rejected_prompts, max_length = 1024, truncation=True, padding='max_length', return_tensors="pt").to(device)



    return {
        # 'prompt': prompts, # Still return original prompts for potential use
        'chosen': tokenized_win_prompt, # List of merged prompt-chosen texts
        'rejected': tokenized_lose_prompt # List of merged prompt-rejected texts
    }

In [12]:


from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=dpo_collate_fn_merged_prompt)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, collate_fn=dpo_collate_fn_merged_prompt)

In [13]:
sample_text = next(iter(train_loader))

In [14]:
sample_text

{'chosen': {'input_ids': tensor([[ 16664,     25,   3988,  ..., 151643, 151643, 151643],
         [ 16664,     25,  62665,  ..., 151643, 151643, 151643]],
        device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')},
 'rejected': {'input_ids': tensor([[ 16664,     25,   3988,  ..., 151643, 151643, 151643],
         [ 16664,     25,  62665,  ..., 151643, 151643, 151643]],
        device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')}}

In [17]:

# Optimizer setup and scheduler steup
sft_model.train()
optimizer = torch.optim.AdamW(sft_model.parameters(), lr=max_lr)

total_steps = 3000
eval_iters = 20




@torch.inference_mode()
def estimate_loss():
    loader = None
    out = {}
    sft_model.eval()
    for split in ['train', 'val']:
        if(split == 'train'):
            loader = train_loader

        elif (split == 'val'):
            loader = val_loader


        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):

            datapoint = next(iter(loader))


            loss = dpo_loss.DPOloss(datapoint)

            losses[k] = loss.item()
        out[split] = losses.mean()
    sft_model.train()
    return out

In [None]:

#Train the  model
from tqdm import tqdm


dpo_loss = DPO(ref_model, sft_model, device, beta, tokenizer)



for step in tqdm(range(total_steps)):


    if (step  % eval_iters == 0 and step != 0) or step == total_steps - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        wandb.log({
            "step": step,
            "training_loss": losses['train'],
            "val_loss": losses['val']
        })

    text  = next(iter(train_loader))


    loss = dpo_loss.DPOloss(text)


    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


  1%|          | 20/3000 [00:18<45:12,  1.10it/s]

step 20: train loss 0.7008, val loss 0.7105


  1%|▏         | 40/3000 [01:00<45:12,  1.09it/s]

step 40: train loss 0.6881, val loss 0.6841


  2%|▏         | 60/3000 [01:42<45:01,  1.09it/s]

step 60: train loss 0.7204, val loss 0.6965


  3%|▎         | 80/3000 [02:23<44:42,  1.09it/s]

step 80: train loss 0.6770, val loss 0.7152


  3%|▎         | 100/3000 [03:05<44:26,  1.09it/s]

step 100: train loss 0.7017, val loss 0.6945


  4%|▍         | 120/3000 [03:47<44:08,  1.09it/s]

step 120: train loss 0.7079, val loss 0.6953


  5%|▍         | 140/3000 [04:29<43:52,  1.09it/s]

step 140: train loss 0.6950, val loss 0.7017


  5%|▌         | 160/3000 [05:11<43:25,  1.09it/s]

step 160: train loss 0.7146, val loss 0.6988


  6%|▌         | 180/3000 [05:53<43:20,  1.08it/s]

step 180: train loss 0.6920, val loss 0.6881


  7%|▋         | 200/3000 [06:35<43:00,  1.09it/s]

step 200: train loss 0.6955, val loss 0.7029


  7%|▋         | 220/3000 [07:17<42:40,  1.09it/s]

step 220: train loss 0.6893, val loss 0.6949


  8%|▊         | 240/3000 [07:59<42:23,  1.09it/s]

step 240: train loss 0.7257, val loss 0.6950


  9%|▊         | 260/3000 [08:41<42:00,  1.09it/s]

step 260: train loss 0.7017, val loss 0.6944


  9%|▉         | 280/3000 [09:23<41:46,  1.09it/s]

step 280: train loss 0.6770, val loss 0.7089


 10%|█         | 300/3000 [10:05<41:27,  1.09it/s]

step 300: train loss 0.7158, val loss 0.6870


 11%|█         | 320/3000 [10:47<41:07,  1.09it/s]

step 320: train loss 0.7076, val loss 0.6806


 11%|█▏        | 340/3000 [11:29<40:50,  1.09it/s]

step 340: train loss 0.6927, val loss 0.6890


 12%|█▏        | 360/3000 [12:11<40:29,  1.09it/s]

step 360: train loss 0.6979, val loss 0.7115


 13%|█▎        | 380/3000 [12:53<40:10,  1.09it/s]

step 380: train loss 0.7059, val loss 0.6992


 13%|█▎        | 400/3000 [13:35<39:53,  1.09it/s]

step 400: train loss 0.7013, val loss 0.6882


 14%|█▍        | 420/3000 [14:17<39:37,  1.09it/s]

step 420: train loss 0.6866, val loss 0.6772


 15%|█▍        | 440/3000 [14:59<39:18,  1.09it/s]

step 440: train loss 0.7051, val loss 0.6928


 15%|█▌        | 460/3000 [15:41<38:59,  1.09it/s]