In [None]:
! pip install bitsandbytes==0.46.0 accelerate==1.7.0

In [None]:
import torch
import numpy as np
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import get_scheduler
from tqdm.notebook import tqdm
import sys, os, gc
from IPython.display import clear_output
from torch.optim import AdamW
from peft import PeftConfig, PeftModel, LoraConfig, get_peft_model
from peft import prepare_model_for_kbit_training

# Google Drive hosts the project folder; it is mounted first so that the
# `helpers` module and the cleaned / precomputed / results paths resolve.
from google.colab import drive
drive.mount("/content/drive")

# expose the project folder on the import path for `from helpers import ...`
sys.path.append("/content/drive/MyDrive/cs329h_project")

# helper functions
from helpers import compute_logprob_and_reply_length_batched as CLRL
from helpers import compute_acc
from helpers import compute_mrpo_objective
from helpers import compute_dpo_objective_many_refs

# run on the GPU whenever one is visible
device = "cuda" if torch.cuda.is_available() else "cpu"


# Experiment grid: 2 datasets x seeds (5 for PKU-SafeRLHF, 3 for
# ultrafeedback_binarized -- chosen in the training cell) x 7 candidate
# reference models.
DATASETS = [
    "PKU-SafeRLHF-30K-standard",
    "ultrafeedback_binarized"]
# Reference models with precomputed log-probs. The names look like HF model
# ids with "/" replaced by "_" and are used as `{name}_train.csv` prefixes.
# NOTE: order matters -- the training cell addresses references by index.
REF_MODELS = [
    "01-ai_Yi-1.5-9B-Chat",
    "meta-llama_Meta-Llama-3.1-8B-Instruct",
    "microsoft_Phi-3-medium-128k-instruct",
    "mistralai_Mistral-7B-Instruct-v0.3",
    "Qwen_Qwen2.5-0.5B-Instruct",
    "Qwen_Qwen2.5-1.5B-Instruct",
    "Qwen_Qwen3-4B-Instruct-2507"
    ]
N_REF_MODELS = len(REF_MODELS)

In [None]:
# Offline-1 "One-Hot" reference selection.
#
# For every (dataset, seed) pair a fresh 4-bit LoRA copy of the policy is
# trained with a length-normalised DPO loss against ONE precomputed reference
# model per batch: the first batch uses Qwen2.5-0.5B-Instruct, and every later
# batch uses the reference whose mean |chosen - rejected| length-normalised
# log-prob margin was largest on the previous batch. Per-batch metrics are
# written to results/Offline-1-One-Hot_dataset={dataset}_seed={seed}.csv.

# keep a counter of finished (dataset, seed) runs, for checkpointing
counter = 0

#### SETTINGS SHARED BY EVERY RUN

PROJECT_DIR = "/content/drive/MyDrive/cs329h_project"
POLICY_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"

# fixed hyperparameter settings
BETA, N_EPOCHS, LR = 0.1, 1, 1e-4
MAX_TRAIN_EXAMPLES = 5000   # only the first 5000 training pairs are used
MICRO_BATCH_SIZE = 16       # gradient-accumulation chunk size (GPU-memory bound)
EVAL_MAX_BATCH = 50         # MAX_BATCH forwarded to compute_acc

# the first batch aligns to Qwen2.5-0.5B-Instruct (index 4 of REF_MODELS)
DEFAULT_REFERENCE = REF_MODELS.index("Qwen_Qwen2.5-0.5B-Instruct")

# Log-prob sums and lengths are kept in float32: float16 overflows past 65504
# and has a spacing of 0.5 for |x| in [512, 1024), which visibly perturbs the
# length-normalised DPO margins of long replies.
LOGPROB_DTYPE = torch.float32


def _seed_everything(seed):
  """Seed numpy and torch so LoRA's gaussian init (torch RNG) repeats per seed."""
  np.random.seed(seed)
  torch.manual_seed(seed)  # also seeds every CUDA device


