In [None]:
!pip install datasets wandb trl
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from tokenizers import Tokenizer
from pathlib import Path

from torch.utils.data import Dataset, DataLoader
import wandb
from transformers import AutoModelForCausalLM, AutoTokenizer

In [15]:
device = "cuda:0"

In [None]:
import wandb
!wandb login


In [None]:
wandb.init(
    project = "ORPO"
)

In [None]:
!pip install peft
import gc
import os

import torch
import wandb
from datasets import load_dataset

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
)
from  trl import setup_chat_format




In [48]:
if torch.cuda.get_device_capability()[0] >= 8:
    !pip install -qqq flash-attn
    attn_implementation = "flash_attention_2"
    torch_dtype = torch.bfloat16
else:
    attn_implementation = "eager"
    torch_dtype = torch.float16

In [49]:
attn_implementation

'flash_attention_2'

In [50]:
class ORPO:
  def __init__(self, model, device, tokenizer):


    self.model = model
    self.device=device

    self.tokenizer = tokenizer
    



  def ORloss(self, datapoint):



    self.win_prompt = datapoint['chosen']
    self.lose_prompt = datapoint['rejected']

   
    self.chosen_log_probs = torch.nn.functional.log_softmax(self.model(**self.win_prompt).logits, dim=-1)
   
    self.chosen_log_probs = torch.gather(self.chosen_log_probs, -1, self.win_prompt['input_ids'].unsqueeze(-1)).squeeze(-1) 
    
    self.chosen_log_probs = self.chosen_log_probs * (self.win_prompt['attention_mask'])
    
    self.chosen_log_probs = self.chosen_log_probs.sum(dim=-1)
    

    self.rejected_log_probs = torch.nn.functional.log_softmax(self.model(**self.lose_prompt).logits, dim=-1)
    self.rejected_log_probs = torch.gather(self.rejected_log_probs, -1, self.lose_prompt['input_ids'].unsqueeze(-1)).squeeze(-1)
    self.rejected_log_probs = self.rejected_log_probs * (self.lose_prompt['attention_mask'])
    self.rejected_log_probs = self.rejected_log_probs.sum(dim=-1)
    
      
    self.log_odds1 = torch.log1p(torch.exp(self.chosen_log_probs)) - (1 - torch.log1p(torch.exp(self.chosen_log_probs)))
    self.log_odds2 = torch.log1p(torch.exp(self.rejected_log_probs)) - (1 - torch.log1p(torch.exp(self.rejected_log_probs))) 
    

     
    self.OR = -nn.functional.logsigmoid(self.log_odds1 - self.log_odds2).mean()

    return self.OR  

In [74]:

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig  
import torch


base_model = "HuggingFaceTB/SmolLM2-135M"
new_model = "Orpo-SMALLM-v2-135M"


torch_dtype = torch.float16


quantization_config = BitsAndBytesConfig(load_in_8bit=True)


peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")



# Load model with quantization
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=quantization_config,
    device_map="auto",
    attn_implementation=attn_implementation  
)



def prepare_model_for_kbit_training(model):
 
    return model

model = prepare_model_for_kbit_training(model)


In [53]:
dataset_name = "argilla/distilabel-math-preference-dpo"
dataset = load_dataset(dataset_name, split="all")
dataset = dataset.shuffle(seed=42).select(range(1000))


dataset = dataset.train_test_split(test_size=0.01)

In [54]:
train_dataset = dataset["train"]
val_dataset = dataset["test"]
print(train_dataset)

Dataset({
    features: ['metadata', 'instruction', 'chosen_response', 'chosen_rating', 'rejected_response', 'rejected_rating'],
    num_rows: 990
})


In [55]:
def orpo_collate_fn_merged_prompt(batch):

    merged_chosen_prompts = []
    merged_rejected_prompts = []

    for sample in batch:

        
        prompt = sample['instruction']
        chosen_data = sample['chosen_response']
        chosen_data = "Instruction: " + prompt + "\n" + "Output: " + chosen_data + "\n"
      
        rejected_data = sample['rejected_response']
        rejected_data =  "Instruction: " + prompt + "\n" + "Output: " + rejected_data + "\n"


        merged_chosen_prompts.append(chosen_data)


        merged_rejected_prompts.append(rejected_data)

    tokenized_win_prompt = tokenizer(merged_chosen_prompts, max_length = 1024, padding='max_length', truncation=True, return_tensors="pt").to(device)

    tokenized_lose_prompt = tokenizer(merged_rejected_prompts, max_length = 1024, truncation=True, padding='max_length', return_tensors="pt").to(device)



    return {
       
        'chosen': tokenized_win_prompt, 
        'rejected': tokenized_lose_prompt 
    }

