In [4]:
%reload_ext autoreload
%autoreload 2

In [5]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from torch.utils.data import DataLoader
import logging

from dataset.countdown_dataloader import Countdown
from dataset.countdown_utils import ( gen_dataset, batch_compute_metrics )
from grpo import grpo_iteration 

In [6]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

In [7]:
# device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [8]:
# Create and save a really simple version of the countdown dataset
dataset_json_path = "simpler_countdown_data.json"
gen_dataset(num_samples=10, num_operands=2, max_target=10, max_number=10, save_path=dataset_json_path)

# Load the dataset
dataset = Countdown(dataset_json_path)

In [9]:
model_name = "Qwen/Qwen2.5-0.5B"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"

# Initialize the model with empty weights if needed
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
model.to(device)

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((896,), eps=1e-06)
    (rotary_emb): Qwen2RotaryEmbe

In [None]:
# Batch out dataset
batch_size = 3
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
batch = next(iter(dataloader))

prompt_batch = batch["prompt"]

# Transform batch numbers and target into list of dictionaries
# This is slightly hacky, might look at instead reworking the reward model to deal with tensors
batch_numbers = list(map(list, zip(*batch["numbers"])))
batch_target = batch["target"]

raw_values_batch = [{'numbers': numbers, 'target': target} for numbers, target in zip(batch_numbers, batch_target)]
print("Batch raw:")
print(raw_values_batch)

Batch raw:
[{'numbers': [tensor(6), tensor(2)], 'target': tensor(4)}, {'numbers': [tensor(2), tensor(7)], 'target': tensor(9)}, {'numbers': [tensor(5), tensor(2)], 'target': tensor(7)}]


In [40]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

In [None]:
updated_policy = grpo_iteration(
    query_batch_prompts=prompt_batch,
    query_batch_raw=raw_values_batch,
    policy_model=model,
    reference_model=model,
    reward_model=batch_compute_metrics,
    tokenizer=tokenizer,
    optimizer=optimizer,
    G=3,
    eps=0.1,
    beta=0.05, 
    mu=3,
    max_new_tokens=64
)

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
2025-04-07 15:48:24,411 - INFO - outputs: [[" Let's define a box to keep track of potential outcomes. We'll write helper functions to iterate over possible values and compute whether adding the current 6 and current 2 will yield 4. If this is true, we add the box identifier into our collection, return the first successful outcome, or return an empty string", " To create an equation that equals 4 using the numbers 6 and 2, we need to explore different combinations of these numbers. Let's solve it step by step:\n\n1. Start with the given numbers: 6 and 2.\n2. Consider the possible ways to combine these numbers or add them using basic", ' 62 + 2 + 4.'], [' The equation that equals 9 using the numbers [2, 7] is: 2 + 7 = 9.', ' Here it is:\n2 x 2 + 7 x 7 = 9.', ' To create an equation that equals 9 using the numbers [2, 7] as variables, we can use the fact that their squares are each 49 (since \\(2^2 = 4\\) and \\(7^2 