In [29]:
%load_ext autoreload
%autoreload 2
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
import torch  
from torch.utils.data import DataLoader

from curious.utils.utils import move_paddings_to_right
from curious.config import TrainingConfig, RLConfig, WandbConfig, BaseConfig, SamplingConfig, RewardConfig, SFLConfig 
from curious.train.training_setup import set_up_training
from curious.train.trainer import PolicyGradientTrainer
from curious.sampling.sampling import sequences_log_probs 
from curious.policy_gradient.loss import ActorLoss, masked_mean 
from curious.replay.experience import Experience, ReplayBuffer, join_experience_batch

PAD_TOKEN_ID = 128

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Utils: sanity check of `move_paddings_to_right`
*** 

In [None]:
# Toy batch of two left-padded prompts (attention_mask is 0 on the padding).
attention_mask = torch.tensor(
    [
        [0, 0, 1, 1, 1],
        [0, 1, 1, 1, 1],
    ]
)

input_ids = torch.tensor(
    [
        [PAD_TOKEN_ID, PAD_TOKEN_ID, 1, 2, 3],
        [PAD_TOKEN_ID, 1,            2, 3, 4],
    ]
)

# Prompt followed by "generated" tokens (890/891 and 899/900/911), then right padding.
sequence_ids = torch.tensor(
    [
        [PAD_TOKEN_ID, PAD_TOKEN_ID, 1, 2, 3, 890, 891, PAD_TOKEN_ID, PAD_TOKEN_ID, PAD_TOKEN_ID],
        [PAD_TOKEN_ID, 1,            2, 3, 4, 899, 900, 911,          PAD_TOKEN_ID, PAD_TOKEN_ID],
    ]
)

moved_sequence_ids, moved_mask = move_paddings_to_right(
    input_ids, attention_mask, sequence_ids, PAD_TOKEN_ID
)

# Expected result: the left padding is removed and the columns that are padding
# in every row are dropped.
torch.testing.assert_close(
    moved_sequence_ids,
    torch.tensor(
        [
            [1, 2, 3, 890, 891, PAD_TOKEN_ID, PAD_TOKEN_ID],
            [1, 2, 3, 4,   899, 900,          911],
        ]
    ),
)
# The mask is one column shorter, lines up with moved_sequence_ids[:, 1:], and is
# True exactly on the generated tokens.
torch.testing.assert_close(
    moved_mask,
    torch.tensor(
        [
            [False, False, True,  True,  False, False],
            [False, False, False, True,  True,  True],
        ]
    ),
)

moved_sequence_ids, moved_mask

(tensor([[  1,   2,   3, 890, 891, 128, 128],
         [  1,   2,   3,   4, 899, 900, 911]]),
 tensor([[False, False,  True,  True, False, False],
         [False, False, False,  True,  True,  True]]))

# Loading the training setup (model, dataset, trainer)
*** 

In [3]:
# Small debugging run: train_batch_size=2 prompts with group_size=3 completions
# each, so one mini-batch of 6 covers the whole replay buffer. kl_weight=0
# turns off the KL weighting (mean_kl prints as 0 further down).
train_config = TrainingConfig(
    rl_config=RLConfig(group_size=3, mini_batch_size=6, kl_weight=0),
    wandb_config=WandbConfig(),
    base_config=BaseConfig(train_batch_size=2),
    sampling_config=SamplingConfig(max_new_tokens=256),
    reward_config=RewardConfig(),
    sfl_config=SFLConfig(),
)
training_setup, init_train_state = set_up_training(train_config)

#### Loading target policy Qwen/Qwen2-0.5B-Instruct ####
Applied Liger kernels to Qwen2
#### Loading dataset openai/gsm8k ####


Generating train split: 100%|██████████| 7473/7473 [00:00<00:00, 250804.44 examples/s]
Generating test split: 100%|██████████| 1319/1319 [00:00<00:00, 21209.50 examples/s]
Map: 100%|██████████| 7473/7473 [00:00<00:00, 28762.80 examples/s]
Map: 100%|██████████| 1319/1319 [00:00<00:00, 28585.31 examples/s]


Detected train_max_length: 212
Setting train_max_length to 212


Map: 100%|██████████| 7473/7473 [00:00<00:00, 7906.18 examples/s]


Detected test_max_length: 188
Setting test_max_length to 188


Map: 100%|██████████| 1319/1319 [00:00<00:00, 8570.12 examples/s]

