In [None]:
!pip install wandb
!pip install datasets

import math
import os
import random
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from tokenizers import Tokenizer
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer


In [None]:
# Open a Weights & Biases run for this notebook. Only the project name is set;
# the other optional arguments (entity, config, save_code, group, job_type)
# are intentionally left at their defaults.
wandb.init(project='GPTJ-DPO')

In [None]:
#Hyperparameters

batch_size = 128  # examples per batch; handed to both the train and validation DataLoader
beta = 2  # scale applied to the chosen-vs-rejected difference inside the SimplePO loss
max_lr = 2e-5  # peak learning rate; also the initial lr given to AdamW
gamma = 1.2  # margin subtracted inside the logsigmoid of the SimplePO loss
min_lr = 0.1 * max_lr  # learning-rate floor used by the scheduler (10% of the peak)


In [None]:

class CustomLRScheduler:
    """Linear warmup followed by cosine decay, applied to every param group of an optimizer.

    The schedule is a function of the step counter ``it``:

    * ``it < warmup_iters``: linear ramp ``max_lr * (it + 1) / (warmup_iters + 1)``
    * ``warmup_iters <= it <= lr_decay_iters``: cosine decay from ``max_lr`` to ``min_lr``
    * ``it > lr_decay_iters``: constant ``min_lr``

    Following the PyTorch scheduler convention, the learning rate for ``it == 0`` is written
    into the optimizer at construction time, and each ``step()`` advances the counter and
    then writes the learning rate for the new counter value. With the usual
    ``optimizer.step(); scheduler.step()`` ordering this means optimizer update ``n`` runs
    with the learning rate for ``it == n`` (in particular the first update uses the warmup
    value rather than whatever ``lr`` the optimizer was created with).

    Args:
        optimizer: any object exposing ``param_groups`` (a list of dicts with an ``'lr'`` key).
        warmup_iters: number of warmup steps (may be a float).
        lr_decay_iters: step index at which the cosine decay reaches ``min_lr``.
        min_lr: learning-rate floor.
        max_lr: peak learning rate.
    """

    def __init__(self, optimizer, warmup_iters, lr_decay_iters, min_lr, max_lr):
        self.optimizer = optimizer
        self.warmup_iters = warmup_iters
        self.lr_decay_iters = lr_decay_iters
        self.min_lr = min_lr
        self.max_lr = max_lr
        self.it = 0
        self._last_lr = [max_lr]
        # Push the it == 0 learning rate into the optimizer right away so the very first
        # optimizer.step() does not run at the optimizer's constructor lr.
        self._apply_lr()

    def _apply_lr(self):
        """Compute the learning rate for the current ``it`` and write it to all param groups."""
        lr = self._get_lr()
        self._last_lr = [lr]  # stored as a list to match PyTorch's get_last_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def step(self):
        """Advance the step counter by one and apply the corresponding learning rate."""
        self.it += 1
        self._apply_lr()

    def get_last_lr(self):
        """Return the most recently applied learning rate as a one-element list."""
        return self._last_lr

    def _get_lr(self):
        """Return the scheduled learning rate for the current value of ``self.it``."""
        # 1) linear warmup for warmup_iters steps
        if self.it < self.warmup_iters:
            return self.max_lr * (self.it + 1) / (self.warmup_iters + 1)
        # 2) past the decay horizon: hold at the floor
        if self.it > self.lr_decay_iters:
            return self.min_lr
        # 3) in between: cosine decay down to min_lr
        decay_span = self.lr_decay_iters - self.warmup_iters
        if decay_span <= 0:
            # Zero-length decay window (warmup_iters == lr_decay_iters): there is nothing to
            # interpolate over, so treat the decay as already finished instead of dividing by 0.
            return self.min_lr
        decay_ratio = (self.it - self.warmup_iters) / decay_span
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes 1 -> 0 over the window
        return self.min_lr + coeff * (self.max_lr - self.min_lr)

    def state_dict(self):
        """Return the schedule parameters and step counter as a plain dict."""
        return {
            'warmup_iters': self.warmup_iters,
            'lr_decay_iters': self.lr_decay_iters,
            'min_lr': self.min_lr,
            'max_lr': self.max_lr,
            'it': self.it
        }

    def load_state_dict(self, state_dict):
        """Restore schedule parameters and step counter, then re-apply the matching lr."""
        self.warmup_iters = state_dict['warmup_iters']
        self.lr_decay_iters = state_dict['lr_decay_iters']
        self.min_lr = state_dict['min_lr']
        self.max_lr = state_dict['max_lr']
        self.it = state_dict['it']
        # Keep get_last_lr() and the optimizer consistent with the restored counter.
        self._apply_lr()



