In [4]:
from pathlib import Path
import torch
from reasoning_from_scratch.ch02 import (
        get_device
)
from reasoning_from_scratch.qwen3 import (
        download_qwen3_small,
        Qwen3Tokenizer,
        Qwen3Model,
        QWEN_CONFIG_06_B
)
from reasoning_from_scratch.qwen3 import KVCache


In [5]:
def load_model_and_tokenizer(which_model, device, use_compile, local_dir="qwen3"):

    if which_model == "base":
        download_qwen3_small(kind="base", tokenizer_only=False, out_dir=local_dir)
        tokenizer_path = Path(local_dir) / "tokenizer-base.json"
        model_path = Path(local_dir) / "qwen3-0.6B-base.pth"
        tokenizer = Qwen3Tokenizer(tokenizer_file_path=tokenizer_path)

    elif which_model == "reasoning":
        download_qwen3_small(kind="reasoning", tokenizer_only=False, out_dir=local_dir)
        tokenizer_path = Path(local_dir) / "tokenizer-reasoning.json"
        model_path = Path(local_dir) / "qwen3-0.6B-reasoning.pth"
        tokenizer = Qwen3Tokenizer(
                tokenizer_file_path=tokenizer_path,
                apply_chat_template=True,
                add_generation_prompt=True,
                add_thinking=True,
        )
    else:
        raise ValueError(f"Invalid choice: which_model={which_model}")
    
    model = Qwen3Model(QWEN_CONFIG_06_B)
    model.load_state_dict(torch.load(model_path))
    model.to(device)
    if use_compile: #Optionally set to true to enable model compilation
        torch._dynamo.config.allow_unspec_int_on_nn_module = True
        model = torch.compile(model)
    return model, tokenizer


In [6]:
WHICH_MODEL = "base" #Uses the base model, similar to chapter 2, by default
device = get_device()
model, tokenizer = load_model_and_tokenizer(
        which_model=WHICH_MODEL,
        device=device,
        use_compile=False
)

Using NVIDIA CUDA GPU
✓ qwen3\qwen3-0.6B-base.pth already up-to-date


  model.load_state_dict(torch.load(model_path))


In [7]:
@torch.inference_mode()
def generate_text_basic_stream_cache(model, input_ids, max_new_tokens, eos_token_id=None):
    model.eval()

    cache = KVCache(n_layers=model.cfg['n_layers'])
    model.reset_kv_cache()

    out = model(input_ids, cache=cache)[:, -1]

    for _ in range(max_new_tokens):
        next_token = torch.argmax(out, dim=-1, keepdim=True)

        if (eos_token_id is not None
                and next_token.item() == eos_token_id):
            break

        yield next_token  # Yield each token as it's generated

        out = model(next_token, cache=cache)[:, -1]
        
    


In [8]:
prompt = ( #MATH PROBLEM
    r"If $a+b=3$ and $ab=\tfrac{13}{6}$, "
    r"what is the value of $a^2+b^2$?"
)

input_ids = torch.tensor(
    tokenizer.encode(prompt),
    device=device
).unsqueeze(0)

all_token_ids = []

for token in generate_text_basic_stream_cache(model, input_ids, max_new_tokens=2048, eos_token_id=tokenizer.eos_token_id):
    token_id = token.squeeze(0)
    decoded_id = tokenizer.decode([token_id])
    print(
        decoded_id,
        end='',
        flush=True
    )

    all_token_ids.append(token_id)

all_tokens = tokenizer.decode(all_token_ids)


 To find the value of \( a^2 + b^2 \) given that \( a + b = 3 \) and \( ab = \frac{13}{6} \), we can use the following algebraic identity:

\[
a^2 + b^2 = (a + b)^2 - 2ab
\]

**Step 1:** Substitute the given values into the equation.

\[
a^2 + b^2 = (3)^2 - 2 \left( \frac{13}{6} \right)
\]

**Step 2:** Calculate \( (3)^2 \).

\[
(3)^2 = 9
\]

**Step 3:** Calculate \( 2 \times \frac{13}{6} \).

\[
2 \times \frac{13}{6} = \frac{26}{6} = \frac{13}{3}
\]

**Step 4:** Subtract the second result from the first.

\[
a^2 + b^2 = 9 - \frac{13}{3}
\]

**Step 5:** Convert 9 to a fraction with a denominator of 3 to perform the subtraction.

\[
9 = \frac{27}{3}
\]

