In [None]:
# ---------------------------------------------------------------------------
# Main settings for running experiments
# ---------------------------------------------------------------------------
# must be one of ["PKU-SafeRLHF-30K-standard", "ultrafeedback_binarized"]
DATASET = "ultrafeedback_binarized"
# data-shuffle seed; must be one of [0, 1, 2, 3, 4]
SEED = 2
# True for MRPO, False for MDPO
USE_MRPO_OVER_MDPO = True
# must be one of ["online_1", "offline_1", "offline_2", "arwc_normalized"]
ALPHA_METHOD = "online_1"


# ---------------------------------------------------------------------------
# Directory settings
# ---------------------------------------------------------------------------

# directory holding the data (preference datasets + precomputed log probs)
DATA_DIR = "/content/drive/MyDrive/mypartcs329h/cleaned/"
# directory to save logs
LOGS_DIR = "/content/drive/MyDrive/mypartcs329h/logs/"



# ---------------------------------------------------------------------------
# Model and training hyperparameters
# ---------------------------------------------------------------------------
# device to use
DEVICE = 'cuda:0'
# default base model name under HuggingFace
BASE_MODEL_PATH = "Qwen/Qwen2.5-0.5B-Instruct"
# whether to load the base model with 4-bit quantization
use_4bit = True
# whether to train LoRA adapters on top of the base model
use_lora = True
# learning rate, beta (for KL penalty), number of epochs
LR, BETA, NUM_EPOCHS = 1e-4, 0.1, 1
# set batch size based on dataset
BATCH_SIZE = 50 if DATASET == "PKU-SafeRLHF-30K-standard" else 25
# number of training examples consumed per epoch (the training loop walks
# through the first NUM_TRAIN_EXAMPLES rows of the training split)
NUM_TRAIN_EXAMPLES = 5000
# batches per epoch = ceil(NUM_TRAIN_EXAMPLES / BATCH_SIZE), in integer arithmetic
NUM_BATCHES = (NUM_TRAIN_EXAMPLES + BATCH_SIZE - 1) // BATCH_SIZE
# whether to use micro-batching (gradient accumulation) to fit in GPU memory
use_micro_batch = False
# examples per micro-batch; must evenly divide BATCH_SIZE
MICRO_BATCH = 5
assert BATCH_SIZE % MICRO_BATCH == 0, f"BATCH_SIZE ({BATCH_SIZE}) must be a multiple of MICRO_BATCH ({MICRO_BATCH})"
NUM_MICRO_BATCHES = BATCH_SIZE // MICRO_BATCH

In [None]:
# This cell installs specific versions of accelerate and bitsandbytes that worked well in our experiments.
! pip install accelerate==1.7.0 bitsandbytes==0.46.0

In [None]:
import sys, os, gc
import torch
import numpy as np
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.optim import AdamW
from peft import LoraConfig, get_peft_model
from peft import prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from tqdm.notebook import tqdm

# Mount Google Drive to get access to data and save logs/models
from google.colab import drive
drive.mount("/content/drive")

# Make the project's `helpers` module importable. Guarded so that re-running
# this cell does not append the same entry to sys.path again and again.
HELPERS_DIR = "/content/drive/MyDrive/mypartcs329h"
if HELPERS_DIR not in sys.path:
    sys.path.append(HELPERS_DIR)
# Import helper functions we implemented
from helpers import CLRL, compute_acc, compute_mrpo_objective, compute_mdpo_objective, compute_dpo_objective_many_refs, expand_vector_columns

In [None]:
# ---------------------------------------------------------------------------
# Other settings
# ---------------------------------------------------------------------------

# dataset name -> short alias (used e.g. in the training-log file name)
PRECOMPUTED_ALIASES = {
    "ultrafeedback_binarized": "UltraFeedback",
    "PKU-SafeRLHF-30K-standard": "SafeRLHF",
}

# the 7 reference models used by MRPO/MDPO; this order fixes the column order
# of every (batch x K) reference tensor built later
REFERENCE_MODELS = [
    "01-ai_Yi-1.5-9B-Chat",
    "meta-llama_Meta-Llama-3.1-8B-Instruct",
    "microsoft_Phi-3-medium-128k-instruct",
    "mistralai_Mistral-7B-Instruct-v0.3",
    "Qwen_Qwen2.5-0.5B-Instruct",
    "Qwen_Qwen2.5-1.5B-Instruct",
    "Qwen_Qwen3-4B-Instruct-2507",
]

