# Setup

## Hyperparameters + Data Tracking

In [1]:
################# HYPERPARAMETERS #########################
# Hugging Face checkpoint used for both the tokenizer and the policy model.
model_name = "Qwen/Qwen2.5-1.5B-Instruct"

# Optimizer
lr = 1e-5
weight_decay = 0.01

# Training + Dataset
num_epochs = 10
batch_size = 10
validation_split = 0.3  # fraction of the prompts held out for validation
max_length = 512        # prompts are padded/truncated to this many tokens
# NOTE: Vocab size of model is 151936

# GRPO
num_samples = 5
# The three values below are placeholders; they are forwarded unchanged to
# grpo_iteration() by the training cell and must be set before a real run.
eps = -1   # TODO: Update values
beta = -1  # TODO: Update values
mu = -1    # TODO: Update values
################# HYPERPARAMETERS #########################

# Initialize wandb
import wandb

# Only a subset of the hyperparameters is recorded in the run config.
wandb.init(
    project="qwen-finetuning",
    name="grpo-finetune-run",
    config=dict(
        learning_rate=lr,
        batch_size=batch_size,
        epochs=num_epochs,
        model_name=model_name,
    ),
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mcormaccureton[0m ([33mcormaccureton-mcgill-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Generating Dataset and Dataloaders

In [None]:
import torch, tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from dataset.countdown_dataloader import Countdown
from dataset.countdown_dataloader import Countdown
from grpo import *
from torch.utils.data import Dataset, DataLoader
from dataset.countdown_utils import (
    gen_dataset,
    compute_metrics
)

# Generate the Countdown problems and save them to a JSON file. The return
# value is discarded because the data is read back from that file below.
dataset_json_path = "data/countdown_data.json"
_ = gen_dataset(
    num_samples=100,
    save_path=dataset_json_path,
    num_operands=4,
)

# Load the saved problems through the project's Countdown dataset wrapper.
countdown_data = Countdown(json_path=dataset_json_path)

# Generates the prompts and the training and validation dataloaders
def create_prompts(queries: dict, model_type: str = 'base'):
    """Build one prompt string per Countdown query.

    Args:
        queries: Iterable of records; each record is indexed with the keys
            ``'numbers'`` and ``'target'``, which are substituted into the
            template in that order.
        model_type: ``'base'`` selects the plain "User / Assistant"
            template; any other value selects the ``<|im_start|>``-delimited
            chat template.

    Returns:
        A list of formatted prompts, in the same order as ``queries``.
    """
    if model_type != 'base':
        preamble = """<|im_start|>system\nYou are a helpful assistant. You first think about the reasoning process in the mind and then provides the user with the answer.<|im_end|>\n<|im_start|>user\n Using the numbers {}, create an equation that equals {}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.<|im_end|>\n<|im_start|>assistant\nLet me solve this step by step.\n<think>"""
    else:
        # The continuation lines of this literal keep their source
        # indentation, so that whitespace is part of the emitted prompt.
        preamble = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
        User: Using the numbers {}, create an equation that equals {}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.
        Assistant: Let me solve this step by step.
        <think>"""

    return [
        preamble.format(record['numbers'], record['target'])
        for record in queries
    ]

# NOTE(review): create_prompts() defaults to the 'base' template although
# model_name points at an Instruct checkpoint -- confirm this is intended.
prompts = create_prompts(countdown_data)

# Hold out the trailing `validation_split` fraction of the prompts for
# validation. The boundary is computed as an explicit index rather than a
# negative offset: with split_size == 0, `prompts[:-0]` would be empty and
# `prompts[-0:]` would be the whole list, i.e. the two sets would be swapped.
split_size     = int(validation_split*len(prompts))
split_index    = len(prompts) - split_size
training_set   = prompts[:split_index]
validation_set = prompts[split_index:]

class tokenized_prompt_dataset(Dataset):
    """Dataset over prompt strings that tokenizes each prompt on access.

    Args:
        prompts: Indexable collection of prompt strings.
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors="pt", padding="max_length",
            truncation=True, max_length=...)``; its result must expose
            ``.items()`` yielding tensors with a leading batch dimension.
        max_length: Length every prompt is padded or truncated to.
    """

    def __init__(self, prompts, tokenizer, max_length=256):
        self.max_length = max_length
        self.tokenizer  = tokenizer
        self.prompts    = prompts

    def __len__(self):
        """Return the number of prompts."""
        return len(self.prompts)

    def __getitem__(self, idx):
        """Tokenize prompt ``idx`` and return its tensors without the batch dim."""
        encoded = self.tokenizer(
            self.prompts[idx],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )
        # The tokenizer returns a batch of one; drop that leading dimension so
        # the DataLoader can stack samples into its own batch dimension.
        sample = {}
        for field, tensor in encoded.items():
            sample[field] = tensor.squeeze(0)
        return sample

# One tokenizer instance is shared by both datasets; each prompt is tokenized
# when its item is requested, not up front.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Training split: shuffled by the DataLoader.
tokenized_training_set = tokenized_prompt_dataset(
    training_set,
    tokenizer,
    max_length=max_length,
)
training_dataloader = DataLoader(
    tokenized_training_set,
    batch_size=batch_size,
    shuffle=True,
)

# Validation split: kept in its original order.
tokenized_validation_set = tokenized_prompt_dataset(
    validation_set,
    tokenizer,
    max_length=max_length,
)
validation_dataloader = DataLoader(
    tokenized_validation_set,
    batch_size=batch_size,
    shuffle=False,
)


  from .autonotebook import tqdm as notebook_tqdm


# Training

In [None]:
# The notebook's `import tqdm` binds the *module*, which is not callable, so
# the progress-bar class is imported here under its own name.
from tqdm import tqdm as progress_bar

# Load the policy model in full precision and move it to the GPU.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).cuda()

# Set pad_token_id explicitly if needed
tokenizer.pad_token_id = tokenizer.eos_token_id

# Set up optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

for epoch_iteration in progress_bar(range(1, num_epochs+1), desc="Training", total=num_epochs):
    # --- Training: one GRPO update per batch; the helper returns the model
    # to keep using (presumably provided by `from grpo import *`).
    for batch in training_dataloader:
        model = grpo_iteration(batch, model, compute_metrics, optimizer, num_samples, eps, beta, mu)

        # TODO: log training metrics to wandb once grpo_iteration exposes them.

    # --- Validation: sample completions without tracking gradients.
    with torch.no_grad():
        for batch in validation_dataloader:
            outputs             = sample_outputs(model, batch, num_samples)
            rewards, accuracies = calculate_rewards_and_accuracies(batch, outputs, compute_metrics)

            # Convert the 0-dim tensors to Python floats so wandb stores plain
            # scalars and the printed value is a number, not `tensor(...)`.
            mean_reward   = torch.mean(rewards).item()
            mean_accuracy = torch.mean(accuracies).item()

            wandb.log({"epoch": epoch_iteration, "Rewards": mean_reward, "Accuracy": mean_accuracy})
            print(f"Epoch {epoch_iteration}: Accuracy: {mean_accuracy}")

wandb.finish()