\[
a^2 + b^2 = \frac{27}{3} - \frac{13}{3} = \frac{14}{3}
\]

**Final Answer:**

\[
\boxed{\dfrac{14}{3}}
\]

In [9]:
from IPython.display import Latex, display
display(Latex(all_tokens))

<IPython.core.display.Latex object>

Now we can prepare the wrapper function

In [10]:
def generate_text_stream_concat(model, tokenizer, prompt, device, max_new_tokens, verbose=False):

    input_ids = torch.tensor(
        tokenizer.encode(prompt),
        device=device
    ).unsqueeze(0)

    all_token_ids = []

    for token in generate_text_basic_stream_cache(model, input_ids, max_new_tokens=max_new_tokens, eos_token_id=tokenizer.eos_token_id):
        token_id = token.squeeze(0)
        all_token_ids.append(token_id)
        

        if verbose:
            decoded_id = tokenizer.decode([token_id])
            print(
                decoded_id,
                end='',
                flush=True
            )

    return tokenizer.decode(all_token_ids)


In [11]:
model_answer = generate_text_stream_concat(model, tokenizer, prompt, device, max_new_tokens=2048, verbose=True)

 To find the value of \( a^2 + b^2 \) given that \( a + b = 3 \) and \( ab = \frac{13}{6} \), we can use the following algebraic identity:

\[
a^2 + b^2 = (a + b)^2 - 2ab
\]

**Step 1:** Substitute the given values into the equation.

\[
a^2 + b^2 = (3)^2 - 2 \left( \frac{13}{6} \right)
\]

**Step 2:** Calculate \( (3)^2 \).

\[
(3)^2 = 9
\]

**Step 3:** Calculate \( 2 \times \frac{13}{6} \).

\[
2 \times \frac{13}{6} = \frac{26}{6} = \frac{13}{3}
\]

**Step 4:** Subtract the second result from the first.

\[
a^2 + b^2 = 9 - \frac{13}{3}
\]

**Step 5:** Convert 9 to a fraction with a denominator of 3 to perform the subtraction.

\[
9 = \frac{27}{3}
\]

\[
a^2 + b^2 = \frac{27}{3} - \frac{13}{3} = \frac{14}{3}
\]

**Final Answer:**

\[
\boxed{\dfrac{14}{3}}
\]

In [12]:
def get_last_boxed(text):
    boxed_start_idx = text.rfind(r"\boxed") 
    if boxed_start_idx == -1:
        return None
    current_idx = boxed_start_idx + len(r"\boxed") 

    while current_idx < len(text) and text[current_idx].isspace():
        current_idx += 1
    
    if current_idx >= len(text) or text[current_idx] != "{":
        return None
    
    current_idx += 1
    brace_depth = 1
    content_start_idx = current_idx
    
    while current_idx < len(text) and brace_depth > 0:
        char = text[current_idx]
        if char == "{":
            brace_depth += 1
        elif char == "}":
            brace_depth -= 1
        current_idx += 1

    if brace_depth != 0: 
        return None
    
    return text[content_start_idx:current_idx-1] 

In [13]:
from IPython.display import Math

extracted_answer = get_last_boxed(model_answer)
display(Math(extracted_answer))

<IPython.core.display.Math object>

We can also handle cases in which the model fails to format a correct answer

In [14]:
import re
RE_NUMBER = re.compile( 
        r"-?(?:\d+/\d+|\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)"
    )
def extract_final_candidate(text, fallback="number_then_full"):
    result = "" 

    if text: 
        boxed = get_last_boxed(text.strip())

        if boxed:
            result = boxed.strip().strip("$ ")
    
        elif fallback in ("number_then_full", "number_only"):
            m = RE_NUMBER.findall(text)

            if m:
                result = m[-1] 
            elif fallback == "number_then_full":
                result = text 

    return result


In [15]:
final_candidate = extract_final_candidate(r"\boxed{ 14/3. }")
display(Math(final_candidate))

<IPython.core.display.Math object>

In [16]:
final_candidate = extract_final_candidate("abc < > 14/3 abc") #last number is 14/3
display(Math(final_candidate))

<IPython.core.display.Math object>

Now we can NORMALIZE the extracted answer

In [17]:
import re

