In [1]:
!pip install wandb datasets transformers



In [2]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from tokenizers import Tokenizer
from pathlib import Path

from torch.utils.data import Dataset, DataLoader
import wandb
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
!wandb login

In [None]:
wandb.init(
    project = "DPO IMPLEMENTATION"
)

In [5]:
# defining hyperparameters(taken from paper)

batch_size = 1
beta = 0.5#(increasing beta will make the model generate preferred respsonses more)
max_lr = 1e-6

In [6]:
class DPO:
  def __init__(self, ref_model, sft_model, device, beta, tokenizer):


    self.ref_model = ref_model
    self.sft_model = sft_model
    self.device=device
    self.beta = beta
    self.tokenizer = tokenizer
    self.ref_model.eval()



  def DPOloss(self, datapoint):



    self.win_prompt = datapoint['chosen']
    self.lose_prompt = datapoint['rejected']

    with torch.no_grad():
      self.win_log_ref = torch.nn.functional.log_softmax(self.ref_model(**self.win_prompt).logits, dim=-1)
      self.win_log_ref = torch.gather(self.win_log_ref, -1, self.win_prompt['input_ids'].unsqueeze(-1)).squeeze(-1) #Why gather? Because its not token level stuff we care about but se
      self.win_log_ref = self.win_log_ref * (self.win_prompt['attention_mask'])
      
      
      self.lose_log_ref = torch.nn.functional.log_softmax(self.ref_model(**self.lose_prompt).logits, dim=-1)
      self.lose_log_ref = torch.gather(self.lose_log_ref, -1, self.lose_prompt['input_ids'].unsqueeze(-1)).squeeze(-1) #Why gather? Because its not token level stuff we care about but se
      self.lose_log_ref = self.lose_log_ref * (self.lose_prompt['attention_mask'])
      
    self.win_log_sft = torch.nn.functional.log_softmax(self.sft_model(**self.win_prompt).logits, dim=-1)
    self.win_log_sft = torch.gather(self.win_log_sft, -1, self.win_prompt['input_ids'].unsqueeze(-1)).squeeze(-1) #Why gather? Because its not token level stuff we care about but se
    self.win_log_sft = self.win_log_sft * (self.win_prompt['attention_mask'])

    self.lose_log_sft = torch.nn.functional.log_softmax(self.sft_model(**self.lose_prompt).logits, dim=-1)
    self.lose_log_sft = torch.gather(self.lose_log_sft, -1, self.lose_prompt['input_ids'].unsqueeze(-1)).squeeze(-1)
    self.lose_log_sft = self.lose_log_sft * (self.lose_prompt['attention_mask'])

    self.diff1 = self.win_log_sft - self.win_log_ref
    self.diff2 = self.win_log_sft - self.lose_log_ref

    self.final = -nn.functional.logsigmoid(self.beta *(self.diff1 - self.diff2)).mean() #Remember we have to maximize the rewards thus minimizing the negative sign! Also, since the var of rewards could be very much, we take mean so as to have a notion of normalizing it!

   
    return self.final


In [7]:
HF_TOKEN =  "HUGGING_FACE_API_KEY"

In [8]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)

cuda:0


In [9]:
torch.cuda.set_device(device)

In [10]:
sft_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct", token=HF_TOKEN, device_map=device)
ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct", token=HF_TOKEN, device_map=device)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct", token=HF_TOKEN, device_map=device)

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


In [11]:
pip install --upgrade accelerate


Note: you may need to restart the kernel to use updated packages.


In [12]:
from datasets import load_dataset, Dataset
dataset = load_dataset("HumanLLMs/Human-Like-DPO-Dataset")
split_dataset = dataset["train"].train_test_split(test_size=0.2, seed=42)
train_dataset = split_dataset["train"]
test_dataset = split_dataset["test"]
print(train_dataset)

Dataset({
    features: ['prompt', 'chosen', 'rejected'],
    num_rows: 8707
})


In [13]:
def dpo_collate_fn_merged_prompt(batch):

    merged_chosen_prompts = []
    merged_rejected_prompts = []

    for sample in batch:

       
        prompt = sample['prompt']
        chosen_data = sample['chosen']
        chosen_data = "Instruction: " + prompt + "\n" + "Output: " + chosen_data + "\n"
        
        rejected_data = sample['rejected']
        rejected_data =  "Instruction: " + prompt + "\n" + "Output: " + rejected_data + "\n"

       
        merged_chosen_prompts.append(chosen_data)


        merged_rejected_prompts.append(rejected_data)

    tokenized_win_prompt = tokenizer(merged_chosen_prompts, max_length = 1024, padding='max_length', truncation=True, return_tensors="pt").to(device)

    tokenized_lose_prompt = tokenizer(merged_rejected_prompts, max_length = 1024, truncation=True, padding='max_length', return_tensors="pt").to(device)



    return {
       
        'chosen': tokenized_win_prompt,
        'rejected': tokenized_lose_prompt 
    }
     

