In [39]:
from pathlib import Path
import re
import torch
import torchinfo
import json
from urllib.request import urlopen
from utils import download_qwen3_small, Qwen3Tokenizer
from qwen3 import Qwen3Model, QWEN_CONFIG_06_B, KVCache
from typing import List, Tuple, Optional, Generator, cast, Literal
from IPython.display import Latex, Math, display
from sympy.parsing import sympy_parser as spp
from sympy.core.sympify import SympifyError
from tokenize import TokenError
from sympy import Expr, simplify
from pprint import pprint

# 3.2 Loading a pre-trained model to generate text

In [2]:
WHICH_MODEL = "base"
USE_COMPILE = False
RE_NUMBER = re.compile(pattern=r"-?(?:\d+/\d+|\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)")

In [3]:
def set_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device(device="cuda")
    elif torch.backends.mps.is_available():
        return torch.device(device="mps")
    else:
        return torch.device(device="cpu")


device = set_device()
print(f"Using device: {device}")

Using device: mps


In [4]:
if WHICH_MODEL == "base":
    download_qwen3_small(kind="base", tokenizer_only=False, out_dir="qwen3")
    tokenizer_path = Path("qwen3") / "tokenizer-base.json"
    model_path = Path("qwen3") / "qwen3-0.6B-base.pth"
    tokenizer = Qwen3Tokenizer(tokenizer_file_path=tokenizer_path)
 
elif WHICH_MODEL == "reasoning":
    download_qwen3_small(kind="reasoning", tokenizer_only=False, out_dir="qwen3")
    tokenizer_path = Path("qwen3") / "tokenizer-reasoning.json"
    model_path = Path("qwen3") / "qwen3-0.6B-reasoning.pth"
    tokenizer = Qwen3Tokenizer(
        tokenizer_file_path=tokenizer_path,
        apply_chat_template=True,
        add_generation_prompt=True,
        add_thinking=True,
    )

else:
    raise ValueError(f"Invalid choice: WHICH_MODEL={WHICH_MODEL}")

qwen3-0.6B-base.pth: 100% (1433 MiB / 1433 MiB)
tokenizer-base.json: 100% (6 MiB / 6 MiB)


In [5]:
text = "Hello, how are you today?"

ids = tokenizer.encode(text)
input_ids = torch.tensor(ids, dtype=torch.long).unsqueeze(0)
input_ids = input_ids.to(device)

In [6]:
model = Qwen3Model(cfg=QWEN_CONFIG_06_B)
model.load_state_dict(torch.load(f=model_path))
model.to(device)

torchinfo.summary(
    model=model,
    input_data=input_ids,
    verbose=0,
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"]
)

Layer (type (var_name))                       Input Shape          Output Shape         Param #              Trainable
Qwen3Model (Qwen3Model)                       [1, 7]               [1, 7, 151936]       --                   True
├─Embedding (tok_emb)                         [1, 7]               [1, 7, 1024]         155,582,464          True
├─ModuleList (trf_blocks)                     --                   --                   --                   True
│    └─TransformerBlock (0)                   [1, 7, 1024]         [1, 7, 1024]         --                   True
│    │    └─RMSNorm (norm1)                   [1, 7, 1024]         [1, 7, 1024]         1,024                True
│    │    └─GroupedQueryAttention (att)       [1, 7, 1024]         [1, 7, 1024]         6,291,712            True
│    │    └─RMSNorm (norm2)                   [1, 7, 1024]         [1, 7, 1024]         1,024                True
│    │    └─FeedForward (ff)                  [1, 7, 1024]         [1, 7, 1024]    

In [7]:
if USE_COMPILE:
  torch._dynamo.config.allow_unspec_int_on_nn_module = True
  model = cast(Qwen3Model, torch.compile(model))

In [8]:
@torch.inference_mode()
def generate_text_basic_stream_cache(
    model: Qwen3Model,
    token_ids: torch.Tensor,
    max_new_tokens: int,
    eos_token_id: Optional[int] = None
) -> Generator[torch.Tensor, None, None]:
    
    model.eval()
    cache = KVCache(n_layers=model.cfg["n_layers"])
    model.reset_kv_cache()
 
    out = model(token_ids, cache=cache)[:, -1]
    for _ in range(max_new_tokens):
        next_token = torch.argmax(out, dim=-1, keepdim=True)
 
        if (eos_token_id is not None
                and torch.all(next_token == eos_token_id)):
            break
 
        yield next_token
        # token_ids = torch.cat([token_ids, next_token], dim=1)
        out = model(next_token, cache=cache)[:, -1]

