In [19]:
import torch
from reasoning_from_scratch.ch02 import get_device
from reasoning_from_scratch.ch03 import load_model_and_tokenizer
from reasoning_from_scratch.ch03 import render_prompt
from reasoning_from_scratch.ch04 import (
generate_text_stream_concat_flex,
generate_text_top_p_stream_cache
)
import json
from pathlib import Path
import requests
from pprint import pprint
from reasoning_from_scratch.qwen3 import KVCache
from reasoning_from_scratch.ch04 import top_p_filter
from reasoning_from_scratch.ch03 import (
extract_final_candidate, grade_answer
)


In [4]:
device = get_device()
model, tokenizer = load_model_and_tokenizer('base', device, local_dir='../models/qwen3_base', use_compile=False) 

Using NVIDIA CUDA GPU
âœ“ ..\models\qwen3_base\qwen3-0.6B-base.pth already up-to-date


  model.load_state_dict(torch.load(model_path))


In [6]:
raw_prompt =  (
"Half the value of $3x-9$ is $x+37$. "
"What is the value of $x$?"
)

prompt = render_prompt(raw_prompt)

torch.manual_seed(0)
response = generate_text_stream_concat_flex(model,
                                            tokenizer, prompt, device, max_new_tokens=2048, verbose=True, 
                                            generate_func=generate_text_top_p_stream_cache,
                                            temperature = 0.9, top_p = 0.9)

 \boxed{38}