In [14]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=dpo_collate_fn_merged_prompt)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, collate_fn=dpo_collate_fn_merged_prompt)

In [15]:
sft_model.train()
optimizer = torch.optim.RMSprop(sft_model.parameters(), lr=max_lr)

total_steps = 4353
eval_iters = 20

dpo_loss = DPO(ref_model, sft_model, device, beta, tokenizer)

@torch.inference_mode()
def estimate_loss():
    out = {}
    sft_model.eval()
    
    for split in ["train", "test"]:
        if split == "train":
            loader = iter(train_loader)  # Reinitialize train iterator
        else:
            loader = iter(test_loader)   # Reinitialize test iterator

        losses = torch.zeros(eval_iters, device=device)

        for k in range(eval_iters):
            try:
                datapoint = next(loader)  # Get next batch
                loss = dpo_loss.DPOloss(datapoint)
                losses[k] = loss.item()
            except StopIteration:
                break  # Stop if dataset is smaller than eval_iters

        out[split] = losses.mean().item()  # Store mean loss

    sft_model.train()
    return out


In [16]:
# from tqdm import tqdm

# train_iterator = iter(train_loader)

# for step in tqdm(range(total_steps)):


#   if (step % eval_iters == 0 and step != 0) or step == total_steps - 1:
#     losses = estimate_loss()
#     print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['test']:.4f}")
#     wandb.log({
#             "step": step,
#             "training_loss": losses['train'],
#             "val_loss": losses['test']
#         })
    
#     text  = next(train_iterator)
#     loss = dpo_loss.DPOloss(text)
#     optimizer.zero_grad(set_to_none=True)
#     loss.backward()
#     optimizer.step()


In [17]:
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()  # Gradient scaler for stable FP16 training
train_iterator = iter(train_loader)

for step in tqdm(range(total_steps)):

    if (step % eval_iters == 0 and step != 0) or step == total_steps - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['test']:.4f}")
        wandb.log({
            "step": step,
            "training_loss": losses['train'],
            "val_loss": losses['test']
        })

    text = next(train_iterator)

    # Enable mixed precision training
    with autocast(dtype=torch.float16):
        loss = dpo_loss.DPOloss(text)

    scaler.scale(loss).backward()  # Scale loss to prevent underflow
    scaler.step(optimizer)  
    scaler.update()  # Update scaler for next iteration
    optimizer.zero_grad(set_to_none=True)


  0%|          | 21/4353 [00:13<3:11:27,  2.65s/it]

step 20: train loss 1.0758, val loss 0.8482


  1%|          | 41/4353 [00:25<3:11:53,  2.67s/it]

step 40: train loss 0.9138, val loss 1.1219


  1%|▏         | 61/4353 [00:37<3:11:56,  2.68s/it]

step 60: train loss 0.9639, val loss 0.9041


  2%|▏         | 81/4353 [00:50<3:11:26,  2.69s/it]

step 80: train loss 1.0306, val loss 1.1060


  2%|▏         | 101/4353 [01:02<3:10:15,  2.68s/it]

step 100: train loss 0.8909, val loss 1.1069


  3%|▎         | 121/4353 [01:15<3:09:20,  2.68s/it]

step 120: train loss 1.0421, val loss 0.9939


  3%|▎         | 141/4353 [01:27<3:09:00,  2.69s/it]

step 140: train loss 0.9483, val loss 0.9057


  4%|▎         | 161/4353 [01:40<3:08:29,  2.70s/it]

step 160: train loss 0.9502, val loss 0.9469


  4%|▍         | 181/4353 [01:52<3:07:44,  2.70s/it]

step 180: train loss 1.0055, val loss 1.0868


  5%|▍         | 201/4353 [02:04<3:07:11,  2.71s/it]

step 200: train loss 1.0353, val loss 1.0237


  5%|▌         | 221/4353 [02:17<3:06:26,  2.71s/it]

step 220: train loss 1.1063, val loss 1.0332


  6%|▌         | 241/4353 [02:29<3:05:33,  2.71s/it]

step 240: train loss 0.9004, val loss 0.9070


  6%|▌         | 261/4353 [02:42<3:05:14,  2.72s/it]

step 260: train loss 1.1180, val loss 1.0122


  6%|▋         | 281/4353 [02:54<3:04:10,  2.71s/it]

step 280: train loss 0.8313, val loss 0.9207


  7%|▋         | 301/4353 [03:07<3:03:57,  2.72s/it]

step 300: train loss 0.9345, val loss 1.0557


  7%|▋         | 321/4353 [03:20<3:03:02,  2.72s/it]

step 320: train loss 0.9249, val loss 1.0573


  8%|▊         | 341/4353 [03:32<3:03:12,  2.74s/it]

step 340: train loss 0.9646, val loss 0.8619


  8%|▊         | 361/4353 [03:45<3:02:22,  2.74s/it]