def _load_policy_and_tokenizer():
  """Load a fresh 4-bit NF4 POLICY_MODEL with LoRA adapters, plus its tokenizer.

  LoRA targets the q/k/v/o and gate/up/down projections.

  Returns:
    (model, tok): the PEFT-wrapped model in train mode with gradient
    checkpointing enabled, and the tokenizer loaded with use_fast=True.
  """
  bnb_cfg = BitsAndBytesConfig(
      load_in_4bit=True, bnb_4bit_quant_type="nf4",
      bnb_4bit_use_double_quant=False, bnb_4bit_compute_dtype=torch.bfloat16)
  target_modules = "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj"
  lora_cfg = LoraConfig(
      r=32, lora_alpha=16, bias="none", task_type="CAUSAL_LM",
      init_lora_weights="gaussian", target_modules=target_modules.split(","))

  model = AutoModelForCausalLM.from_pretrained(
      POLICY_MODEL, device_map="auto", quantization_config=bnb_cfg, trust_remote_code=True)
  model.config.use_cache = False
  model = prepare_model_for_kbit_training(model)
  model = get_peft_model(model, lora_cfg)
  model.train()
  tok = AutoTokenizer.from_pretrained(
      POLICY_MODEL, trust_remote_code=True, use_fast=True)

  # some other settings to reduce memory
  model.gradient_checkpointing_enable()
  model.enable_input_require_grads()
  gc.collect(); torch.cuda.empty_cache()
  return model, tok


def _accuracy(model, tok, prompts, chosen, rejected):
  """Run compute_acc without gradients and in eval mode; return a Python float.

  The model is always put back into train mode afterwards. transformers only
  applies gradient checkpointing while a module is training, so this keeps the
  memory savings for the next step even if compute_acc changes the mode itself.
  """
  model.eval()
  try:
    with torch.no_grad():
      return compute_acc(
          model=model, tok=tok, prompt=prompts, chosen=chosen, rejected=rejected,
          MAX_BATCH=EVAL_MAX_BATCH).cpu().item()
  finally:
    model.train()


def _reference_lpm(ref_df, index):
  """Length-normalised log-probs of one precomputed reference frame for rows `index`.

  Returns:
    (chosen, rejected): 1-D LOGPROB_DTYPE tensors on `device`, computed as
    logprob_chosen / L_chosen and logprob_rejected / L_rejected.
  """
  rows = ref_df.loc[index]

  def column(name):
    return torch.tensor(rows[name].to_numpy(), dtype=LOGPROB_DTYPE, device=device)

  return (column("logprob_chosen") / column("L_chosen"),
          column("logprob_rejected") / column("L_rejected"))


def _all_reference_lpm(ref_frames, index):
  """Stack `_reference_lpm` over all references into (len(index), N_REF_MODELS) tensors."""
  per_ref = [_reference_lpm(ref_frames[i], index) for i in range(N_REF_MODELS)]
  chosen = torch.stack([c for c, _ in per_ref], dim=1)
  rejected = torch.stack([r for _, r in per_ref], dim=1)
  return chosen, rejected