In [8]:
def load_math_train(local_path = 'math_train.json', save_copy=True):
    local_path = Path(local_path)
    url = (
        "https://raw.githubusercontent.com/rasbt/"
        "math_full_minus_math500/refs/heads/main/"
        "math_full_minus_math500.json"
    )

    if local_path.exists():
        with local_path.open('r', encoding='utf-8') as f:
            data = json.load(f)
    else:
        r = requests.get(url, timeout=30)
        r.raise_for_status()
        data = r.json()

        if save_copy:
            with local_path.open('w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
    
    return data

In [9]:
math_train = load_math_train(save_copy=True)
print(f'Dataset size: {len(math_train)}')

Dataset size: 12000


In [11]:
pprint(math_train[4])

{'answer': '6',
 'level': 'Level 3',
 'problem': 'Sam is hired for a 20-day period. On days that he works, he earns '
            '$\\$$60. For each day that he does not work, $\\$$30 is '
            'subtracted from his earnings. At the end of the 20-day period, he '
            'received $\\$$660. How many days did he not work?',
 'solution': 'Call $x$ the number of days Sam works and $y$ the number of days '
             'he does not. We can set up the following system of equations to '
             'represent the given information: \\begin{align*}\n'
             'x+y &= 20 \\\\\n'
             '60x - 30y &= 660 \\\\\n'
             '\\end{align*} The first equation represents the total number of '
             'days Sam works, and the second equation represents his total '
             'profit. Solving for $x$ in the first equation yields $x = 20 - '
             'y$. Substituting into the second equation gives $60(20-y) - 30y '
             '= 660$. Canceling a factor of $10$ an

We could use 'solution' field to validate step-by-step reasoning of the model but this would lead to overfitting; instead, wewant the model to generalize its reasoning ability without memorizing this specific approach.

# Implementing GRPO PHASES

**PHASE 1** 
_GENEARTING & SAMPLING ROLLOUTS_

We could use the **generate_text_stream_concat_flex** to generate responses, but this was decorated with @inference_only(), making it useless for training purposes


In [15]:
@torch.no_grad()
def sample_response(
    model, 
    tokenizer, 
    prompt, 
    device,
    max_new_tokens=512, 
    temperature = 0.8,
    top_p = 0.9
):
    input_ids = torch.tensor(tokenizer.encode(prompt), device=device)
    cache = KVCache(n_layers=model.cfg['n_layers'])
    model.reset_kv_cache()
    next_token_logits = model(input_ids.unsqueeze(0), cache=cache)[:, -1]

    generated = []
    for _ in range(max_new_tokens):
        if temperature and temperature != 1.0: 
            next_token_logits /= temperature

        next_token_probas = torch.softmax(next_token_logits, dim=-1)
        next_token_probas = top_p_filter(next_token_probas, top_p)
        next_token_id = torch.multinomial(
            next_token_probas.cpu(), num_samples=1
        ).to(device)

        if (tokenizer.eos_token_id is not None and next_token_id.item() == tokenizer.eos_token_id):
            break

        generated.append(next_token_id.item())
        next_token_logits = model(next_token_id, cache=cache)[:, -1]
    
    full_token_ids = torch.cat([input_ids, torch.tensor(generated, device=device, dtype=input_ids.dtype)])

    return full_token_ids, input_ids.numel(), tokenizer.decode(generated) #return full answer + prompt, number of token of prompt and generated text



In [16]:
torch.manual_seed(0)
raw_prompt = (
"Half the value of $3x-9$ is $x+37$. "
"What is the value of $x$?"
)
prompt = render_prompt(raw_prompt)
token_ids, prompt_len, answer_text = sample_response(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        device=device,
        max_new_tokens=512,
        temperature=0.9,
        top_p=0.9,
    )
print(answer_text)

 \boxed{38}


In [18]:
rollouts = []

for _ in range(4):
    rollouts.append(sample_response(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        device=device,
        max_new_tokens=512,
        temperature=0.9,
        top_p=0.9,
    )[2])

print(rollouts)

[' \\boxed{42}', ' To solve the equation, we start by simplifying the given equation.\n\n\\[\n\\frac{1}{2}(3x - 9) = x + 37\n\\]\n\nFirst, distribute the \\(\\frac{1}{2}\\) on the left side:\n\n\\[\n\\frac{3x}{2} - \\frac{9}{2} = x + 37\n\\]\n\nNext, eliminate the fraction by multiplying every term by 2:\n\n\\[\n3x - 9 = 2x + 74\n\\]\n\nNow, isolate \\(x\\) by subtracting \\(2x\\) from both sides:\n\n\\[\n3x - 2x - 9 = 74\n\\]\n\nSimplify:\n\n\\[\nx - 9 = 74\n\\]\n\nAdd 9 to both sides to solve for \\(x\\):\n\n\\[\nx = 74 + 9\n\\]\n\n\\[\nx = 83\n\\]\n\nThus, the value of \\(x\\) is \\(\\boxed{83}\\).', ' \\boxed{32}', ' To solve the equation "Half the value of \\(3x - 9\\) is \\(x + 37\\)", follow these steps:\n\n1. **Translate the problem into an equation:**\n   \n   \\[\n   \\frac{1}{2}(3x - 9) = x + 37\n   \\]\n\n2. **Eliminate the fraction by multiplying both sides by 2:**\n   \n   \\[\n   2 \\times \\frac{1}{2}(3x - 9) = 2 \\times (x + 37)\n   \\]\n   \n   \\[\n   3x - 9 = 2x + 7

**PHASE 2** 
_CALCULATING REWARDS_

In [23]:
def reward_rlvr(answer_text, gt_text):
    extracted = extract_final_candidate(answer_text, fallback=None) #we require /boxed format
    if not extracted: #not boxed, we can still assign 0.5 if the number was correct but not boxed
        extracted_not_boxed = extract_final_candidate(answer_text, fallback='number_only')

        return float(grade_answer(extracted_not_boxed, gt_text)) / 2
            

    correct = grade_answer(extracted, gt_text) #True or False
    return float(correct)

In [30]:
rollouts_rewards = []

for rollout in rollouts:
    reward = reward_rlvr(rollout, '83')
    print(f"Answer: {rollout!r}")
    print(f"Reward: {reward}\n")

    rollouts_rewards.append(reward)

Answer: ' \\boxed{42}'
Reward: 0.0

Answer: ' To solve the equation, we start by simplifying the given equation.\n\n\\[\n\\frac{1}{2}(3x - 9) = x + 37\n\\]\n\nFirst, distribute the \\(\\frac{1}{2}\\) on the left side:\n\n\\[\n\\frac{3x}{2} - \\frac{9}{2} = x + 37\n\\]\n\nNext, eliminate the fraction by multiplying every term by 2:\n\n\\[\n3x - 9 = 2x + 74\n\\]\n\nNow, isolate \\(x\\) by subtracting \\(2x\\) from both sides:\n\n\\[\n3x - 2x - 9 = 74\n\\]\n\nSimplify:\n\n\\[\nx - 9 = 74\n\\]\n\nAdd 9 to both sides to solve for \\(x\\):\n\n\\[\nx = 74 + 9\n\\]\n\n\\[\nx = 83\n\\]\n\nThus, the value of \\(x\\) is \\(\\boxed{83}\\).'
Reward: 1.0

Answer: ' \\boxed{32}'
Reward: 0.0

Answer: ' To solve the equation "Half the value of \\(3x - 9\\) is \\(x + 37\\)", follow these steps:\n\n1. **Translate the problem into an equation:**\n   \n   \\[\n   \\frac{1}{2}(3x - 9) = x + 37\n   \\]\n\n2. **Eliminate the fraction by multiplying both sides by 2:**\n   \n   \\[\n   2 \\times \\frac{1}{2}(3x

We could decide to train the model with an additional intermediate reward during the generation process (like PPO), but these experiments were unsuccessful

**PHASE 3** 
_CALCULATING ADVANTAGES_

In [45]:
rewards = torch.tensor(rollouts_rewards, device=device)
advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards, dim=-1) + 1e-4)

print(rewards)
print(advantages)

tensor([0., 1., 0., 1.], device='cuda:0')
tensor([-0.8659,  0.8659, -0.8659,  0.8659], device='cuda:0')


**PHASE 4** 
_CALCULATING LOGPROBS_

In [37]:
def sequence_logprob_answer(model, tokenizer, prompt, answer, device='cpu'):

    prompt_ids = tokenizer.encode(prompt)
    answer_ids = tokenizer.encode(answer)
    full_ids = torch.tensor(prompt_ids + answer_ids, device=device)

    logits = model(full_ids.unsqueeze(0)).squeeze(0)
    logprobs = torch.log_softmax(logits, dim=-1)

    start = len(prompt_ids) - 1
    end = full_ids.shape[0] - 1

    t_idx = torch.arange(start, end, device=device)
    next_tokens = full_ids[start + 1:end + 1]
    next_token_logprobs = logprobs[t_idx, next_tokens]

    return torch.sum(next_token_logprobs) #sum instead of mean, like it was for token-level logprob

In [38]:
avg_logprob_val = sequence_logprob_answer(
    model, tokenizer,
    prompt=prompt,
    answer=answer_text,
    device=device)
print(avg_logprob_val)

tensor(-5.7500, device='cuda:0', dtype=torch.bfloat16, grad_fn=<SumBackward0>)


We now build the same function, but already with token_ids and prompt_len as input fields, since they are geenratyed buy the **sample_response** function written before

In [39]:
def sequence_logprob_draft(model, token_ids, prompt_len):

    logits = model(token_ids.unsqueeze(0)).squeeze(0)
    logprobs = torch.log_softmax(logits, dim=-1)

    start = prompt_len - 1
    end = token_ids.shape[0] - 1

    t_idx = torch.arange(start, end, device=device)
    next_tokens = token_ids[start + 1:end + 1]
    next_token_logprobs = logprobs[t_idx, next_tokens]

    return torch.sum(next_token_logprobs) #sum instead of mean, like it was for token-level logprob

In [40]:
print(sequence_logprob_draft(model, token_ids, prompt_len))

tensor(-5.7500, device='cuda:0', dtype=torch.bfloat16, grad_fn=<SumBackward0>)


As we can see, the SumBackward0 entry is present, this shows that the summation of token level logprobs are part og the computational graph. We need exactly this when updating the model weights to minimize the policy gradient loss.

We can also implement a more efficient version of sequence_logprob, using **gather** function instead of manually indexig the input_ids

In [41]:
def sequence_logprob(model, token_ids, prompt_len):
    logits = model(token_ids.unsqueeze(0)).squeeze(0).float()
    logprobs = torch.log_softmax(logits, dim=-1)
    
    selected = logprobs[:-1].gather(
        1, token_ids[1:].unsqueeze(-1)
    ).squeeze(-1)

    return torch.sum(selected[prompt_len - 1:])

In [42]:
print(sequence_logprob(model, token_ids, prompt_len))

tensor(-5.7268, device='cuda:0', grad_fn=<SumBackward0>)


In [43]:
rollout_logps = []

for text in rollouts:
    token_ids = tokenizer.encode(prompt + ' ' + text)
    logprob = sequence_logprob(
        model, token_ids=torch.tensor(token_ids, device=device),
        prompt_len=prompt_len
    )
    print(f"Answer: {text}")
    print(f"Logprob: {logprob.item():.4f}\n")
    rollout_logps.append(logprob)


Answer:  \boxed{42}
Logprob: -10.7715

Answer:  To solve the equation, we start by simplifying the given equation.

\[
\frac{1}{2}(3x - 9) = x + 37
\]

First, distribute the \(\frac{1}{2}\) on the left side:

\[
\frac{3x}{2} - \frac{9}{2} = x + 37
\]

Next, eliminate the fraction by multiplying every term by 2:

\[
3x - 9 = 2x + 74
\]

Now, isolate \(x\) by subtracting \(2x\) from both sides:

\[
3x - 2x - 9 = 74
\]

Simplify:

\[
x - 9 = 74
\]

Add 9 to both sides to solve for \(x\):

\[
x = 74 + 9
\]

\[
x = 83
\]

Thus, the value of \(x\) is \(\boxed{83}\).
Logprob: -36.6990

Answer:  \boxed{32}
Logprob: -10.3265

Answer:  To solve the equation "Half the value of \(3x - 9\) is \(x + 37\)", follow these steps:

1. **Translate the problem into an equation:**
   
   \[
   \frac{1}{2}(3x - 9) = x + 37
   \]

2. **Eliminate the fraction by multiplying both sides by 2:**
   
   \[
   2 \times \frac{1}{2}(3x - 9) = 2 \times (x + 37)
   \]
   
   \[
   3x - 9 = 2x + 74
   \]

3. **Isolate t

**PHASE 5** 
_CALCULATING POLICY GRADIENTS LOSS_

In [46]:
logps = torch.stack(rollout_logps) #transform to tensor
pg_loss = - (advantages.detach() * logps).mean() #detach prevent the gradients from flowing back through the advanatage calculation, since it is a fixed parameter (not learnable)

print(logps)
print(pg_loss)

tensor([-10.7715, -36.6990, -10.3265, -26.8390], device='cuda:0',
       grad_fn=<StackBackward0>)
tensor(9.1869, device='cuda:0', grad_fn=<NegBackward0>)


**PUTTING EVERYTHING TOGETHER** 

In [None]:
def compute_grpo_loss(
    model,
    tokenizer,
    example,
    device,
    num_rollouts=2,
    max_new_tokens=256,
    temperature=0.8,
    top_p=0.9,
):
    assert num_rollouts >= 2
    roll_logps, roll_rewards, samples = [], [], []
    prompt = render_prompt(example["problem"]) #take the prompt
    was_training = model.training
    model.eval()
    for _ in range(num_rollouts):

        token_ids, prompt_len, text = sample_response(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            device=device,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        ) #sample a single response

        reward = reward_rlvr(text, example["answer"]) #calculate reward

        logp = sequence_logprob(model, token_ids, prompt_len) #calculate sequence_logprobs
        roll_logps.append(logp)
        roll_rewards.append(reward)
        samples.append(
            {
                "text": text,
                "reward": reward,
                "gen_len": token_ids.numel() - prompt_len,
            }
        )
    if was_training:
        model.train()

    rewards = torch.tensor(roll_rewards, device=device)

    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4) #calculate advantages

    logps = torch.stack(roll_logps)
    
    pg_loss = -(advantages.detach() * logps).mean() #calculate policy gradient loss
    loss = pg_loss  # after we'll add a KL term here
    return {
        "loss": loss.item(),
        "pg_loss": pg_loss.item(),
        "rewards": roll_rewards,
        "advantages": advantages.detach().cpu().tolist(),
        "samples": samples,
        "loss_tensor": loss,
    }