step 360: train loss 0.8967, val loss 0.8050


  9%|▉         | 381/4353 [03:57<3:01:59,  2.75s/it]

step 380: train loss 0.9796, val loss 0.9795


  9%|▉         | 401/4353 [04:10<3:01:39,  2.76s/it]

step 400: train loss 1.1827, val loss 0.9175


 10%|▉         | 421/4353 [04:23<3:00:22,  2.75s/it]

step 420: train loss 0.7990, val loss 1.0740


 10%|█         | 441/4353 [04:35<3:00:19,  2.77s/it]

step 440: train loss 0.9341, val loss 1.1314


 11%|█         | 461/4353 [04:48<2:59:54,  2.77s/it]

step 460: train loss 0.8057, val loss 1.0610


 11%|█         | 481/4353 [05:01<2:58:21,  2.76s/it]

step 480: train loss 1.0020, val loss 1.0874


 12%|█▏        | 501/4353 [05:13<2:57:57,  2.77s/it]

step 500: train loss 1.1180, val loss 0.9325


 12%|█▏        | 521/4353 [05:26<2:57:26,  2.78s/it]

step 520: train loss 0.8345, val loss 0.8629


 12%|█▏        | 541/4353 [05:39<2:59:04,  2.82s/it]

step 540: train loss 1.0650, val loss 0.9347


 13%|█▎        | 561/4353 [05:52<2:59:05,  2.83s/it]

step 560: train loss 0.9499, val loss 0.8814


 13%|█▎        | 581/4353 [06:05<2:59:01,  2.85s/it]

step 580: train loss 1.0365, val loss 0.8991


 14%|█▍        | 601/4353 [06:18<2:54:48,  2.80s/it]

step 600: train loss 0.8479, val loss 0.9138


 14%|█▍        | 621/4353 [06:31<2:52:59,  2.78s/it]

step 620: train loss 1.0078, val loss 0.9466


 15%|█▍        | 641/4353 [06:43<2:50:50,  2.76s/it]

step 640: train loss 0.8700, val loss 1.0423


 15%|█▌        | 661/4353 [06:56<2:49:22,  2.75s/it]

step 660: train loss 0.8928, val loss 0.9512


 16%|█▌        | 681/4353 [07:09<2:48:38,  2.76s/it]

step 680: train loss 0.9352, val loss 0.8705


 16%|█▌        | 701/4353 [07:21<2:47:39,  2.75s/it]

step 700: train loss 1.0811, val loss 0.9919


 17%|█▋        | 721/4353 [07:34<2:47:11,  2.76s/it]

step 720: train loss 1.0107, val loss 1.1116


 17%|█▋        | 741/4353 [07:47<2:46:49,  2.77s/it]

step 740: train loss 0.8766, val loss 1.0157


 17%|█▋        | 761/4353 [07:59<2:46:50,  2.79s/it]

step 760: train loss 0.9766, val loss 0.9073


 18%|█▊        | 781/4353 [08:12<2:45:54,  2.79s/it]

step 780: train loss 1.0148, val loss 0.9247


 18%|█▊        | 801/4353 [08:25<2:46:48,  2.82s/it]

step 800: train loss 0.9411, val loss 1.2304


 19%|█▉        | 821/4353 [08:38<2:47:42,  2.85s/it]

step 820: train loss 1.0848, val loss 0.9912


 19%|█▉        | 841/4353 [08:51<2:46:57,  2.85s/it]

step 840: train loss 1.0754, val loss 0.9707


 20%|█▉        | 861/4353 [09:04<2:41:49,  2.78s/it]

step 860: train loss 0.9971, val loss 0.9922


 20%|██        | 881/4353 [09:16<2:39:48,  2.76s/it]

step 880: train loss 1.1014, val loss 0.9037


 21%|██        | 901/4353 [09:29<2:38:11,  2.75s/it]

step 900: train loss 0.8928, val loss 0.9337


 21%|██        | 921/4353 [09:42<2:36:09,  2.73s/it]

step 920: train loss 1.0605, val loss 0.9049


 22%|██▏       | 941/4353 [09:54<2:35:27,  2.73s/it]

step 940: train loss 1.0545, val loss 1.0503


 22%|██▏       | 961/4353 [10:07<2:34:04,  2.73s/it]

step 960: train loss 1.0951, val loss 1.0423


 23%|██▎       | 981/4353 [10:19<2:32:49,  2.72s/it]

step 980: train loss 1.1285, val loss 1.0322


 23%|██▎       | 1001/4353 [10:32<2:32:11,  2.72s/it]

step 1000: train loss 1.1274, val loss 0.9913


 23%|██▎       | 1021/4353 [10:44<2:31:11,  2.72s/it]

step 1020: train loss 0.9284, val loss 1.0452


 24%|██▍       | 1041/4353 [10:57<2:30:24,  2.72s/it]

