In [None]:
import wandb

from dataclasses import dataclass, field
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardConfig, RewardTrainer
from peft import LoraConfig, TaskType # Parameter Efficient Fine Tuning
from tqdm import tqdm

import pandas as pd

In [None]:
# Register tqdm's pandas integration (adds the `progress_apply` family to pandas objects).
tqdm.pandas()

In [None]:
# Log in to Weights and Biases for training logging
# (the reward config further down is created with report_to="wandb").
wandb.login()

In [None]:
# Load the training and validation splits for the reward model from Feather files
# into pandas DataFrames. Paths are relative to the notebook's working directory.
DATASET_TRAIN = pd.read_feather("../mini_codenet/data/split/reward_train.ftr")
DATASET_EVAL = pd.read_feather("../mini_codenet/data/split/reward_val.ftr")

In [None]:
# Preview the first rows of the training split.
DATASET_TRAIN.head()

In [None]:
# Total number of training rows, followed by the count of non-null `solution`
# entries for each distinct `status` value.
print(len(DATASET_TRAIN))
DATASET_TRAIN.groupby("status")["solution"].count()

In [None]:
# Preview the first rows of the evaluation split.
DATASET_EVAL.head()

In [None]:
# Total number of evaluation rows, followed by the count of non-null `solution`
# entries for each distinct `status` value.
print(len(DATASET_EVAL))
DATASET_EVAL.groupby("status")["solution"].count()

In [None]:
# Split each dataset into accepted and rejected submissions, keeping only the
# columns needed to build contrastive pairs. No sampling happens here: every
# row ends up in exactly one of the two frames, based on its `status`.
# The column order matters to downstream code, so keep it stable.
PAIR_COLUMNS = ["submission_id", "problem_id", "language", "solution"]


def _split_by_status(dataset):
    """Return `(accepted, rejected)` frames of `dataset`, restricted to `PAIR_COLUMNS`.

    A row is "accepted" when its `status` equals "Accepted"; every other row
    is "rejected".
    """
    accepted = dataset.loc[dataset["status"] == "Accepted", PAIR_COLUMNS]
    rejected = dataset.loc[dataset["status"] != "Accepted", PAIR_COLUMNS]
    return accepted, rejected


def _print_split_summary(split_name, accepted, rejected):
    """Print row counts and the number of distinct problem ids for one split."""
    print(f"Total Accepted Problems in {split_name}:", len(accepted["submission_id"]))
    print(f"Total Rejected Problems in {split_name}:", len(rejected["submission_id"]))
    print(f"Unique IDs in Accepted {split_name}:", len(accepted["problem_id"].unique()))
    print(f"Unique IDs in Rejected {split_name}:", len(rejected["problem_id"].unique()))


accepted_train, rejected_train = _split_by_status(DATASET_TRAIN)
accepted_eval, rejected_eval = _split_by_status(DATASET_EVAL)

_print_split_summary("TRAIN", accepted_train, rejected_train)
print("------------")
_print_split_summary("EVAL", accepted_eval, rejected_eval)

In [None]:
# For each accepted solution, choose up to `n` contrasting rejected solutions
# written for the same problem in the same language.
def get_contrastive_pairs(data_accepted, data_rejected, n=10):
    """Build (accepted, rejected) solution pairs.

    Args:
        data_accepted: DataFrame with at least the columns `problem_id`,
            `language` and `solution`, one row per accepted submission.
        data_rejected: DataFrame with the same three columns, one row per
            rejected submission.
        n: Maximum number of rejected solutions paired with each accepted
            solution. The first `n` rejected solutions of the matching
            (problem_id, language) group are used.

    Returns:
        A `Dataset` with two parallel string columns, `accepted` and
        `rejected`. Accepted solutions that have no rejected counterpart for
        the same problem and language contribute no rows.
    """
    data = { "accepted": [], "rejected": [] }

    # SPEED UP!! Group rejected answers by problem_id and language and cache the results so
    # we do not have to filter the whole dataset inside the main for-loop on every iteration.
    # Plus, we get O(1) look up time 😎
    grouped_rejected = data_rejected.groupby(["problem_id", "language"])["solution"].apply(list).to_dict()

    # Select the columns by name rather than unpacking `.values` positionally, so the
    # function does not silently depend on the caller's column order or column count.
    accepted_rows = data_accepted[["problem_id", "language", "solution"]].itertuples(index=False, name=None)

    for accepted_pid, accepted_lang, accepted_sol in tqdm(accepted_rows, total=len(data_accepted)):
        # Problems that only have correct solutions in this language yield an empty
        # list here and are therefore skipped.
        rejected_filtered = grouped_rejected.get((accepted_pid, accepted_lang), [])[:n]

        for rejected_sol in rejected_filtered:
            data["accepted"].append(accepted_sol)
            data["rejected"].append(rejected_sol)

    return Dataset.from_dict(data)