In [9]:
prompt: str = (
    r"If $a+b=3$ and $ab=\tfrac{13}{6}$, "
    r"what is the value of $a^2+b^2$?"
)

In [10]:
input_token_ids_tensor = torch.tensor(tokenizer.encode(prompt), device=device).unsqueeze(0)

In [11]:
all_token_ids = []
 
for token in generate_text_basic_stream_cache(
    model=model,
    token_ids=input_token_ids_tensor,
    max_new_tokens=2048,
    eos_token_id=tokenizer.eos_token_id
):
    token_id = token.squeeze(0)
    decoded_id = tokenizer.decode(token_id.tolist())
    print(
        decoded_id,   
        end="",
        flush=True
    )
    all_token_ids.append(token_id)

 To find the value of \( a^2 + b^2 \) given that \( a + b = 3 \) and \( ab = \frac{13}{6} \), we can use the following algebraic identity:

\[
a^2 + b^2 = (a + b)^2 - 2ab
\]

**Step 1:** Substitute the given values into the equation.

\[
a^2 + b^2 = (3)^2 - 2 \left( \frac{13}{6} \right)
\]

**Step 2:** Calculate \( (3)^2 \).

\[
(3)^2 = 9
\]

**Step 3:** Calculate \( 2 \times \frac{13}{6} \).

\[
2 \times \frac{13}{6} = \frac{26}{6} = \frac{13}{3}
\]

**Step 4:** Subtract the second result from the first.

\[
a^2 + b^2 = 9 - \frac{13}{3}
\]

**Step 5:** Convert 9 to a fraction with a denominator of 3 to perform the subtraction.

\[
9 = \frac{27}{3}
\]

\[
a^2 + b^2 = \frac{27}{3} - \frac{13}{3} = \frac{14}{3}
\]

**Final Answer:**

\[
\boxed{\dfrac{14}{3}}
\]

In [12]:
all_tokens = tokenizer.decode(all_token_ids)

display(Latex(all_tokens))

<IPython.core.display.Latex object>

# 3.3 Implementing a wrapper for easier text generation

In [13]:
def generate_text_stream_concat(
    model: Qwen3Model, tokenizer: Qwen3Tokenizer, prompt: str, 
    device: torch.device, max_new_tokens: int, verbose: bool=False) -> str:
    
    input_ids = torch.tensor(tokenizer.encode(prompt), device=device).unsqueeze(0)
 
    generated_ids: List[int] = []
    for token in generate_text_basic_stream_cache(
        model=model,
        token_ids=input_ids,
        max_new_tokens=max_new_tokens,
        eos_token_id=tokenizer.eos_token_id,
    ):
        next_token_id_tensor = token.squeeze(0)
        next_token_id = cast(typ=int, val=next_token_id_tensor.item())
        generated_ids.append(next_token_id)
 
        
        if verbose:
            print(
                tokenizer.decode(token_ids=next_token_id_tensor.tolist()),
                end="",
                flush=True
            )
    
    return tokenizer.decode(token_ids=generated_ids)

In [14]:
generated_text = generate_text_stream_concat(
    model, tokenizer, prompt, device,
    max_new_tokens=2048,
    verbose=True
)

 To find the value of \( a^2 + b^2 \) given that \( a + b = 3 \) and \( ab = \frac{13}{6} \), we can use the following algebraic identity:

\[
a^2 + b^2 = (a + b)^2 - 2ab
\]

**Step 1:** Substitute the given values into the equation.

\[
a^2 + b^2 = (3)^2 - 2 \left( \frac{13}{6} \right)
\]

**Step 2:** Calculate \( (3)^2 \).

\[
(3)^2 = 9
\]

**Step 3:** Calculate \( 2 \times \frac{13}{6} \).

\[
2 \times \frac{13}{6} = \frac{26}{6} = \frac{13}{3}
\]

**Step 4:** Subtract the second result from the first.

\[
a^2 + b^2 = 9 - \frac{13}{3}
\]

**Step 5:** Convert 9 to a fraction with a denominator of 3 to perform the subtraction.