step 1040: train loss 0.9933, val loss 1.0252


 24%|██▍       | 1061/4353 [11:09<2:29:35,  2.73s/it]

step 1060: train loss 1.0033, val loss 0.8377


 25%|██▍       | 1081/4353 [11:22<2:28:55,  2.73s/it]

step 1080: train loss 0.9644, val loss 1.2261


 25%|██▌       | 1101/4353 [11:35<2:28:34,  2.74s/it]

step 1100: train loss 0.9472, val loss 0.8874


 26%|██▌       | 1121/4353 [11:47<2:28:09,  2.75s/it]

step 1120: train loss 1.1345, val loss 0.9751


 26%|██▌       | 1141/4353 [12:00<2:27:19,  2.75s/it]

step 1140: train loss 1.0515, val loss 0.9358


 27%|██▋       | 1161/4353 [12:12<2:26:18,  2.75s/it]

step 1160: train loss 1.0492, val loss 0.8228


 27%|██▋       | 1181/4353 [12:25<2:25:58,  2.76s/it]

step 1180: train loss 1.1327, val loss 1.0623


 28%|██▊       | 1201/4353 [12:38<2:25:21,  2.77s/it]

step 1200: train loss 1.0501, val loss 0.9081


 28%|██▊       | 1221/4353 [12:51<2:25:37,  2.79s/it]

step 1220: train loss 1.0809, val loss 0.9894


 29%|██▊       | 1241/4353 [13:04<2:26:51,  2.83s/it]

step 1240: train loss 0.8917, val loss 1.0670


 29%|██▉       | 1261/4353 [13:16<2:23:33,  2.79s/it]

step 1260: train loss 1.0146, val loss 1.0067


 29%|██▉       | 1281/4353 [13:29<2:21:30,  2.76s/it]

step 1280: train loss 1.3100, val loss 1.0091


 30%|██▉       | 1301/4353 [13:42<2:20:22,  2.76s/it]

step 1300: train loss 1.1017, val loss 0.9648


 30%|███       | 1321/4353 [13:54<2:19:20,  2.76s/it]

step 1320: train loss 0.9363, val loss 0.9732


 31%|███       | 1341/4353 [14:07<2:17:44,  2.74s/it]

step 1340: train loss 0.9595, val loss 0.7893


 31%|███▏      | 1361/4353 [14:20<2:16:40,  2.74s/it]

step 1360: train loss 0.9465, val loss 1.0590


 32%|███▏      | 1381/4353 [14:32<2:15:24,  2.73s/it]

step 1380: train loss 0.9752, val loss 0.8906


 32%|███▏      | 1401/4353 [14:45<2:14:33,  2.73s/it]

step 1400: train loss 0.9321, val loss 0.9686


 33%|███▎      | 1421/4353 [14:58<2:13:58,  2.74s/it]

step 1420: train loss 1.1250, val loss 1.0109


 33%|███▎      | 1441/4353 [15:10<2:13:17,  2.75s/it]

step 1440: train loss 0.9182, val loss 0.9121


 34%|███▎      | 1461/4353 [15:23<2:13:15,  2.76s/it]

step 1460: train loss 0.9876, val loss 0.9717


 34%|███▍      | 1481/4353 [15:36<2:12:55,  2.78s/it]

step 1480: train loss 0.9647, val loss 1.0179


 34%|███▍      | 1501/4353 [15:48<2:12:51,  2.80s/it]

step 1500: train loss 1.0911, val loss 1.1067


 35%|███▍      | 1521/4353 [16:01<2:13:18,  2.82s/it]

step 1520: train loss 0.9730, val loss 1.0390


 35%|███▌      | 1541/4353 [16:14<2:11:58,  2.82s/it]

step 1540: train loss 0.9351, val loss 1.1033


 36%|███▌      | 1561/4353 [16:27<2:10:41,  2.81s/it]

step 1560: train loss 0.9661, val loss 1.0266


 36%|███▋      | 1581/4353 [16:40<2:09:23,  2.80s/it]

step 1580: train loss 1.0181, val loss 1.0056


 37%|███▋      | 1601/4353 [16:53<2:09:08,  2.82s/it]

step 1600: train loss 0.8623, val loss 1.2604


 37%|███▋      | 1621/4353 [17:06<2:09:07,  2.84s/it]

step 1620: train loss 0.9401, val loss 0.9955


 38%|███▊      | 1641/4353 [17:19<2:07:21,  2.82s/it]

step 1640: train loss 1.0104, val loss 0.9736


 38%|███▊      | 1661/4353 [17:31<2:04:42,  2.78s/it]

step 1660: train loss 0.9495, val loss 0.9329


 39%|███▊      | 1681/4353 [17:44<2:02:54,  2.76s/it]

step 1680: train loss 1.1185, val loss 0.8635


 39%|███▉      | 1701/4353 [17:57<2:01:48,  2.76s/it]

step 1700: train loss 0.8847, val loss 0.9944


 40%|███▉      | 1721/4353 [18:09<2:01:17,  2.77s/it]

