In [4]:
# Reload edited project modules automatically so changes are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
import os
# Set before the remaining imports in this cell run. Presumably read by the
# Hugging Face tokenizers library used downstream — TODO confirm that "true"
# is the intended setting for this environment.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
import torch  
from torch.utils.data import DataLoader

from curious.utils.utils import move_paddings_to_right
from curious.config import TrainingConfig, RLConfig, WandbConfig, BaseConfig, SamplingConfig, RewardConfig, SFLConfig 
from curious.train.training_setup import set_up_training
from curious.train.trainer import PolicyGradientTrainer
from curious.sampling.sampling import sequences_log_probs 
from curious.policy_gradient.loss import ActorLoss
from curious.replay.experience import Experience, ReplayBuffer, join_experience_batch

# Pad-token id used to build the toy tensors in the Utils section below.
# NOTE(review): not necessarily the real tokenizer's pad id — confirm if reused elsewhere.
PAD_TOKEN_ID = 128

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Utils 
*** 

In [5]:
# Toy batch: two left-padded prompts of width 5.
input_ids = torch.tensor(
    [
        [PAD_TOKEN_ID, PAD_TOKEN_ID, 1, 2, 3],
        [PAD_TOKEN_ID, 1, 2, 3, 4],
    ]
)

# 1 marks a real token, 0 marks padding:
#   [[0, 0, 1, 1, 1],
#    [0, 1, 1, 1, 1]]
attention_mask = input_ids.ne(PAD_TOKEN_ID).long()

# Prompt followed by generated tokens, padded on the right out to width 10.
sequence_ids = torch.tensor(
    [
        [PAD_TOKEN_ID] * 2 + [1, 2, 3, 890, 891] + [PAD_TOKEN_ID] * 3,
        [PAD_TOKEN_ID] * 1 + [1, 2, 3, 4, 899, 900, 911] + [PAD_TOKEN_ID] * 2,
    ]
)

# Last expression of the cell, so the returned tuple is displayed as the output.
move_paddings_to_right(input_ids, attention_mask, sequence_ids, PAD_TOKEN_ID)

(tensor([[  1,   2,   3, 890, 891, 128, 128],
         [  1,   2,   3,   4, 899, 900, 911]]),
 tensor([[False, False,  True,  True, False, False],
         [False, False, False,  True,  True,  True]]))

# Loading 
*** 

In [6]:
# Small configuration for an interactive run; every section that is not
# overridden here is built from its constructor defaults.
train_config = TrainingConfig(
    rl_config=RLConfig(group_size=3, mini_batch_size=6),
    wandb_config=WandbConfig(),
    base_config=BaseConfig(train_batch_size=2),
    sampling_config=SamplingConfig(max_new_tokens=256),
    reward_config=RewardConfig(),
    sfl_config=SFLConfig(),
)

# Returns the setup object handed to the trainer and the initial training state.
training_setup, init_train_state = set_up_training(train_config)

#### Loading target policy Qwen/Qwen2-0.5B-Instruct ####
Applied Liger kernels to Qwen2
#### Loading dataset openai/gsm8k ####


Generating train split: 100%|██████████| 7473/7473 [00:00<00:00, 157982.45 examples/s]
Generating test split: 100%|██████████| 1319/1319 [00:00<00:00, 271363.47 examples/s]
Map: 100%|██████████| 7473/7473 [00:00<00:00, 16842.80 examples/s]
Map: 100%|██████████| 1319/1319 [00:00<00:00, 16191.05 examples/s]


Detected train_max_length: 212
Setting train_max_length to 212


Map: 100%|██████████| 7473/7473 [00:01<00:00, 4720.12 examples/s]


Detected test_max_length: 188
Setting test_max_length to 188


Map: 100%|██████████| 1319/1319 [00:00<00:00, 5350.43 examples/s]


