In [1]:
import wandb

from dataclasses import dataclass, field
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardConfig, RewardTrainer
from peft import LoraConfig, TaskType # Parameter Efficient Fine Tuning
from tqdm import tqdm

import pandas as pd
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Enable tqdm's pandas integration.
# NOTE(review): none of the visible cells appears to use it — confirm whether it is still needed.
tqdm.pandas()

In [3]:
# Log in to Weights and Biases up front; the reward config below sets
# report_to="wandb", so training progress is logged there.
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtrevorashby[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
# Load the pre-split reward-model tables (Feather files): one for training,
# one for validation. Paths are relative to the notebook's working directory.
DATASET_TRAIN = pd.read_feather("../mini_codenet/data/split/reward_train.ftr")
DATASET_EVAL = pd.read_feather("../mini_codenet/data/split/reward_val.ftr")

In [5]:
# Peek at the first rows of the training table.
DATASET_TRAIN.head()

Unnamed: 0,level_0,index,submission_id,problem_id,language,filename_ext,status,cpu_time,memory,code_size,accuracy,solution,problem_statement
0,732215,1955,s679316564,p02629,Python,py,Accepted,27.0,9096.0,166,,"n = int(input())\nc = "" ""\nwhile n:\n if n ...",Score : 300 points \n Problem Statement 100000...
1,2873817,1241,s424811397,p03700,C++,cpp,Accepted,91.0,1024.0,711,,#include<iostream>\n#include<queue>\nusing nam...,Score : 400 points \n Problem Statement You ar...
2,2117312,2364,s962040422,p02904,C++,cpp,Wrong Answer,17.0,14336.0,2852,,#include <bits/stdc++.h>\n//#include <tr1/unor...,Score : 700 points \n Problem Statement Snuke ...
3,494734,7691,s601052755,p02945,C++,cpp,Accepted,1.0,256.0,147,,"#include ""bits/stdc++.h""\n\nusing namespace st...",Score : 100 points \n Problem Statement We hav...
4,2040572,8188,s811638706,p03574,Python,py,Accepted,27.0,3188.0,486,,"h, w = map(int, input().split())\ns = [[""a""]*(...",Score : 200 points \n Problem Statement You ar...


In [6]:
# Total number of training rows, followed by the number of solutions per `status` value.
print(len(DATASET_TRAIN))
DATASET_TRAIN.groupby("status")["solution"].count()

666324


status
Accepted                  358603
Compile Error              24979
Memory Limit Exceeded        542
Output Limit Exceeded         48
Query Limit Exceeded           2
Runtime Error              41398
Time Limit Exceeded        36160
WA: Presentation Error      2763
Wrong Answer              201829
Name: solution, dtype: int64

In [7]:
# Peek at the first rows of the validation table.
DATASET_EVAL.head()

Unnamed: 0,level_0,index,submission_id,problem_id,language,filename_ext,status,cpu_time,memory,code_size,accuracy,solution,problem_statement
0,2360322,5322,s850061551,p03288,C++,cpp,Accepted,1.0,256.0,213,,#include <bits/stdc++.h>\nusing namespace std;...,Score : 100 points \n Problem Statement A prog...
1,2047495,7500,s837081478,p03548,C,c,Accepted,1.0,128.0,183,,#include<stdio.h>\n#include<string.h>\n#includ...,Score : 200 points \n Problem Statement We hav...
2,119655,7092,s828220597,p03308,C++,cpp,Compile Error,,,406,,#include <iostream>\n#include <string>\n#inclu...,Score : 200 points \n Problem Statement You ar...
3,1327179,1296,s282902347,p02924,Python,py,Wrong Answer,28.0,9008.0,54,,n = int(input())\nans = int(1/2 * n * (n-1))\n...,Score : 400 points \n Problem Statement For an...
4,2153223,1499,s083246489,p03944,Python,py,Wrong Answer,17.0,3064.0,323,,"w, h, n = map(int, input().split())\nx_reg = 0...",Score : 200 points \n Problem Statement There ...


In [8]:
# Total number of validation rows, followed by the number of solutions per `status` value.
print(len(DATASET_EVAL))
DATASET_EVAL.groupby("status")["solution"].count()

190379


status
Accepted                  102749
Compile Error               7135
Memory Limit Exceeded        156
Output Limit Exceeded          8
Query Limit Exceeded           1
Runtime Error              11691
Time Limit Exceeded        10223
WA: Presentation Error       747
Wrong Answer               57669
Name: solution, dtype: int64

In [9]:
# Split each table into "Accepted" submissions and everything else, keeping only
# the columns needed to build contrastive pairs, then report the sizes.
PAIR_COLUMNS = ["submission_id", "problem_id", "language", "solution"]

is_accepted_train = DATASET_TRAIN["status"] == "Accepted"
is_rejected_train = DATASET_TRAIN["status"] != "Accepted"
is_accepted_eval = DATASET_EVAL["status"] == "Accepted"
is_rejected_eval = DATASET_EVAL["status"] != "Accepted"

accepted_train = DATASET_TRAIN.loc[is_accepted_train, PAIR_COLUMNS]
rejected_train = DATASET_TRAIN.loc[is_rejected_train, PAIR_COLUMNS]
accepted_eval = DATASET_EVAL.loc[is_accepted_eval, PAIR_COLUMNS]
rejected_eval = DATASET_EVAL.loc[is_rejected_eval, PAIR_COLUMNS]

# Row counts are per submission; the "Unique IDs" lines count distinct problem ids.
print("Total Accepted Problems in TRAIN:", len(accepted_train))
print("Total Rejected Problems in TRAIN:", len(rejected_train))
print("Unique IDs in Accepted TRAIN:", accepted_train["problem_id"].nunique(dropna=False))
print("Unique IDs in Rejected TRAIN:", rejected_train["problem_id"].nunique(dropna=False))
print("------------")
print("Total Accepted Problems in EVAL:", len(accepted_eval))
print("Total Rejected Problems in EVAL:", len(rejected_eval))
print("Unique IDs in Accepted EVAL:", accepted_eval["problem_id"].nunique(dropna=False))
print("Unique IDs in Rejected EVAL:", rejected_eval["problem_id"].nunique(dropna=False))

Total Accepted Problems in TRAIN: 358603
Total Rejected Problems in TRAIN: 307721
Unique IDs in Accepted TRAIN: 2450
Unique IDs in Rejected TRAIN: 2359
------------
Total Accepted Problems in EVAL: 102749
Total Rejected Problems in EVAL: 87630
Unique IDs in Accepted EVAL: 2188
Unique IDs in Rejected EVAL: 2124


In [14]:
# For each accepted solution, choose up to `n` distinct contrasting rejected solutions.
def get_contrastive_pairs(data_accepted, data_rejected, n=3, seed=None):
    """Pair accepted solutions with rejected ones for the same problem and language.

    Args:
        data_accepted: DataFrame with at least the columns "problem_id",
            "language" and "solution"; one row per accepted submission.
        data_rejected: DataFrame with the same three columns; one row per
            rejected submission.
        n: Maximum number of rejected solutions paired with each accepted one.
        seed: Optional seed for the sampler. ``None`` (the default) draws fresh
            entropy on every call, so results differ between runs.

    Returns:
        A ``Dataset`` with two aligned columns, "accepted" and "rejected".
        Accepted solutions that have no rejected counterpart for the same
        (problem_id, language) key are skipped.
    """
    data = {"accepted": [], "rejected": []}
    rng = np.random.default_rng(seed)

    # SPEED UP!! Group rejected answers by problem_id and language once, so the
    # main loop does a dictionary lookup instead of filtering the whole frame
    # on every iteration.
    grouped_rejected = (
        data_rejected.groupby(["problem_id", "language"])["solution"].apply(list).to_dict()
    )

    # Select columns by name so extra or reordered columns do not break unpacking.
    rows = zip(
        data_accepted["problem_id"],
        data_accepted["language"],
        data_accepted["solution"],
    )
    for accepted_pid, accepted_lang, accepted_sol in tqdm(rows, total=len(data_accepted)):
        rejected_pool = grouped_rejected.get((accepted_pid, accepted_lang))
        if not rejected_pool:
            # The problem only has correct solutions in this language. Skip it.
            continue

        size = min(len(rejected_pool), n)
        if size <= 0:
            continue

        # Sample WITHOUT replacement: the same rejected solution must not be
        # paired with this accepted solution more than once.
        for idx in rng.choice(len(rejected_pool), size=size, replace=False):
            data["accepted"].append(accepted_sol)
            data["rejected"].append(rejected_pool[idx])

    return Dataset.from_dict(data)

# Tokenize chosen/rejected pairs of inputs
def preprocess_function(examples, tokenizer):
    """Tokenize a batch of accepted/rejected solution pairs.

    Args:
        examples: Mapping with two aligned lists of strings under the keys
            "accepted" and "rejected" (a batch from a batched ``Dataset.map``).
        tokenizer: Callable that accepts a list of strings and returns a
            mapping with "input_ids" and "attention_mask", each holding one
            entry per input string.

    Returns:
        A dict with the keys "input_ids_chosen", "attention_mask_chosen",
        "input_ids_rejected" and "attention_mask_rejected", each a list with
        one entry per pair.
    """
    # zip() keeps the two sides aligned and stops at the shorter one.
    pairs = list(zip(examples["accepted"], examples["rejected"]))
    if not pairs:
        # Nothing to tokenize; avoid calling the tokenizer with an empty batch.
        return {
            "input_ids_chosen": [],
            "attention_mask_chosen": [],
            "input_ids_rejected": [],
            "attention_mask_rejected": [],
        }

    chosen_texts = [chosen for chosen, _ in pairs]
    rejected_texts = [rejected for _, rejected in pairs]

    # One tokenizer call per side for the whole batch, instead of one call per
    # string. No padding or truncation is requested, so each sequence keeps its
    # own length.
    tokenized_chosen = tokenizer(chosen_texts)
    tokenized_rejected = tokenizer(rejected_texts)

    return {
        "input_ids_chosen": list(tokenized_chosen["input_ids"]),
        "attention_mask_chosen": list(tokenized_chosen["attention_mask"]),
        "input_ids_rejected": list(tokenized_rejected["input_ids"]),
        "attention_mask_rejected": list(tokenized_rejected["attention_mask"]),
    }

# Build the pair dataset, tokenize it, and drop pairs longer than args.reward_config.max_length
def process_data(accepted, rejected, tokenizer, args):
    """Turn accepted/rejected submissions into a tokenized, length-filtered dataset.

    Args:
        accepted: Accepted submissions, passed to ``get_contrastive_pairs``.
        rejected: Rejected submissions, passed to ``get_contrastive_pairs``.
        tokenizer: Tokenizer forwarded to ``preprocess_function``.
        args: Object exposing ``reward_config.max_length``.

    Returns:
        The dataset of pairs whose chosen and rejected token sequences are
        both at most ``args.reward_config.max_length`` tokens long.
    """
    limit = args.reward_config.max_length

    def _tokenize_batch(batch):
        # Runs on whole batches because map() is called with batched=True.
        return preprocess_function(batch, tokenizer)

    def _within_limit(row):
        # A pair is kept only when neither side exceeds the limit.
        if len(row["input_ids_chosen"]) > limit:
            return False
        return len(row["input_ids_rejected"]) <= limit

    pairs = get_contrastive_pairs(accepted, rejected)
    tokenized = pairs.map(_tokenize_batch, batched=True)
    return tokenized.filter(_within_limit)

In [11]:
def _default_reward_config():
    """Build the default RewardConfig used by ScriptArguments."""
    settings = {
        "output_dir": "output",
        "num_train_epochs": 10,
        "learning_rate": 1.41e-5,
        "optim": "adamw_torch",
        "per_device_train_batch_size": 64,
        "gradient_accumulation_steps": 16,
        "gradient_checkpointing": True,
        "gradient_checkpointing_kwargs": {"use_reentrant": False},
        "remove_unused_columns": False,
        "report_to": "wandb",  # log training progress to Weights and Biases
        "logging_steps": 500,
        "evaluation_strategy": "no",
        "max_length": 256,  # TODO: NEED TO CHANGE THIS!
    }
    return RewardConfig(**settings)


@dataclass
class ScriptArguments:
    """Settings for the reward-model training run."""

    # Path (or name) of the model checkpoint to load the tokenizer and model from.
    # TODO: Change path to correct SFT model
    model_name: str = "../finetuning/hf_model/"

    # Whether to evaluate during training; defaults to False (no evaluation).
    # The statement following this class maps it to evaluation_strategy "steps" / "no".
    eval_split: bool = False

    # Trainer configuration; a fresh RewardConfig is built for every instance.
    reward_config: RewardConfig = field(default_factory=_default_reward_config)

# Instantiate the run settings with their defaults.
args = ScriptArguments()
# Evaluate every few steps only when eval_split is set; otherwise never evaluate.
# NOTE(review): this assigns the field after the RewardConfig was constructed, so any
# settings the config derives at construction time are presumably not refreshed —
# confirm before enabling eval_split.
args.reward_config.evaluation_strategy = "steps" if args.eval_split else "no"

In [12]:
# Step 1: Load the tokenizer from the same checkpoint path as the model
# (the datasets themselves are built and tokenized in the next cell).
tokenizer = AutoTokenizer.from_pretrained(args.model_name)

In [15]:
# Build, tokenize and length-filter the contrastive pairs for both splits.
# NOTE(review): in the recorded run the evaluation pass was interrupted
# (KeyboardInterrupt), so `eval_dataset` was never assigned there; the trainer
# cell further down needs it. The cell execution counts are also out of order,
# so re-check with a fresh kernel and Run All.
print("Training Data:")
train_dataset = process_data(accepted_train, rejected_train, tokenizer, args)

print("\nEvaluation Data:")
eval_dataset = process_data(accepted_eval, rejected_eval, tokenizer, args)

Training Data:


100%|██████████| 358603/358603 [00:10<00:00, 33452.72it/s]
Map: 100%|██████████| 1068246/1068246 [43:46<00:00, 406.79 examples/s] 
Filter: 100%|██████████| 1068246/1068246 [22:57<00:00, 775.22 examples/s]



Evaluation Data:


100%|██████████| 102749/102749 [00:03<00:00, 31615.77it/s]
Map:  93%|█████████▎| 282000/301682 [12:44<00:53, 369.01 examples/s]


KeyboardInterrupt: 

In [None]:
# Step 2: Load the model with a sequence-classification head that has a single
# output (num_labels=1).
model = AutoModelForSequenceClassification.from_pretrained(args.model_name, num_labels=1)
# Make the model use the tokenizer's padding token id.
# NOTE(review): if the tokenizer defines no pad token this assigns None — confirm.
model.config.pad_token_id = tokenizer.pad_token_id

In [None]:
# Step 3: LoRA (parameter-efficient fine-tuning) configuration for the
# sequence-classification task; passed to the trainer below.
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)

In [None]:
# Step 4: Define the Trainer
# It receives the model, tokenizer, training config, the tokenized train/eval
# pair datasets and the LoRA configuration defined in the cells above.
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    args=args.reward_config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config
)

# Start training.
trainer.train()