In [None]:
! pip install bitsandbytes==0.46.0 accelerate==1.7.0

In [None]:
import torch
import numpy as np
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import get_scheduler
from tqdm.notebook import tqdm
import sys, os, gc
from IPython.display import clear_output
from torch.optim import AdamW
from peft import PeftConfig, PeftModel, LoraConfig, get_peft_model
from peft import prepare_model_for_kbit_training

# Colab setup: mount Google Drive, then put the project folder on sys.path so
# the project-local `helpers` module imported below can be found.
from google.colab import drive
drive.mount("/content/drive")
sys.path.append("/content/drive/MyDrive/cs329h_project")

# project-local helper functions (imported after the sys.path change above)
from helpers import (
    compute_acc,
    compute_logprob_and_reply_length_batched as CLRL,
    compute_mrpo_objective,
)

# run on the GPU when CUDA is visible, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Experiment grid: the training cell runs once per
# (dataset, reference model, seed) combination drawn from these lists.

# seeds; each one also names a `seed=<n>` folder of pre-made data splits
SEEDS = [0, 1, 2]

# reference models, written as the file-name stems of their precomputed
# log-prob CSVs (`<name>_train.csv`)
REF_MODELS = [
    "01-ai_Yi-1.5-9B-Chat",
    "meta-llama_Meta-Llama-3.1-8B-Instruct",
    "microsoft_Phi-3-medium-128k-instruct",
    "mistralai_Mistral-7B-Instruct-v0.3",
    "Qwen_Qwen2.5-0.5B-Instruct",
    "Qwen_Qwen2.5-1.5B-Instruct",
    "Qwen_Qwen3-4B-Instruct-2507",
]

# preference datasets, written as their folder names under `cleaned/`
DATASETS = [
    # "helpsteer2-preference-v2",   (currently disabled)
    "PKU-SafeRLHF-30K-standard",
    "ultrafeedback_binarized",
]

In [None]:
# Baseline DPO sweep.
#
# For every (dataset, reference model, seed) combination this cell:
#   1. loads a fresh 4-bit quantized policy model and wraps it with LoRA adapters,
#   2. reads the train/val/test splits and the precomputed reference log-probs,
#   3. trains for N_EPOCHS with a DPO-style loss on length-normalized log-probs,
#   4. logs per-batch accuracy/loss (plus val/test accuracy about 10 times per epoch),
#   5. writes the log table to the results folder on Drive and frees memory.

# keep a counter for checkpointing (number of completed runs)
counter = 0

# upper bound on the number of training rows used per run
MAX_TRAIN_EXAMPLES = 5000

# folder for the per-run log CSVs; created up front so a missing folder cannot
# make the final `to_csv` fail after a complete training run
RESULTS_DIR = "/content/drive/MyDrive/cs329h_project/results"
os.makedirs(RESULTS_DIR, exist_ok=True)