#### Loading rollout data loader ####
#### Loading KL controller ####
#### Loading reference model ####
Applied Liger kernels to Qwen2
#### Defining actor loss ####
#### Defining reward model ####
#### Defining generation config ####
#### Defining evaluation config ####
#### Defining optimizer ####
#### Defining lr scheduler ####


In [7]:
# List the entries available in the initial training state.
init_train_state.keys()

dict_keys(['run_name', 'device', 'seed', 'model', 'optimizer', 'lr_scheduler', 'reference_model', 'kl_controller'])

In [8]:
# Build the trainer from the setup object returned by set_up_training.
trainer = PolicyGradientTrainer(training_setup) 

In [9]:
# Pull a single batch from the trainer's rollout data loader.
batch_inputs = next(iter(trainer.rollout_data_loader))

In [10]:
# Roll out the policy on the sampled batch and keep the resulting experiences.
# NOTE(review): the trailing 0 is passed positionally; presumably a step or
# batch index — confirm against the collect_trajectories signature.
replay_buffer = trainer.collect_trajectories(init_train_state, batch_inputs, 0)

W0511 09:25:09.122000 5754 /system/conda/miniconda3/envs/cloudspace/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:869] [0/0] Graph break from `Tensor.item()`, consider setting:
W0511 09:25:09.122000 5754 /system/conda/miniconda3/envs/cloudspace/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:869] [0/0]     torch._dynamo.config.capture_scalar_outputs = True
W0511 09:25:09.122000 5754 /system/conda/miniconda3/envs/cloudspace/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:869] [0/0] or:
W0511 09:25:09.122000 5754 /system/conda/miniconda3/envs/cloudspace/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:869] [0/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0511 09:25:09.122000 5754 /system/conda/miniconda3/envs/cloudspace/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:869] [0/0] to include these operations in the captured graph.
W0511 09:25:09.122000 5754 /system/conda/miniconda3/envs/cloudspace/lib/python3.11/

In [11]:
# First experience stored in the replay buffer.
exp = replay_buffer[0]

In [12]:
# Short handle to the trainer's tokenizer, used for decoding in the cells below.
tokenizer = trainer.tokenizer

In [13]:
# Report the shape of every populated field on the single experience.
for field_name in exp.keys:
    field_value = getattr(exp, field_name)
    if field_value is None:
        # Unset fields have no shape to show.
        continue
    print(field_name, '->', field_value.shape)

sequences -> torch.Size([220])
action_log_probs -> torch.Size([219])
log_probs_ref -> torch.Size([219])
returns -> torch.Size([])
solved_mask -> torch.Size([])
advantages -> torch.Size([])
attention_mask -> torch.Size([220])
action_mask -> torch.Size([219])


In [14]:
# Decode the whole stored sequence, padding tokens included.
print(tokenizer.decode(exp.sequences))

<|im_start|>system
Please reason step by step, and put your final answer within $\boxed{}$.<|im_end|>
<|im_start|>user
A pencil costs $0.5 each and a folder costs $0.9 each. An office needs two dozen pencils and 20 pieces of folders. How much does it cost to buy the office supplies?<|im_end|>
<|im_start|>assistant
Two dozen pencils is equal to 2 x 12 = 24 pencils.
The cost of the pencils is $0.5 x 24 = $12.
The cost of the folders is $0.9 x 20 = $18.
Therefore, the total cost to buy the office supplies is $12 + $18 = $30.
The answer is: $\boxed{30}$<|im_end|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endo

In [15]:
# Index with attention_mask so that only the non-padding tokens are decoded.
print(tokenizer.decode(exp.sequences[exp.attention_mask]))

<|im_start|>system
Please reason step by step, and put your final answer within $\boxed{}$.<|im_end|>
<|im_start|>user
A pencil costs $0.5 each and a folder costs $0.9 each. An office needs two dozen pencils and 20 pieces of folders. How much does it cost to buy the office supplies?<|im_end|>
<|im_start|>assistant
Two dozen pencils is equal to 2 x 12 = 24 pencils.
The cost of the pencils is $0.5 x 24 = $12.
The cost of the folders is $0.9 x 20 = $18.
Therefore, the total cost to buy the office supplies is $12 + $18 = $30.
The answer is: $\boxed{30}$<|im_end|>