LATEX_FIXES = [  # A
    (r"\\left\s*", ""),
    (r"\\right\s*", ""),
    (r"\\,|\\!|\\;|\\:", ""),
    (r"\\cdot", "*"),
    (r"\u00B7|\u00D7", "*"),
    (r"\\\^\\circ", ""),
    (r"\\dfrac", r"\\frac"),
    (r"\\tfrac", r"\\frac"),
    (r"°", ""),
]
RE_SPECIAL = re.compile(r"<\|[^>]+?\|>")  # B
SUPERSCRIPT_MAP = {
    "⁰": "0", "¹": "1", "²": "2", "³": "3", "⁴": "4",  # C
    "⁵": "5", "⁶": "6", "⁷": "7", "⁸": "8", "⁹": "9",  # C
    "⁺": "+", "⁻": "-", "⁽": "(", "⁾": ")",  # C
}

def normalize_text(text):
    if not text:
        return ""
    text = RE_SPECIAL.sub("", text).strip()
    # D
    match = re.match(r"^[A-Za-z]\s*[.:]\s*(.+)$", text)
    if match:
        text = match.group(1)
    text = re.sub(r"\^\s*\{\s*\\circ\s*\}", "", text)  # D
    text = re.sub(r"\^\s*\\circ", "", text)  # E
    text = text.replace("°", "")  # E
    match = re.match(r"^\\text\{(?P<x>.+?)\}$", text)  # F
    if match:
        text = match.group("x")
    text = re.sub(r"\\\(|\\\)|\\\[|\\\]", "", text)  # G
    for pat, rep in LATEX_FIXES:  # H
        text = re.sub(pat, rep, text)
    
    def convert_superscripts(s, base=None):
        converted = "".join(
            SUPERSCRIPT_MAP[ch] if ch in SUPERSCRIPT_MAP else ch
            for ch in s
        )
        if base is None:
            return converted
        return f"{base}**{converted}"
    
    text = re.sub(
        r"([0-9A-Za-z\)\]\}])([⁰¹²³⁴⁵⁶⁷⁸⁹⁺⁻]+)",
        lambda m: convert_superscripts(m.group(2), base=m.group(1)),
        text,
    )
    text = convert_superscripts(text)
    # I
    text = text.replace("\\%", "%").replace("$", "").replace("%", "")
    text = re.sub(
        r"\\sqrt\s*\{([^}]*)\}",
        lambda match: f"sqrt({match.group(1)})",
        text,
    )
    text = re.sub(
        r"\\sqrt\s+([^\\\s{}]+)",
        lambda match: f"sqrt({match.group(1)})",
        text,
    )
    # J
    text = re.sub(
        r"\\frac\s*\{([^{}]+)\}\s*\{([^{}]+)\}",
        lambda match: f"({match.group(1)})/({match.group(2)})",
        text,
    )
    text = re.sub(
        r"\\frac\s+([^\s{}]+)\s+([^\s{}]+)",
        lambda match: f"({match.group(1)})/({match.group(2)})",
        text,
    )
    # K
    text = text.replace("^", "**")
    text = re.sub(
        r"(?<=\d)\s+(\d+/\d+)",
        lambda match: "+" + match.group(1),
        text,
    )
    # L
    text = re.sub(
        r"(?<=\d),(?=\d\d\d(\D|$))",
        "",
        text,
    )
    return text.replace("{", "").replace("}", "").strip().lower()


In [18]:
print(normalize_text(extract_final_candidate(model_answer)))

(14)/(3)


In [19]:
print(normalize_text(r"\text{\[\frac{14}{3}\]}"))


(14)/(3)


We can now verify the mathematical equivalence between the extracted answer from LLm and a GT

In [20]:
from sympy.parsing import sympy_parser as spp
from sympy.core.sympify import SympifyError
from sympy.polys.polyerrors import PolynomialError
from tokenize import TokenError


In [21]:
def sympy_parser(expr):
    try:
        return spp.parse_expr(expr, transformations=(

            *spp.standard_transformations, #like handling parenthesis invariant symbols
            spp.implicit_multiplication_application, #allow omitted mul symbols, like 2y == 2*y

            ),
        evaluate=True,
        )
    except (SympifyError, SyntaxError, TypeError, AttributeError,
            IndexError, TokenError, ValueError, PolynomialError):
        return None




In [22]:
print(sympy_parser(
    normalize_text(
        extract_final_candidate(
            model_answer))))

14/3


In [23]:
print(sympy_parser("28/6")) #normalized by sympy

14/3


We can noe build the equality function

In [24]:
from sympy import simplify

