In [1]:
!pip install transformers
!pip install datasets
!pip install trl
!pip install wandb

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m28.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [2]:
import argparse
import random
import numpy as np

import torch

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments

from trl import DPOTrainer

import wandb

In [3]:
import argparse
import random
import numpy as np
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW

from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

import wandb
from tqdm import tqdm


In [4]:
def seed_everything(seed=42):
  """Seed every RNG this notebook relies on so runs are repeatable.

  Seeds Python's `random`, NumPy's global RNG and PyTorch's CPU and CUDA
  generators, and makes cuDNN pick deterministic kernels without autotuning.
  Some CUDA ops can still be nondeterministic unless
  `torch.use_deterministic_algorithms(True)` is enabled as well.

  Args:
    seed: integer seed applied to every generator.
  """
  random.seed(seed)  # previously missing although `random` is imported
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)  # every visible GPU, including the current one
  torch.backends.cudnn.deterministic = True
  torch.backends.cudnn.benchmark = False

In [17]:
# Run configuration: training hyper-parameters, model/dataset identifiers and
# the W&B project name. The same dict is handed to W&B as the run config.
args = dict(
    epochs=1,
    beta=0.1,
    batch_size=4,
    lr=1e-5,
    seed=42,
    max_length=512,
    model_name="meta-llama/Llama-3.2-1B",
    dataset_name="jondurbin/truthy-dpo-v0.1",
    wandb_project="truthy-dpo",
)

seed_everything(args["seed"])

# Experiment tracking.
wandb.login()
wandb.init(project=args["wandb_project"], config=args)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")



In [6]:
def _pack_prompt_response(prompt_ids, response_ids, pad_id, device):
  """Right-pad per-example prompt+response token lists into batch tensors.

  Returns (ids, attention_mask, loss_mask), each of shape (batch, longest row):
  attention_mask is 1 on every real token and 0 on padding; loss_mask is 1 only
  on response tokens.
  """
  rows = [p + r for p, r in zip(prompt_ids, response_ids)]
  width = max((len(row) for row in rows), default=0)
  ids = torch.full((len(rows), width), pad_id, dtype=torch.long)
  attention_mask = torch.zeros((len(rows), width), dtype=torch.long)
  loss_mask = torch.zeros((len(rows), width), dtype=torch.long)
  for i, (row, prompt) in enumerate(zip(rows, prompt_ids)):
    ids[i, :len(row)] = torch.tensor(row, dtype=torch.long)
    attention_mask[i, :len(row)] = 1
    loss_mask[i, len(prompt):len(row)] = 1
  return ids.to(device), attention_mask.to(device), loss_mask.to(device)


def collate_fn(batch, tokenizer, max_length, device):
  """Tokenize a batch of preference examples into padded DPO tensors.

  Each example must provide 'prompt', 'chosen' and 'rejected' strings. The text
  is formatted as 'Instruct: <prompt>\\n' followed by 'Output: <response>'.
  Prompt and response are tokenized separately, each truncated to `max_length`
  tokens, and concatenated per example. Padding therefore only appears at the
  right end of a row, and rows can be up to 2 * max_length tokens long.

  Returns a dict of tensors on `device`:
    prompt_prefered_ids / prompt_disprefered_ids: (batch, seq) token ids.
    prompt_prefered_mask / prompt_disprefered_mask: attention masks, 1 on
      every real prompt or response token and 0 on padding.
    prompt_prefered_loss_mask / prompt_disprefered_loss_mask: 1 on response
      tokens only, 0 on prompt tokens and padding.
  """
  prompts = ['Instruct: ' + item['prompt'] + '\n' for item in batch]
  chosen_responses = ['Output: ' + item['chosen'] for item in batch]
  rejected_responses = ['Output: ' + item['rejected'] for item in batch]

  pad_id = tokenizer.pad_token_id
  if pad_id is None:
    pad_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0

  # Only the prompt gets the tokenizer's special tokens (e.g. a BOS token), so
  # none end up in the middle of the concatenated sequence.
  prompt_ids = tokenizer(prompts, truncation=True, max_length=max_length)['input_ids']
  chosen_ids = tokenizer(chosen_responses, add_special_tokens=False, truncation=True, max_length=max_length)['input_ids']
  rejected_ids = tokenizer(rejected_responses, add_special_tokens=False, truncation=True, max_length=max_length)['input_ids']

  prompt_prefered_ids, prompt_prefered_mask, prompt_prefered_loss_mask = _pack_prompt_response(prompt_ids, chosen_ids, pad_id, device)
  prompt_disprefered_ids, prompt_disprefered_mask, prompt_disprefered_loss_mask = _pack_prompt_response(prompt_ids, rejected_ids, pad_id, device)

  return {'prompt_prefered_ids': prompt_prefered_ids,
          'prompt_disprefered_ids': prompt_disprefered_ids,
          'prompt_prefered_mask': prompt_prefered_mask,
          'prompt_disprefered_mask': prompt_disprefered_mask,
          'prompt_prefered_loss_mask': prompt_prefered_loss_mask,
          'prompt_disprefered_loss_mask': prompt_disprefered_loss_mask}