# experiment seeds, one per data shuffle: [0, 1, 2, 3, 4]
SEEDS = list(range(5))

# supported alpha (reference-weighting) methods:
#   offline_1       -> VDW
#   offline_2       -> VAW
#   online_1        -> SWCW
#   arwc_normalized -> original
ALPHA_METHODS = ["offline_1", "offline_2", "online_1", "arwc_normalized"]

# fail fast on a mistyped main setting
assert ALPHA_METHOD in ALPHA_METHODS, f"ALPHA_METHOD {ALPHA_METHOD} not in {ALPHA_METHODS}"
assert DATASET in PRECOMPUTED_ALIASES, f"DATASET {DATASET} not in {list(PRECOMPUTED_ALIASES)}"
assert SEED in SEEDS, f"SEED {SEED} not in {SEEDS}"

In [None]:
# ---------------------------------------------------------------------------
# Load preference data and precomputed log probs
# ---------------------------------------------------------------------------
import warnings

# preference data (prompt, chosen, rejected)
train_df = pd.read_csv(f"{DATA_DIR}/{DATASET}/seed={SEED}/data/train.csv")
val_df = pd.read_csv(f"{DATA_DIR}/{DATASET}/seed={SEED}/data/val.csv")
test_df = pd.read_csv(f"{DATA_DIR}/{DATASET}/seed={SEED}/data/test.csv")
# plain Python lists for the val/test accuracy calls
val_df_prompts = val_df["prompt"].tolist()
val_df_chosen = val_df["chosen"].tolist()
val_df_rejected = val_df["rejected"].tolist()
test_df_prompts = test_df["prompt"].tolist()
test_df_chosen = test_df["chosen"].tolist()
test_df_rejected = test_df["rejected"].tolist()

# rows per preference split, used to sanity-check the precomputed files
split_sizes = {"train": len(train_df), "val": len(val_df), "test": len(test_df)}

# log probs data per split and reference model:
# log_prob_ref_dict[split][reference_model] -> DataFrame
log_prob_ref_dict = {}
for SPLIT in ["train", "val", "test"]:
    log_prob_ref_dict[SPLIT] = {}
    for REFERENCE_MODEL in REFERENCE_MODELS:
        df = pd.read_csv(f"{DATA_DIR}/{DATASET}/seed={SEED}/precomputed/{REFERENCE_MODEL}_{SPLIT}.csv")
        # non-inplace equivalent of the original reset_index(inplace=True)
        df = df.reset_index()
        # precomputed rows are later looked up by the preference split's index
        # labels, so a length mismatch would silently misalign examples
        if len(df) != split_sizes[SPLIT]:
            warnings.warn(
                f"{REFERENCE_MODEL}_{SPLIT}.csv has {len(df)} rows but the "
                f"{SPLIT} preference split has {split_sizes[SPLIT]} rows"
            )
        log_prob_ref_dict[SPLIT][REFERENCE_MODEL] = df

In [None]:
# ---------------------------------------------------------------------------
# Load in base model + tokenizer
# ---------------------------------------------------------------------------

# 4-bit bitsandbytes weights are not trained directly, so quantization only
# makes sense together with LoRA adapters (otherwise nothing is trainable).
assert use_lora or not use_4bit, "use_4bit=True requires use_lora=True (no trainable parameters otherwise)"

# BitsAndBytes + LoRA configurations (None when the corresponding flag is off)
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=torch.bfloat16,
) if use_4bit else None
target_modules = "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj"
lora_cfg = LoraConfig(
    r=32,
    lora_alpha=16,
    bias="none",
    task_type="CAUSAL_LM",
    init_lora_weights="gaussian",
    target_modules=target_modules.split(","),
) if use_lora else None

# Load in base model + tokenizer
train_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_PATH,
    device_map=DEVICE,
    quantization_config=bnb_cfg,
    trust_remote_code=True,
)
tok = AutoTokenizer.from_pretrained(BASE_MODEL_PATH, trust_remote_code=True, use_fast=False)

# k-bit preparation is only meant for a quantized model; previously it ran
# unconditionally, and get_peft_model(model, None) crashed when use_lora=False
if use_4bit:
    train_model = prepare_model_for_kbit_training(train_model)
if use_lora:
    train_model = get_peft_model(train_model, lora_cfg)