def equality_check(expr_gt, expr_pred):

    if expr_gt == expr_pred:
        return True
    
    gt, pred = sympy_parser(expr_gt), sympy_parser(expr_pred)

    if gt is not None and pred is not None:

        try:
            return simplify(gt - pred) == 0 #for example, 14/3 and 28/6
        
        except(SympifyError, TypeError):
            return False
    
    return False

In [25]:
print(equality_check(
    normalize_text("13/4."),
    normalize_text(r"(13)/(4)")
))

True


In [26]:
print(equality_check(
    normalize_text("0.5"),
    normalize_text(r"(1)/(2)")
))

True


In [27]:
print(equality_check(
    normalize_text("14/3"),
    normalize_text("15/3")
))

False


In [28]:
print(equality_check(
    normalize_text("(14/3, 2/3)"),
    normalize_text("(14/3, 4/6)")
))
#it cannot handle tuples as of now

False


Now we implement an helper function that can handle the tuple-like expressions by isolating each term

In [29]:
def split_into_parts(text):
    result = [text]
    if (text):

        if (
            len(text) >= 2 and
            text[0] in "([" and text[-1] in ")]"
            and "," in text[1:-1]
        ):
            items = [p.strip() for p in text[1:-1].split(",")]
            if all(items): #all not None
                result = items
    else:
        result = []
    
    return result        
    

In [30]:
split_into_parts(normalize_text(r"(14/3, 2/3)"))

['14/3', '2/3']

Now we can implement a function that grades the result w.r.t. the GT that generalize the equality_check function 

In [31]:
def grade_answer(pred_text, gt_text) -> bool: #or True or False
    result = False
    if pred_text is not None and gt_text is not None:
        gt_parts = split_into_parts(normalize_text(gt_text))
        pred_parts = split_into_parts(normalize_text(pred_text))

        if (gt_parts and pred_parts 
            and len(gt_parts) == len(pred_parts)):
            result = all(equality_check(gt, pred) for gt, pred in zip(gt_parts, pred_parts))
    
    return result



In [32]:
grade_answer("14/3", r"\frac{14}{3}")

True

In [33]:
#with TUPLES
grade_answer(r"(14/3, 2/3)", "(14/3, 4/6)")

True

We can check with more tests

In [34]:
tests = [ #A
("check_1", "3/4", r"\frac{3}{4}", True),
("check_2", "(3)/(4)", r"3/4", True),
("check_3", r"\frac{\sqrt{8}}{2}", "sqrt(2)", True),
("check_4", r"\( \frac{1}{2} + \frac{1}{6} \)", "2/3", True),
("check_5", "(1, 2)", r"(1,2)", True),
("check_6", "(2, 1)", "(1, 2)", False),
("check_7", "(1, 2, 3)", "(1, 2)", False),
("check_8", "0.5", "1/2", True),
("check_9", "0.3333333333", "1/3", False),
("check_10", "1,234/2", "617", True),
("check_11", r"\text{2/3}", "2/3", True),
("check_12", "50%", "1/2", False),
("check_13", r"2\cdot 3/4", "3/2", True),
("check_14", r"90^\circ", "90", True),
("check_15", r"\left(\frac{3}{4}\right)", "3/4", True),
("check_16", r"2²", "2**2", True),
]

def run_demos_table(tests):
    header = ("Test", "Expect", "Got", "Status")
    rows = []
    for name, pred, gtruth, expect in tests:
        got = grade_answer(pred, gtruth) #B
        status = "PASS" if got == expect else "FAIL"
        rows.append((name, str(expect), str(got), status))

    data = [header] + rows
    col_widths = [ #C
        max(len(row[i]) for row in data)
        for i in range(len(header))
    ]

    for row in data: #D
        line = " | ".join(
            row[i].ljust(col_widths[i])
            for i in range(len(header))
        )

        print(line)
    passed = sum(r[3] == "PASS" for r in rows) #E
    print(f"\nPassed {passed}/{len(rows)}")

In [35]:
run_demos_table(tests)

Test     | Expect | Got   | Status
check_1  | True   | True  | PASS  
check_2  | True   | True  | PASS  
check_3  | True   | True  | PASS  
check_4  | True   | True  | PASS  
check_5  | True   | True  | PASS  
check_6  | False  | False | PASS  
check_7  | False  | False | PASS  
check_8  | True   | True  | PASS  
check_9  | False  | False | PASS  
check_10 | True   | True  | PASS  
check_11 | True   | True  | PASS  
check_12 | False  | False | PASS  
check_13 | True   | True  | PASS  
check_14 | True   | True  | PASS  
check_15 | True   | True  | PASS  
check_16 | True   | True  | PASS  

