In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import logging

from dataset.countdown_dataloader import Countdown
from dataset.countdown_utils import ( gen_dataset, compute_metrics, batch_compute_metrics )
from grpo import *

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


In [4]:
# device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [5]:
# Create and save a tiny dataset with 5 samples
dataset_json_path = "simpler_countdown_data.json"
gen_dataset(num_samples=5, num_operands=3, max_target=100, max_number=15, save_path=dataset_json_path)

# Load the dataset
dataset = Countdown(dataset_json_path)

In [6]:
model_name = "Qwen/Qwen2.5-0.5B"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id

# Initialize the model with empty weights if needed
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
model.to(device)

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((896,), eps=1e-06)
    (rotary_emb): Qwen2RotaryEmbe

In [7]:
# Batch out dataset
batch_size = 3
batch_raw = dataset.get_batch(batch_size)

# Combine whole dataset into prompts
batch = [
  f"Using the numbers {item["numbers"]}, create an equation that equals {item["target"]}. Box your answer." 
  for item in batch_raw
  ]

In [8]:
# Use grpo sample outputs function
outputs_ids, outputs = sample_outputs(
    policy=model,
    tokenizer=tokenizer,
    query_batch=batch,
    G=3
)

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
2025-04-07 11:13:20,987 - INFO - Generated IDs shape: torch.Size([9, 512])
2025-04-07 11:13:20,999 - INFO - Responses shape: 3, 3
2025-04-07 11:13:21,000 - INFO - Generated IDs reshaped: torch.Size([3, 3, 512])


In [9]:
# Calculate rewards for outputs
rewards, accuracies = batch_compute_metrics(
    outputs,
    queries=batch_raw
)

2025-04-07 11:13:21,121 - INFO - Rewards tensor shape: torch.Size([3, 3])
2025-04-07 11:13:21,122 - INFO - Accuracies tensor shape: torch.Size([3, 3])


In [10]:
# Print the outputs and rewards
for i, output in enumerate(outputs):
    print(f"Input: {batch[i]}")
    print(f"Output: {output}")
    print(f"Reward: {rewards[i]}")
    print(f"Accuracy: {accuracies[i]}")
    print("-" * 20)

Input: Using the numbers [1, 2, 4], create an equation that equals 5. Box your answer.
Output: ['Human beings should not be asked to guess.  \n[1, 1, 2, 2, 2, 4, 4, 4]\n[1, 4, 4, 4, 4, 4, 4, 4]\n[2, 1, 2, 1, 1, 4, 4, 4]\n[1, 1, 4, 2, 2, 4, 4, 4]', "If you haven't already, don't change your calculator to scientific mode.\nTo create an equation that equals 5 using the numbers [1, 2, 4] and the rules of solving for an unknown variable, let's denote the unknown variable as \\( x \\).\n\nHere's the equation you could create:\n\\[ x + 4 - 2 + 1 = 5 \\]\n\nLet's break it down step-by-step:\n\n1. Start with the sum of the given numbers and subtract 2.\n2. Add 1 and divide by 4 to isolate \\( x \\).\n\nNow, let's verify this by plugging in the numbers to see if we get 5.\n\n### Verification\n\\[ x + 4 - 2 + 1 = 5 \\]\n\\[ x + 3 = 5 \\]\n\\[ x = 2 \\]\n\nThe solution \\( x = 2 \\) satisfies the equation, so the equation \\( x + 4 - 2 + 1 = 5 \\) works.\n\nSo, an equation that equals 5 using the 

In [11]:
# Compute GRPO advantage
advantage = calculate_grpo_advantage(
    rewards
)

2025-04-07 11:13:21,249 - INFO - Advantages shape: torch.Size([3, 3])


In [12]:
print(f"Advantage: {advantage}")

Advantage: tensor([[-1.1547,  0.5774,  0.5774],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000]])


In [13]:
# Compute log probabilities
log_probs = compute_log_probs(
    policy=model,
    tokenizer=tokenizer,
    query_batch=batch,
    generated_ids=outputs_ids
)