step 1720: train loss 0.8900, val loss 0.9350


 40%|███▉      | 1741/4353 [18:22<1:59:40,  2.75s/it]

step 1740: train loss 0.9051, val loss 0.9949


 40%|████      | 1761/4353 [18:35<1:58:33,  2.74s/it]

step 1760: train loss 0.9261, val loss 1.3458


 41%|████      | 1781/4353 [18:47<1:57:17,  2.74s/it]

step 1780: train loss 1.0320, val loss 0.9327


 41%|████▏     | 1801/4353 [19:00<1:56:59,  2.75s/it]

step 1800: train loss 0.9630, val loss 1.0298


 42%|████▏     | 1820/4353 [19:04<08:46,  4.81it/s]  

step 1820: train loss 1.0566, val loss 1.0697


 42%|████▏     | 1841/4353 [19:25<1:55:46,  2.77s/it]

step 1840: train loss 0.9503, val loss 0.9344


 43%|████▎     | 1861/4353 [19:38<1:55:16,  2.78s/it]

step 1860: train loss 0.9876, val loss 0.9941


 43%|████▎     | 1881/4353 [19:51<1:55:13,  2.80s/it]

step 1880: train loss 0.9530, val loss 0.9236


 44%|████▎     | 1901/4353 [20:04<1:56:17,  2.85s/it]

step 1900: train loss 1.1109, val loss 1.1078


 44%|████▍     | 1921/4353 [20:17<1:56:49,  2.88s/it]

step 1920: train loss 0.9774, val loss 1.0098


 45%|████▍     | 1941/4353 [20:30<1:52:38,  2.80s/it]

step 1940: train loss 0.9538, val loss 1.0138


 45%|████▌     | 1961/4353 [20:42<1:50:35,  2.77s/it]

step 1960: train loss 0.9833, val loss 0.9541


 46%|████▌     | 1981/4353 [20:55<1:48:20,  2.74s/it]

step 1980: train loss 0.8670, val loss 0.8680


 46%|████▌     | 2001/4353 [21:08<1:47:06,  2.73s/it]

step 2000: train loss 0.9513, val loss 1.0551


 46%|████▋     | 2021/4353 [21:20<1:45:32,  2.72s/it]

step 2020: train loss 0.9669, val loss 0.9836


 47%|████▋     | 2041/4353 [21:33<1:44:38,  2.72s/it]

step 2040: train loss 0.8790, val loss 1.0142


 47%|████▋     | 2061/4353 [21:45<1:43:29,  2.71s/it]

step 2060: train loss 1.0182, val loss 0.9041


 48%|████▊     | 2081/4353 [21:58<1:42:35,  2.71s/it]

step 2080: train loss 0.9479, val loss 0.9797


 48%|████▊     | 2101/4353 [22:10<1:41:30,  2.70s/it]

step 2100: train loss 0.9212, val loss 0.9755


 49%|████▊     | 2121/4353 [22:23<1:40:50,  2.71s/it]

step 2120: train loss 0.9128, val loss 0.9542


 49%|████▉     | 2141/4353 [22:35<1:40:04,  2.71s/it]

step 2140: train loss 1.0627, val loss 1.0708


 50%|████▉     | 2161/4353 [22:48<1:38:59,  2.71s/it]

step 2160: train loss 1.0176, val loss 0.8797


 50%|█████     | 2181/4353 [23:00<1:38:22,  2.72s/it]

step 2180: train loss 1.0237, val loss 0.9676


 51%|█████     | 2201/4353 [23:13<1:37:45,  2.73s/it]

step 2200: train loss 0.9770, val loss 1.1752


 51%|█████     | 2221/4353 [23:26<1:36:52,  2.73s/it]

step 2220: train loss 1.0437, val loss 1.0026


 51%|█████▏    | 2241/4353 [23:38<1:36:09,  2.73s/it]

step 2240: train loss 0.9096, val loss 1.0354


 52%|█████▏    | 2261/4353 [23:51<1:35:17,  2.73s/it]

step 2260: train loss 1.2736, val loss 1.1113


 52%|█████▏    | 2281/4353 [24:03<1:34:35,  2.74s/it]

step 2280: train loss 0.8031, val loss 0.9516


 53%|█████▎    | 2301/4353 [24:16<1:33:29,  2.73s/it]

step 2300: train loss 0.9736, val loss 0.9822


 53%|█████▎    | 2321/4353 [24:29<1:32:35,  2.73s/it]

step 2320: train loss 0.8718, val loss 0.8435


 54%|█████▍    | 2341/4353 [24:41<1:32:04,  2.75s/it]

step 2340: train loss 1.0036, val loss 0.9531


 54%|█████▍    | 2361/4353 [24:54<1:31:47,  2.76s/it]

step 2360: train loss 0.8691, val loss 1.0558


 55%|█████▍    | 2381/4353 [25:07<1:31:44,  2.79s/it]