train_model.train()
gc.collect()
torch.cuda.empty_cache()
# the KV cache is useless during training
train_model.config.use_cache = False
train_model.gradient_checkpointing_enable()
# lets gradients flow through checkpointed blocks when the input embeddings are frozen
train_model.enable_input_require_grads()
train_model.to(DEVICE)

# set up optimizer (trainable parameters only) and objective function
optimizer = AdamW([p for p in train_model.parameters() if p.requires_grad], lr=LR)
compute_objective = compute_mrpo_objective if USE_MRPO_OVER_MDPO else compute_mdpo_objective

In [None]:
# ---------------------------------------------------------------------------
# Compute alphas based on ALPHA_METHOD
# ---------------------------------------------------------------------------
# The offline methods fix one weight per reference model from the validation
# split before training. For the other methods alphas2use stays None here and
# is set inside the training loop.


def offline_alpha_weights(score_fn):
    """Turn a per-reference validation score into weights that sum to 1.

    score_fn(chosen_norm, rejected_norm) receives the length-normalized log
    probs of the chosen and rejected responses (numpy arrays over the
    validation examples) for one reference model and returns a scalar score.

    Returns a dict mapping reference-model name -> normalized weight.
    Raises AssertionError if the scores do not sum to a positive number
    (normalizing would otherwise yield inf/NaN weights).
    """
    ref_scores = {}
    for ref_model, log_probs in log_prob_ref_dict['val'].items():
        chosen_norm = log_probs['logprob_chosen'].values / log_probs['L_chosen'].values
        rejected_norm = log_probs['logprob_rejected'].values / log_probs['L_rejected'].values
        ref_scores[ref_model] = score_fn(chosen_norm, rejected_norm)
    total_score = sum(ref_scores.values())
    assert total_score > 0, f"cannot normalize alphas: total score is {total_score}"
    return {ref_model: score / total_score for ref_model, score in ref_scores.items()}


alphas2use = None
if ALPHA_METHOD == 'offline_1':
    # score = sum of absolute differences in length-normalized log probs
    ref2alpha = offline_alpha_weights(lambda chosen, rejected: np.sum(np.abs(chosen - rejected)))
    # (1, K) tensor on DEVICE, columns in REFERENCE_MODELS order
    alphas_offline_1 = torch.tensor([ref2alpha[x] for x in REFERENCE_MODELS]).to(DEVICE)
    alphas2use = alphas_offline_1[None, :]
elif ALPHA_METHOD == 'offline_2':
    # score = number of pairs where the length-normalized chosen log prob
    # exceeds the rejected one
    ref2alpha = offline_alpha_weights(lambda chosen, rejected: np.sum(chosen > rejected))
    # (1, K) tensor on DEVICE, columns in REFERENCE_MODELS order
    alphas_offline_2 = torch.tensor([ref2alpha[x] for x in REFERENCE_MODELS]).to(DEVICE)
    alphas2use = alphas_offline_2[None, :]

In [None]:
# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
# For every batch, the loop runs these steps:
# 1. Log the pre-update accuracy.
# 2. Choose the per-reference weights (alphas) according to ALPHA_METHOD.
# 3. Take ONE optimizer step, either on the full batch or by accumulating
#    gradients over micro-batches.
# 4. Log the post-update accuracy and the per-reference DPO losses.
# Every ~10% of an epoch, the model is also evaluated on val/test and the
# logs are checkpointed to CSV.