In [16]:
# Scalar return stored for this experience.
# NOTE(review): the saved output shows -1 even though the decoded completion
# ends in \boxed{30}, which looks correct for the prompt — confirm how the
# reward model parses answers.
print(exp.returns)

tensor(-1., dtype=torch.bfloat16)


In [17]:
# Scalar solved flag stored for this experience.
print(exp.solved_mask)

tensor(0., dtype=torch.bfloat16)


In [18]:
# Scalar advantage stored for this experience.
print(exp.advantages)

tensor(0., dtype=torch.bfloat16)


In [19]:
# Second experience: solved flag, return and advantage (one per line),
# followed by the decoded non-padding tokens.
exp1 = replay_buffer[1]
print(exp1.solved_mask, exp1.returns, exp1.advantages, sep="\n")
print(tokenizer.decode(exp1.sequences[exp1.attention_mask]))

tensor(0., dtype=torch.bfloat16)
tensor(-1., dtype=torch.bfloat16)
tensor(0., dtype=torch.bfloat16)
<|im_start|>system
Please reason step by step, and put your final answer within $\boxed{}$.<|im_end|>
<|im_start|>user
A pencil costs $0.5 each and a folder costs $0.9 each. An office needs two dozen pencils and 20 pieces of folders. How much does it cost to buy the office supplies?<|im_end|>
<|im_start|>assistant
 Two dozen pencils is equal to 2 x 12 = 24 pencils.
If each pencil costs $0.5, then the total cost for the pencils is 24 x $0.5 = $12.
The office needs 20 folders, and each folder costs $0.9, so the total cost for the folders is 20 x $0.9 = $18.
To find the total cost for the office supplies, we add the cost of the pencils and the cost of the folders, so $12 + $18 = $30.
The answer is $\boxed{30}$.<|im_end|>


In [20]:
# Third experience: solved flag, return and advantage (one per line),
# followed by the decoded non-padding tokens.
exp2 = replay_buffer[2]
print(exp2.solved_mask, exp2.returns, exp2.advantages, sep="\n")
print(tokenizer.decode(exp2.sequences[exp2.attention_mask]))

tensor(0., dtype=torch.bfloat16)
tensor(-1., dtype=torch.bfloat16)
tensor(0., dtype=torch.bfloat16)
<|im_start|>system
Please reason step by step, and put your final answer within $\boxed{}$.<|im_end|>
<|im_start|>user
A pencil costs $0.5 each and a folder costs $0.9 each. An office needs two dozen pencils and 20 pieces of folders. How much does it cost to buy the office supplies?<|im_end|>
<|im_start|>assistant
 Two dozen pencils is equal to 2 x 12 = 24 pencils.
The cost of one pencil is $0.5, so the cost of 24 pencils is 24 x $0.5 = $12.
The cost of one folder is $0.9, so the cost of 20 folders is 20 x $0.9 = $18.
To find the total cost, we add the cost of the pencils and the cost of the folders, so $12 + $18 = $30.
The answer is $\boxed{30}$.<|im_end|>


In [21]:
# Fourth experience: solved flag, return and advantage (one per line),
# followed by the decoded non-padding tokens.
exp3 = replay_buffer[3]
print(exp3.solved_mask, exp3.returns, exp3.advantages, sep="\n")
print(tokenizer.decode(exp3.sequences[exp3.attention_mask]))