In [7]:
def calculate_DPO_losss(model_prefered_logprob, model_disprefered_logprob, ref_prefered_logprob, ref_disprefered_logprob, beta=0.5):
  """DPO loss and diagnostics for a batch of preference pairs.

  Each argument holds per-example log-probabilities (presumably shape (batch,);
  every statistic below is averaged over the last dimension). A response's
  relative log-prob is its policy log-prob minus its reference log-prob; the
  loss is the mean of -log sigmoid(beta * (chosen_rel - rejected_rel)).

  Returns:
    (loss, mean chosen relative log-prob, mean rejected relative log-prob,
     fraction of pairs where chosen_rel > rejected_rel,
     mean of chosen_rel - rejected_rel). The last two are not scaled by beta.
  """
  chosen_shift = model_prefered_logprob - ref_prefered_logprob
  rejected_shift = model_disprefered_logprob - ref_disprefered_logprob
  gap = chosen_shift - rejected_shift

  accuracy = (chosen_shift > rejected_shift).float().mean(dim=-1)
  margin = gap.mean(dim=-1)

  loss = -F.logsigmoid(beta * gap).mean(dim=-1)
  return (loss,
          chosen_shift.mean(dim=-1),
          rejected_shift.mean(dim=-1),
          accuracy,
          margin)

In [8]:
def get_log_prob(logits, labels, mask=None, average=True):
  """Per-sequence log-likelihood of `labels` under causal-LM `logits`.

  In a causal LM the logits at position t score the token at position t + 1,
  so predictions and targets are shifted by one before gathering. The first
  token has no prediction, and the last logit has no target.

  Args:
    logits: (batch, seq, vocab) model outputs.
    labels: (batch, seq) token ids that were fed to the model.
    mask: optional (batch, seq) 0/1 tensor selecting which target tokens are
      counted (e.g. response tokens only). Defaults to all tokens.
    average: if True (default), return the mean log-prob over counted tokens.
      If False, return their sum, which is the sequence log-likelihood used by
      the original DPO formulation.

  Returns:
    (batch,) tensor of per-sequence log-probabilities.
  """
  shifted_logits = logits[:, :-1, :]
  targets = labels[:, 1:]
  log_probs = F.log_softmax(shifted_logits, dim=-1)
  token_log_probs = torch.gather(log_probs, dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)

  if mask is None:
    weights = torch.ones_like(token_log_probs)
  else:
    # mask is aligned with labels, so drop its first column like the targets.
    weights = mask[:, 1:].to(token_log_probs.dtype)

  total = (token_log_probs * weights).sum(dim=-1)
  if not average:
    return total
  # clamp avoids 0/0 on rows where no token is selected
  return total / weights.sum(dim=-1).clamp(min=1)

In [9]:
def train(model, ref_model, tokenizer, optimizer, train_dataloader, epochs = 1, beta = 0.1):
  """Fine-tune `model` with the DPO loss against `ref_model`.

  For each batch:
    1. Both models score the prompt+chosen and prompt+rejected sequences
       via `get_log_prob`.
    2. `calculate_DPO_losss` turns those scores into the DPO loss.
    3. The loss is back-propagated and the optimizer takes one step.
  Per-step scalar metrics are logged to Weights & Biases.

  Args:
    model: policy being optimised; switched to train mode.
    ref_model: reference model; switched to eval mode and run without autograd.
    tokenizer: not used inside the loop; kept so existing calls still work.
    optimizer: stepped once per batch (built over `model`'s parameters where
      this notebook creates it).
    train_dataloader: yields dicts with 'prompt_prefered_ids',
      'prompt_disprefered_ids', 'prompt_prefered_mask' and
      'prompt_disprefered_mask' tensors.
    epochs: number of passes over `train_dataloader`.
    beta: scale applied to the reward gap inside the DPO loss.
  """
  model.train()
  ref_model.eval()

  for epoch in range(epochs):
    for batch in tqdm(train_dataloader, desc=f"epoch {epoch + 1}/{epochs}"):
      optimizer.zero_grad()
      prompt_prefered_ids = batch['prompt_prefered_ids']
      prompt_disprefered_ids = batch['prompt_disprefered_ids']
      prompt_prefered_mask = batch['prompt_prefered_mask']
      prompt_disprefered_mask = batch['prompt_disprefered_mask']

      # NOTE(review): get_log_prob is called without a response-only mask, so
      # prompt (and any padding) positions contribute to the scores as well.
      model_prefered_logprobs = get_log_prob(model(prompt_prefered_ids, attention_mask=prompt_prefered_mask).logits, prompt_prefered_ids)
      model_disprefered_logprobs = get_log_prob(model(prompt_disprefered_ids, attention_mask=prompt_disprefered_mask).logits, prompt_disprefered_ids)

      # The reference model only provides scores. Running it without autograd
      # avoids building a graph for it, and stops loss.backward() from computing
      # (and accumulating) gradients on its parameters.
      with torch.no_grad():
        ref_prefered_logprobs = get_log_prob(ref_model(prompt_prefered_ids, attention_mask=prompt_prefered_mask).logits, prompt_prefered_ids)
        ref_disprefered_logprobs = get_log_prob(ref_model(prompt_disprefered_ids, attention_mask=prompt_disprefered_mask).logits, prompt_disprefered_ids)

      loss, prefered_relative_logprob, disprefered_relative_logprob, reward_accuracies, reward_margins = calculate_DPO_losss(model_prefered_logprobs,
                                                                                                                             model_disprefered_logprobs,
                                                                                                                             ref_prefered_logprobs,
                                                                                                                             ref_disprefered_logprobs,
                                                                                                                             beta=beta)
      loss.backward()
      optimizer.step()
      # Log plain Python floats; the metric tensors are still attached to the autograd graph.
      wandb.log({"loss": loss.item(),
                 "prefered_relative_logprob": prefered_relative_logprob.item(),
                 "disprefered_relative_logprob": disprefered_relative_logprob.item(),
                 "reward_accuracies": reward_accuracies.item(),
                 "reward_margins": reward_margins.item()})

