In [None]:
import torch
import numpy as np
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm.notebook import tqdm
import sys, os, gc
from IPython.display import clear_output

# Mount Google Drive and make the project folder importable, so `helpers` resolves.
from google.colab import drive
drive.mount("/content/drive")

project_dir = "/content/drive/MyDrive/cs329h_project"
# guard: re-running this cell must not keep appending duplicate sys.path entries
if project_dir not in sys.path:
  sys.path.append(project_dir)

# project helper, aliased as CLRL; the scoring cell calls it with
# model=, tok=, prompts=, replies=, MAX_LEN=
from helpers import compute_logprob_and_reply_length_batched as CLRL

In [None]:
# Reference (pi_ref) models: every model directory under MODEL_DIR except Gemma
# variants, which are deliberately excluded.
#  - the "gemma" check is case-insensitive, so e.g. "Gemma-..." folders are excluded too
#  - non-directories (stray files such as .DS_Store) are skipped, since each entry
#    is later passed to `from_pretrained` as a model folder
#  - sorted because os.listdir returns entries in arbitrary order; this makes the
#    processing order reproducible across runs
MODEL_DIR = "drive/MyDrive/cs329h_project/models"
REFERENCE_MODELS = sorted(
    m for m in os.listdir(MODEL_DIR)
    if "gemma" not in m.lower()
    and os.path.isdir(os.path.join(MODEL_DIR, m))
)

# Datasets: display name -> file prefix. Each split is read from
# f"{DATASET_DIR}/{prefix}_{split}.csv".
DATASET_DIR = "drive/MyDrive/cs329h_project/datasets"
DATASETS = {
    "HelpSteer2" : "helpsteer2-preference-v2",
    "UltraFeedback" : "ultrafeedback_binarized",
    "SafeRLHF" : "PKU-SafeRLHF-30K-standard"
}

In [None]:
# Precompute reference-model (pi_ref) log-probabilities of the chosen and rejected
# replies for every (reference model, dataset, split), writing one CSV per
# combination with columns:
#   model, dataset, split, index (row label in the source split CSV),
#   logprob_chosen, L_chosen, logprob_rejected, L_rejected
# Combinations whose CSV already exists are skipped, so re-running this cell
# resumes after a crash / Colab disconnect.

# we can process pi_refs in batches
BATCH_SIZE = 32
# passed straight through to CLRL as MAX_LEN
MAX_LEN = 1024
PRECOMPUTED_DIR = "drive/MyDrive/cs329h_project/precomputed"
SPLITS = ["train", "val", "test"]
LOG_COLUMNS = ["model", "dataset", "split", "index",
               "logprob_chosen", "L_chosen",
               "logprob_rejected", "L_rejected"]

# bf16 on GPU, fp32 on CPU -- loop-invariant, so decide once
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32


def _output_path(model_name, dataset_name, split):
  """Path of the checkpoint CSV for one (model, dataset, split).

  The "_+" in the file name is kept as-is so previously written checkpoints
  are still recognised.
  """
  return f"{PRECOMPUTED_DIR}/{model_name}_+{dataset_name}_{split}_logprobs.csv"


def _score_replies(model, tok, prompts, replies):
  """Run CLRL on (prompt, reply) pairs; return (log-probs, reply lengths) as numpy arrays.

  Assumes CLRL returns one scalar tensor per pair (hence `.item()`) and a
  sequence of lengths -- TODO confirm against helpers.py.
  """
  # no gradients, to avoid OOM
  with torch.no_grad():
    logprobs, lengths = CLRL(
        model=model, tok=tok,
        prompts=prompts, replies=replies, MAX_LEN=MAX_LEN)
  return np.array([lp.item() for lp in logprobs]), np.array(lengths)


def _score_split(model, tok, model_name, dataset_name, split):
  """Score every row of one dataset split; return a DataFrame with LOG_COLUMNS."""
  data = pd.read_csv(f"{DATASET_DIR}/{DATASETS[dataset_name]}_{split}.csv")

  batch_frames = []
  # stepping by BATCH_SIZE yields exactly ceil(len / BATCH_SIZE) batches --
  # no trailing empty batch when the row count is a multiple of BATCH_SIZE
  for start in tqdm(range(0, len(data), BATCH_SIZE),
                    desc=f"{model_name} + {dataset_name} ({split})"):
    # positional slice; data_batch.index keeps the original row labels
    data_batch = data.iloc[start : start + BATCH_SIZE]
    prompts = list(data_batch.prompt.values)

    logprobs_chosen, L_chosen = _score_replies(
        model, tok, prompts, list(data_batch.chosen.values))
    logprobs_rejected, L_rejected = _score_replies(
        model, tok, prompts, list(data_batch.rejected.values))

    batch_frames.append(pd.DataFrame({
        "model" : model_name, "dataset" : dataset_name, "split" : split,
        "index" : data_batch.index,
        "logprob_chosen" : logprobs_chosen, "L_chosen" : L_chosen,
        "logprob_rejected" : logprobs_rejected, "L_rejected" : L_rejected}))

  # concatenate once (linear) instead of growing a frame inside the loop
  if not batch_frames:
    return pd.DataFrame(columns=LOG_COLUMNS)
  return pd.concat(batch_frames, ignore_index=True)


os.makedirs(PRECOMPUTED_DIR, exist_ok=True)

# go thru all of our reference models
for REFERENCE_MODEL in REFERENCE_MODELS[::-1]:

  # intermediate checkpointing: only (dataset, split) pairs without a CSV yet;
  # skip the expensive model load entirely if nothing is left for this model
  pending = [
      (dataset_name, split)
      for dataset_name in DATASETS
      for split in SPLITS
      if not os.path.exists(_output_path(REFERENCE_MODEL, dataset_name, split))
  ]
  if not pending:
    continue

  # load in our reference model + tokenizer
  model_path = f"{MODEL_DIR}/{REFERENCE_MODEL}"
  tok = AutoTokenizer.from_pretrained(model_path, use_fast=True)
  model = AutoModelForCausalLM.from_pretrained(
      model_path, torch_dtype=DTYPE, device_map="auto").eval()

  # in case the tokenizer doesn't have a pad token
  if tok.pad_token is None:
    tok.pad_token = tok.eos_token
  model.config.pad_token_id = tok.pad_token_id

  for DATASET, SPLIT in pending:

    # clean house
    clear_output(wait=True)

    logs = _score_split(model, tok, REFERENCE_MODEL, DATASET, SPLIT)

    # write to a temp file and rename, so an interrupted write never leaves a
    # truncated CSV that the checkpoint check above would treat as complete
    out_path = _output_path(REFERENCE_MODEL, DATASET, SPLIT)
    tmp_path = f"{out_path}.tmp"
    logs.to_csv(tmp_path, index=False)
    os.replace(tmp_path, out_path)

    del logs
    gc.collect()

  # release the model before loading the next one; gc alone does not return
  # cached CUDA blocks to the device
  del model, tok
  gc.collect()
  if torch.cuda.is_available():
    torch.cuda.empty_cache()