Passed 16/16


In [36]:
more_tests = [
    # Different bracket types
    ("check_17", "[1, 2]", "(1, 2)", True),

    # Scientific notation
    ("check_18", "1e-3", "0.001", True),

    # Algebraic simplification with caret exponent
    ("check_19", "(-3)^2", "9", True),

    # Unicode minus (U+2212) vs ASCII hyphen-minus
    ("check_20", "−1", "-1", True), 

]

run_demos_table(more_tests)

Test     | Expect | Got   | Status
check_17 | True   | True  | PASS  
check_18 | True   | True  | PASS  
check_19 | True   | True  | PASS  
check_20 | True   | False | FAIL  

Passed 3/4


Our function doesn't handle ASCII special characters, to address this, we could generalize the normalize_text function

In [37]:
extra_tests_2 = [
    ('check_21', extract_final_candidate('Text around 3 I think'), '3', True)
]
run_demos_table(extra_tests_2)

Test     | Expect | Got  | Status
check_21 | True   | True | PASS  

Passed 1/1


# Loading the evaluation dataset

We will use the MATH-500 dataset (https://huggingface.co/datasets/HuggingFaceH4/MATH-500)

In [38]:
import json
import requests

In [39]:
def load_math500_test(local_path='math500_test.json', save_copy=True):
    local_path = Path(local_path)
    url = (
        "https://raw.githubusercontent.com/rasbt/reasoning-from-scratch/"
        "main/ch03/01_main-chapter-code/math500_test.json"
    )

    if local_path.exists():
        with open(local_path, 'r', encoding='utf-8') as f:
            data = json.load(f) 
    else:
        r = requests.get(url, timeout=30)
        r.raise_for_status
        data=r.json()

        if save_copy:
            with open(local_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)

    return data 


In [40]:
math_data = load_math500_test()
print("Number of entries:", len(math_data))

Number of entries: 500


In [41]:
#OPTIONALLY we could have done it with huggingface

# from datasets import load_dataset
# dset = load_dataset("HuggingFaceH4/MATH-500", split="test")

let's print the first entry of the dataset

In [42]:
from pprint import pprint

In [43]:
pprint(math_data[0])

{'answer': '\\left( 3, \\frac{\\pi}{2} \\right)',
 'level': 2,
 'problem': 'Convert the point $(0,3)$ in rectangular coordinates to polar '
            'coordinates.  Enter your answer in the form $(r,\\theta),$ where '
            '$r > 0$ and $0 \\le \\theta < 2 \\pi.$',
 'solution': 'We have that $r = \\sqrt{0^2 + 3^2} = 3.$  Also, if we draw the '
             'line connecting the origin and $(0,3),$ this line makes an angle '
             'of $\\frac{\\pi}{2}$ with the positive $x$-axis.\n'
             '\n'
             '[asy]\n'
             'unitsize(0.8 cm);\n'
             '\n'
             'draw((-0.5,0)--(3.5,0));\n'
             'draw((0,-0.5)--(0,3.5));\n'
             'draw(arc((0,0),3,0,90),red,Arrow(6));\n'
             '\n'
             'dot((0,3), red);\n'
             'label("$(0,3)$", (0,3), W);\n'
             'dot((3,0), red);\n'
             '[/asy]\n'
             '\n'
             'Therefore, the polar coordinates are $\\boxed{\\left( 3, '
             '\\frac

We need that the model outputs te final answer as a boxed string. To increase the likelihood that this happens, we can feed into the model a System Prompt that encourage this format.

In [46]:
def render_prompt(prompt):
    template = """ 
You are an helpfu and powerful assistant.
Answer the question and write the final result on a new line as:
\\boxed{{ANSWER}}

Question:
{PROMPT}

Answer:
    """.format(PROMPT=prompt)

    return template

In [47]:
prompt = (
    r"If $a+b=3$ and $ab=\tfrac{13}{6}$, "
    r"what is the value of $a^2+b^2$?"
)
prompt_fmt = render_prompt(prompt)
print(prompt_fmt)

 
You are an helpfu and powerful assistant.
Answer the question and write the final result on a new line as:
\boxed{ANSWER}

Question:
If $a+b=3$ and $ab=\tfrac{13}{6}$, what is the value of $a^2+b^2$?

Answer:
    


In [48]:
generated_text = generate_text_stream_concat(model, tokenizer, prompt_fmt, device, max_new_tokens=2048, verbose=True)

 \boxed{13}

As we can see, this prmpt encourages the model to outputs short anwers, prioritizing speed against correctness. In fact, without this formatting, the output of the mdoel was longer and slower but correct (**14/3**), while now it is quicker but wrong (**13**).

On Math-500 dataset, it was proven that just changing, in the previous prompt, the qeury indicator word from 'Qeustion' to 'Problem' increases ther BASE model performance of nearly 20%, while the REASONING model drops its performance of nearly 30%.

In [52]:
def mini_eval_demo(model, tokenizer, device):
    ex = {
        'problem': 'Comput 1/2 + 1/6',
        'answer': '2/3'
    }
    prompt = render_prompt(ex['problem'])
    gen_text = generate_text_stream_concat(model, tokenizer, prompt, device, max_new_tokens=64, verbose=True)
    pred_answer = extract_final_candidate(gen_text)
    is_correct = grade_answer(pred_answer, ex['answer'])

    print(f"\nDevice: {device}")
    print(f"Prediction: {pred_answer}")
    print(f"Ground truth: {ex['answer']}")
    print(f"Correct: {is_correct}")

In [53]:
mini_eval_demo(model, tokenizer, device)

 \boxed{1/3}
Device: cuda
Prediction: 1/3
Ground truth: 2/3
Correct: False


End-to-end evaluation on math-500

In [54]:
import time

In [55]:
def eta_progress_message(processed, total, start_time, show_eta=False, label="Progress"):
    progress = f'{label}: {processed}/{total}'

    if not show_eta or processed <= 0:
        return progress
    
    elapsed = time.time() - start_time
    if elapsed <= 0 :
        return progress
    
    remaining = max(total - processed, 0)

    if processed:
        avg_time = elapsed / processed
        eta_seconds = avg_time * remaining
    else:
        eta_seconds = 0
    
    eta_seconds = max(round(eta_seconds), 0)
    minutes, rem_seconds = divmod(eta_seconds, 60) 
    hours, minutes, = divmod(minutes, 60)
    if hours:
        eta = f"{hours}h {minutes:02d}m {rem_seconds:02d}s"
    elif minutes:
        eta = f"{minutes:02d}m {rem_seconds:02d}s"
    else:
        eta = f"{rem_seconds:02d}s"
        
    return f"{progress} | ETA: {eta}"

In [61]:
def evaluate_math500_stream(model, model_type, tokenizer, device, math_data, out_path=None, max_new_tokens=512, verbose=False):
    if out_path is None:
        dev_name = str(device).replace(':', '-')
        out_path = Path(f'math500-{dev_name}-{model_type}.jsonl') #jsonl is a file format in which we have a JSON entry for each row
    
    num_examples = len(math_data)
    num_correct = 0
    start_time = time.time()

    with open(out_path, 'w', encoding="utf-8") as f:
        for i, row in enumerate(math_data, start=1):
            prompt = render_prompt(row['problem'])
            gen_text = generate_text_stream_concat(model, tokenizer, prompt, device, max_new_tokens=max_new_tokens, verbose=verbose)

            extracted = extract_final_candidate(gen_text)
            is_correct = grade_answer(extracted, row['answer'])

            num_correct += int(is_correct)

            record = { 
                "index": i,
                "problem": row["problem"],
                "gtruth_answer": row["answer"],
                "generated_text": gen_text,
                "extracted": extracted,
                "correct": bool(is_correct),
            }
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

            progress_message = eta_progress_message(processed=i, total=num_examples, start_time=start_time, show_eta=True, label='MATH-500')

            print(progress_message, end="\r", flush=True)

            if verbose: 
                print(
                    f"\n\n{'='*50}\n{progress_message}\n"
                    f"{'='*50}\nExtracted: {extracted}\n"
                    f"Expected: {row['answer']}\n"
                    f"Correct so far: {num_correct}\n{'-'*50}"
                )

    seconds_elapsed = time.time() - start_time
    acc = num_correct / num_examples if num_examples else 0.0
    print(f"\nAccuracy: {acc*100:.1f}% ({num_correct}/{num_examples})")
    print(f"Total time: {seconds_elapsed/60:.1f} min")
    print(f"Logs written to: {out_path}")
    return num_correct, num_examples, acc


In [62]:
print("Model:", WHICH_MODEL)
print(f'Device: {device}')
num_correct, num_examples, acc = evaluate_math500_stream(
    model, WHICH_MODEL, tokenizer, device,
    math_data=math_data[:10], #only evaluates the first 10 examples
    max_new_tokens=2048,
    verbose=False 
)


Model: base
Device: cuda
MATH-500: 10/10 | ETA: 00s
Accuracy: 20.0% (2/10)
Total time: 0.0 min
Logs written to: math500-cuda-base.jsonl


The poor performance is due to the absence of REASONING in the base model, which is crucial to solve math problems. We can see the switch solving the same tasks with the REASONING model:

In [65]:
model_reasoning, tokenizer_reasoning = load_model_and_tokenizer(
    which_model='reasoning',
    device=device,
    use_compile=False, local_dir='qwen3_reasoning'
)

✓ qwen3_reasoning\qwen3-0.6B-reasoning.pth already up-to-date
✓ qwen3_reasoning\tokenizer-reasoning.json already up-to-date


  model.load_state_dict(torch.load(model_path))


In [66]:
print("Model:", 'reasoning')
print(f'Device: {device}')
num_correct, num_examples, acc = evaluate_math500_stream(
    model_reasoning, 'reasoning', tokenizer_reasoning, device,
    math_data=math_data[:10], #only evaluates the first 10 examples
    max_new_tokens=2048,
    verbose=False 
)

Model: reasoning
Device: cuda
MATH-500: 10/10 | ETA: 00s25s
Accuracy: 90.0% (9/10)
Total time: 5.1 min
Logs written to: math500-cuda-reasoning.jsonl


The reasoning model performs way better (90% vs 20%) but is more compute heavy, as deomnstarted by the execution time

In [90]:
def compute_average_response_length(report_path, tokenizer):
    if not report_path:
        return None
    
    path = Path(report_path)
    tot_len = 0
    tot_els = 0

    with open(path, 'r', encoding='utf-8') as f:
        for row in f:
            json_row = json.loads(row) #converts a string, while load() a file
            generated_text = tokenizer.encode(json_row['generated_text'])
            tot_len += len(generated_text)
            tot_els += 1
    
    return float(tot_len/(1.0*tot_els))
    

In [100]:
reasoning_tokenizer_path = Path('qwen3_reasoning') / 'tokenizer-reasoning.json' 
base_tokenizer_path = Path('qwen3') / 'tokenizer-base.json' 

reasoning_tokenizer = Qwen3Tokenizer(reasoning_tokenizer_path, apply_chat_template=True)
base_tokenizer = Qwen3Tokenizer(base_tokenizer_path)

In [92]:
print(f'Average length of BASE model response: {compute_average_response_length('math500-cuda-base.jsonl', tokenizer=base_tokenizer)} tokens')
print(f'Average length of REASONING model response: {compute_average_response_length('math500-cuda-reasoning.jsonl', tokenizer=reasoning_tokenizer)} tokens')

Average length of BASE model response: 6.8 tokens
Average length of REASONING model response: 891.4 tokens


We see a huge difference!, now let's try with the chat template and the keyword 'Problem' instead of 'Question'

In [103]:
def render_prompt(prompt):
    template = """ 
You are an helpfu and powerful assistant.
Answer the question and write the final result on a new line as:
\\boxed{{ANSWER}}

Problem:
{PROMPT}

Answer:
    """.format(PROMPT=prompt)

    return template

In [104]:
print("Model:", WHICH_MODEL)
print(f'Device: {device}')
num_correct, num_examples, acc = evaluate_math500_stream(
    model, WHICH_MODEL, base_tokenizer, device,
    math_data=math_data[:10], #only evaluates the first 10 examples
    max_new_tokens=2048,
    verbose=False 
)

Model: base
Device: cuda
MATH-500: 10/10 | ETA: 00s
Accuracy: 20.0% (2/10)
Total time: 0.1 min
Logs written to: math500-cuda-base.jsonl


In [None]:
print(f'Average length of BASE model response: {compute_average_response_length('math500-cuda-base.jsonl', tokenizer=base_tokenizer)} tokens')
 #more tokens generated (same performance on this small dataset, but overall would increase from 15% to 31%)

Average length of BASE model response: 10.1 tokens
