In [15]:
%load_ext autoreload
%autoreload 2
import re 
import os 
os.environ["TOKENIZERS_PARALLELIS"] = "true"
import torch 
from torch.utils.data import DataLoader 
import sympy 

import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers import PreTrainedTokenizer, PreTrainedModel

import numpy as np
from memory_profiler import memory_usage

import datasets 
from datasets import load_dataset 

from curious.data import ReasoningGymDataset, GSM8KDataset
from curious.utils import tokenize_questions, load_model_tokenizer
from curious.sampling import (
    rollout, 
    sequences_log_probs, 
    sequences_log_probs_with_mask, 
    sample_responses,
    compute_rewards,
    compute_group_advantages,
)
from curious.buffer import Experience, ReplayBuffer, join_experience_batch
from curious.reward import GSM8KRewardModel, QWEN_ANSWER_PATTERN
from curious.prompt import *

from lightning import seed_everything
import reasoning_gym 

from math_verify import verify, parse
from pprint import pprint

MAX_PROMPT_LENGTH = 512
PER_DEVICE_BATCH_SIZE = 2
EACH_DATASET_SIZE = 100
SEED = 42
GROUP_SIZE = 2
MODEL_NAME = "Qwen/Qwen2-0.5B-Instruct"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Loading 
*** 

In [None]:
tokenizer:PreTrainedTokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model:PreTrainedModel = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer.pad_token_id = tokenizer.eos_token_id


Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


# Compare implementation 
*** 

# Dataset 

In [3]:
dataset = GSM8KDataset(
    tokenizer=tokenizer,
    dataset_name="openai/gsm8k",
    seed=SEED,
    mode="train",
    max_prompt_length=MAX_PROMPT_LENGTH,
)
train_dataset = dataset.train
test_dataset = dataset.test
print(dataset)

Detected train_max_length: 212
Setting train_max_length to 212


Map: 100%|██████████| 7473/7473 [00:01<00:00, 4576.88 examples/s]


Detected test_max_length: 188
Setting test_max_length to 188


Map: 100%|██████████| 1319/1319 [00:00<00:00, 2813.36 examples/s]

<curious.data.GSM8KDataset object at 0x7848661cb3d0>





In [4]:
train_dataset

Dataset({
    features: ['question', 'answer', 'oracle_answer', 'input_ids', 'attention_mask'],
    num_rows: 7473
})

In [5]:
dataset.train["input_ids"][0].shape, dataset.train["input_ids"][1].shape

(torch.Size([212]), torch.Size([212]))

## Rollout 
*** 

In [5]:
replay_buffer = ReplayBuffer()

In [6]:
start_idx = 0
x = dataset[start_idx: start_idx + PER_DEVICE_BATCH_SIZE]

In [7]:
dataset.train 

Dataset({
    features: ['question', 'answer', 'oracle_answer', 'input_ids', 'attention_mask'],
    num_rows: 7473
})

In [8]:
reward_model = GSM8KRewardModel(
    answer_pattern=QWEN_ANSWER_PATTERN,
    use_format_reward=False,
    use_strict_format_reward=False
)

rollout_out = rollout(
    model,
    tokenizer,
    batch_inputs=x,
    reward_model=reward_model,
    generation_config=GenerationConfig(
        max_new_tokens=125,
        do_sample=True,
        num_return_sequences=GROUP_SIZE,
        top_p=1.0,
        top_k=50,
        temperature=0.9,
    ),
    group_size=GROUP_SIZE,
    seed=SEED,
    normalize_centered_returns=True,
)

In [9]:
for k, v in rollout_out.items():
    if isinstance(v, torch.Tensor):
        print(k, "->", v.shape)
    else:
        print(k, "->", v)