# loop through our REF_MODELS, DATASETS, SEEDS
for dataset in DATASETS:
  for ref_model in REF_MODELS:
    for seed in SEEDS:
      # seed the RNGs so the random (gaussian) LoRA initialization is reproducible
      # and identical across reference models that share a seed
      torch.manual_seed(seed)
      np.random.seed(seed)

      # LOADING IN THE TRAINING MODEL

      ## configuring BitsandBytes (4bit) + LORA (default) settings.
      bnb_cfg = BitsAndBytesConfig(
          load_in_4bit=True, bnb_4bit_quant_type="nf4",
          bnb_4bit_use_double_quant=False, bnb_4bit_compute_dtype=torch.bfloat16)
      target_modules = "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj"
      lora_cfg = LoraConfig(
          r=32, lora_alpha=16, bias="none", task_type="CAUSAL_LM",
          init_lora_weights="gaussian", target_modules=target_modules.split(","))

      ## getting our base model Qwen2.5-0.5B-Instruct
      POLICY_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
      train_model = AutoModelForCausalLM.from_pretrained(
          POLICY_MODEL, device_map="auto", quantization_config=bnb_cfg, trust_remote_code=True)
      train_model.config.use_cache = False
      train_model = prepare_model_for_kbit_training(train_model)
      train_model = get_peft_model(train_model, lora_cfg)
      train_model.train()
      tok = AutoTokenizer.from_pretrained(
          POLICY_MODEL, trust_remote_code=True, use_fast=True)
      torch.cuda.empty_cache()
      gc.collect()

      # some other settings to reduce memory
      train_model.gradient_checkpointing_enable()
      train_model.enable_input_require_grads()

      #### HYPERPARAMETER SETTINGS

      # fixed hyperparameter settings
      BETA, N_EPOCHS, LR = 0.1, 1, 1e-4
      BATCH_SIZE = 50 if dataset == "PKU-SafeRLHF-30K-standard" else 25 # let's use batchsize 25 for ultra-feedback

      # what logs are we trying to record? NO TRAIN_ACC BECAUSE TOO EXPENSIVE
      logs = pd.DataFrame(
          data=None, columns=[
              "epoch", "batch", "batch_acc_pre", "batch_acc_post",
              "val_acc", "test_acc", "dpo_loss"])

      # get our optimizer
      optimizer = AdamW(train_model.parameters(), lr=LR)

      #### DATA LOADING

      # load in the train/val/test splits for this seed + the pre-computed logits
      CLEANED_DIR = f"/content/drive/MyDrive/cs329h_project/cleaned/{dataset}/seed={seed}"

      # the train/val/test data for this split
      train_df = pd.read_csv(f"{CLEANED_DIR}/data/train.csv")[
          ["prompt", "chosen", "rejected"]]
      val_df = pd.read_csv(f"{CLEANED_DIR}/data/val.csv")[
          ["prompt", "chosen", "rejected"]]
      test_df = pd.read_csv(f"{CLEANED_DIR}/data/test.csv")[
          ["prompt", "chosen", "rejected"]]

      # reference train log-probs for this split
      ref_train = pd.read_csv(f"{CLEANED_DIR}/precomputed/{ref_model}_train.csv")

      # compute how many batches we will need (ceil division). The row count is
      # capped by the actual size of the training split so that a split with
      # fewer than MAX_TRAIN_EXAMPLES rows never produces an empty batch.
      n_train = min(MAX_TRAIN_EXAMPLES, len(train_df))
      NUM_BATCHES = -(-n_train // BATCH_SIZE)

      # evaluate on val/test roughly 10 times per epoch; at least every batch
      # when there are fewer than 10 batches (avoids a modulo-by-zero)
      LOG_EVERY = max(1, NUM_BATCHES // 10)

      #### MODEL TRAINING

      # training loop over epochs and batches
      for epoch in range(N_EPOCHS):
        for batch in tqdm(range(NUM_BATCHES)):

          # get this mini-batch worth of data (half-open positional slice; the
          # frames come straight from read_csv, so positions equal index labels)
          start = batch * BATCH_SIZE
          stop = min(start + BATCH_SIZE, n_train)
          batch_data = train_df.iloc[start:stop]
          batch_prompt = batch_data["prompt"].tolist()
          batch_chosen = batch_data["chosen"].tolist()
          batch_rejected = batch_data["rejected"].tolist()

          # compute accuracy on this batch before we do the gradient update
          with torch.no_grad():
            batch_acc_pre = compute_acc(
                model=train_model, tok=tok, prompt=batch_prompt, chosen=batch_chosen, rejected=batch_rejected, MAX_BATCH=50).cpu().item()

          # compute the current log-probs on the training batch's preferred,
          # then divide by the reply lengths returned by the helper.
          # float32 (not float16) so long lengths and large summed log-probs
          # are not rounded before the division.
          preferred_train_log_probs, L_chosen = CLRL(
              model=train_model, tok=tok, prompts=batch_prompt, replies=batch_chosen, device=device)
          policy_lpm_chosen = torch.stack(preferred_train_log_probs) / torch.tensor(
              L_chosen, device=device, dtype=torch.float32)

          # compute the current log-probs on the training batch's rejected.
          nonpreferred_train_log_probs, L_rejected = CLRL(
              model=train_model, tok=tok, prompts=batch_prompt, replies=batch_rejected, device=device)
          policy_lpm_rejected = torch.stack(nonpreferred_train_log_probs) / torch.tensor(
              L_rejected, device=device, dtype=torch.float32)

          # get the logprobs and L of the precomputed reference model; rows are
          # matched to the batch by index label, looked up once per batch
          ref_batch = ref_train.loc[batch_data.index]
          ref_lp_chosen = torch.tensor(ref_batch["logprob_chosen"].values, dtype=torch.float32, device=device)
          ref_lp_rejected = torch.tensor(ref_batch["logprob_rejected"].values, dtype=torch.float32, device=device)
          ref_L_chosen = torch.tensor(ref_batch["L_chosen"].values, dtype=torch.float32, device=device)
          ref_L_rejected = torch.tensor(ref_batch["L_rejected"].values, dtype=torch.float32, device=device)
          ref_lpm_chosen = ref_lp_chosen / ref_L_chosen
          ref_lpm_rejected = ref_lp_rejected / ref_L_rejected

          # compute the DPO objective on the length-normalized log-probs:
          # mean over the batch of -logsigmoid(BETA * (chosen margin - rejected margin))
          chosen_margin = policy_lpm_chosen - ref_lpm_chosen
          rejected_margin = policy_lpm_rejected - ref_lpm_rejected
          loss = torch.mean(
              -torch.nn.functional.logsigmoid(BETA * (chosen_margin - rejected_margin)))

          # backward-pass
          loss.backward()

          # update our parameters + zero our gradient
          optimizer.step()
          optimizer.zero_grad()

          # compute accuracy on this batch AFTER we did the gradient update
          with torch.no_grad():
            batch_acc_post = compute_acc(
                model=train_model, tok=tok, prompt=batch_prompt, chosen=batch_chosen, rejected=batch_rejected, MAX_BATCH=50).cpu().item()

          # record 10 batches and guarantee the ending.
          if ((batch + 1) % LOG_EVERY == 0) or (batch == (NUM_BATCHES - 1)):

            # compute val and test accuracy too.
            with torch.no_grad():

              # validation set
              val_acc = compute_acc(
                  train_model, tok, val_df.prompt.values, val_df.chosen.values,
                  val_df.rejected.values, MAX_BATCH=50).cpu().item()

              # test set acc
              test_acc = compute_acc(
                  train_model, tok, test_df.prompt.values, test_df.chosen.values,
                  test_df.rejected.values, MAX_BATCH=50).cpu().item()
              print(test_acc)

          else:

            # just put nan's.
            val_acc, test_acc = np.nan, np.nan

          # record our results
          logs.loc[len(logs.index)] = [
              epoch, batch, batch_acc_pre, batch_acc_post,
              val_acc, test_acc, loss.cpu().item()]

          # clean house
          del batch_data, batch_prompt, batch_chosen, batch_rejected, batch_acc_pre, preferred_train_log_probs, L_chosen, policy_lpm_chosen
          del nonpreferred_train_log_probs, L_rejected, policy_lpm_rejected, ref_batch, ref_lp_chosen, ref_lp_rejected, ref_L_chosen, ref_L_rejected, ref_lpm_chosen, ref_lpm_rejected
          del chosen_margin, rejected_margin, loss, batch_acc_post, val_acc, test_acc
          gc.collect(); torch.cuda.empty_cache()

      # save our logs at the end
      logs.to_csv(
          f"{RESULTS_DIR}/baseline-DPO_dataset={dataset}_rm={ref_model}_seed={seed}.csv", index=False)

      #### CLEAN HOUSE + STATUS UPDATE

      # update our counter
      counter += 1

      # clear memory
      del bnb_cfg, target_modules, lora_cfg, train_model, tok, logs, optimizer, train_df, val_df, test_df, ref_train
      gc.collect(); torch.cuda.empty_cache()

      # clear output too
      clear_output(wait=True)