#### Loading rollout data loader ####
#### Defining actor loss ####
#### Defining reward model ####
#### Defining generation config ####
#### Defining evaluation config ####
#### Defining optimizer ####
#### Defining lr scheduler ####





In [None]:
# List which objects set_up_training placed in the initial training state.
init_train_state.keys()

dict_keys(['run_name', 'device', 'seed', 'model', 'optimizer', 'lr_scheduler', 'reference_model', 'kl_controller'])

In [5]:
# Build the trainer from the setup object returned by set_up_training.
trainer = PolicyGradientTrainer(training_setup) 

In [6]:
# Take only the first batch of prompts from the rollout data loader.
batch_inputs = next(iter(trainer.rollout_data_loader))

In [7]:
# Sample completions for the prompt batch using init_train_state and collect them
# into a replay buffer. NOTE(review): the meaning of the trailing positional `0`
# (presumably a step/iteration index) is not visible here; confirm in the trainer.
replay_buffer = trainer.collect_trajectories(init_train_state, batch_inputs, 0)

W0511 17:05:28.711000 4901 /system/conda/miniconda3/envs/cloudspace/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:869] [0/0] Graph break from `Tensor.item()`, consider setting:
W0511 17:05:28.711000 4901 /system/conda/miniconda3/envs/cloudspace/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:869] [0/0]     torch._dynamo.config.capture_scalar_outputs = True
W0511 17:05:28.711000 4901 /system/conda/miniconda3/envs/cloudspace/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:869] [0/0] or:
W0511 17:05:28.711000 4901 /system/conda/miniconda3/envs/cloudspace/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:869] [0/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0511 17:05:28.711000 4901 /system/conda/miniconda3/envs/cloudspace/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:869] [0/0] to include these operations in the captured graph.
W0511 17:05:28.711000 4901 /system/conda/miniconda3/envs/cloudspace/lib/python3.11/

In [8]:
# First experience in the buffer: a single un-batched trajectory (see shapes below).
exp = replay_buffer[0]

In [9]:
# The trainer's tokenizer, used below to decode token ids back to text.
tokenizer = trainer.tokenizer

In [None]:
# Print the shape of every populated (non-None) field of the single experience.
for field_name in exp.keys:
    field_value = getattr(exp, field_name)
    if field_value is None:
        continue
    print(field_name, '->', field_value.shape)

sequences -> torch.Size([278])
action_log_probs -> torch.Size([277])
returns -> torch.Size([])
solved_mask -> torch.Size([])
advantages -> torch.Size([])
attention_mask -> torch.Size([278])
action_mask -> torch.Size([277])


In [None]:
# Decode the full stored sequence, padding included (the trailing <|endoftext|> tokens).
print(tokenizer.decode(exp.sequences))

<|im_start|>system
Please reason step by step, and put your final answer within $\boxed{}$.<|im_end|>
<|im_start|>user
Nala found 5 seashells at the beach. The next day, she found another 7, and the following day, she found twice the number she got from the first two days. How many seashells does Nala have?<|im_end|>
<|im_start|>assistant
Nala found 5 seashells on the first day and 7 seashells on the second day. So, the total number of seashells she has found so far is $5 + 7 = 12$.

On the following day, she found twice the number she got from the first two days. So, the number of seashells she found on the second day is $2 \times 12 = 24$.

To find the total number of seashells Nala has, we add the number she found on the first two days and the number she found on the following day. That is $12 + 24 = 36$.
The answer is: $\boxed{36}$<|im_end|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>

In [None]:
# Boolean-index with attention_mask to keep only attended positions, which strips the padding.
print(tokenizer.decode(exp.sequences[exp.attention_mask]))

<|im_start|>system
Please reason step by step, and put your final answer within $\boxed{}$.<|im_end|>
<|im_start|>user
Nala found 5 seashells at the beach. The next day, she found another 7, and the following day, she found twice the number she got from the first two days. How many seashells does Nala have?<|im_end|>
<|im_start|>assistant
Nala found 5 seashells on the first day and 7 seashells on the second day. So, the total number of seashells she has found so far is $5 + 7 = 12$.

On the following day, she found twice the number she got from the first two days. So, the number of seashells she found on the second day is $2 \times 12 = 24$.

To find the total number of seashells Nala has, we add the number she found on the first two days and the number she found on the following day. That is $12 + 24 = 36$.
The answer is: $\boxed{36}$<|im_end|>


In [None]:
# Scalar (0-d) return recorded for this trajectory.
print(exp.returns)

tensor(1., dtype=torch.bfloat16)


In [None]:
# Scalar solved flag for this trajectory (1. for this correct answer).
print(exp.solved_mask)

