In [1]:
%load_ext autoreload
%autoreload 2
import re 
import os 
os.environ["TOKENIZERS_PARALLELIS"] = "true"
import torch 
from torch.utils.data import DataLoader 
import sympy 

import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers import PreTrainedTokenizer, PreTrainedModel
import numpy as np
from memory_profiler import memory_usage

import datasets 
from datasets import load_dataset 

from curious.data import ReasoningGymDataset, GSM8KDataset
from curious.utils import tokenize_questions, load_model_tokenizer
from curious.grpo import rollout, sequences_log_probs, sequences_log_probs_with_mask
from curious.reward import GSM8KRewardModel

import reasoning_gym 

from math_verify import verify, parse
from pprint import pprint

EACH_DATASET_SIZE = 100
SEED = 42
MODEL_NAME = "Qwen/Qwen2-0.5B-Instruct"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer:PreTrainedTokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model:PreTrainedModel = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer(["Today is"], return_tensors="pt")

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


In [3]:
# Example 1: Print the scores for each token generated with Greedy Search
outputs = model.generate(
    **inputs, 
    max_new_tokens=10, 
    return_dict_in_generate=True, 
    output_logits=True, 
    do_sample=True,
    num_return_sequences=4,
)
print(type(outputs))
print(outputs.sequences.shape)
print(len(outputs.logits))
for new_token in outputs.logits:
    print(new_token.shape)

<class 'transformers.generation.utils.GenerateDecoderOnlyOutput'>
torch.Size([4, 12])
10
torch.Size([4, 151936])
torch.Size([4, 151936])
torch.Size([4, 151936])
torch.Size([4, 151936])
torch.Size([4, 151936])
torch.Size([4, 151936])
torch.Size([4, 151936])
torch.Size([4, 151936])
torch.Size([4, 151936])
torch.Size([4, 151936])


In [4]:
masked_logprobs = sequences_log_probs_with_mask(
    model, 
    tokenizer, 
    outputs, 
)

In [5]:
transition_logprobs = model.compute_transition_scores(
    outputs.sequences, outputs.logits, normalize_logits=True
)
print(transition_logprobs.shape)
# print(outputs.scores)

torch.Size([4, 10])


In [6]:
# input_length is the length of the input prompt for decoder-only models, like the GPT family, and 1 for
# encoder-decoder models, like BART or T5.
input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
print(input_length)
generated_tokens = outputs.sequences[:, input_length:]
print(generated_tokens.shape)
for tok, score in zip(generated_tokens[0], transition_logprobs[0]):
    # | token | token string | log probability | probability
    print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}")

2
torch.Size([4, 10])
|   279 |  the     | -1.314 | 26.87%
|  1156 |  first   | -1.747 | 17.44%
|  1899 |  day     | -0.348 | 70.62%
|   315 |  of      | -0.157 | 85.46%
|   847 |  my      | -2.964 | 5.16%
|   501 |  new     | -1.947 | 14.26%
|  2618 |  job     | -1.123 | 32.55%
|   323 |  and     | -2.536 | 7.92%
|   358 |  I       | -0.524 | 59.21%
|  1079 |  am      | -1.176 | 30.85%


In [17]:
log_probs = sequences_log_probs(model, outputs.sequences, inputs.attention_mask)
print(log_probs.shape)

torch.Size([4, 11])


In [8]:
out_seqs = outputs.sequences[:,1:]
in_seqs = outputs.sequences[:,:-1]

for i in range(log_probs.shape[1]):

    in_token = in_seqs[0][i]
    out_token = out_seqs[0][i]

    print( tokenizer.decode(in_token), "->" , tokenizer.decode(out_token), f"({log_probs[0][i].item()})")

Today ->  is (-4.035375595092773)
 is ->  the (-1.3148212432861328)
 the ->  first (-1.7471065521240234)
 first ->  day (-0.34832000732421875)
 day ->  of (-0.15752410888671875)
 of ->  my (-2.96502685546875)
 my ->  new (-1.947946548461914)
 new ->  job (-1.1230058670043945)
 job ->  and (-2.5361595153808594)
 and ->  I (-0.5244293212890625)
 I ->  am (-1.1761493682861328)