In [10]:
import os

# Read the Hugging Face access token from the environment instead of hard-coding it.
# The token that was previously written here is exposed and should be revoked.
HF_token = os.environ.get("HF_TOKEN")

In [11]:
# Interactive Hugging Face Hub login; as the output shows, it prompts for an access token.
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) Y
Token is valid (permission: fineGrained).
The token `Multiclass` has been saved to /root/.cache/huggingface/stored_tokens
[1m[31mCannot authenticate through git-credential as no helper is defined on your machine.
You might have to re-a

In [18]:
tokenizer = AutoTokenizer.from_pretrained(args["model_name"])
# Padding reuses the EOS token (assumes the tokenizer ships without its own pad token — TODO confirm).
tokenizer.pad_token = tokenizer.eos_token

# Policy: the copy that is fine-tuned.
model = AutoModelForCausalLM.from_pretrained(args["model_name"]).to(device)

# Reference: a second copy of the same checkpoint, used only for scoring.
# Freeze it (eval mode, no parameter gradients) so backward never allocates grads for it.
# The chain is assigned, not left as the last expression, so the cell does not print the model repr.
ref_model = AutoModelForCausalLM.from_pretrained(args["model_name"]).to(device).eval().requires_grad_(False)

tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/301 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/843 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

In [19]:
# Preference records (the 'prompt' / 'chosen' / 'rejected' fields consumed by collate_fn), train split.
dataset = load_dataset(path=args["dataset_name"], split="train")
# Only the policy's parameters are optimised.
optimizer = AdamW(params=model.parameters(), lr=args["lr"])

README.md:   0%|          | 0.00/904 [00:00<?, ?B/s]

truthy-dpo.parquet:   0%|          | 0.00/653k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1016 [00:00<?, ? examples/s]

In [20]:
train_dataloader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=args["batch_size"],
    # Bind tokenizer, length cap and device up front; the DataLoader then
    # calls collate_fn with just the list of examples.
    collate_fn=partial(collate_fn, device=device, max_length=args["max_length"], tokenizer=tokenizer),
)

In [None]:
train(
    model,
    ref_model,
    tokenizer,
    optimizer,
    train_dataloader,
    epochs=args["epochs"],
    beta=args["beta"],
)
# Close the W&B run so buffered metrics are flushed and the run is marked finished.
wandb.finish()

In [15]:
# Diagnostic: show the GPU model, driver/CUDA versions and current memory usage.
!nvidia-smi

Thu Jan 23 22:14:22 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   30C    P0              49W / 400W |      5MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [15]:
# NOTE(review): presumably needed when loading the tokenizer — confirm; consider moving it into the install cell at the top.
!pip install blobfile

Collecting blobfile
  Downloading blobfile-3.0.0-py3-none-any.whl.metadata (15 kB)
Collecting pycryptodomex>=3.8 (from blobfile)
  Downloading pycryptodomex-3.21.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.4 kB)
Downloading blobfile-3.0.0-py3-none-any.whl (75 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.4/75.4 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pycryptodomex-3.21.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m80.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pycryptodomex, blobfile
Successfully installed blobfile-3.0.0 pycryptodomex-3.21.0