tensor(1., dtype=torch.bfloat16)


In [None]:
# Scalar (0-d) advantage assigned to this trajectory.
print(exp.advantages)

tensor(1.1172, dtype=torch.bfloat16)


In [None]:
# Second completion for the same prompt: solved flag, return, advantage, then the text.
exp1 = replay_buffer[1]
for reward_signal in (exp1.solved_mask, exp1.returns, exp1.advantages):
    print(reward_signal)
print(tokenizer.decode(exp1.sequences[exp1.attention_mask]))

tensor(0., dtype=torch.bfloat16)
tensor(-0.5000, dtype=torch.bfloat16)
tensor(-0.3203, dtype=torch.bfloat16)
<|im_start|>system
Please reason step by step, and put your final answer within $\boxed{}$.<|im_end|>
<|im_start|>user
Nala found 5 seashells at the beach. The next day, she found another 7, and the following day, she found twice the number she got from the first two days. How many seashells does Nala have?<|im_end|>
<|im_start|>assistant
 Nala found 5 seashells on the first day.
She found 7 seashells on the second day.
Let's call the number of seashells she found on the third day $x$.
So, the total number of seashells she has after three days is $5 + 7 + x$.
We know that the total number of seashells she found is 5 + 7 + $x$.
So, we have the equation $5 + 7 + x = 5 + 7 + x + x$.
Simplifying the equation, we get $12 + x = 12 + 2x$.
Subtracting $12$ from both sides gives $x = 2x$.
So, the number of seashells she found each day is $x = 2$.
Therefore, Nala has 2 seashells.
The answ

In [None]:
# Third completion for the first prompt (its final sum is wrong and it gets returns -1).
exp2 = replay_buffer[2]
for reward_signal in (exp2.solved_mask, exp2.returns, exp2.advantages):
    print(reward_signal)
print(tokenizer.decode(exp2.sequences[exp2.attention_mask]))

tensor(0., dtype=torch.bfloat16)
tensor(-1., dtype=torch.bfloat16)
tensor(-0.8008, dtype=torch.bfloat16)
<|im_start|>system
Please reason step by step, and put your final answer within $\boxed{}$.<|im_end|>
<|im_start|>user
Nala found 5 seashells at the beach. The next day, she found another 7, and the following day, she found twice the number she got from the first two days. How many seashells does Nala have?<|im_end|>
<|im_start|>assistant
To find out how many seashells Nala has in total, we first need to calculate the number of seashells she found each day:

1. On the first day, she found 5 seashells.
2. On the second day, she found 7 more seashells than the first day.
3. On the third day, she found twice the number she got from the first two days, which means \(2 \times 5 + 7 = 10 + 7 = 17\) seashells.

Adding these up, we get the total number of seashells Nala has:

\(5 + 7 + 17 = 39\) seashells.<|im_end|>


In [None]:
# First completion for the second prompt in the batch.
# NOTE(review): this completion reaches 60 + 15 = 75 in \boxed{}, which answers the
# prompt correctly, yet solved_mask is 0 and returns is -1. Worth checking the
# reward's answer extraction (e.g. the trailing "." after \boxed{75}) or the reference answer.
exp3 = replay_buffer[3]
for reward_signal in (exp3.solved_mask, exp3.returns, exp3.advantages):
    print(reward_signal)
print(tokenizer.decode(exp3.sequences[exp3.attention_mask]))

tensor(0., dtype=torch.bfloat16)
tensor(-1., dtype=torch.bfloat16)
tensor(0., dtype=torch.bfloat16)
<|im_start|>system
Please reason step by step, and put your final answer within $\boxed{}$.<|im_end|>
<|im_start|>user
Zig wrote four times as many books as Flo. If Zig wrote 60 books, how many books did they write altogether?<|im_end|>
<|im_start|>assistant
 If Zig wrote 60 books, and Zig wrote four times as many books as Flo, then Flo wrote 60/4 = 15 books.
To find out how many books they wrote altogether, we add the number of books Zig wrote and the number of books Flo wrote. So, 60 + 15 = 75 books.
So, they wrote 75 books altogether.
The answer is $\boxed{75}$.<|im_end|>


In [None]:
# action_mask is one position shorter than sequences (277 vs 278) and lines up with
# sequences[1:]; selecting with it leaves only the generated completion.
print(tokenizer.decode(exp3.sequences[1:][exp3.action_mask]))

 If Zig wrote 60 books, and Zig wrote four times as many books as Flo, then Flo wrote 60/4 = 15 books.