# objective tag used in the log-file name
OBJECTIVE_NAME = "MRPO" if USE_MRPO_OVER_MDPO else "MDPO"
LOG_FNAME = LOGS_DIR + f"training_logs_{PRECOMPUTED_ALIASES[DATASET]}_{ALPHA_METHOD}_{SEED}_{OBJECTIVE_NAME}.csv"
# evaluate + checkpoint every ~10% of an epoch; max(1, ...) avoids a
# modulo-by-zero when there are fewer than 10 batches
EVAL_EVERY = max(1, NUM_BATCHES // 10)


def stack_ref_column(col_name, row_index):
    """Gather one precomputed column for `row_index` from every reference model.

    Returns a (len(row_index), K) float32 tensor on DEVICE. Column k comes from
    the k-th reference model in log_prob_ref_dict["train"].
    float32 instead of float16: summed log probs of long responses are large in
    magnitude, where float16 resolution is coarse (0.5 for |x| in [512, 1024)).
    """
    per_ref = [
        torch.tensor(ref_df.loc[row_index, col_name].values, dtype=torch.float32)
        for ref_df in log_prob_ref_dict["train"].values()
    ]
    return torch.stack(per_ref, dim=0).T.to(DEVICE)


def normalized_abs_margins(pref_norm, nonpref_norm):
    """Weights proportional to |chosen - rejected|, normalized so each row sums to 1.

    Both inputs are length-normalized reference log probs of shape (B, K).
    """
    margins = torch.abs(pref_norm - nonpref_norm)
    return margins / torch.sum(margins, dim=1, keepdim=True)


def policy_norm_log_probs(prompts, replies):
    """Length-normalized log probs of `replies` given `prompts` under train_model.

    Not wrapped in no_grad: the result is used to backpropagate the loss.
    """
    log_probs, lengths = CLRL(model=train_model, tok=tok, prompts=prompts, replies=replies, device=DEVICE)
    lengths = torch.from_numpy(np.asarray(lengths)).to(DEVICE)
    return torch.stack(log_probs) / lengths


def forward_backward(prompts, chosen, rejected, pref_ref_norm, nonpref_ref_norm, alphas, loss_scale=1):
    """Forward + backward on one (micro-)batch; gradients accumulate into .grad.

    The loss is divided by `loss_scale` (the number of micro-batches) so that
    the accumulated gradient matches a full-batch step, assuming
    compute_objective averages over examples (TODO confirm in helpers).
    Returns (scaled loss tensor, per-reference DPO losses computed without grad).
    """
    pref_train_norm = policy_norm_log_probs(prompts, chosen)
    nonpref_train_norm = policy_norm_log_probs(prompts, rejected)
    loss = compute_objective(
        pref_train_norm,
        nonpref_train_norm,
        pref_ref_norm,
        nonpref_ref_norm,
        beta=BETA, alphas=alphas) / loss_scale
    loss.backward()
    with torch.no_grad():
        dpo_by_refs = compute_dpo_objective_many_refs(
            pref_train_norm,
            nonpref_train_norm,
            pref_ref_norm,
            nonpref_ref_norm,
            beta=BETA)  # K
    return loss, dpo_by_refs


def to_loggable(value):
    """Detach tensors to CPU numpy arrays; pass other values (e.g. None) through."""
    return value.detach().cpu().numpy() if torch.is_tensor(value) else value


# dict to log results (one entry per batch in every list)
logging_results = {'epoch':[], 'batch':[], 'alphas':[], 'batch_acc_pre':[], 'batch_acc_post':[], 'val_acc':[], 'test_acc':[], 'loss': [], 'dpo_by_refs':[]}

# online_1 uses the previous batch's weights; start from uniform weights
alphas2use_from_prev_batch = torch.ones((1, len(REFERENCE_MODELS)), device=DEVICE) / len(REFERENCE_MODELS)

for epoch in tqdm(range(NUM_EPOCHS), desc='epoch'):

    for batch in tqdm(range(NUM_BATCHES), desc='batch'):

        logging_results['epoch'].append(epoch)
        logging_results['batch'].append(batch)

        # current batch (label-based .loc slice, so the end label is inclusive)
        batch_data = train_df.loc[batch * BATCH_SIZE : ((batch+1) * BATCH_SIZE) - 1]
        batch_prompt = batch_data["prompt"].tolist()
        batch_chosen = batch_data["chosen"].tolist()
        batch_rejected = batch_data["rejected"].tolist()

        # reference log probs and lengths for this batch, each (B, K)
        preferred_ref_log_probs = stack_ref_column("logprob_chosen", batch_data.index)
        nonpreferred_ref_log_probs = stack_ref_column("logprob_rejected", batch_data.index)
        L_chosen_refs = stack_ref_column("L_chosen", batch_data.index)
        L_rejected_refs = stack_ref_column("L_rejected", batch_data.index)
        # length-normalized reference log probs
        preferred_ref_log_probs_2use = preferred_ref_log_probs / L_chosen_refs
        nonpreferred_ref_log_probs_2use = nonpreferred_ref_log_probs / L_rejected_refs

        # accuracy on this batch BEFORE the gradient update
        with torch.no_grad():
            batch_acc_pre = compute_acc(model=train_model, tok=tok, prompt=batch_prompt, chosen=batch_chosen, rejected=batch_rejected, MAX_BATCH=200, device=DEVICE)
        logging_results['batch_acc_pre'].append(to_loggable(batch_acc_pre))

        # per-batch alphas for the online methods (offline alphas2use is fixed)
        if ALPHA_METHOD == 'arwc_normalized':
            alphas2use = normalized_abs_margins(preferred_ref_log_probs_2use, nonpreferred_ref_log_probs_2use)  # B x K
        elif ALPHA_METHOD == 'online_1':
            # NOTE(review): these are the PREVIOUS batch's per-example weights
            # (B x K) applied row-wise to this batch's different examples; confirm
            # this is intended, or whether a batch mean (.mean(0, keepdim=True))
            # was meant. It also requires consecutive batches of equal size.
            alphas2use = alphas2use_from_prev_batch
        logging_results['alphas'].append(alphas2use.mean(0).detach().cpu().numpy())

        if not use_micro_batch:
            loss, dpo_by_refs = forward_backward(
                batch_prompt, batch_chosen, batch_rejected,
                preferred_ref_log_probs_2use, nonpreferred_ref_log_probs_2use,
                alphas2use)
            batch_loss = loss.item()
            print(f"Epoch {epoch}, Batch {batch}, Loss: {batch_loss}")
        else:
            n_rows = len(batch_prompt)
            # broadcast (1, K) weights to one row per example so they can be
            # sliced per micro-batch, without mutating the global alphas2use
            batch_alphas = alphas2use.expand(n_rows, -1) if alphas2use.shape[0] == 1 else alphas2use
            micro_starts = range(0, n_rows, MICRO_BATCH)
            micro_losses, dpo_by_refs_accum = [], []
            # every micro-batch does its own forward/backward inside this loop
            # (previously only the last micro-batch was trained on)
            for micro_batch, start in enumerate(tqdm(micro_starts, desc='micro_batch')):
                rows = slice(start, start + MICRO_BATCH)
                loss, mb_dpo_losses = forward_backward(
                    batch_prompt[rows], batch_chosen[rows], batch_rejected[rows],
                    preferred_ref_log_probs_2use[rows], nonpreferred_ref_log_probs_2use[rows],
                    batch_alphas[rows], loss_scale=len(micro_starts))
                micro_losses.append(loss.item())
                dpo_by_refs_accum.append(mb_dpo_losses)
                print(f"Epoch {epoch}, Batch {batch}, micro_batch {micro_batch}, Loss: {loss.item()}")
            batch_loss = np.sum(micro_losses)
            # (num_micro_batches, K) -> K
            dpo_by_refs = torch.vstack(dpo_by_refs_accum).mean(0)

        # ONE parameter update per batch, after all gradients have accumulated
        optimizer.step()
        optimizer.zero_grad()
        logging_results['loss'].append(batch_loss)
        logging_results['dpo_by_refs'].append(to_loggable(dpo_by_refs))

        # accuracy on this batch AFTER the gradient update
        with torch.no_grad():
            batch_acc_post = compute_acc(model=train_model, tok=tok, prompt=batch_prompt, chosen=batch_chosen, rejected=batch_rejected, MAX_BATCH=200, device=DEVICE)
        logging_results['batch_acc_post'].append(to_loggable(batch_acc_post))

        # online_1: weights computed on this batch are used for the next batch
        if ALPHA_METHOD == 'online_1':
            alphas2use_from_prev_batch = normalized_abs_margins(preferred_ref_log_probs_2use, nonpreferred_ref_log_probs_2use)  # B x K

        # val/test accuracy and log checkpoint every ~10% of training
        is_checkpoint = ((batch + 1) % EVAL_EVERY == 0) or (batch == NUM_BATCHES - 1)
        if is_checkpoint:
            # free cached GPU memory before the larger evaluation passes
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            gc.collect()
            with torch.no_grad():
                val_acc = compute_acc(model=train_model, tok=tok, prompt=val_df_prompts, chosen=val_df_chosen, rejected=val_df_rejected, MAX_BATCH=100, device=DEVICE)
                test_acc = compute_acc(model=train_model, tok=tok, prompt=test_df_prompts, chosen=test_df_chosen, rejected=test_df_rejected, MAX_BATCH=100, device=DEVICE)
        else:
            val_acc, test_acc = None, None
        # convert like the batch accuracies, so no raw tensors end up in the CSV
        logging_results['val_acc'].append(to_loggable(val_acc))
        logging_results['test_acc'].append(to_loggable(test_acc))

        if is_checkpoint:
            expand_vector_columns(logging_results, REFERENCE_MODELS).to_csv(LOG_FNAME, index=False)

# save final logs after training
df = expand_vector_columns(logging_results, REFERENCE_MODELS)
fname = LOG_FNAME
df.to_csv(fname, index=False)