\[
9 = \frac{27}{3}
\]

\[
a^2 + b^2 = \frac{27}{3} - \frac{13}{3} = \frac{14}{3}
\]

**Final Answer:**

\[
\boxed{\dfrac{14}{3}}
\]

# 3.4 Extracting the final answer box

In [15]:
model_answer = (
r"""... some explanation...
**Final Answer:**
 
\[
\boxed{\dfrac{14}{3}}
\]
""")

In [16]:
def get_last_boxed(text: str) -> Optional[str]:
    boxed_start_idx = text.rfind(r"\boxed")
    if boxed_start_idx == -1:
        return None
 
    current_idx = boxed_start_idx + len(r"\boxed")
 

    while current_idx < len(text) and text[current_idx].isspace():
        current_idx += 1
 

    if current_idx >= len(text) or text[current_idx] != "{":
        return None
 
    current_idx += 1
    brace_depth = 1
    content_start_idx = current_idx
 

    while current_idx < len(text) and brace_depth > 0:
        char = text[current_idx]
        if char == "{":
            brace_depth += 1
        elif char == "}":
            brace_depth -= 1
        current_idx += 1
 
    
    if brace_depth != 0:
        return None
 
    
    return text[content_start_idx:current_idx-1]

In [17]:
extracted_answer = get_last_boxed(model_answer)
print(extracted_answer)

\dfrac{14}{3}


In [18]:
display(Math(r"\dfrac{14}{3}"))

<IPython.core.display.Math object>

In [19]:
Fallback = Literal[
    "number_then_full", # (default): pick the last simple number, else the whole text
    "number_only", # pick the last simple number, else return an empty string "";
    "none" # extract only boxed content, else return empty string "".
]

def extract_final_candidate(text: str, fallback: Fallback="number_then_full") -> str:
    
    result = ""
 
    if text:
        boxed = get_last_boxed(text.strip())
        if boxed:
            result = boxed.strip().strip("$ ")
        elif fallback in ("number_then_full", "number_only"):
            m = RE_NUMBER.findall(text)
            if m:
                result = m[-1]
            elif fallback == "number_then_full":
                result = text
    return result

In [20]:
print(extract_final_candidate(model_answer))
print(extract_final_candidate(r"\boxed{ 14/3. }"))
print(extract_final_candidate("abc < > 14/3 abc"))

\dfrac{14}{3}
14/3.
14/3


# 3.5 Normalizing the extracted answer

In [21]:
LATEX_FIXES = [
    (r"\\left\s*", ""),
    (r"\\right\s*", ""),
    (r"\\,|\\!|\\;|\\:", ""),
    (r"\\cdot", "*"),
    (r"\u00B7|\u00D7", "*"),
    (r"\\\^\\circ", ""),
    (r"\\dfrac", r"\\frac"),
    (r"\\tfrac", r"\\frac"),
    (r"°", ""),
]
 
RE_SPECIAL = re.compile(r"<\|[^>]+?\|>")

In [22]:
def normalize_text(text: str) -> str:
    if not text:
        return ""
    text = RE_SPECIAL.sub(repl="", string=text).strip()
 
    text = re.sub(pattern=r"\^\s*\{\s*\\circ\s*\}", repl="", string=text)
    text = re.sub(pattern=r"\^\s*\\circ", repl="", string=text)
    text = text.replace("°", "")
 
    match = re.match(pattern=r"^\\text\{(?P<x>.+?)\}$", string=text)
    if match:
        text = match.group("x")
 
    text = re.sub(pattern=r"\\\(|\\\)|\\\[|\\\]", repl="", string=text)
 
    
    for pat, rep in LATEX_FIXES:
        text = re.sub(pat, rep, text)
 

    text = text.replace("\\%", "%").replace("$", "").replace("%", "")
    text = re.sub(
        pattern=r"\\sqrt\s*\{([^}]*)\}",
        repl=lambda match: f"sqrt({match.group(1)})",
        string=text
    )
    text = re.sub(
        pattern=r"\\sqrt\s+([^\\\s{}]+)",
        repl=lambda match: f"sqrt({match.group(1)})",
        string=text
    )
 

    text = re.sub(
        r"\\frac\s*\{([^{}]+)\}\s*\{([^{}]+)\}",
        lambda match: f"({match.group(1)})/({match.group(2)})",
        text,
    )
    text = re.sub(
        r"\\frac\s+([^\s{}]+)\s+([^\s{}]+)",
        lambda match: f"({match.group(1)})/({match.group(2)})",
        text,
    )
 

    text = text.replace("^", "**")
    text = re.sub(
        r"(?<=\d)\s+(\d+/\d+)",
        lambda match: "+" + match.group(1),
        text,
    )
 

    text = re.sub(
        r"(?<=\d),(?=\d\d\d(\D|$))",
        "",
        text,
    )
 
    return text.replace("{", "").replace("}", "").strip().lower()