In [None]:
class SimplePO:
  """Reference-free SimPO preference loss over batches of (chosen, rejected) sequences.

  For each pair the loss is ``-logsigmoid(beta * (avg_logp(chosen) - avg_logp(rejected)) - gamma)``
  where ``avg_logp`` is the mean per-token log-probability the policy assigns to a sequence
  (i.e. the sequence log-probability divided by its number of scored tokens).

  Args:
      sft_model: causal LM being optimised; called as ``sft_model(**batch)`` and expected to
          return an object with a ``.logits`` tensor of shape (batch, seq, vocab).
      device: stored for reference; not used by the loss itself.
      beta: scale applied to the length-normalised log-probability difference.
      gamma: target margin subtracted inside the sigmoid.
      tokenizer: stored for reference; not used by the loss itself.
  """

  def __init__(self, sft_model, device, beta, gamma, tokenizer):
    self.sft_model = sft_model
    self.device = device
    self.beta = beta
    self.tokenizer = tokenizer
    self.gamma = gamma

  def _avg_token_logps(self, batch):
    """Return, per sequence, the mean log-probability of its attended tokens.

    ``batch`` must provide ``input_ids`` and ``attention_mask`` of shape (batch, seq).
    """
    logits = self.sft_model(**batch).logits
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']

    # A causal LM's logits at position t score the token at position t + 1, so drop the last
    # logit and the first label to line every prediction up with the token it predicts.
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    targets = input_ids[:, 1:]
    # gather picks, for every position, the log-probability of the token that actually follows.
    token_logps = torch.gather(log_probs, -1, targets.unsqueeze(-1)).squeeze(-1)

    # Score a position only when both the target and the token it is predicted from are real
    # (non-padding); this is correct for left- as well as right-padded batches.
    mask = (attention_mask[:, 1:] * attention_mask[:, :-1]).to(token_logps.dtype)
    # clamp guards against division by zero for a sequence with fewer than two real tokens.
    token_counts = mask.sum(dim=-1).clamp(min=1)
    return (token_logps * mask).sum(dim=-1) / token_counts

  def SimplePOloss(self, datapoint):
    """Compute the batch-mean SimPO loss.

    Args:
        datapoint: mapping with keys ``'chosen'`` and ``'rejected'``, each a tokenized batch
            (``input_ids`` / ``attention_mask``) accepted by the model.

    Returns:
        A scalar tensor carrying gradients with respect to the model parameters.

    NOTE(review): every attended token is scored, prompt included, because the inputs carry
    no prompt/response boundary. SimPO as published scores response tokens only.
    """
    # Intermediate tensors are kept local (not on self) so the autograd graph of one call is
    # not held alive until the next call.
    chosen_logps = self._avg_token_logps(datapoint['chosen'])
    rejected_logps = self._avg_token_logps(datapoint['rejected'])

    margin = self.beta * (chosen_logps - rejected_logps) - self.gamma
    return -F.logsigmoid(margin).mean()



In [None]:
# Use the first CUDA device when one is present; otherwise fall back to the CPU so the
# notebook can still be executed (slowly) on a machine without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
if device != 'cpu':
    # Selecting the current CUDA device only makes sense when a GPU was chosen above.
    torch.cuda.set_device(device)

In [None]:

# Take the Hugging Face access token from the environment instead of relying on a name left
# over from an earlier kernel session (and instead of hard-coding a secret in the notebook).
# When the variable is unset this is None, which is passed through as "no token".
HF_TOKEN = os.environ.get("HF_TOKEN")

# Policy model to be fine-tuned, loaded directly onto the selected device.
sft_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct", token=HF_TOKEN, device_map=device)
# A tokenizer is not placed on a device, so device_map is not passed here.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct", token=HF_TOKEN)

In [None]:
from datasets import load_dataset, Dataset

# Fetch both splits of the preference dataset in one pass: the "train" split is used for
# optimisation and the "test" split serves as the validation set.
_dataset_id = "trl-lib/ultrafeedback_binarized"
train_dataset, val_dataset = (
    load_dataset(_dataset_id, split=split_name, token=HF_TOKEN)
    for split_name in ("train", "test")
)