2025-04-07 11:13:21,298 - INFO - Query IDs shape: torch.Size([3, 3, 27])
2025-04-07 11:13:21,298 - INFO - Generated IDs shape: torch.Size([3, 3, 512])
2025-04-07 11:13:21,299 - INFO - Input IDs shape: torch.Size([3, 3, 539])
2025-04-07 11:13:21,299 - INFO - Reshaped Input IDs shape: torch.Size([9, 539])
2025-04-07 11:13:21,299 - INFO - Attention mask shape: torch.Size([9, 539])
2025-04-07 11:14:34,465 - INFO - Logits shape: torch.Size([9, 539, 151936])
2025-04-07 11:14:34,742 - INFO - Generated logits shape: torch.Size([9, 512, 151936])
2025-04-07 11:14:38,196 - INFO - Log probabilities shape: torch.Size([9, 512, 151936])
2025-04-07 11:14:38,423 - INFO - Gathered log probabilities shape: torch.Size([9, 512])
2025-04-07 11:14:38,446 - INFO - Reshaped log probabilities shape: torch.Size([3, 3, 512])


In [14]:
# Calculate GRPO objective
grpo_objective = calculate_grpo_objective(
    model_log_probs=log_probs,
    old_model_log_probs=log_probs,  # Assuming old model is the same for this example
    ref_model_log_probs=log_probs,  # Assuming ref model is the same for this example
    advantages=advantage,
    eps=0.1,  # Epsilon for clipping
    beta=0.05,  # Beta for the objective function
)

print(f"GRPO Objective: {grpo_objective}")

2025-04-07 11:14:41,787 - INFO - Prob ratios shape: torch.Size([3, 3, 512])
2025-04-07 11:14:41,823 - INFO - Clipped ratios shape: torch.Size([3, 3, 512])
2025-04-07 11:14:41,823 - INFO - Advantages shape: torch.Size([3, 3, 1])
2025-04-07 11:14:41,885 - INFO - Min product shape: torch.Size([3, 3, 512])
2025-04-07 11:14:41,889 - INFO - KL divergence shape: torch.Size([3, 3, 512])
2025-04-07 11:14:41,892 - INFO - Objective shape: torch.Size([3, 3, 512])
2025-04-07 11:14:41,933 - INFO - Final objective shape: torch.Size([3])


GRPO Objective: tensor([-7.9473e-08,  0.0000e+00,  0.0000e+00], grad_fn=<MeanBackward1>)


In [15]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

In [17]:
updated_policy = grpo_iteration(
    query_batch_prompts=batch,
    query_batch_raw=batch_raw,
    policy_model=model,
    reference_model=model,  # Assuming reference model is the same for this example
    reward_model=batch_compute_metrics,
    tokenizer=tokenizer,
    optimizer=optimizer,
    G=3,
    eps=0.1,  # Epsilon for clipping
    beta=0.05,  # Beta for the objective function
    mu=3
)

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
2025-04-07 11:17:34,076 - INFO - Generated IDs shape: torch.Size([9, 512])
2025-04-07 11:17:34,114 - INFO - Responses shape: 3, 3
2025-04-07 11:17:34,114 - INFO - Generated IDs reshaped: torch.Size([3, 3, 512])
2025-04-07 11:17:34,121 - INFO - Rewards tensor shape: torch.Size([3, 3])
2025-04-07 11:17:34,122 - INFO - Accuracies tensor shape: torch.Size([3, 3])
2025-04-07 11:17:34,124 - INFO - Advantages shape: torch.Size([3, 3])
2025-04-07 11:17:34,133 - INFO - Query IDs shape: torch.Size([3, 3, 27])
2025-04-07 11:17:34,134 - INFO - Generated IDs shape: torch.Size([3, 3, 512])
2025-04-07 11:17:34,134 - INFO - Input IDs shape: torch.Size([3, 3, 539])
2025-04-07 11:17:34,134 - INFO - Reshaped Input IDs shape: torch.Size([9, 539])
2025-04-07 11:17:

RuntimeError: grad can be implicitly created only for scalar outputs