In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import torch
import os
import sys
import logging
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Setup module path for local imports
project_root = os.path.abspath(os.path.join(".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Importing the necessary modules
from grpo import grpo_iteration, evaluate_policy, sample_outputs, compute_log_probs, compute_log_probs_2
from dataset.countdown_utils import batch_compute_metrics
from dataset.countdown_dataloader import *

In [4]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
def collate_fn(batch):
    prompts = [item["prompt"] for item in batch]
    targets = torch.tensor([item["target"] for item in batch])
    numbers = [torch.tensor(item["numbers"]) for item in batch]
    # Pad the numbers to the same length
    padded_numbers = pad_sequence(
        numbers, batch_first=True, padding_value=0
    )  # Pad with zeros
    return {
        "prompt": prompts,
        "target": targets,
        "numbers": padded_numbers,
    }

# Load dataset
dataset = Countdown(
    json_path="../data/small-scale/countdown.json",
    model_type="instruct"
)
data_loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=collate_fn
)

In [7]:
model_name = "Qwen/Qwen2.5-0.5B"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"

# Initialize the model with empty weights if needed
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
model.to(device)

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((896,), eps=1e-06)
    (rotary_emb): Qwen2RotaryEmbe

In [8]:
G=2
max_new_tokens=256
temperature=1.0

In [9]:
batch = next(iter(data_loader))

output_ids, generated_ids, outputs = sample_outputs(
    model,
    tokenizer,
    batch["prompt"],
    G,
    max_new_tokens,
    temperature,
)

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


In [10]:
print("Output IDs:", output_ids.shape)
print("Generated IDs:", generated_ids.shape)
print("Outputs:", outputs)

Output IDs: torch.Size([2, 2, 391])
Generated IDs: torch.Size([2, 2, 256])
Outputs: [["  First, we'll assign the numbers to their corresponding operators. So now here we are:\n( 1 + 4 ) / 24</think> </think>\nSo first we calculate 1 + 4 = 5, and then divide 5 by 24, which gives us 5/24</think> </think>\nNext, we'll calculate 5/24 times 1/24 (which is basically the reciprocal of 24), so we get 5/576</think> </think>\nFinally, we'll divide 5/576 by 24. This gives us 1/2768</think> </think>\nNow, we should replace the numbers back in the order of the given list:\n( 1 + 4 ) / 24 * ( 1 / 24 + 5 / 576 ) / 2\nWe already calculated ( 1 + 4 ) and ( 1 / 24 + 5 / 576 ) for us, and the result is 2  <answer> (1 + 4 ) / 24 * ( 1 / 24 + 5", 'First, I will define the equation to be 21. </think>\nThink:\n\t4 + 24 - 1 = 21\n<br>\nHere, I used the numbers 4 and 24 and then did the addition, subtraction, and division step by step in the equation.\nThink:\n\t4 + 1 + 24 - 1 = 21\n<br>\nHere, I used the same

In [11]:
log_probs = compute_log_probs(
    policy=model,
    tokenizer=tokenizer,
    query_batch=batch["prompt"],
    generated_ids=generated_ids,
    temperature=temperature,
)

2025-04-16 18:58:55,469 - INFO - Attention mask shape: torch.Size([4, 391])
2025-04-16 18:58:55,469 - INFO - Amount of tokens: 1133
2025-04-16 18:59:10,587 - INFO - Generated logits shape: torch.Size([4, 255, 151936])
2025-04-16 18:59:11,053 - INFO - Generated IDs shape: torch.Size([2, 2, 255])


In [12]:
print("Log Probs:", log_probs.shape)
print("Log Probs:", log_probs)

Log Probs: torch.Size([2, 2, 255])
Log Probs: tensor([[[ -5.4844, -14.3594, -13.4062,  ..., -14.0781, -11.0391,  -8.8125],
         [-19.2188, -13.3516, -10.2422,  ..., -12.2109, -10.3359, -13.1562]],

        [[ -6.1719,  -7.1328,  -7.7930,  ..., -17.6719, -17.3125, -17.3125],
         [ -6.9883,  -9.0234, -10.8594,  ..., -16.4062, -16.7656, -16.2656]]],
       dtype=torch.float16, grad_fn=<ViewBackward0>)


In [15]:
log_probs_2 = compute_log_probs_2(
    policy=model,
    tokenizer=tokenizer,
    query_batch=batch["prompt"],
    output_ids=output_ids,
    generated_ids=generated_ids,
    temperature=temperature,
)

2025-04-16 19:02:36,551 - INFO - Output IDs shape: torch.Size([2, 2, 391])
2025-04-16 19:02:36,553 - INFO - Generated IDs shape: torch.Size([2, 2, 256])
2025-04-16 19:02:36,562 - INFO - Output IDs reshaped: torch.Size([4, 391])
2025-04-16 19:02:52,390 - INFO - Logits shape: torch.Size([4, 390, 151936])
2025-04-16 19:02:52,484 - INFO - Output IDs shape: torch.Size([4, 390])
2025-04-16 19:02:53,347 - INFO - Log probs shape: torch.Size([4, 390])
2025-04-16 19:02:53,355 - INFO - Query length: 134
2025-04-16 19:02:53,392 - INFO - Log probs reshaped: torch.Size([2, 2, 256])


In [16]:
print("Log Probs 2:", log_probs_2.shape)
print("Log Probs 2:", log_probs_2)

Log Probs 2: torch.Size([2, 2, 256])
Log Probs 2: tensor([[[-2.3008e+00, -6.6875e+00, -1.0992e-01,  ..., -2.9259e-03,
          -9.7609e-04, -3.8986e-03],
         [-2.3164e+00, -1.0205e-01, -1.1035e+00,  ..., -2.9831e-02,
          -9.7609e-04, -2.1255e-02]],

        [[-2.2812e+00, -5.9062e+00, -2.6641e+00,  ..., -1.7672e+01,
          -1.7312e+01, -1.7312e+01],
         [-3.7344e+00, -5.1392e-02, -2.6587e-01,  ..., -1.6406e+01,
          -1.6766e+01, -1.6266e+01]]], dtype=torch.float16,
       grad_fn=<ReshapeAliasBackward0>)