tensor(0., dtype=torch.bfloat16)
tensor(-1., dtype=torch.bfloat16)
tensor(-0.5781, dtype=torch.bfloat16)
<|im_start|>system
Please reason step by step, and put your final answer within $\boxed{}$.<|im_end|>
<|im_start|>user
Jackson had 20 kilograms of meat. He used 1/4 of the meat to make meatballs and used 3 kilograms of meat to make spring rolls. How many kilograms of meat are left?<|im_end|>
<|im_start|>assistant
Jackson used 20*1/4 = 5 kilograms of meat for meatballs.
Jackson used 3 kilograms of meat for spring rolls.
So, Jackson used 5+3 = 8 kilograms of meat in total.
Jackson has 20-8 = 12 kilograms of meat left.
The answer is: $\boxed{12}$<|im_end|>


In [22]:
# action_mask is one element shorter than sequences, so drop the first token
# before masking; what remains is the text selected by the action mask.
print(tokenizer.decode(exp3.sequences[1:][exp3.action_mask]))

Jackson used 20*1/4 = 5 kilograms of meat for meatballs.
Jackson used 3 kilograms of meat for spring rolls.
So, Jackson used 5+3 = 8 kilograms of meat in total.
Jackson has 20-8 = 12 kilograms of meat left.
The answer is: $\boxed{12}$<|im_end|>


In [23]:
# Serve the collected experiences in buffer order, one mini-batch at a time;
# join_experience_batch is the collate function that merges them into a batch.
experience_sampler = DataLoader(
    dataset=replay_buffer,
    batch_size=trainer.rl_config.mini_batch_size,
    collate_fn=join_experience_batch,
    num_workers=trainer.base_config.num_workers,
    shuffle=False,
    drop_last=False,
)

In [24]:
# Take the first collated mini-batch from the experience sampler.
next_exp = next(iter(experience_sampler))

In [25]:
# Display the shapes of the batched returns and advantages.
next_exp.returns.shape, next_exp.advantages.shape 

(torch.Size([6]), torch.Size([6]))

In [None]:
# Report the shape of every populated field on the collated mini-batch.
for field_name in next_exp.keys:
    field_value = getattr(next_exp, field_name)
    if field_value is None:
        # Unset fields have no shape to show.
        continue
    print(field_name, '->', field_value.shape)

sequences -> torch.Size([6, 220])
action_log_probs -> torch.Size([6, 219])
log_probs_ref -> torch.Size([6, 219])
returns -> torch.Size([6])
solved_mask -> torch.Size([6])
advantages -> torch.Size([6])
attention_mask -> torch.Size([6, 220])
action_mask -> torch.Size([6, 219])


In [27]:
# Move the mini-batch onto the policy model's device before the forward pass.
next_exp = next_exp.to(init_train_state["model"].device)

# The second return value is discarded; presumably it is the entropy slot,
# which is not requested here (return_entropy=False) — confirm.
log_probs, _ = sequences_log_probs(
    init_train_state["model"],
    sequence_ids=next_exp.sequences,
    attention_mask=next_exp.attention_mask,
    logits_minibatch_size=trainer.rl_config.logits_minibatch_size,
    return_entropy=False,
)

In [28]:
# Show the shape of the freshly computed log-probabilities.
print(log_probs.shape)

torch.Size([6, 219])


In [29]:
# NOTE(review): this import sits mid-notebook; consider moving it to the
# import cell at the top.
import torch._dynamo

# NOTE(review): suppress_errors hides TorchDynamo compilation failures
# (presumably falling back to eager execution) — confirm this is intended
# outside of debugging.
torch._dynamo.config.suppress_errors = True

# Evaluate the trainer's actor loss on the mini-batch with the new log-probs.
loss, mean_kl, mean_actor_loss = trainer.actor_loss(log_probs=log_probs, experience=next_exp)

In [30]:
# Total loss, mean KL term and mean actor loss, one per line.
print(loss, mean_kl, mean_actor_loss, sep="\n")

tensor(0.0003, device='cuda:0', grad_fn=<CompiledFunctionBackward>)
tensor(0.0001, device='cuda:0', grad_fn=<CompiledFunctionBackward>)
tensor(-0.0864, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<CompiledFunctionBackward>)