step 2380: train loss 0.9134, val loss 0.9073


 55%|█████▌    | 2401/4353 [25:20<1:31:55,  2.83s/it]

step 2400: train loss 1.0767, val loss 0.8668


 56%|█████▌    | 2421/4353 [25:32<1:30:24,  2.81s/it]

step 2420: train loss 0.8884, val loss 0.8869


 56%|█████▌    | 2441/4353 [25:45<1:28:45,  2.79s/it]

step 2440: train loss 1.0898, val loss 1.0747


 57%|█████▋    | 2461/4353 [25:58<1:26:41,  2.75s/it]

step 2460: train loss 0.9531, val loss 0.9417


 57%|█████▋    | 2481/4353 [26:10<1:25:51,  2.75s/it]

step 2480: train loss 0.9000, val loss 0.9150


 57%|█████▋    | 2501/4353 [26:23<1:25:02,  2.75s/it]

step 2500: train loss 0.8737, val loss 0.9383


 58%|█████▊    | 2521/4353 [26:36<1:24:03,  2.75s/it]

step 2520: train loss 1.0413, val loss 1.0076


 58%|█████▊    | 2541/4353 [26:49<1:23:38,  2.77s/it]

step 2540: train loss 0.8969, val loss 1.1108


 59%|█████▉    | 2561/4353 [27:01<1:22:56,  2.78s/it]

step 2560: train loss 1.1100, val loss 1.1655


 59%|█████▉    | 2581/4353 [27:14<1:22:28,  2.79s/it]

step 2580: train loss 0.9627, val loss 1.0269


 60%|█████▉    | 2601/4353 [27:27<1:22:08,  2.81s/it]

step 2600: train loss 1.0854, val loss 0.9852


 60%|██████    | 2621/4353 [27:40<1:22:14,  2.85s/it]

step 2620: train loss 1.1107, val loss 1.0086


 61%|██████    | 2641/4353 [27:53<1:19:59,  2.80s/it]

step 2640: train loss 0.8896, val loss 0.8652


 61%|██████    | 2661/4353 [28:05<1:18:27,  2.78s/it]

step 2660: train loss 1.0374, val loss 0.9851


 62%|██████▏   | 2681/4353 [28:18<1:16:50,  2.76s/it]

step 2680: train loss 0.8438, val loss 0.8949


 62%|██████▏   | 2701/4353 [28:31<1:16:39,  2.78s/it]

step 2700: train loss 0.9900, val loss 0.9125


 63%|██████▎   | 2721/4353 [28:44<1:15:27,  2.77s/it]

step 2720: train loss 1.0476, val loss 0.9239


 63%|██████▎   | 2741/4353 [28:56<1:13:57,  2.75s/it]

step 2740: train loss 1.0111, val loss 0.9357


 63%|██████▎   | 2761/4353 [29:09<1:13:13,  2.76s/it]

step 2760: train loss 1.0478, val loss 1.0354


 64%|██████▍   | 2781/4353 [29:22<1:12:19,  2.76s/it]

step 2780: train loss 0.9601, val loss 0.9038


 64%|██████▍   | 2801/4353 [29:34<1:11:51,  2.78s/it]

step 2800: train loss 1.0706, val loss 0.8794


 65%|██████▍   | 2821/4353 [29:47<1:11:14,  2.79s/it]

step 2820: train loss 1.2366, val loss 0.9146


 65%|██████▌   | 2841/4353 [30:00<1:10:57,  2.82s/it]

step 2840: train loss 0.8565, val loss 0.9362


 66%|██████▌   | 2861/4353 [30:13<1:10:26,  2.83s/it]

step 2860: train loss 0.9885, val loss 1.1562


 66%|██████▌   | 2881/4353 [30:26<1:09:38,  2.84s/it]

step 2880: train loss 0.9894, val loss 0.9507


 67%|██████▋   | 2901/4353 [30:39<1:08:40,  2.84s/it]

step 2900: train loss 0.9906, val loss 1.0022


 67%|██████▋   | 2921/4353 [30:52<1:07:36,  2.83s/it]

step 2920: train loss 0.9700, val loss 1.0276


 68%|██████▊   | 2941/4353 [31:04<1:05:29,  2.78s/it]

step 2940: train loss 0.8731, val loss 0.9390


 68%|██████▊   | 2961/4353 [31:17<1:04:28,  2.78s/it]

step 2960: train loss 0.9220, val loss 1.0532


 68%|██████▊   | 2981/4353 [31:30<1:03:18,  2.77s/it]

step 2980: train loss 0.9461, val loss 0.8765


 69%|██████▉   | 3001/4353 [31:43<1:02:35,  2.78s/it]

step 3000: train loss 0.9590, val loss 1.0164


 69%|██████▉   | 3021/4353 [31:55<1:01:28,  2.77s/it]

