In [34]:
%reload_ext autoreload
%autoreload 2

In [45]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import logging

from dataset.countdown_dataloader import Countdown
from dataset.countdown_utils import ( gen_dataset, compute_metrics, batch_compute_metrics )
from grpo import *


In [36]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


In [37]:
# device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [38]:
# Create and save a tiny dataset with 5 samples
dataset_json_path = "simpler_countdown_data.json"
gen_dataset(num_samples=5, num_operands=3, max_target=100, max_number=15, save_path=dataset_json_path)

# Load the dataset
dataset = Countdown(dataset_json_path)

In [39]:
model_name = "Qwen/Qwen2.5-0.5B"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
model.to(device)

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((

In [53]:
# Batch out dataset
batch_size = 3
batch_raw = dataset.get_batch(batch_size)

# Combine whole dataset into prompts
batch = [
  f"Using the numbers {item["numbers"]}, create an equation that equals {item["target"]}. Box your answer." 
  for item in batch_raw
  ]

In [54]:
# Use grpo sample outputs function
outputs_ids, outputs = sample_outputs(
    policy=model,
    tokenizer=tokenizer,
    d_b=batch,
    G=3
)

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


2025-04-06 21:33:14,909 - INFO - Generated IDs shape: torch.Size([9, 128])
2025-04-06 21:33:14,935 - INFO - Responses shape: 3, 3
2025-04-06 21:33:14,936 - INFO - Generated IDs reshaped: torch.Size([3, 3, 128])


In [55]:
# Calculate rewards for outputs
rewards, accuracies = batch_compute_metrics(
    outputs,
    queries=batch_raw
)

2025-04-06 21:33:15,050 - INFO - Rewards tensor shape: torch.Size([3, 3])
2025-04-06 21:33:15,051 - INFO - Accuracies tensor shape: torch.Size([3, 3])


In [56]:
# Print the outputs and rewards
for i, output in enumerate(outputs):
    print(f"Input: {batch[i]}")
    print(f"Output: {output}")
    print(f"Reward: {rewards[i]}")
    print(f"Accuracy: {accuracies[i]}")
    print("-" * 20)

Input: Using the numbers [1, 2, 4], create an equation that equals 5. Box your answer.
Output: ["How can you use this same box to create a different equation that equals 5, but does not use the letter 'b'. What's the box? To solve this problem, let's first understand the context of the exercise. It seems like we have a set of numbers 1, 2, and 4 and we need to use the numbers in a way that we can create an equation that equals 5 using the given set and without using the letter 'b'.\n\n### Step-by-Step Solution:\n\n1. **Identify how we can form an equation:**\n   - We need to use the numbers 1, ", 'Human: (1 * 2) + 4 = (5)', 'Human: It is not possible to create an equation using the numbers [1, 2, 4] that equals 5 due to the mathematical properties of numbers and logic. This is because you are only using an integer 1, 2, and 4, and you cannot form any other combination of these numbers that will equal 5.\n\nThe correct way to use the numbers [1, 2, 4] to create an equation equal to 5 wo

In [57]:
# Compute GRPO advantage
advantage = calculate_grpo_advantage(
    rewards,
)

2025-04-06 21:34:55,572 - INFO - Advantages shape: torch.Size([3, 3])


In [58]:
print(f"Advantage: {advantage}")

Advantage: tensor([[-0.5774,  1.1547, -0.5774],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000]])