In [46]:
model, tokenizer = load_model_tokenizer(MODEL_NAME)

ImportError: FlashAttention2 has been toggled on, but it cannot be used due to the following error: the package flash_attn seems to be not installed. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.

In [21]:
gsm8k = GSM8KDataset()
gsm8k_rm = GSM8KRewardModel(use_format_reward=False, use_strict_format_reward=False)

In [22]:
idx = 8
item = gsm8k[idx]
pprint(item)

{'answer': 'He has 36 eggs because 3 x 12 = <<3*12=36>>36\n'
           'He can make 9 omelets because 36 / 4 = <<36/4=9>>9\n'
           'Each person gets 3 omelets because 9 / 3 = <<9/3=3>>3\n'
           '#### 3',
 'oracle_answer': '3',
 'question': 'Pauly is making omelets for his family. There are three dozen '
             'eggs, and he plans to use them all. Each omelet requires 4 eggs. '
             'Including himself, there are 3 people. How many omelets does '
             'each person get?'}


In [23]:
gsm8k_rm.answer_pattern

'<answer>(.*?)</answer>'

In [24]:
completions = gsm8k[:10]["answer"]
completions = [
    completion.replace("####", "<answer>").strip() + " </answer>" \
    for completion in completions
] 
oracle_answers = gsm8k[:10]["oracle_answer"]

In [25]:
gsm8k_rm.outcome_reward(completions[0], oracle_answers[0])

Mimi has 2 x 12 = <<2*12=24>>24 sea shells.
Kyle has 24 x 2 = <<24*2=48>>48 sea shells.
Leigh has 48 / 3 = <<48/3=16>>16 sea shells.
<answer> 16 </answer>
[' 16 ']
<answer>(.*?)</answer>


([16, '16'], 1.0, {'outcome': None})

In [26]:
print(completions[0])
print(oracle_answers[0])

Mimi has 2 x 12 = <<2*12=24>>24 sea shells.
Kyle has 24 x 2 = <<24*2=48>>48 sea shells.
Leigh has 48 / 3 = <<48/3=16>>16 sea shells.
<answer> 16 </answer>
16


In [29]:
rewards, infos , solved_rate = gsm8k_rm(completions, oracle_answers)

In [30]:
print(solved_rate)

1.0


In [103]:
from tqdm import tqdm
for i in tqdm(range(len(gsm8k))):
    item = gsm8k[i]
    completion = item["answer"]
    completion = completion.replace("####", "<answer>").strip()
    completion += "</answer>"
    #print(completion)

    parsed_answer, outcome_reward, outcome_info = gsm8k_rm.outcome_reward(completion, item["oracle_answer"])
    #print(parsed_answer, reward, outcome_info)


    parsed_reasoning, format_reward, format_info = gsm8k_rm.format_reward(completion)
    #print(parsed_reasoning, reward, format_info)
    assert format_reward == 0.0, print(format_info)


    reward, info = gsm8k_rm.instance_reward(completion, item["oracle_answer"])
    #print(reward, info)
    assert reward == 1.0, print(info)

100%|██████████| 7473/7473 [00:07<00:00, 945.25it/s]


In [75]:
print(type(parsed_answer[0]))

<class 'sympy.core.numbers.Integer'>


In [78]:
isinstance(parsed_answer[0], sympy.Expr)

True

In [65]:
item["oracle_answer"]

['3', '55']

In [64]:
verify(parsed_answer, item["oracle_answer"])

False

In [40]:
text = """<think> 
because the number is too big, we need to use scientific notation
</think>
"""
think_match = re.search(r"<think>(.*?)</think>", text, flags=re.DOTALL)
print(think_match.group(1))
parsed_answer = parse(text, fallback_mode="first_match", extraction_mode="first_match")
print(parsed_answer)
print(type(parsed_answer[0]), type(parsed_answer[1]))

 
because the number is too big, we need to use scientific notation