In [23]:
print(normalize_text(extract_final_candidate(model_answer)))
print(normalize_text(r"\text{\[\frac{14}{3}\]}"))

(14)/(3)
(14)/(3)


# 3.6 Verifying mathematical equivalence

In [24]:
def sympy_parser(expr: str) -> Optional[Expr]:
    try:
        e = spp.parse_expr(
            s=expr,
            transformations=(*spp.standard_transformations, spp.implicit_multiplication_application),
            evaluate=True,
        )
        return e if isinstance(e, Expr) else None
    except (SympifyError, SyntaxError, TypeError, IndexError, TokenError):
        return None

In [25]:
print(sympy_parser(expr=normalize_text(extract_final_candidate(text=model_answer))))
print(sympy_parser(expr="28/6"))

14/3
14/3


In [26]:
def equality_check(expr_gtruth: str, expr_pred: str) -> bool:
    if expr_gtruth == expr_pred:
        return True

    gtruth, pred = sympy_parser(expr_gtruth), sympy_parser(expr_pred)
 
    if gtruth is not None and pred is not None:
        try:
            return simplify(expr=gtruth - pred) == 0
        except (SympifyError, TypeError):
            pass
 
    return False

In [27]:
print(equality_check(expr_gtruth=normalize_text(text="13/4."), expr_pred=normalize_text(text=r"(13)/(4)")))
print(equality_check(expr_gtruth=normalize_text(text="0.5"), expr_pred=normalize_text(text=r"(1)/(2)")))
print(equality_check(expr_gtruth=normalize_text(text="14/3"), expr_pred=normalize_text(text="15/3")))
print(equality_check(expr_gtruth=normalize_text(text="(14/3, 2/3)"), expr_pred=normalize_text(text="(14/3, 4/6)")))

True
True
False
False


# 3.7 Grading answers

In [28]:
def split_into_parts(text: str) -> List[str]:
    result = [text]
 
    if text:
        if (
            len(text) >= 2
            and text[0] in "([" and text[-1] in ")]"
            and "," in text[1:-1]
        ):
            items = [p.strip() for p in text[1:-1].split(",")]
            if all(items):
                result = items
    else:
        result = []
 
    return result

In [29]:
split_into_parts(normalize_text(r"(14/3, 2/3)"))

['14/3', '2/3']

In [30]:
def grade_answer(pred_text: str, gt_text: str) -> bool:
    result = False
    if pred_text is not None and gt_text is not None:
        gt_parts = split_into_parts(normalize_text(gt_text))
        pred_parts = split_into_parts(normalize_text(pred_text))
 
        if (gt_parts and pred_parts and len(gt_parts) == len(pred_parts)):
            result = all(equality_check(gt, pred) for gt, pred in zip(gt_parts, pred_parts))
 
    return result
 

In [31]:
print(grade_answer(pred_text="14/3", gt_text=r"\frac{14}{3}"))
print(grade_answer(pred_text=r"(14/3, 2/3)", gt_text="(14/3, 4/6)"))

True
True


In [32]:
tests: List[Tuple[str, str, str, bool]] = [
        ("check_1", "3/4", r"\frac{3}{4}", True),
        ("check_2", "(3)/(4)", r"3/4", True),
        ("check_3", r"\frac{\sqrt{8}}{2}", "sqrt(2)", True),
        ("check_4", r"\( \frac{1}{2} + \frac{1}{6} \)", "2/3", True),
        ("check_5", "(1, 2)", r"(1,2)", True),
        ("check_6", "(2, 1)", "(1, 2)", False),
        ("check_7", "(1, 2, 3)", "(1, 2)", False),
        ("check_8", "0.5", "1/2", True),
        ("check_9", "0.3333333333", "1/3", False),
        ("check_10", "1,234/2", "617", True),
        ("check_11", r"\text{2/3}", "2/3", True),
        ("check_12", "50%", "1/2", False),
        ("check_13", r"2\cdot 3/4", "3/2", True),
        ("check_14", r"90^\circ", "90", True),
        ("check_15", r"\left(\frac{3}{4}\right)", "3/4", True)
]

