In [34]:
%reload_ext autoreload
%autoreload 2

In [35]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import logging

from dataset.countdown_dataloader import Countdown
from dataset.countdown_utils import ( gen_dataset, compute_metrics )
from grpo import *


In [36]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


In [37]:
# device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [38]:
# Create and save a tiny dataset with 5 samples
dataset_json_path = "simpler_countdown_data.json"
gen_dataset(num_samples=5, num_operands=3, max_target=100, max_number=15, save_path=dataset_json_path)

# Load the dataset
dataset = Countdown(dataset_json_path)

In [39]:
model_name = "Qwen/Qwen2.5-0.5B"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
model.to(device)

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((

In [40]:
# Combine whole dataset into prompts
batch = [
  f"Using the numbers {item["numbers"]}, create an equation that equals {item["target"]}. Box your answer." 
  for item in dataset
  ]

smaller_batch = batch[:2]

In [None]:
# Use grpo sample outputs function
outputs, logprobs = sample_outputs(
    policy=model,
    tokenizer=tokenizer,
    d_b=smaller_batch,
    G=3
)

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
2025-04-06 18:48:39,486 - INFO - Responses shape: 2, 3
2025-04-06 18:48:39,491 - INFO - Log probabilities shape: torch.Size([2, 3, 128])


[['Human: It is incorrect to assume that a linear combination of two numbers will always equal another number. For example, using the numbers {2,5} we can create\n2 * 4 + 5 = 14\n2 + 5 * 4 = 26.\nSo if we were looking for a solution to the equation 2 + x = 5, this is the answer.', 'Human: The sum of the two middle numbers equals 3 plus the quotient of the three numbers 1 divided by 2.\nTherefore, the answer is: \\( 3 = 1 + \\frac{4}{2} \\).', "(Do not use exponents, square roots and parentheses. Use only addition, subtraction, multiplication, division and parentheses.)\n\nWhat's the lowest possible value for the second number?\n\nWhat's the highest possible value for the second number? (Note: your code must work for arbitrary integers).\n\nWhat's the largest possible value for the second number?\n\nThe following number should not be the answer: 2100 or 45254"], [" How can you use multiple numbers in a single equation to reach a target sum? To create an equation using the numbers [14, 6