returns -> torch.Size([2, 2])
solved_masks -> torch.Size([2, 2])
infos -> [[{'outcome': None, 'parsed_answer': [16, '16'], 'outcome_reward': 1.0, 'format_': None, 'parsed_reasoning': None, 'format_reward': 0.0}, {'outcome': 'no_answer_in_required_format', 'parsed_answer': None, 'outcome_reward': -1.0, 'format_': None, 'parsed_reasoning': None, 'format_reward': 0.0}], [{'outcome': 'no_answer_in_required_format', 'parsed_answer': None, 'outcome_reward': -1.0, 'format_': None, 'parsed_reasoning': None, 'format_reward': 0.0}, {'outcome': 'no_answer_in_required_format', 'parsed_answer': None, 'outcome_reward': -1.0, 'format_': None, 'parsed_reasoning': None, 'format_reward': 0.0}]]
num_samples -> 4
num_samples_per_group -> 2
sequence_ids -> torch.Size([4, 637])
action_mask -> torch.Size([4, 636])
advantages -> torch.Size([2, 2])
num_words_in_completions -> torch.Size([4])
completions -> [' Mimi picked up 2 dozen seashells. Since 1 dozen is equal to 12, Mimi picked up 2*12 = 24 seashells.\nK

In [10]:
print(rollout_out["advantages"])
print(rollout_out["solved_masks"])
print(rollout_out["returns"])

tensor([[ 0.7071, -0.7071],
        [ 0.0000,  0.0000]])
tensor([[1., 0.],
        [0., 0.]])
tensor([[ 1., -1.],
        [-1., -1.]])


In [11]:
pad_token_id = tokenizer.eos_token_id
attn_mask = rollout_out["sequence_ids"] != pad_token_id
log_probs = sequences_log_probs(model, rollout_out["sequence_ids"], attn_mask)
print(log_probs.shape)

torch.Size([4, 636])


In [27]:
sequence_ids = rollout_out["sequence_ids"]
action_mask = rollout_out["action_mask"]

returns = rollout_out["returns"].reshape(-1)
solved_mask = rollout_out["solved_masks"].reshape(-1)
advantages = rollout_out["advantages"].reshape(-1)

print(sequence_ids.shape)
print(returns.shape)
print(solved_mask.shape)
print(advantages.shape)


torch.Size([4, 637])
torch.Size([4])
torch.Size([4])
torch.Size([4])


In [28]:
for completion in rollout_out["completions"]:
    print(completion)
    print("-"*100)

 Mimi picked up 2 dozen seashells. Since 1 dozen is equal to 12, Mimi picked up 2*12 = 24 seashells.
Kyle found twice as many shells as Mimi, which means Kyle found 24*2 = 48 seashells.
Leigh grabbed one-third of the shells that Kyle found, which means Leigh grabbed 48/3 = 16 seashells.
So, Leigh had 16 seashells.
The answer is: $\boxed{16}$.
----------------------------------------------------------------------------------------------------
 First, let's find out how many seashells Mimi found. Since she picked up 2 dozen seashells, a dozen is 12. So, Mimi found 2 * 12 = 24 seashells.

Then, Kyle found twice as many shells as Mimi. So, Kyle found 24 * 2 = 48 seashells.

Finally, Leigh grabbed one-third of the shells that Kyle found. Since Kyle found 48 shells, Leigh grabbed 48 / 3 = 16 seashells.

So, Leigh had 1
----------------------------------------------------------------------------------------------------
Let's break down the information given to make it easier to solve:

1. Fra

In [29]:
experience = Experience(
    sequences=sequence_ids,
    action_log_probs=log_probs,
    log_probs_ref=log_probs,
    returns=returns,
    solved_mask=solved_mask,
    advantages=advantages,
    attention_mask=attn_mask,
    action_mask=action_mask,
)

In [30]:
print(experience.keys)

['sequences', 'action_log_probs', 'log_probs_ref', 'returns', 'solved_mask', 'advantages', 'attention_mask', 'action_mask']


In [31]:
replay_buffer.append(experience)

['sequences', 'action_log_probs', 'log_probs_ref', 'returns', 'solved_mask', 'advantages', 'attention_mask', 'action_mask']
0 dict_keys(['sequences', 'action_log_probs', 'log_probs_ref', 'returns', 'solved_mask', 'advantages', 'attention_mask', 'action_mask'])
1 dict_keys(['sequences', 'action_log_probs', 'log_probs_ref', 'returns', 'solved_mask', 'advantages', 'attention_mask', 'action_mask'])
2 dict_keys(['sequences', 'action_log_probs', 'log_probs_ref', 'returns', 'solved_mask', 'advantages', 'attention_mask', 'action_mask'])
3 dict_keys(['sequences', 'action_log_probs', 'log_probs_ref', 'returns', 'solved_mask', 'advantages', 'attention_mask', 'action_mask'])


In [37]:
exp = replay_buffer[1]
for k in exp.keys:
    print(k, getattr(exp, k).size())

sequences torch.Size([637])
action_log_probs torch.Size([636])
log_probs_ref torch.Size([636])
returns torch.Size([])
solved_mask torch.Size([])
advantages torch.Size([])
attention_mask torch.Size([637])
action_mask torch.Size([636])


In [52]:
e1 = replay_buffer[0]
e2 = replay_buffer[1]
e3 = replay_buffer[2]
e4 = replay_buffer[3]

joint_experience = join_experience_batch([e1, e2, e3, e4])
print(joint_experience)

Experience(sequences=tensor([[151645, 151645, 151645,  ..., 151643, 151643, 151643],
        [151645, 151645, 151645,  ...,   1030,    220,     16],
        [151645, 151645, 151645,  ...,    279,  25103,    374],
        [151645, 151645, 151645,  ...,    220,     18,     20]]), action_log_probs=tensor([[-8.4310e+00, -8.4310e+00, -8.4310e+00,  ..., -1.6949e+01,
         -1.6806e+01, -1.6656e+01],
        [-8.5017e+00, -8.5017e+00, -8.5017e+00,  ..., -1.4700e-01,
         -3.5477e-04,  0.0000e+00],
        [-8.6663e+00, -8.6663e+00, -8.6663e+00,  ..., -5.6907e-01,
         -1.2562e+00, -2.0771e+00],
        [-8.6917e+00, -8.6917e+00, -8.6917e+00,  ..., -5.7220e-06,
         -1.3485e-01, -6.0103e-01]], grad_fn=<StackBackward0>), attention_mask=tensor([[False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True]]), action_mask=ten

In [54]:
for k in joint_experience.keys:
    assert getattr(joint_experience, k).size() == getattr(experience, k).size()
    assert torch.isclose(getattr(joint_experience, k), getattr(experience, k)).all()

# Logprobs computation 
*** 

In [9]:
inputs = tokenizer(["Today is a nice day"], return_tensors="pt")

In [10]:
# Example 1: Print the scores for each token generated with Greedy Search
outputs = model.generate(
    **inputs, 
    max_new_tokens=10, 
    return_dict_in_generate=True, 
    output_logits=True, 
    do_sample=True,
    num_return_sequences=4,
)
print(type(outputs))
print(outputs.sequences.shape)
print(len(outputs.logits))
for new_token in outputs.logits:
    print(new_token.shape)

<class 'transformers.generation.utils.GenerateDecoderOnlyOutput'>
torch.Size([4, 15])
10
torch.Size([4, 151936])
torch.Size([4, 151936])
torch.Size([4, 151936])
torch.Size([4, 151936])
torch.Size([4, 151936])
torch.Size([4, 151936])
torch.Size([4, 151936])
torch.Size([4, 151936])
torch.Size([4, 151936])
torch.Size([4, 151936])


In [11]:
transition_logprobs = model.compute_transition_scores(
    outputs.sequences, outputs.logits, normalize_logits=True
)
print(transition_logprobs.shape)
# print(outputs.scores)

torch.Size([4, 10])


In [12]:
# input_length is the length of the input prompt for decoder-only models, like the GPT family, and 1 for
# encoder-decoder models, like BART or T5.
input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
print(input_length)
generated_tokens = outputs.sequences[:, input_length:]
print(generated_tokens.shape)
for tok, score in zip(generated_tokens[0], transition_logprobs[0]):
    # | token | token string | log probability | probability
    print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}")

5
torch.Size([4, 10])
|   369 |  for     | -2.168 | 11.44%
|   264 |  a       | -1.541 | 21.42%
| 68883 |  stroll  | -3.770 | 2.31%
|  1526 |  through | -1.842 | 15.85%
|   279 |  the     | -0.433 | 64.84%
|  3283 |  city    | -2.946 | 5.26%
|    13 | .        | -1.126 | 32.42%
|   358 |  I       | -1.329 | 26.48%
|  3003 | 've      | -3.050 | 4.74%
|  2684 |  got     | -2.642 | 7.12%


In [14]:

attention_mask = outputs.sequences != tokenizer.pad_token_id
log_probs = sequences_log_probs(model, outputs.sequences, attention_mask)
print(log_probs.shape)

TorchRuntimeError: Failed running call_method view(*(FakeTensor(..., size=(s1 - 1, 151936), grad_fn=<SelectBackward0>), 1, -1, -1), **{}):
only one dimension can be inferred

from user code:
   File "/teamspace/studios/this_studio/curious/curious/sampling.py", line 248, in sequences_log_probs
    log_probs = slow_sequence_log_probs_from_logits(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.11/site-packages/torch/_dynamo/polyfills/__init__.py", line 160, in getattr_and_trace
    return fn(*args[2:], **kwargs)
  File "/teamspace/studios/this_studio/curious/curious/sampling.py", line 217, in slow_sequence_log_probs_from_logits
    logits=logits_row.view(1, -1, -1),

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True


In [8]:
out_seqs = outputs.sequences[:,1:]
in_seqs = outputs.sequences[:,:-1]

for i in range(log_probs.shape[1]):

    in_token = in_seqs[0][i]
    out_token = out_seqs[0][i]

    print( tokenizer.decode(in_token), "->" , tokenizer.decode(out_token), f"({log_probs[0][i].item()})")

Today ->  is (-4.035375595092773)
 is ->  the (-1.3148212432861328)
 the ->  first (-1.7471065521240234)
 first ->  day (-0.34832000732421875)
 day ->  of (-0.15752410888671875)
 of ->  my (-2.96502685546875)
 my ->  new (-1.947946548461914)
 new ->  job (-1.1230058670043945)
 job ->  and (-2.5361595153808594)
 and ->  I (-0.5244293212890625)
 I ->  am (-1.1761493682861328)


In [46]:
model, tokenizer = load_model_tokenizer(MODEL_NAME)

ImportError: FlashAttention2 has been toggled on, but it cannot be used due to the following error: the package flash_attn seems to be not installed. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.

In [21]:
gsm8k = GSM8KDataset()
gsm8k_rm = GSM8KRewardModel(use_format_reward=False, use_strict_format_reward=False)

In [22]:
idx = 8
item = gsm8k[idx]
pprint(item)

{'answer': 'He has 36 eggs because 3 x 12 = <<3*12=36>>36\n'
           'He can make 9 omelets because 36 / 4 = <<36/4=9>>9\n'
           'Each person gets 3 omelets because 9 / 3 = <<9/3=3>>3\n'
           '#### 3',
 'oracle_answer': '3',
 'question': 'Pauly is making omelets for his family. There are three dozen '
             'eggs, and he plans to use them all. Each omelet requires 4 eggs. '
             'Including himself, there are 3 people. How many omelets does '
             'each person get?'}


In [23]:
gsm8k_rm.answer_pattern

'<answer>(.*?)</answer>'

In [24]:
completions = gsm8k[:10]["answer"]
completions = [
    completion.replace("####", "<answer>").strip() + " </answer>" \
    for completion in completions
] 
oracle_answers = gsm8k[:10]["oracle_answer"]

In [25]:
gsm8k_rm.outcome_reward(completions[0], oracle_answers[0])

Mimi has 2 x 12 = <<2*12=24>>24 sea shells.
Kyle has 24 x 2 = <<24*2=48>>48 sea shells.
Leigh has 48 / 3 = <<48/3=16>>16 sea shells.
<answer> 16 </answer>
[' 16 ']
<answer>(.*?)</answer>


([16, '16'], 1.0, {'outcome': None})

In [26]:
print(completions[0])
print(oracle_answers[0])

Mimi has 2 x 12 = <<2*12=24>>24 sea shells.
Kyle has 24 x 2 = <<24*2=48>>48 sea shells.
Leigh has 48 / 3 = <<48/3=16>>16 sea shells.
<answer> 16 </answer>
16


In [29]:
rewards, infos , solved_rate = gsm8k_rm(completions, oracle_answers)

In [30]:
print(solved_rate)

1.0


In [103]:
from tqdm import tqdm
for i in tqdm(range(len(gsm8k))):
    item = gsm8k[i]
    completion = item["answer"]
    completion = completion.replace("####", "<answer>").strip()
    completion += "</answer>"
    #print(completion)

    parsed_answer, outcome_reward, outcome_info = gsm8k_rm.outcome_reward(completion, item["oracle_answer"])
    #print(parsed_answer, reward, outcome_info)


    parsed_reasoning, format_reward, format_info = gsm8k_rm.format_reward(completion)
    #print(parsed_reasoning, reward, format_info)
    assert format_reward == 0.0, print(format_info)


    reward, info = gsm8k_rm.instance_reward(completion, item["oracle_answer"])
    #print(reward, info)
    assert reward == 1.0, print(info)

100%|██████████| 7473/7473 [00:07<00:00, 945.25it/s]


In [75]:
print(type(parsed_answer[0]))

<class 'sympy.core.numbers.Integer'>


In [78]:
isinstance(parsed_answer[0], sympy.Expr)

True

In [65]:
item["oracle_answer"]

['3', '55']

In [64]:
verify(parsed_answer, item["oracle_answer"])

False

In [40]:
text = """<think> 
because the number is too big, we need to use scientific notation
</think>
"""
think_match = re.search(r"<think>(.*?)</think>", text, flags=re.DOTALL)
print(think_match.group(1))
parsed_answer = parse(text, fallback_mode="first_match", extraction_mode="first_match")
print(parsed_answer)
print(type(parsed_answer[0]), type(parsed_answer[1]))

 
because the number is too big, we need to use scientific notation

[]


IndexError: list index out of range

# Load the dataset
*** 

In [91]:
datasets = [
    "complex_arithmetic",
    "intermediate_integration",
    "polynomial_equations",
    "simple_equations",
]

dataset = ReasoningGymDataset(
    datasets_name=datasets,
    size=EACH_DATASET_SIZE,
    seed=SEED
)

print(len(dataset))


400


In [92]:
dataset[0]

{'question': 'Add the complex numbers: (-10.0 - 2.0i) + (-3.0 - 3.0i)',
 'answer': '-13.0 - 5.0i',
 'dataset_name': 'complex_arithmetic',
 'dataset_idx': 0,
 'item_idx': 0,
 'global_idx': 0}

In [93]:
dataset[9]

{'question': 'Subtract the complex numbers: (6.0 + 7.0i) - (-5.0 - 3.0i)',
 'answer': '11.0 + 10.0i',
 'dataset_name': 'complex_arithmetic',
 'dataset_idx': 0,
 'item_idx': 9,
 'global_idx': 9}

In [94]:
data_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0
)
print(len(data_loader))    


13


In [53]:
for batch in data_loader:
    print(len(batch["question"]))
    print(batch["dataset_name"])
    print('*'*100)

32
['complex_arithmetic', 'intermediate_integration', 'polynomial_equations', 'polynomial_equations', 'intermediate_integration', 'simple_equations', 'polynomial_equations', 'polynomial_equations', 'simple_equations', 'simple_equations', 'simple_equations', 'polynomial_equations', 'complex_arithmetic', 'polynomial_equations', 'polynomial_equations', 'simple_equations', 'simple_equations', 'simple_equations', 'polynomial_equations', 'simple_equations', 'intermediate_integration', 'complex_arithmetic', 'intermediate_integration', 'simple_equations', 'intermediate_integration', 'polynomial_equations', 'polynomial_equations', 'simple_equations', 'polynomial_equations', 'complex_arithmetic', 'intermediate_integration', 'intermediate_integration']
****************************************************************************************************
32
['intermediate_integration', 'complex_arithmetic', 'complex_arithmetic', 'polynomial_equations', 'simple_equations', 'simple_equations', 'polyno

# Load the model 
*** 

In [54]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [55]:
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, 
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    #attn_implementation="flash_attention_2",
)

# Tokenization 
*** 

In [56]:
samples = next(iter(data_loader))

In [83]:
encodings = tokenize_questions(
    tokenizer,
    questions= samples["question"],
    max_length= 1024,
)

Adjusting padding side from right to left for training


In [84]:
encodings["attention_mask"].shape 

torch.Size([32, 1024])

In [85]:
tokenizer.padding_side

'left'

In [86]:
print(
    tokenizer.decode(encodings["input_ids"][0])
)


<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|

In [76]:
tokenizer.special_tokens_map

{'eos_token': '<|im_end|>',
 'pad_token': '<|endoftext|>',
 'additional_special_tokens': ['<|im_start|>', '<|im_end|>']}

In [79]:
tokenizer.all_special_ids

[151645, 151643, 151644]

In [81]:
tokenizer.convert_ids_to_tokens(tokenizer.all_special_ids)

['<|im_end|>', '<|endoftext|>', '<|im_start|>']

# Policy Rollout 
*** 

In [95]:
generation_config = GenerationConfig(
    max_new_tokens=512,
    do_sample=True,
    top_p=0.9,
    top_k=50,
    temperature=0.7,
    num_return_sequences=4,
)

In [99]:
seq_ids, returns, action_mask, completions = rollout(
    model,
    tokenizer,
    batch_inputs=encodings,
    oracle_answers=samples["answer"],
    generation_config=generation_config,
)



RuntimeError: Placeholder storage has not been allocated on MPS device!