To find out how many books they wrote altogether, we add the number of books Zig wrote and the number of books Flo wrote. So, 60 + 15 = 75 books.
So, they wrote 75 books altogether.
The answer is $\boxed{75}$.<|im_end|>


In [20]:
# Batch the replay buffer into mini-batches in buffer order, keeping any partial last batch.
experience_sampler = DataLoader(
    dataset=replay_buffer,
    batch_size=trainer.rl_config.mini_batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=join_experience_batch,
    num_workers=trainer.base_config.num_workers,
)

In [21]:
# First collated mini-batch of experiences (built by join_experience_batch).
next_exp = next(iter(experience_sampler))

In [None]:
# After collation the per-trajectory scalars become 1-D tensors of length mini_batch_size.
next_exp.returns.shape, next_exp.advantages.shape 

(torch.Size([6]), torch.Size([6]))

In [None]:
# Shapes of the populated fields of the collated mini-batch (leading dim = batch).
for field_name in next_exp.keys:
    field_value = getattr(next_exp, field_name)
    if field_value is not None:
        print(field_name, '->', field_value.shape)

sequences -> torch.Size([6, 278])
action_log_probs -> torch.Size([6, 277])
returns -> torch.Size([6])
solved_mask -> torch.Size([6])
advantages -> torch.Size([6])
attention_mask -> torch.Size([6, 278])
action_mask -> torch.Size([6, 277])


In [24]:
# Move the mini-batch to the policy's device and recompute per-token log-probs
# of the stored sequences under the current policy.
next_exp = next_exp.to(init_train_state["model"].device)
# Bind the unused second output to a real name rather than `_`: once `_` is
# assigned in the user namespace, IPython stops updating it with the last output.
# (Presumably None here, since return_entropy=False.)
log_probs, _unused_entropy = sequences_log_probs(
    init_train_state["model"],
    sequence_ids=next_exp.sequences,
    attention_mask=next_exp.attention_mask,
    return_entropy=False,
    logits_minibatch_size=trainer.rl_config.logits_minibatch_size,
)

In [None]:
# The recomputed log-probs must line up token-for-token with the stored
# rollout log-probs and the action mask they are combined with in the loss.
assert log_probs.shape == next_exp.action_log_probs.shape == next_exp.action_mask.shape, (
    log_probs.shape,
    next_exp.action_log_probs.shape,
    next_exp.action_mask.shape,
)
print(log_probs.shape)

torch.Size([6, 277])


In [26]:
import torch._dynamo

# Suppress torch.compile errors (so dynamo falls back to eager) only for this call,
# instead of setting the global flag for the rest of the kernel session, where it
# would silently hide compile failures in every later compiled function.
with torch._dynamo.config.patch(suppress_errors=True):
    # NOTE: despite its name, `mean_actor_loss` is the per-token actor loss
    # (same shape as action_mask); `loss` is reproduced by its masked mean in the last cell.
    loss, mean_kl, mean_actor_loss = trainer.actor_loss(
        log_probs=log_probs,
        experience=next_exp,
    )

In [None]:
# Total loss, KL term (prints 0 in this run, where kl_weight=0) and per-token actor loss.
for loss_term in (loss, mean_kl, mean_actor_loss):
    print(loss_term)

tensor(0.0010, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<CompiledFunctionBackward>)
0
tensor([[-0.4805, -1.3438, -1.2266,  ..., -1.1172, -1.1172, -1.1172],
        [ 0.2559,  0.4121,  0.3516,  ...,  0.3203,  0.3203,  0.3203],
        [ 0.6406,  1.0312,  0.8789,  ...,  0.8008,  0.8008,  0.8008],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<CompiledFunctionBackward>)


In [28]:
# Despite the name, this is per-token: one entry per action position, same shape as action_mask.
print(mean_actor_loss.shape)

torch.Size([6, 277])


In [32]:
# Re-derive the scalar loss by hand: masked mean of the per-token actor loss over
# each sequence's action tokens, then the mean over the mini-batch. With kl_weight=0
# this should reproduce `loss` returned by trainer.actor_loss.
recomputed_loss = masked_mean(
    mean_actor_loss,
    mask=next_exp.action_mask,
    dim=-1,
    use_fixed_response_length=False,
).mean()
# Loose tolerance: both values are bfloat16 and `loss` comes from a compiled graph.
torch.testing.assert_close(recomputed_loss.detach(), loss.detach(), rtol=1e-2, atol=1e-4)
recomputed_loss

tensor(0.0010, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)