In [None]:
def dpo_collate_fn_merged_prompt(batch):
    """Collate preference samples into tokenized chosen / rejected batches.

    Each sample must provide ``'chosen'`` and ``'rejected'`` as lists of chat messages
    (dicts with a ``'content'`` key) whose last message is the response. The prompt text is
    taken from ``sample['prompt']`` when that column exists (a string, or a list of messages
    whose last entry is used); otherwise it falls back to the first message of ``'chosen'``,
    which covers datasets that store the prompt only inside the conversations.

    Prompt and response are merged into ``"Instruction: ...\\nOutput: ...\\n"``, then
    tokenized with the module-level ``tokenizer`` (truncated / padded to 1024 tokens) and
    moved to the module-level ``device``.

    Returns:
        dict with keys ``'chosen'`` and ``'rejected'``, each the tokenizer output
        (``input_ids`` / ``attention_mask`` tensors) for the merged texts.
    """

    def _prompt_text(sample):
        prompt = sample.get('prompt')
        if prompt is None:
            # No separate prompt column: the user turn opens the conversation.
            return sample['chosen'][0]['content']
        if isinstance(prompt, (list, tuple)):
            # Prompt given as chat messages: use the most recent one.
            return prompt[-1]['content']
        return prompt

    def _merge(prompt_text, messages):
        # The response is the final message; for a two-message conversation this is index 1.
        return "Instruction: " + prompt_text + "\n" + "Output: " + messages[-1]['content'] + "\n"

    merged_chosen_prompts = []
    merged_rejected_prompts = []
    for sample in batch:
        prompt_text = _prompt_text(sample)
        merged_chosen_prompts.append(_merge(prompt_text, sample['chosen']))
        merged_rejected_prompts.append(_merge(prompt_text, sample['rejected']))

    # Identical tokenization settings for both sides so the two batches have the same shape.
    tokenize_options = dict(max_length=1024, padding='max_length', truncation=True, return_tensors="pt")
    tokenized_win_prompt = tokenizer(merged_chosen_prompts, **tokenize_options).to(device)
    tokenized_lose_prompt = tokenizer(merged_rejected_prompts, **tokenize_options).to(device)

    return {
        'chosen': tokenized_win_prompt,  # tokenized merged prompt + chosen response
        'rejected': tokenized_lose_prompt  # tokenized merged prompt + rejected response
    }

In [None]:


from torch.utils.data import DataLoader

# Train and validation loaders share one configuration: shuffled batches of `batch_size`
# samples, collated into tokenized chosen/rejected pairs.
_loader_options = dict(
    batch_size=batch_size,
    shuffle=True,
    collate_fn=dpo_collate_fn_merged_prompt,
)
train_loader = DataLoader(train_dataset, **_loader_options)
val_loader = DataLoader(val_dataset, **_loader_options)

In [None]:

# Optimizer and scheduler setup
sft_model.train()

val_iterator = iter(val_loader)
# `train_iterator` is the spelling the training cell reads; the misspelled `train_itertaor`
# is kept as an alias for any code that still uses that name.
train_iterator = iter(train_loader)
train_itertaor = train_iterator

epoch = 1  # number of passes over the training set
eval_iters = 20  # evaluation interval (in steps) and number of batches averaged per evaluation

steps_per_epoch = len(train_loader)
warmup_iters = 0.1 * steps_per_epoch  # warm up over the first 10% of an epoch
# Decay across the whole run; with a per-epoch horizon, any epoch after the first would
# train entirely at min_lr.
lr_decay_iters = epoch * steps_per_epoch
optimizer = torch.optim.AdamW(sft_model.parameters(), lr=max_lr)
scheduler = CustomLRScheduler(optimizer, warmup_iters, lr_decay_iters, min_lr, max_lr)

simplepo_loss = SimplePO(sft_model, device, beta, gamma, tokenizer)





@torch.inference_mode()
def estimate_loss():
    """Estimate the SimplePO loss on the train and validation splits.

    For each split, up to ``eval_iters`` batches are drawn from a fresh pass over the
    corresponding DataLoader and their losses are averaged. Starting a fresh pass on every
    call means evaluation never runs out of batches part-way through training and never
    consumes batches from an iterator used by the training loop. If a loader yields fewer
    than ``eval_iters`` batches, the mean is taken over the batches that exist (an empty
    loader yields NaN).

    The model is switched to eval mode for the duration of the call and put back into train
    mode afterwards, even if evaluation raises.

    Returns:
        dict mapping ``'train'`` and ``'val'`` to a 0-dim float tensor holding the mean loss.
    """
    out = {}
    sft_model.eval()
    try:
        for split, loader in (('train', train_loader), ('val', val_loader)):
            batch_losses = []
            # zip stops after eval_iters batches or when the loader is exhausted, whichever
            # comes first; range is listed first so no extra batch is fetched at the end.
            for _, datapoint in zip(range(eval_iters), loader):
                loss = simplepo_loss.SimplePOloss(datapoint)
                batch_losses.append(loss.item())
            out[split] = torch.tensor(batch_losses).mean()
    finally:
        sft_model.train()
    return out

In [None]:

#Train the model
from tqdm import tqdm

steps_per_epoch = len(train_loader)

# `epoch` holds the configured number of epochs (an int), so loop over range(epoch). A
# separate loop variable is used so that `epoch` keeps its value and this cell can be re-run.
for epoch_idx in range(epoch):
    # Iterate the loader itself: every epoch gets a fresh, complete pass over the data and
    # the loop ends cleanly when the loader is exhausted.
    for step, batch in enumerate(tqdm(train_loader)):

        # Periodic evaluation, plus one at the last step of the epoch.
        if (step % eval_iters == 0 and step != 0) or step == steps_per_epoch - 1:
            losses = estimate_loss()
            print(f"epoch {epoch_idx}, step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
            wandb.log({
                "epoch": epoch_idx,
                "step": step,
                "training_loss": losses['train'],
                "val_loss": losses['val']
            })

        loss = simplepo_loss.SimplePOloss(batch)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        scheduler.step()