step 3020: train loss 0.8798, val loss 0.9785


 70%|██████▉   | 3041/4353 [32:08<1:00:44,  2.78s/it]

step 3040: train loss 0.8637, val loss 0.9314


 70%|███████   | 3061/4353 [32:21<59:49,  2.78s/it]  

step 3060: train loss 1.0124, val loss 1.1753


 71%|███████   | 3081/4353 [32:34<59:10,  2.79s/it]

step 3080: train loss 1.0545, val loss 0.9612


 71%|███████   | 3101/4353 [32:47<58:46,  2.82s/it]

step 3100: train loss 0.9408, val loss 1.0200


 72%|███████▏  | 3121/4353 [33:00<58:41,  2.86s/it]

step 3120: train loss 0.9323, val loss 1.0429


 72%|███████▏  | 3141/4353 [33:12<56:55,  2.82s/it]

step 3140: train loss 0.9466, val loss 1.0119


 73%|███████▎  | 3161/4353 [33:25<55:28,  2.79s/it]

step 3160: train loss 1.0620, val loss 1.0225


 73%|███████▎  | 3181/4353 [33:38<54:08,  2.77s/it]

step 3180: train loss 0.8602, val loss 0.9404


 74%|███████▎  | 3201/4353 [33:51<53:06,  2.77s/it]

step 3200: train loss 0.7796, val loss 0.9886


 74%|███████▍  | 3221/4353 [34:03<52:06,  2.76s/it]

step 3220: train loss 0.9174, val loss 0.9356


 74%|███████▍  | 3241/4353 [34:16<51:17,  2.77s/it]

step 3240: train loss 0.9938, val loss 0.9901


 75%|███████▍  | 3261/4353 [34:29<50:22,  2.77s/it]

step 3260: train loss 1.0455, val loss 0.9367


 75%|███████▌  | 3281/4353 [34:42<49:47,  2.79s/it]

step 3280: train loss 1.0526, val loss 1.1276


 76%|███████▌  | 3301/4353 [34:54<48:51,  2.79s/it]

step 3300: train loss 0.9683, val loss 0.9088


 76%|███████▋  | 3321/4353 [35:07<48:01,  2.79s/it]

step 3320: train loss 0.9731, val loss 1.0536


 77%|███████▋  | 3341/4353 [35:20<47:22,  2.81s/it]

step 3340: train loss 0.8845, val loss 0.9464


 77%|███████▋  | 3361/4353 [35:33<46:50,  2.83s/it]

step 3360: train loss 0.9170, val loss 1.0127


 78%|███████▊  | 3381/4353 [35:46<45:33,  2.81s/it]

step 3380: train loss 0.8947, val loss 0.8903


 78%|███████▊  | 3401/4353 [35:59<44:37,  2.81s/it]

step 3400: train loss 1.0574, val loss 1.0702


 79%|███████▊  | 3421/4353 [36:12<43:43,  2.81s/it]

step 3420: train loss 1.0967, val loss 0.9555


 79%|███████▉  | 3441/4353 [36:24<43:01,  2.83s/it]

step 3440: train loss 1.1132, val loss 0.9735


 80%|███████▉  | 3461/4353 [36:37<42:20,  2.85s/it]

step 3460: train loss 0.9020, val loss 0.9798


 80%|███████▉  | 3481/4353 [36:50<40:29,  2.79s/it]

step 3480: train loss 0.8188, val loss 0.9814


 80%|████████  | 3501/4353 [37:03<39:09,  2.76s/it]

step 3500: train loss 0.8825, val loss 1.0095


 81%|████████  | 3521/4353 [37:16<38:25,  2.77s/it]

step 3520: train loss 0.8605, val loss 0.8587


 81%|████████▏ | 3541/4353 [37:28<37:17,  2.76s/it]

step 3540: train loss 1.1089, val loss 0.9594


 82%|████████▏ | 3561/4353 [37:41<36:20,  2.75s/it]

step 3560: train loss 0.9661, val loss 0.9327


 82%|████████▏ | 3581/4353 [37:54<35:19,  2.75s/it]

step 3580: train loss 0.9699, val loss 1.0119


 83%|████████▎ | 3601/4353 [38:06<34:28,  2.75s/it]

step 3600: train loss 1.0208, val loss 1.1100


 83%|████████▎ | 3621/4353 [38:19<33:37,  2.76s/it]

step 3620: train loss 0.8149, val loss 0.9467


 84%|████████▎ | 3641/4353 [38:32<32:42,  2.76s/it]

step 3640: train loss 0.9649, val loss 0.8865


 84%|████████▍ | 3661/4353 [38:44<31:56,  2.77s/it]

step 3660: train loss 0.9145, val loss 0.9865


 85%|████████▍ | 3681/4353 [38:57<30:54,  2.76s/it]

step 3680: train loss 0.9152, val loss 0.9520


 85%|████████▌ | 3701/4353 [39:10<29:52,  2.75s/it]