# Tokenize every accepted/rejected pair of a batch
def preprocess_function(examples, tokenizer):
    """Tokenize a batch of solution pairs into chosen/rejected columns.

    Args:
        examples: Mapping with two parallel sequences, `accepted` and
            `rejected`, holding the raw solution texts.
        tokenizer: Callable that takes one text and returns a mapping with
            `input_ids` and `attention_mask` entries.

    Returns:
        A dict with the keys `input_ids_chosen`, `attention_mask_chosen`,
        `input_ids_rejected` and `attention_mask_rejected`; each value is a
        list with one entry per input pair.
    """
    ids_chosen, mask_chosen = [], []
    ids_rejected, mask_rejected = [], []

    for accepted_text, rejected_text in zip(examples["accepted"], examples["rejected"]):
        # The accepted solution plays the "chosen" role in each pair.
        encoded_chosen = tokenizer(accepted_text)
        encoded_rejected = tokenizer(rejected_text)

        ids_chosen.append(encoded_chosen["input_ids"])
        mask_chosen.append(encoded_chosen["attention_mask"])
        ids_rejected.append(encoded_rejected["input_ids"])
        mask_rejected.append(encoded_rejected["attention_mask"])

    return {
        "input_ids_chosen": ids_chosen,
        "attention_mask_chosen": mask_chosen,
        "input_ids_rejected": ids_rejected,
        "attention_mask_rejected": mask_rejected,
    }

# Build contrastive pairs, tokenize them, and filter out examples that are longer
# than args.reward_config.max_length
def process_data(accepted, rejected, tokenizer, args, num_proc=4):
    """Turn accepted/rejected submissions into a tokenized pair dataset.

    Args:
        accepted: DataFrame of accepted submissions, passed to
            `get_contrastive_pairs`.
        rejected: DataFrame of rejected submissions, passed to
            `get_contrastive_pairs`.
        tokenizer: Tokenizer callable forwarded to `preprocess_function`.
        args: Object exposing `reward_config.max_length`.
        num_proc: Number of worker processes used for the tokenization map.
            Defaults to 4, the value that was previously hard-coded.

    Returns:
        The tokenized dataset containing only pairs where both the chosen and
        the rejected token sequence have at most `max_length` tokens. When
        `max_length` is None, no length filtering is applied.
    """
    # Read the limit once instead of on every filtered example.
    max_length = args.reward_config.max_length

    dataset = get_contrastive_pairs(accepted, rejected)

    dataset = dataset.map(
        lambda example: preprocess_function(example, tokenizer),
        batched=True,
        num_proc=num_proc,
    )

    if max_length is None:
        # No limit configured: comparing a length against None would raise a TypeError.
        return dataset

    dataset = dataset.filter(
        lambda x: len(x["input_ids_chosen"]) <= max_length
        and len(x["input_ids_rejected"]) <= max_length
    )

    return dataset

In [None]:
@dataclass
class ScriptArguments:
    """Settings for this reward-model training run."""

    # Path (or model name) passed to `from_pretrained` for both the tokenizer and the model.
    model_name: str = "../hf_model/" # TODO: Change path to correct SFT model
    # Whether to run periodic evaluation during training; False means no evaluation.
    eval_split: bool = False
    # Training hyperparameters handed to `RewardTrainer`.
    reward_config: RewardConfig = field(
        default_factory=lambda: RewardConfig(
            output_dir="output",
            per_device_train_batch_size=64,
            num_train_epochs=10,
            gradient_accumulation_steps=16,
            gradient_checkpointing=True,
            gradient_checkpointing_kwargs={"use_reentrant": False},
            learning_rate=1.41e-5,
            report_to="wandb", # log training progress to Weights and Biases
            remove_unused_columns=False,
            optim="adamw_torch",
            logging_steps=500,
            evaluation_strategy="no",
            max_length=256, # TODO: NEED TO CHANGE THIS!
        )
    )

args = ScriptArguments()
if args.eval_split:
    # Assigning the strategy after construction bypasses the config's own
    # post-init normalisation, so the values it would normally derive
    # (`do_eval`, and `eval_steps` falling back to `logging_steps`) are set
    # explicitly here to keep step-based evaluation usable.
    args.reward_config.evaluation_strategy = "steps"
    args.reward_config.do_eval = True
    if not args.reward_config.eval_steps:
        args.reward_config.eval_steps = args.reward_config.logging_steps
else:
    args.reward_config.evaluation_strategy = "no"

In [None]:
# Step 1: Load the tokenizer that belongs to the base model
tokenizer = AutoTokenizer.from_pretrained(args.model_name)

# The model cell copies `tokenizer.pad_token_id` into the model config, so it must
# not be None. Some tokenizers ship without a padding token; in that case reuse
# the end-of-sequence token for padding. Tokenizers that already define a padding
# token are left untouched.
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Build the paired, tokenized and length-filtered datasets for both splits.
# The printed headers label the progress output produced by each call.
print("Training Data:")
train_dataset = process_data(accepted_train, rejected_train, tokenizer, args)

print("\nEvaluation Data:")
eval_dataset = process_data(accepted_eval, rejected_eval, tokenizer, args)

In [None]:
# Step 2: Load the model
# `num_labels=1` requests a sequence-classification head with a single output,
# which serves as the scalar reward score here.
model = AutoModelForSequenceClassification.from_pretrained(args.model_name, num_labels=1)
# Keep the model's padding id in sync with the tokenizer's.
model.config.pad_token_id = tokenizer.pad_token_id

In [None]:
# Step 3: Configure LoRA (parameter-efficient fine-tuning) for the
# sequence-classification model.
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,  # the adapter is going to be trained, not just used for inference
    r=8,  # LoRA rank
    lora_alpha=32,  # LoRA scaling factor
    lora_dropout=0.1,  # dropout applied inside the LoRA layers
)

In [None]:
# Step 4: Define the Trainer
# Combines the model, tokenizer, reward config, both datasets and the LoRA config.
# NOTE(review): `eval_dataset` is passed unconditionally; whether it is actually used
# depends on the evaluation strategy set on `args.reward_config`.
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    args=args.reward_config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config
)

# Start training.
trainer.train()