In [33]:
def run_demos_table(tests: List[Tuple[str, str, str, bool]]) -> None:
    header = ("Test", "Expect", "Got", "Status")
    rows = []
    for name, pred, gtruth, expect in tests:
        got = grade_answer(pred, gtruth)
        status = "PASS" if got == expect else "FAIL"
        rows.append((name, str(expect), str(got), status))
 
    data = [header] + rows
    
    col_widths = [
        max(len(row[i]) for row in data)
        for i in range(len(header))
    ]
 
    for row in data:
        line = " | ".join(
            row[i].ljust(col_widths[i])
            for i in range(len(header))
        )
        print(line)
 
    passed = sum(r[3] == "PASS" for r in rows)
    print(f"\nPassed {passed}/{len(rows)}")

In [34]:
run_demos_table(tests)

Test     | Expect | Got   | Status
check_1  | True   | True  | PASS  
check_2  | True   | True  | PASS  
check_3  | True   | True  | PASS  
check_4  | True   | True  | PASS  
check_5  | True   | True  | PASS  
check_6  | False  | False | PASS  
check_7  | False  | False | PASS  
check_8  | True   | True  | PASS  
check_9  | False  | False | PASS  
check_10 | True   | True  | PASS  
check_11 | True   | True  | PASS  
check_12 | False  | False | PASS  
check_13 | True   | True  | PASS  
check_14 | True   | True  | PASS  
check_15 | True   | True  | PASS  

Passed 15/15


Exercise 3.1: Adding more test cases

In [35]:
tests2: List[Tuple[str, str, str, bool]] = [
        ("check_1", "3/4", r"\frac{3}{4}", True),
        ("check_2", "(3)/(5)", r"3/4", False),
        ("check_3", r"\sqrt{3}", "1.73", False),
        ("check_4", r"\frac{1}{0}", "1/0", True),
        ("check_5", r"+inf", "oo", True),
]

In [36]:
run_demos_table(tests2)

Test    | Expect | Got   | Status
check_1 | True   | True  | PASS  
check_2 | False  | False | PASS  
check_3 | False  | False | PASS  
check_4 | True   | False | FAIL  
check_5 | True   | False | FAIL  

Passed 3/5


# 3.8 Loading the evaluation dataset

In [38]:
local_path = Path("math500_test.json")
url = (
    "https://raw.githubusercontent.com/rasbt/reasoning-from-scratch/"
    "main/ch03/01_main-chapter-code/math500_test.json"
)
 
if local_path.exists():
    with local_path.open("r", encoding="utf-8") as f:
        math_data = json.load(f)
else:
    with urlopen(url) as f:
        math_data = json.load(f)
 
print("Number of entries:", len(math_data))

Number of entries: 500


In [40]:
pprint(math_data[0])

{'answer': '\\left( 3, \\frac{\\pi}{2} \\right)',
 'level': 2,
 'problem': 'Convert the point $(0,3)$ in rectangular coordinates to polar '
            'coordinates.  Enter your answer in the form $(r,\\theta),$ where '
            '$r > 0$ and $0 \\le \\theta < 2 \\pi.$',
 'solution': 'We have that $r = \\sqrt{0^2 + 3^2} = 3.$  Also, if we draw the '
             'line connecting the origin and $(0,3),$ this line makes an angle '
             'of $\\frac{\\pi}{2}$ with the positive $x$-axis.\n'
             '\n'
             '[asy]\n'
             'unitsize(0.8 cm);\n'
             '\n'
             'draw((-0.5,0)--(3.5,0));\n'
             'draw((0,-0.5)--(0,3.5));\n'
             'draw(arc((0,0),3,0,90),red,Arrow(6));\n'
             '\n'
             'dot((0,3), red);\n'
             'label("$(0,3)$", (0,3), W);\n'
             'dot((3,0), red);\n'
             '[/asy]\n'
             '\n'
             'Therefore, the polar coordinates are $\\boxed{\\left( 3, '
             '\\frac

# 3.9 Evaluating the model