step 3700: train loss 0.8435, val loss 0.9487


 85%|████████▌ | 3721/4353 [39:22<29:06,  2.76s/it]

step 3720: train loss 1.0909, val loss 1.1575


 86%|████████▌ | 3741/4353 [39:35<28:21,  2.78s/it]

step 3740: train loss 1.0933, val loss 1.0366


 86%|████████▋ | 3761/4353 [39:48<27:29,  2.79s/it]

step 3760: train loss 0.8792, val loss 0.9757


 87%|████████▋ | 3781/4353 [40:01<26:39,  2.80s/it]

step 3780: train loss 0.9174, val loss 0.9563


 87%|████████▋ | 3801/4353 [40:14<25:56,  2.82s/it]

step 3800: train loss 0.9355, val loss 1.0478


 88%|████████▊ | 3821/4353 [40:27<24:57,  2.81s/it]

step 3820: train loss 0.9750, val loss 0.9155


 88%|████████▊ | 3841/4353 [40:39<23:33,  2.76s/it]

step 3840: train loss 0.8078, val loss 1.0603


 89%|████████▊ | 3861/4353 [40:52<22:50,  2.78s/it]

step 3860: train loss 1.0329, val loss 0.9609


 89%|████████▉ | 3881/4353 [41:05<21:44,  2.76s/it]

step 3880: train loss 0.9775, val loss 0.9610


 90%|████████▉ | 3901/4353 [41:17<20:48,  2.76s/it]

step 3900: train loss 0.9385, val loss 1.1576


 90%|█████████ | 3921/4353 [41:30<19:52,  2.76s/it]

step 3920: train loss 0.9597, val loss 1.0571


 91%|█████████ | 3941/4353 [41:43<18:54,  2.75s/it]

step 3940: train loss 0.9606, val loss 1.2178


 91%|█████████ | 3961/4353 [41:55<17:54,  2.74s/it]

step 3960: train loss 1.1622, val loss 1.0690


 91%|█████████▏| 3981/4353 [42:08<16:53,  2.73s/it]

step 3980: train loss 0.9643, val loss 0.8712


 92%|█████████▏| 4001/4353 [42:21<16:04,  2.74s/it]

step 4000: train loss 0.9783, val loss 0.9306


 92%|█████████▏| 4021/4353 [42:33<15:13,  2.75s/it]

step 4020: train loss 1.0774, val loss 0.9936


 93%|█████████▎| 4041/4353 [42:46<14:17,  2.75s/it]

step 4040: train loss 1.0729, val loss 0.9311


 93%|█████████▎| 4061/4353 [42:59<13:27,  2.77s/it]

step 4060: train loss 1.0171, val loss 0.9097


 94%|█████████▍| 4081/4353 [43:11<12:36,  2.78s/it]

step 4080: train loss 0.9438, val loss 0.8935


 94%|█████████▍| 4101/4353 [43:24<11:48,  2.81s/it]

step 4100: train loss 1.0014, val loss 0.8723


 95%|█████████▍| 4121/4353 [43:37<10:46,  2.79s/it]

step 4120: train loss 1.0110, val loss 1.0528


 95%|█████████▌| 4141/4353 [43:50<09:47,  2.77s/it]

step 4140: train loss 1.0627, val loss 0.9537


 96%|█████████▌| 4161/4353 [44:02<08:48,  2.75s/it]

step 4160: train loss 1.0043, val loss 0.9891


 96%|█████████▌| 4181/4353 [44:15<07:52,  2.75s/it]

step 4180: train loss 0.9740, val loss 0.9617


 97%|█████████▋| 4201/4353 [44:28<06:56,  2.74s/it]

step 4200: train loss 0.8644, val loss 0.9037


 97%|█████████▋| 4221/4353 [44:40<06:02,  2.75s/it]

step 4220: train loss 0.9793, val loss 0.9650


 97%|█████████▋| 4241/4353 [44:53<05:08,  2.75s/it]

step 4240: train loss 1.0113, val loss 0.8949


 98%|█████████▊| 4261/4353 [45:06<04:13,  2.75s/it]

step 4260: train loss 0.8497, val loss 0.8562


 98%|█████████▊| 4281/4353 [45:18<03:18,  2.76s/it]

step 4280: train loss 1.1040, val loss 0.9037


 99%|█████████▉| 4301/4353 [45:31<02:23,  2.76s/it]

step 4300: train loss 0.9678, val loss 0.9770


 99%|█████████▉| 4321/4353 [45:44<01:28,  2.78s/it]

step 4320: train loss 1.0223, val loss 0.9275


100%|█████████▉| 4341/4353 [45:57<00:33,  2.79s/it]

step 4340: train loss 1.0120, val loss 1.0177


100%|██████████| 4353/4353 [46:08<00:00,  1.57it/s]

step 4352: train loss 0.9909, val loss 0.9870





In [20]:
torch.save(sft_model.state_dict(), "dpo_model_weights.pth")