# 02 – Reward Design & Validation

This notebook checks that the five reward functions give **higher scores** to
correct samples and **lower scores** to incorrect ones.

Reward functions (a perfect completion scores 1.0 + 1.0 + 2.0 + 1.0 + 2.0 = 7.0 in total):

| Function | What it checks | Max score |
|---|---|---|
| `format_reward_func` | `<reasoning>`/`<answer>` tags and a numeric answer | 1.0 |
| `numbering_reward_func` | Sequential line numbers (1, 2, 3, ...) | ~1.0 |
| `spelling_reward_func` | Listed letters match the word | 2.0 |
| `counting_reward_func` | Running count of the target letter is accurate | ~1.0 |
| `correct_answer_reward_func` | Final answer matches the true count | 2.0 |
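
The printed `RewardBreakdown(...)` values and the `r.total` attribute used below suggest a simple container for the five scores. The sketch here is a hypothetical reconstruction, not the class defined in `src.rewards`: it assumes `total` is the unweighted sum of the five components, which agrees with every total printed in this notebook and gives a ceiling of 7.0.

```python
from dataclasses import astuple, dataclass


@dataclass(frozen=True)
class RewardBreakdown:
    """Hypothetical container mirroring the printed RewardBreakdown(...) repr."""

    formatting: float   # format_reward_func, max 1.0
    numbering: float    # numbering_reward_func, max ~1.0
    spelling: float     # spelling_reward_func, max 2.0
    counting: float     # counting_reward_func, max ~1.0
    correctness: float  # correct_answer_reward_func, max 2.0

    @property
    def total(self) -> float:
        # Assumption: the total is the plain, unweighted sum of the five components.
        return sum(astuple(self))


best = RewardBreakdown(1.0, 1.0, 2.0, 1.0, 2.0)
print(best)                 # same repr format as the outputs below
print(f"{best.total:.2f}")  # 7.00
```

The default dataclass `repr` already produces the `RewardBreakdown(formatting=..., ...)` format seen in the outputs, so no custom `__repr__` is needed.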

In [1]:
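import os
import sys

# Make the project root importable so that the `src` package resolves from this notebook's directory
sys.path.insert(0, os.path.abspath(".."))

from src.utils import read_jsonl
from src.rewards import reward_single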

## Correct Samples (expect HIGH rewards)

In [2]:
correct = read_jsonl("../data/samples_correct.jsonl")

for row in correct:
    r = reward_single(row["word"], row["letter"], row["count"], row["completion"])
    print(f'\n{"="*55}')
    print(f'  "{row["letter"]}" in "{row["word"]}"  (true count: {row["count"]})')
    print(f'{"="*55}')
    print(f"  Response:\n{row['completion']}")
    print(f"\n  Rewards: {r}")
    print(f"  Total:   {r.total:.2f}")


  "g" in "engage"  (true count: 2)
  Response:
<reasoning>
Counting the number of g's in the word engage
1. e - 0 so far
2. n - 0 so far
3. g - 1 so far
4. a - 1 so far
5. g - 2 so far
6. e - 2 so far
</reasoning>
<answer>
2
</answer>

  Rewards: RewardBreakdown(formatting=1.0, numbering=1.0, spelling=2.0, counting=1.0, correctness=2.0)
  Total:   7.00

  "a" in "banana"  (true count: 3)
  Response:
<reasoning>
Counting the number of a's in the word banana
1. b - 0 so far
2. a - 1 so far
3. n - 1 so far
4. a - 2 so far
5. n - 2 so far
6. a - 3 so far
</reasoning>
<answer>
3
</answer>

  Rewards: RewardBreakdown(formatting=1.0, numbering=1.0, spelling=2.0, counting=1.0, correctness=2.0)
  Total:   7.00

  "r" in "strawberry"  (true count: 3)
  Response:
<reasoning>
Counting the number of r's in the word strawberry
1. s - 0 so far
2. t - 0 so far
3. r - 1 so far
4. a - 1 so far
5. w - 1 so far
6. b - 1 so far
7. e - 1 so far
8. r - 2 so far
9. r - 3 so far
10. y - 3 so far
</reasoning>


## Incorrect Samples (expect LOW rewards)

Each sample contains at least one error: wrong running counts, skipped letters, missing format tags, or a wrong final answer.
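
The numbering and counting checks both read the `N. x - K so far` lines of the reasoning. Below is a hypothetical sketch of one scoring rule per check; `parse_lines`, `numbering_score`, and `counting_score` are illustrative names, not functions from `src.rewards`. Each rule gives +1 per correct line and -1 per incorrect one. Numbering divides by the word length; counting recounts the target letter over the letters actually listed, divides by the number of lines, and returns -1.0 when there are no lines. These formulas were picked because they reproduce the numbering and counting values printed for the samples below, but the real implementation may differ.

```python
import re

# Hypothetical pattern for reasoning lines such as "3. g - 1 so far"
STEP_RE = re.compile(r"^\s*(\d+)\.\s*([A-Za-z])\s*-\s*(\d+)\s+so far", re.MULTILINE)


def parse_lines(completion: str) -> list[tuple[int, str, int]]:
    """Return (line number, letter, claimed running count) for every step line."""
    return [(int(n), ch.lower(), int(k)) for n, ch, k in STEP_RE.findall(completion)]


def numbering_score(word: str, completion: str) -> float:
    """+1 per line numbered with its position, -1 otherwise, divided by len(word)."""
    steps = parse_lines(completion)
    good = sum(1 for i, (n, _, _) in enumerate(steps, start=1) if n == i)
    bad = len(steps) - good
    return (good - bad) / len(word)


def counting_score(letter: str, completion: str) -> float:
    """+1 per correct running count, -1 per wrong one, divided by the number of lines."""
    steps = parse_lines(completion)
    if not steps:
        return -1.0  # no step-by-step count at all
    running = good = 0
    for _, ch, claimed in steps:
        if ch == letter:
            running += 1  # recount using the letters the model actually listed
        if claimed == running:
            good += 1
    bad = len(steps) - good
    return (good - bad) / len(steps)


engage_wrong = "\n".join([
    "1. e - 0 so far", "2. n - 0 so far", "3. g - 1 so far",
    "4. a - 1 so far", "5. g - 1 so far", "6. e - 1 so far",
])
banana_skipped = "1. b - 0 so far\n3. a - 1 so far\n5. n - 1 so far\n6. a - 2 so far"

print(counting_score("g", engage_wrong))          # 0.3333333333333333
print(numbering_score("banana", banana_skipped))  # -0.3333333333333333
print(counting_score("a", banana_skipped))        # 1.0
```

Recounting over the listed letters, rather than over the true word, is why the "banana" sample keeps a full counting score: its counts are consistent with the letters it wrote, and the skipped letters are left for the numbering and spelling checks to penalise.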

In [3]:
incorrect = read_jsonl("../data/samples_incorrect.jsonl")

for row in incorrect:
    r = reward_single(row["word"], row["letter"], row["count"], row["completion"])
    print(f'\n{"="*55}')
    print(f'  "{row["letter"]}" in "{row["word"]}"  (true count: {row["count"]})')
    print(f'{"="*55}')
    print(f"  Response:\n{row['completion']}")
    print(f"\n  Rewards: {r}")
    print(f"  Total:   {r.total:.2f}")


  "g" in "engage"  (true count: 2)
  Response:
<reasoning>
Counting the number of g's in the word engage
1. e - 0 so far
2. n - 0 so far
3. g - 1 so far
4. a - 1 so far
5. g - 1 so far
6. e - 1 so far
</reasoning>
<answer>
1
</answer>

  Rewards: RewardBreakdown(formatting=1.0, numbering=1.0, spelling=2.0, counting=0.3333333333333333, correctness=0.0)
  Total:   4.33

  "a" in "banana"  (true count: 3)
  Response:
<reasoning>
Counting the number of a's in the word banana
1. b - 0 so far
3. a - 1 so far
5. n - 1 so far
6. a - 2 so far
</reasoning>
<answer>
2
</answer>

  Rewards: RewardBreakdown(formatting=1.0, numbering=-0.3333333333333333, spelling=-0.6666666666666666, counting=1.0, correctness=0.0)
  Total:   1.00

  "r" in "strawberry"  (true count: 3)
  Response:
The word strawberry has 2 r's.

  Rewards: RewardBreakdown(formatting=0.0, numbering=0.0, spelling=-2.0, counting=-1.0, correctness=0.0)
  Total:   -3.00


## Summary

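On these three word/letter pairs, every correct sample reaches the maximum total of 7.00, while the
incorrect ones score 4.33, 1.00, and -3.00. This ordering is what GRPO needs from its reward signal,
although three hand-written examples cannot by themselves show that training will succeed.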

Key failure modes penalised:
- Wrong final answer: `correct_answer_reward_func` drops from 2.0 to 0.0
- Wrong running counts: `counting_reward_func` falls below 1.0 (0.33 for the "engage" sample)
- Skipped or extra letters: `spelling_reward_func` and `numbering_reward_func` drop, and can go negative
- Missing `<reasoning>`/`<answer>` tags: `format_reward_func` drops to 0.0
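
The two all-or-nothing checks in this list can be sketched with regular expressions. This is a minimal sketch that assumes the tag layout used in the samples; the patterns and the names `format_score` and `correct_answer_score` are hypothetical stand-ins for `format_reward_func` and `correct_answer_reward_func`, not their actual code.

```python
import re

# Hypothetical patterns: a <reasoning> block, then an <answer> block holding only digits
FORMAT_RE = re.compile(r"<reasoning>.*?</reasoning>\s*<answer>\s*\d+\s*</answer>", re.DOTALL)
ANSWER_RE = re.compile(r"<answer>\s*(\d+)\s*</answer>")


def format_score(completion: str) -> float:
    """1.0 when both tag pairs and a numeric answer are present, else 0.0."""
    return 1.0 if FORMAT_RE.search(completion) else 0.0


def correct_answer_score(completion: str, true_count: int) -> float:
    """2.0 when the number inside <answer> equals the true count, else 0.0."""
    match = ANSWER_RE.search(completion)
    return 2.0 if match and int(match.group(1)) == true_count else 0.0


tagged = "<reasoning>\n1. e - 0 so far\n</reasoning>\n<answer>\n2\n</answer>"  # abbreviated
plain = "The word strawberry has 2 r's."

print(format_score(tagged), correct_answer_score(tagged, 2))  # 1.0 2.0
print(correct_answer_score(tagged, 3))                        # 0.0
print(format_score(plain), correct_answer_score(plain, 3))    # 0.0 0.0
```

`re.DOTALL` lets `.*?` span the multi-line reasoning block; without it the format check would fail on every real completion.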

In [4]:
# Totals copied from the outputs above (4.33 is rounded)
correct_totals = [7.0, 7.0, 7.0]
incorrect_totals = [4.33, 1.0, -3.0]

avg_correct = sum(correct_totals) / len(correct_totals)
avg_incorrect = sum(incorrect_totals) / len(incorrect_totals)
print("Average correct:", avg_correct)
print("Average incorrect:", avg_incorrect)
print("Pass condition:", avg_correct > avg_incorrect)

Average correct: 7.0
Average incorrect: 0.7766666666666667
Pass condition: True
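
Comparing averages is a weak test: one very bad incorrect sample can pull the incorrect mean down even when another incorrect sample outscores a correct one. A stricter check requires the lowest correct total to beat the highest incorrect total. The helper below (`check_separation` is a hypothetical name, not part of `src`) computes both, using the totals printed above.

```python
from statistics import mean


def check_separation(
    correct_totals: list[float], incorrect_totals: list[float]
) -> tuple[float, float]:
    """Return (difference of means, worst-case margin) between two groups of totals.

    A positive worst-case margin means every correct sample outscores every
    incorrect one, which is stricter than comparing averages.
    """
    mean_gap = mean(correct_totals) - mean(incorrect_totals)
    margin = min(correct_totals) - max(incorrect_totals)
    return mean_gap, margin


# Totals as printed above (4.33 is rounded)
gap, margin = check_separation([7.0, 7.0, 7.0], [4.33, 1.0, -3.0])
print(f"mean gap: {gap:.2f}, worst-case margin: {margin:.2f}")  # mean gap: 6.22, worst-case margin: 2.67
print("Strict pass:", margin > 0)                               # Strict pass: True
```

For example, `check_separation([1.0, 5.0], [2.0])` has a positive mean gap but a negative margin, so it passes the average test while failing the strict one. Because GRPO scores each completion relative to the other completions sampled for the same prompt, this per-sample ordering is arguably closer to what training relies on than the difference in means.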