In [56]:
batch_size = 2 
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=orpo_collate_fn_merged_prompt)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, collate_fn=orpo_collate_fn_merged_prompt)

In [57]:
for batch in train_loader:
    print(batch)
    break

{'chosen': {'input_ids': tensor([[25464,    42,  1073,  ...,     2,     2,     2],
        [25464,    42, 16222,  ...,     2,     2,     2]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')}, 'rejected': {'input_ids': tensor([[25464,    42,  1073,  ...,     2,     2,     2],
        [25464,    42, 16222,  ...,     2,     2,     2]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')}}


In [58]:
batch_size = 2
beta = 0.2
max_lr = 8e-6
betas = (0.95, 0.99)
weight_decay=0.1

In [71]:
model = model

In [79]:


# 2. Freeze the base model's parameters
for param in model.parameters():
    param.requires_grad = False  # Freeze the parameter
    

# 3. Configure PEFT (e.g., LoRA)
peft_config = LoraConfig(
    r=8,  # Rank of the LoRA matrices
    lora_alpha=32, 
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM" 
)

# 4. Apply PEFT
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # Verify trainable parameters.  Should be PEFT params only

r

# If you want to use bitsandbytes Adam8bit, then
import bitsandbytes as bnb
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-4)

trainable params: 460,800 || all params: 134,975,808 || trainable%: 0.3414


In [81]:
import bitsandbytes as bnb
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm  # Import tqdm if you haven't already



# Add batch_size here
total_steps = 2 * len(train_loader)
batch_size = train_loader.batch_size  # Access batch size from the train_loader

# Assuming val_loader and train_loader are already defined
model.train()

val_iterator = iter(val_loader)
train_iterator = iter(train_loader)


@torch.inference_mode()
def estimate_loss(batch_size): 
    out = {}
    model.eval()

    # Create a new validation loader iterator each time
    temp_val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True,
                                 collate_fn=orpo_collate_fn_merged_prompt)
    temp_val_iterator = iter(temp_val_loader)

    for split in ['val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
           
            try:
                text = next(temp_val_iterator)
            except StopIteration:
                
                temp_val_iterator = iter(temp_val_loader)
                text = next(temp_val_iterator)

            targets = text['chosen']['input_ids']
            logits = model(**text['chosen']).logits
            logits = logits[..., :-1, :].contiguous()
            targets = targets[..., 1:].contiguous()

            batch_size, block_size, embeddings_dims = logits.shape
            logits = logits.view(batch_size * block_size, embeddings_dims)
            targets = targets.view(batch_size * block_size)

            loss = torch.nn.functional.cross_entropy(logits, targets,
                                                       ignore_index=tokenizer.pad_token_id) + beta * orpo.ORloss(
                text)
            losses[k] = loss.item()

        out[split] = losses.mean()

    model.train()
    return out



for step in tqdm(range(total_steps)):
    if (step % eval_iters == 0 and step != 0) or step == total_steps - 1:
        losses = estimate_loss(batch_size)  # Pass batch_size here
        print(f"step {step}: val loss {losses['val']:.4f}")
       
        wandb.log({
            "step": step,
            "val_loss": losses['val']
        })

    try:
        text = next(train_iterator)
    except StopIteration:
        train_iterator = iter(train_loader)
        text = next(train_iterator)

    targets = text['chosen']['input_ids']
    logits = model(**text['chosen']).logits
    targets = targets[..., 1:].contiguous()
    logits = logits[..., :-1, :].contiguous()

    batch_size, block_size, vocab_size = logits.shape
    logits = logits.view(batch_size * block_size, vocab_size)
    targets = targets.view(batch_size * block_size)
    loss = torch.nn.functional.cross_entropy(logits, targets,
                                               ignore_index=tokenizer.pad_token_id) + beta * orpo.ORloss(text)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    wandb.log({
        "step": step,
        "training_loss": loss.item()
    })

  2%|▏         | 20/990 [00:14<11:43,  1.38it/s]

step 20: val loss 1.4689


  4%|▍         | 40/990 [00:36<12:13,  1.29it/s]

step 40: val loss 1.4603


  6%|▌         | 60/990 [00:58<11:20,  1.37it/s]

step 60: val loss 1.4699


  8%|▊         | 80/990 [01:21<11:06,  1.37it/s]

step 80: val loss 1.4698


 10%|█         | 100/990 [01:43<11:06,  1.33it/s]

step 100: val loss 1.4646


 12%|█▏        | 120/990 [02:05<11:39,  1.24it/s]

step 120: val loss 1.4723


 14%|█▍        | 140/990 [02:27<10:20,  1.37it/s]

step 140: val loss 1.4567


 16%|█▌        | 160/990 [02:50<10:11,  1.36it/s]

step 160: val loss 1.4689


 18%|█▊        | 180/990 [03:12<09:51,  1.37it/s]

step 180: val loss 1.4535


 20%|██        | 200/990 [03:34<09:50,  1.34it/s]

step 200: val loss 1.4486


 22%|██▏       | 220/990 [03:56<09:26,  1.36it/s]

step 220: val loss 1.4496


 24%|██▍       | 240/990 [04:18<09:17,  1.35it/s]

step 240: val loss 1.4552


 26%|██▋       | 260/990 [04:41<08:53,  1.37it/s]

step 260: val loss 1.4635


 28%|██▊       | 280/990 [05:03<08:55,  1.33it/s]

step 280: val loss 1.4469


 30%|███       | 300/990 [05:25<08:20,  1.38it/s]

step 300: val loss 1.4433


 32%|███▏      | 320/990 [05:47<08:13,  1.36it/s]

step 320: val loss 1.4607


 34%|███▍      | 340/990 [06:09<08:00,  1.35it/s]

step 340: val loss 1.4554


 36%|███▋      | 360/990 [06:31<08:16,  1.27it/s]

step 360: val loss 1.4646


 38%|███▊      | 380/990 [06:53<07:24,  1.37it/s]

step 380: val loss 1.4649


 40%|████      | 400/990 [07:16<07:10,  1.37it/s]

step 400: val loss 1.4351


 42%|████▏     | 420/990 [07:38<06:52,  1.38it/s]

step 420: val loss 1.4374


 44%|████▍     | 440/990 [08:00<06:48,  1.35it/s]

step 440: val loss 1.4404


 46%|████▋     | 460/990 [08:23<06:54,  1.28it/s]

step 460: val loss 1.4420


 48%|████▊     | 480/990 [08:45<06:13,  1.37it/s]

step 480: val loss 1.4338


 51%|█████     | 500/990 [09:07<05:58,  1.37it/s]

step 500: val loss 1.4288


 53%|█████▎    | 520/990 [09:30<05:48,  1.35it/s]

step 520: val loss 1.4497


 55%|█████▍    | 540/990 [09:52<05:43,  1.31it/s]

step 540: val loss 1.4287


 57%|█████▋    | 560/990 [10:14<05:13,  1.37it/s]

step 560: val loss 1.4505


 59%|█████▊    | 580/990 [10:36<05:02,  1.35it/s]

step 580: val loss 1.4369


 61%|██████    | 600/990 [10:58<04:47,  1.36it/s]

step 600: val loss 1.4460


 63%|██████▎   | 620/990 [11:20<04:28,  1.38it/s]

step 620: val loss 1.4375


 65%|██████▍   | 640/990 [11:43<04:23,  1.33it/s]

step 640: val loss 1.4458


 67%|██████▋   | 660/990 [12:05<04:04,  1.35it/s]

step 660: val loss 1.4453


 69%|██████▊   | 680/990 [12:27<03:47,  1.36it/s]

step 680: val loss 1.4460


 71%|███████   | 700/990 [12:49<03:34,  1.35it/s]

step 700: val loss 1.4443


 73%|███████▎  | 720/990 [13:11<03:17,  1.37it/s]

step 720: val loss 1.4281


 75%|███████▍  | 740/990 [13:34<03:07,  1.33it/s]

step 740: val loss 1.4294


 77%|███████▋  | 760/990 [13:56<02:49,  1.36it/s]

step 760: val loss 1.4421


 79%|███████▉  | 780/990 [14:18<02:32,  1.38it/s]

step 780: val loss 1.4379


 81%|████████  | 800/990 [14:40<02:17,  1.38it/s]

step 800: val loss 1.4362


 83%|████████▎ | 820/990 [15:02<02:07,  1.34it/s]

step 820: val loss 1.4302


 85%|████████▍ | 840/990 [15:25<01:52,  1.33it/s]

step 840: val loss 1.4172


 87%|████████▋ | 860/990 [15:47<01:35,  1.36it/s]

step 860: val loss 1.4346


 89%|████████▉ | 880/990 [16:09<01:21,  1.36it/s]

step 880: val loss 1.4305


 91%|█████████ | 900/990 [16:31<01:10,  1.27it/s]

step 900: val loss 1.4117


 93%|█████████▎| 920/990 [16:54<00:51,  1.37it/s]

step 920: val loss 1.4349


 95%|█████████▍| 940/990 [17:16<00:36,  1.38it/s]

step 940: val loss 1.4304


 97%|█████████▋| 960/990 [17:38<00:22,  1.33it/s]

step 960: val loss 1.4572


 99%|█████████▉| 980/990 [18:00<00:07,  1.36it/s]

step 980: val loss 1.4308


100%|█████████▉| 989/990 [18:14<00:00,  1.14it/s]

step 989: val loss 1.4153


100%|██████████| 990/990 [18:22<00:00,  1.11s/it]