def _dpo_step(model, tok, optimizer, prompts, chosen, rejected, ref_lpm_chosen, ref_lpm_rejected):
  """One optimizer step of length-normalised DPO on a batch, accumulated over micro-batches.

  Each micro-batch's mean loss is weighted by its share of the batch, so the
  accumulated gradient (and the returned loss) equal those of the full-batch mean.

  Args:
    prompts, chosen, rejected: equal-length lists of strings for this batch.
    ref_lpm_chosen, ref_lpm_rejected: 1-D tensors aligned with `prompts`.

  Returns:
    (loss, policy_lpm_chosen, policy_lpm_rejected): the batch loss as a float and
    the detached per-example policy log-probs divided by reply length.
  """
  n_examples = len(prompts)
  lpm_chosen_parts, lpm_rejected_parts = [], []
  batch_loss = 0.0

  for start in range(0, n_examples, MICRO_BATCH_SIZE):
    end = min(n_examples, start + MICRO_BATCH_SIZE)
    micro_prompts = prompts[start:end]

    # current policy log-probs and reply lengths, as returned by CLRL
    lp_chosen, L_chosen = CLRL(
        model=model, tok=tok, prompts=micro_prompts, replies=chosen[start:end], device=device)
    policy_lpm_chosen = torch.stack(lp_chosen) / torch.tensor(L_chosen, device=device, dtype=LOGPROB_DTYPE)
    lp_rejected, L_rejected = CLRL(
        model=model, tok=tok, prompts=micro_prompts, replies=rejected[start:end], device=device)
    policy_lpm_rejected = torch.stack(lp_rejected) / torch.tensor(L_rejected, device=device, dtype=LOGPROB_DTYPE)

    # DPO objective on this micro-batch, scaled by its share of the batch
    margin = ((policy_lpm_chosen - ref_lpm_chosen[start:end])
              - (policy_lpm_rejected - ref_lpm_rejected[start:end]))
    micro_loss = (-torch.nn.functional.logsigmoid(BETA * margin).mean()
                  * (end - start) / n_examples)
    micro_loss.backward()

    lpm_chosen_parts.append(policy_lpm_chosen.detach())
    lpm_rejected_parts.append(policy_lpm_rejected.detach())
    batch_loss += micro_loss.item()

  # a single parameter update for the full batch
  optimizer.step()
  optimizer.zero_grad()
  return batch_loss, torch.cat(lpm_chosen_parts), torch.cat(lpm_rejected_parts)


#### MAIN LOOP OVER DATASETS AND SEEDS