In [48]:
torch.manual_seed(123)
stats = compute_grpo_loss(
    model=model,
    tokenizer=tokenizer,
    example=math_train[4],
    device=device,
    num_rollouts=2,
    max_new_tokens=256,
    temperature=0.8,
    top_p=0.9
)
pprint(stats)


{'advantages': [0.0, 0.0],
 'loss': -0.0,
 'loss_tensor': tensor(-0., device='cuda:0', grad_fn=<NegBackward0>),
 'pg_loss': -0.0,
 'rewards': [0.0, 0.0],
 'samples': [{'gen_len': 4, 'reward': 0.0, 'text': ' 10 days'},
             {'gen_len': 256,
              'reward': 0.0,
              'text': ' \n'
                      'Sam earned a total of $660$ dollars at the end of the '
                      '20-day period. He earns $60$ dollars per day if he '
                      'works, and loses $30$ dollars per day if he does not '
                      "work. Let's denote the number of days Sam worked as $w$ "
                      'and the number of days he did not work as $d$. We know '
                      'that $w + d = 20$ (since he worked for 20 days in '
                      'total).\n'
                      '\n'
                      "Sam's total earnings can be expressed as:\n"
                      '\\[60w - 30d = 660\\]\n'
                      '\n'
                      

When all rewards are 0.0, so they are the advantages and so the loss. In this case, the weight eould not be updated at all.

In [50]:
import time

def train_rlvr_grpo(
    model,
    tokenizer,
    math_data,
    device,
    steps=None,
    num_rollouts=2,
    max_new_tokens=256,
    temperature=0.8,
    top_p=0.9,
    lr=1e-5,
    checkpoint_every=50,
    checkpoint_dir=".",
    csv_log_path=None,
):
    if steps is None:
        steps = len(math_data)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    current_step = 0
    if csv_log_path is None:
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        csv_log_path = f"train_rlvr_grpo_metrics_{timestamp}.csv"
    csv_log_path = Path(csv_log_path)
    try:
        for step in range(steps):
            optimizer.zero_grad()
            current_step = step + 1
            example = math_data[step % len(math_data)]
            stats = compute_grpo_loss(
                model=model,
                tokenizer=tokenizer,
                example=example,
                device=device,
                num_rollouts=num_rollouts,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
            )
            stats["loss_tensor"].backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            reward_avg = torch.tensor(stats["rewards"]).mean().item()
            step_tokens = sum(
                sample["gen_len"] for sample in stats["samples"]
            )
            avg_response_len = (
                step_tokens / len(stats["samples"]) if stats["samples"] else 0.0
            )
            append_csv_metrics(
                csv_log_path, current_step, steps,
                stats["loss"], reward_avg, avg_response_len,
            )
            print(
                f"[Step {current_step}/{steps}] "
                f"loss={stats['loss']:.4f} "
                f"reward_avg={reward_avg:.3f} "
                f"avg_resp_len={avg_response_len:.1f}"
            )
            if checkpoint_every and current_step % checkpoint_every == 0:
                ckpt_path = save_checkpoint(
                    model=model,
                    checkpoint_dir=checkpoint_dir,
                    step=current_step,
                )
                print(f"Saved checkpoint to {ckpt_path}")
    except KeyboardInterrupt:
        ckpt_path = save_checkpoint(
            model=model,
            checkpoint_dir=checkpoint_dir,
            step=max(1, current_step),
            suffix="interrupt",
        )
        print(f"\nKeyboardInterrupt. Saved checkpoint to {ckpt_path}")
        return model
    return model

def save_checkpoint(model, checkpoint_dir, step, suffix=""):
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    suffix = f"-{suffix}" if suffix else ""
    ckpt_path = (
        checkpoint_dir /
        f"qwen3-0.6B-rlvr-grpo-step{step:05d}{suffix}.pth"
    )
    torch.save(model.state_dict(), ckpt_path)
    return ckpt_path

def append_csv_metrics(
    csv_log_path, step_idx, total_steps,
    loss, reward_avg, avg_response_len,
):
    if not csv_log_path.exists():
        csv_log_path.write_text(
            "step,total_steps,loss,reward_avg,avg_response_len\n",
            encoding="utf-8",
        )
    with csv_log_path.open("a", encoding="utf-8") as f:
        f.write(
            f"{step_idx},{total_steps},{loss:.6f},{reward_avg:.6f},"
            f"{avg_response_len:.6f}\n"
        )

In [51]:
torch.manual_seed(1)
train_rlvr_grpo(
    model=model,
    tokenizer=tokenizer,
    math_data=math_train,
    device=device,
    steps=50,
    num_rollouts=4,
    max_new_tokens=512,
    temperature=0.8,
    top_p=0.9,
    lr=1e-5,
    checkpoint_every=5,
)

[Step 1/50] loss=-0.0000 reward_avg=0.000 avg_resp_len=5.5
[Step 2/50] loss=-0.0000 reward_avg=0.000 avg_resp_len=6.8
[Step 3/50] loss=-0.7200 reward_avg=0.375 avg_resp_len=8.2
[Step 4/50] loss=2.8870 reward_avg=0.500 avg_resp_len=55.0
[Step 5/50] loss=-0.0000 reward_avg=0.000 avg_resp_len=164.2
Saved checkpoint to qwen3-0.6B-rlvr-grpo-step00005.pth
[Step 6/50] loss=-0.0000 reward_avg=0.000 avg_resp_len=117.8

KeyboardInterrupt. Saved checkpoint to qwen3-0.6B-rlvr-grpo-step00007-interrupt.pth


Qwen3Model(
  (tok_emb): Embedding(151936, 1024)
  (trf_blocks): ModuleList(
    (0-27): 28 x TransformerBlock(
      (att): GroupedQueryAttention(
        (W_query): Linear(in_features=1024, out_features=2048, bias=False)
        (W_key): Linear(in_features=1024, out_features=1024, bias=False)
        (W_value): Linear(in_features=1024, out_features=1024, bias=False)
        (out_proj): Linear(in_features=2048, out_features=1024, bias=False)
        (q_norm): RMSNorm()
        (k_norm): RMSNorm()
      )
      (ff): FeedForward(
        (fc1): Linear(in_features=1024, out_features=3072, bias=False)
        (fc2): Linear(in_features=1024, out_features=3072, bias=False)
        (fc3): Linear(in_features=3072, out_features=1024, bias=False)
      )
      (norm1): RMSNorm()
      (norm2): RMSNorm()
    )
  )
  (final_norm): RMSNorm()
  (out_head): Linear(in_features=1024, out_features=151936, bias=False)
)

Overall the loss fluctuates a lot, and thius is normal in RL training. Long term we want to see:

1) **REWARD AVG** to increase (model produces accurate responses)
2) The **REASONING ACCURACY** should improve (we evaluate this later)

# Loading and Evaluating saved model checkpoints

In [52]:
from reasoning_from_scratch.qwen3 import download_qwen3_grpo_checkpoints
download_qwen3_grpo_checkpoints(grpo_type="no_kl", step="00050")

qwen3-0.6B-rlvr-grpo-step00050.pth: 100% (1433 MiB / 1433 MiB)


In [None]:
!python  scripts/evaluate_math500.py --dataset_size 500 --which_model base --checkpoint_path qwen3-0.6B-rlvr-grpo-step00050.pth