[]


IndexError: list index out of range

# Load the dataset
*** 

In [91]:
datasets = [
    "complex_arithmetic",
    "intermediate_integration",
    "polynomial_equations",
    "simple_equations",
]

dataset = ReasoningGymDataset(
    datasets_name=datasets,
    size=EACH_DATASET_SIZE,
    seed=SEED
)

print(len(dataset))


400


In [92]:
dataset[0]

{'question': 'Add the complex numbers: (-10.0 - 2.0i) + (-3.0 - 3.0i)',
 'answer': '-13.0 - 5.0i',
 'dataset_name': 'complex_arithmetic',
 'dataset_idx': 0,
 'item_idx': 0,
 'global_idx': 0}

In [93]:
dataset[9]

{'question': 'Subtract the complex numbers: (6.0 + 7.0i) - (-5.0 - 3.0i)',
 'answer': '11.0 + 10.0i',
 'dataset_name': 'complex_arithmetic',
 'dataset_idx': 0,
 'item_idx': 9,
 'global_idx': 9}

In [94]:
data_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0
)
print(len(data_loader))    


13


In [53]:
for batch in data_loader:
    print(len(batch["question"]))
    print(batch["dataset_name"])
    print('*'*100)

32
['complex_arithmetic', 'intermediate_integration', 'polynomial_equations', 'polynomial_equations', 'intermediate_integration', 'simple_equations', 'polynomial_equations', 'polynomial_equations', 'simple_equations', 'simple_equations', 'simple_equations', 'polynomial_equations', 'complex_arithmetic', 'polynomial_equations', 'polynomial_equations', 'simple_equations', 'simple_equations', 'simple_equations', 'polynomial_equations', 'simple_equations', 'intermediate_integration', 'complex_arithmetic', 'intermediate_integration', 'simple_equations', 'intermediate_integration', 'polynomial_equations', 'polynomial_equations', 'simple_equations', 'polynomial_equations', 'complex_arithmetic', 'intermediate_integration', 'intermediate_integration']
****************************************************************************************************
32
['intermediate_integration', 'complex_arithmetic', 'complex_arithmetic', 'polynomial_equations', 'simple_equations', 'simple_equations', 'polyno

# Load the model 
*** 

In [54]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [55]:
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, 
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    #attn_implementation="flash_attention_2",
)

# Tokenization 
*** 

In [56]:
samples = next(iter(data_loader))

In [83]:
encodings = tokenize_questions(
    tokenizer,
    questions= samples["question"],
    max_length= 1024,
)

Adjusting padding side from right to left for training


In [84]:
encodings["attention_mask"].shape 

torch.Size([32, 1024])

In [85]:
tokenizer.padding_side

'left'

In [86]:
print(
    tokenizer.decode(encodings["input_ids"][0])
)


<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|

In [76]:
tokenizer.special_tokens_map

{'eos_token': '<|im_end|>',
 'pad_token': '<|endoftext|>',
 'additional_special_tokens': ['<|im_start|>', '<|im_end|>']}

In [79]:
tokenizer.all_special_ids

[151645, 151643, 151644]

In [81]:
tokenizer.convert_ids_to_tokens(tokenizer.all_special_ids)

['<|im_end|>', '<|endoftext|>', '<|im_start|>']

# Policy Rollout 
*** 

In [95]:
generation_config = GenerationConfig(
    max_new_tokens=512,
    do_sample=True,
    top_p=0.9,
    top_k=50,
    temperature=0.7,
    num_return_sequences=4,
)

In [99]:
seq_ids, returns, action_mask, completions = rollout(
    model,
    tokenizer,
    batch_inputs=encodings,
    oracle_answers=samples["answer"],
    generation_config=generation_config,
)



RuntimeError: Placeholder storage has not been allocated on MPS device!