for dataset in DATASETS:

  # PKU-SafeRLHF gets 5 seeds, the other dataset 3
  SEEDS = range(5) if dataset == "PKU-SafeRLHF-30K-standard" else range(3)

  for seed in SEEDS:

    # seed BEFORE building the model so the LoRA init is reproducible
    _seed_everything(seed)

    #### LOADING IN THE TRAINING MODEL
    train_model, tok = _load_policy_and_tokenizer()

    #### HYPERPARAMETERS THAT DEPEND ON THE DATASET
    BATCH_SIZE = 50 if dataset == "PKU-SafeRLHF-30K-standard" else 25
    optimizer = AdamW(train_model.parameters(), lr=LR)

    #### DATA LOADING

    # train/val/test splits for this seed + the pre-computed reference log-probs
    CLEANED_DIR = f"{PROJECT_DIR}/cleaned/{dataset}/seed={seed}"
    train_df = pd.read_csv(f"{CLEANED_DIR}/data/train.csv")[
        ["prompt", "chosen", "rejected"]]
    val_df = pd.read_csv(f"{CLEANED_DIR}/data/val.csv")[
        ["prompt", "chosen", "rejected"]]
    test_df = pd.read_csv(f"{CLEANED_DIR}/data/test.csv")[
        ["prompt", "chosen", "rejected"]]
    all_ref_train = {
        i : pd.read_csv(f"{CLEANED_DIR}/precomputed/{ref_model}_train.csv")
        for (i, ref_model) in enumerate(REF_MODELS)}

    # ceil-divide the (capped) training set into batches
    n_train = min(MAX_TRAIN_EXAMPLES, len(train_df))
    NUM_BATCHES = (n_train + BATCH_SIZE - 1) // BATCH_SIZE
    # val/test are evaluated ~10 times per run plus on the final batch;
    # max(1, ...) avoids a modulo-by-zero when there are fewer than 10 batches
    eval_every = max(1, NUM_BATCHES // 10)

    # what logs are we recording? NO TRAIN_ACC BECAUSE TOO EXPENSIVE
    log_columns = (
        ["epoch", "batch", "batch_acc_pre", "batch_acc_post",
         "chosen_dpo_loss", "chosen_reference"]
        + [f"dpo{i}" for i in range(N_REF_MODELS)]
        + ["val_acc", "test_acc"])
    log_rows = []

    #### MODEL TRAINING

    # mean |chosen - rejected| reference margin of the previous batch (Offline 1)
    prev_sliding_window_weights = None

    for epoch in range(N_EPOCHS):
      for batch in tqdm(range(NUM_BATCHES)):

        # this mini-batch worth of data (positional slice, capped at n_train)
        start = batch * BATCH_SIZE
        batch_data = train_df.iloc[start : min(start + BATCH_SIZE, n_train)]
        batch_prompt = batch_data["prompt"].tolist()
        batch_chosen = batch_data["chosen"].tolist()
        batch_rejected = batch_data["rejected"].tolist()

        # accuracy on this batch before the gradient update
        batch_acc_pre = _accuracy(train_model, tok, batch_prompt, batch_chosen, batch_rejected)

        # one-hot reference choice: argmax of the previous batch's margins
        if prev_sliding_window_weights is None:
          chosen_reference = DEFAULT_REFERENCE
        else:
          chosen_reference = int(np.argmax(prev_sliding_window_weights))

        # every reference's length-normalised log-probs for this batch
        all_ref_lpm_chosen, all_ref_lpm_rejected = _all_reference_lpm(all_ref_train, batch_data.index)

        # DPO update against the chosen reference only
        loss, policy_lpm_chosen, policy_lpm_rejected = _dpo_step(
            train_model, tok, optimizer, batch_prompt, batch_chosen, batch_rejected,
            all_ref_lpm_chosen[:, chosen_reference], all_ref_lpm_rejected[:, chosen_reference])

        # accuracy on this batch AFTER the gradient update
        batch_acc_post = _accuracy(train_model, tok, batch_prompt, batch_chosen, batch_rejected)

        # record val/test on ~10 batches and always on the last one
        if ((batch + 1) % eval_every == 0) or (batch == (NUM_BATCHES - 1)):
          val_acc = _accuracy(
              train_model, tok, val_df.prompt.values, val_df.chosen.values, val_df.rejected.values)
          test_acc = _accuracy(
              train_model, tok, test_df.prompt.values, test_df.chosen.values, test_df.rejected.values)
          print(test_acc)
        else:
          val_acc, test_acc = np.nan, np.nan

        # per-reference DPO losses of the current policy outputs
        with torch.no_grad():
          per_reference_dpo_loss = compute_dpo_objective_many_refs(
              policy_lpm_chosen, policy_lpm_rejected,
              all_ref_lpm_chosen, all_ref_lpm_rejected, BETA)

        log_rows.append(
            [epoch, batch, batch_acc_pre, batch_acc_post, loss, chosen_reference]
            + per_reference_dpo_loss.cpu().tolist()
            + [val_acc, test_acc])

        # margins that pick the NEXT batch's reference (Offline 1)
        prev_sliding_window_weights = (
            (all_ref_lpm_chosen - all_ref_lpm_rejected).abs().mean(dim=0).cpu().numpy())

        # release this batch's GPU tensors before the next forward pass
        del policy_lpm_chosen, policy_lpm_rejected, per_reference_dpo_loss
        del all_ref_lpm_chosen, all_ref_lpm_rejected, batch_data
        gc.collect(); torch.cuda.empty_cache()

    # save our logs at the end
    logs = pd.DataFrame(log_rows, columns=log_columns)
    logs.to_csv(
      f"{PROJECT_DIR}/results/Offline-1-One-Hot_dataset={dataset}_seed={seed}.csv", index=False)

    #### CLEAN HOUSE + STATUS UPDATE

    counter += 1

    del train_model, tok, logs, log_rows, optimizer
    del train_df, val_df, test_df, all_ref_train
    gc.collect(); torch.cuda.empty_cache()

    # clear output too
    clear_output(wait=True)