# 4.2 Loading a pre-trained model

In [23]:
import torch
import time
import json
import re
from pathlib import Path
from qwen3 import Qwen3Model
from utils import render_prompt, generate_text_basic_stream_cache, Qwen3Tokenizer, load_model_and_tokenizer
from IPython.display import Latex, display
from typing import List, Callable, Optional, Tuple, TypedDict, Literal
from sympy import Expr, simplify
from sympy.parsing import sympy_parser as spp
from sympy.core.sympify import SympifyError
from tokenize import TokenError
from urllib.request import urlopen

In [2]:
def set_device() -> torch.device:
    """Return the preferred compute device: CUDA if present, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(device=backend)


# Resolve the compute device once; later cells pass it to the model loader and generation.
device: torch.device = set_device()
print(f"Using device: {device}")

Using device: mps


In [3]:
# Checkpoint variant used in this notebook; it is also embedded in the MATH-500 log file name.
WHICH_MODEL: str = "base"

In [4]:
# Worked example for this section: a single linear-equation word problem.
raw_prompt = (
    "Half the value of $3x-9$ is $x+37$. "
    "What is the value of $x$?"
)

# Wrap the problem in the math-assistant prompt template and show the result.
prompt = render_prompt(raw_prompt)
print(prompt)

You are a helpful math assistant.
Solve the problem and write the final result on a new line as:
\boxed{ANSWER}

Problem:
Half the value of $3x-9$ is $x+37$. What is the value of $x$?

Answer:


In [5]:
def generate_text_stream_concat_flex(
    model: Qwen3Model,
    tokenizer: Qwen3Tokenizer,
    prompt: str,
    device: torch.device,
    max_new_tokens: int,
    verbose: bool = False,
    generate_func: Optional[Callable] = None,
    **generate_kwargs,
) -> str:
    """Generate a completion for ``prompt`` and return it as one decoded string.

    Args:
        model: Model handed to ``generate_func``.
        tokenizer: Provides ``encode``/``decode`` and ``eos_token_id``.
        prompt: Text to condition on.
        device: Device on which the prompt token tensor is created.
        max_new_tokens: Generation budget forwarded to ``generate_func``.
        verbose: If True, print each token as soon as it is produced.
        generate_func: Streaming generator yielding one token tensor per step;
            defaults to ``generate_text_basic_stream_cache``.
        **generate_kwargs: Extra keyword arguments forwarded to ``generate_func``.

    Returns:
        The decoded text of all generated tokens (the prompt is not included).
    """
    step_fn = generate_text_basic_stream_cache if generate_func is None else generate_func

    # Encode the prompt and add a leading batch dimension of size 1.
    prompt_ids = tokenizer.encode(prompt)
    input_ids = torch.tensor(prompt_ids, device=device).unsqueeze(0)

    token_stream = step_fn(
        model=model,
        token_ids=input_ids,
        max_new_tokens=max_new_tokens,
        eos_token_id=tokenizer.eos_token_id,
        **generate_kwargs,
    )

    collected: List[int] = []
    for step_token in token_stream:
        # Drop the batch dimension; .item() requires exactly one element remaining.
        token_tensor = step_token.squeeze(0)
        collected.append(token_tensor.item())
        if verbose:
            print(tokenizer.decode(token_tensor.tolist()), end="", flush=True)

    return tokenizer.decode(collected)

In [6]:
# Load the checkpoint selected by WHICH_MODEL (instead of a hard-coded "base") so the
# loaded weights and the log-file names derived from WHICH_MODEL cannot drift apart.
model, tokenizer = load_model_and_tokenizer(which_model=WHICH_MODEL, device=device, use_compile=False)

✓ qwen3/qwen3-0.6B-base.pth already up-to-date
✓ qwen3/tokenizer-base.json already up-to-date


In [7]:
# Plain prompt (no chain-of-thought cue); tokens are streamed to stdout as they arrive.
response = generate_text_stream_concat_flex(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    device=device,
    max_new_tokens=2048,
    verbose=True,
    generate_func=generate_text_basic_stream_cache,
)

 $x=100$ To solve the problem, we need to set up an equation based on the given information and then solve for \( x \).

The problem states that half the value of \( 3x - 9 \) is \( x + 37 \). We can express this as an equation:

\[
\frac{1}{2}(3x - 9) = x + 37
\]

First, we'll eliminate the fraction by multiplying both sides of the equation by 2:

\[
3x - 9 = 2(x + 37)
\]

Next, we'll distribute the 2 on the right side of the equation:

\[
3x - 9 = 2x + 74
\]

Now, we'll isolate \( x \) by subtracting \( 2x \) from both sides:

\[
3x - 2x - 9 = 74
\]

This simplifies to:

\[
x - 9 = 74
\]

Next, we'll add 9 to both sides to solve for \( x \):

\[
x = 74 + 9
\]

So:

\[
x = 83
\]

However, the problem states that the value of \( x \) is \( 100 \). Let's verify this by substituting \( x = 100 \) back into the original equation to ensure it holds true.

The original equation is:

\[
\frac{1}{2}(3x - 9) = x + 37
\]

Substituting \( x = 100 \):

\[
\frac{1}{2}(3(100) - 9) = 100 + 37
\]

Si

In [8]:
# Re-render the streamed response with its LaTeX math typeset.
display(Latex(data=response))

<IPython.core.display.Latex object>

# 4.3 Generating better responses with chain-of-thought prompting

In [9]:
# Same problem, with an explicit chain-of-thought cue appended to the prompt.
prompt_cot = prompt + " \n\nExplain step by step."

response_cot = generate_text_stream_concat_flex(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt_cot,
    device=device,
    max_new_tokens=2048,
    verbose=True,
)

 To solve the problem, we need to find the value of \( x \) such that half the value of \( 3x - 9 \) is equal to \( x + 37 \). Let's break this down step by step.

### Step 1: Set up the equation
We are given that half the value of \( 3x - 9 \) is equal to \( x + 37 \). This can be written as:
\[
\frac{1}{2}(3x - 9) = x + 37
\]

### Step 2: Eliminate the fraction
To eliminate the fraction, multiply both sides of the equation by 2:
\[
2 \cdot \frac{1}{2}(3x - 9) = 2(x + 37)
\]
Simplifying both sides:
\[
3x - 9 = 2x + 74
\]

### Step 3: Solve for \( x \)
Now, we need to isolate \( x \) on one side of the equation. Start by subtracting \( 2x \) from both sides:
\[
3x - 2x - 9 = 74
\]
Simplifying:
\[
x - 9 = 74
\]
Next, add 9 to both sides to solve for \( x \):
\[
x = 74 + 9
\]
\[
x = 83
\]

### Step 4: Write the final answer
The value of \( x \) is:
\[
\boxed{83}
\]

In [10]:
# Re-render the chain-of-thought response with its LaTeX math typeset.
display(Latex(data=response_cot))

<IPython.core.display.Latex object>

## Exercise 4.1: Use chain-of-thought prompting on MATH-500

In [11]:
# One MATH-500 row. The evaluation loop reads "problem" and "answer";
# the remaining keys are presumably carried along from the dataset JSON — verify against the file.
MathDatum = TypedDict(
    "MathDatum",
    {
        "problem": str,    # problem statement
        "solution": str,   # reference worked solution
        "answer": str,     # ground-truth final answer used for grading
        "subject": str,
        "level": int,
        "unique_id": str,
    },
)

In [12]:
# What extract_final_candidate returns when the text has no \boxed{...} answer:
#   "number_then_full" (default): the last simple number, else the whole text
#   "number_only":                the last simple number, else ""
#   "none":                       boxed content only, else ""
Fallback = Literal["number_then_full", "number_only", "none"]

In [13]:
def get_last_boxed(text: str) -> Optional[str]:
    r"""Return the contents of the last ``\boxed{...}`` in ``text``.

    Whitespace between ``\boxed`` and the opening brace is allowed, and nested
    braces inside the box are balanced. Returns None when there is no
    ``\boxed``, when it is not followed by ``{``, or when its braces never close.
    Only the last occurrence is considered.
    """
    marker = r"\boxed"
    pos = text.rfind(marker)
    if pos < 0:
        return None

    end = len(text)
    pos += len(marker)
    while pos < end and text[pos].isspace():
        pos += 1

    if pos >= end or text[pos] != "{":
        return None

    body_start = pos + 1
    depth = 0
    for idx in range(pos, end):
        ch = text[idx]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                # Matching close brace of the opening "{": everything in between is the answer.
                return text[body_start:idx]
    return None

In [14]:
# Simple numeric literals with an optional leading minus: integer fractions ("-3/4"),
# integers/decimals ("12", "2.5") and exponent forms ("2.5e3"). Used as the fallback
# when a response has no \boxed{} answer. Thousands separators are not understood:
# "1,000" yields the two matches "1" and "000".
RE_NUMBER = re.compile(r"-?(?:\d+/\d+|\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)")

In [15]:
def extract_final_candidate(text: str, fallback: Fallback="number_then_full") -> str:
    r"""Pick the final-answer candidate out of a model response.

    Preference order:
      1. the content of the last ``\boxed{...}`` (trimmed of spaces and ``$``);
      2. if ``fallback`` is "number_then_full" or "number_only": the last number
         matched by ``RE_NUMBER``;
      3. if ``fallback`` is "number_then_full": the whole text.
    Anything else (including empty/None ``text``) yields "".
    """
    if not text:
        return ""

    boxed = get_last_boxed(text.strip())
    if boxed:
        return boxed.strip().strip("$ ")

    if fallback not in ("number_then_full", "number_only"):
        return ""

    numbers = RE_NUMBER.findall(text)
    if numbers:
        return numbers[-1]
    return text if fallback == "number_then_full" else ""

In [16]:
def split_into_parts(text: str) -> List[str]:
    """Split a bracketed, comma-separated answer like "(1, 2)" or "[a,b]" into its items.

    Returns [] for empty input. If ``text`` is not wrapped in ( or [ ... ) or ]
    with a comma inside, or if any item would be empty after stripping,
    ``[text]`` is returned unchanged.
    """
    if not text:
        return []

    is_bracketed_list = (
        len(text) >= 2
        and text[0] in "(["
        and text[-1] in ")]"
        and "," in text[1:-1]
    )
    if not is_bracketed_list:
        return [text]

    pieces = [piece.strip() for piece in text[1:-1].split(",")]
    return pieces if all(pieces) else [text]

In [17]:
# Ordered (regex pattern, replacement) pairs; normalize_text applies each with re.sub.
LATEX_FIXES = [
    (r"\\left\s*", ""),  # drop \left sizing commands
    (r"\\right\s*", ""),  # drop \right sizing commands
    (r"\\,|\\!|\\;|\\:", ""),  # drop LaTeX spacing commands \, \! \; \:
    (r"\\cdot", "*"),  # \cdot -> multiplication
    (r"\u00B7|\u00D7", "*"),  # Unicode middle dot / multiplication sign -> multiplication
    # NOTE(review): this matches the literal text "\^\circ" (backslash, caret, \circ),
    # not "^\circ"; normalize_text already strips "^\circ" before this table runs.
    (r"\\\^\\circ", ""),
    (r"\\dfrac", r"\\frac"),  # display-style fraction -> \frac
    (r"\\tfrac", r"\\frac"),  # text-style fraction -> \frac
    (r"°", ""),  # degree sign
]

# Matches "<|...|>" markers (presumably tokenizer special tokens) so they can be stripped.
RE_SPECIAL = re.compile(r"<\|[^>]+?\|>")

In [18]:
def normalize_text(text: str) -> str:
    r"""Canonicalize a LaTeX-ish answer string for comparison / SymPy parsing.

    Applies a fixed, order-dependent sequence of rewrites (later steps rely on
    earlier ones): strip ``<|...|>`` markers, degree notation, a whole-string
    ``\text{...}`` wrapper, escaped math delimiters, the LATEX_FIXES table, and
    ``$``/``%`` signs; convert ``\sqrt`` and ``\frac`` into ``sqrt(...)`` and
    ``(a)/(b)``; turn ``^`` into ``**``; join mixed numbers with ``+``; remove
    thousands separators. The result has braces removed, is stripped and
    lower-cased. Empty/None input returns "".
    """
    if not text:
        return ""
    # Remove <|...|> markers, then trim surrounding whitespace.
    text = RE_SPECIAL.sub(repl="", string=text).strip()
    # Remove degree notation: ^{\circ}, ^\circ and the literal degree sign.
    text = re.sub(pattern=r"\^\s*\{\s*\\circ\s*\}", repl="", string=text)
    text = re.sub(pattern=r"\^\s*\\circ", repl="", string=text)
    text = text.replace("°", "")

    # Unwrap a \text{...} that wraps the entire string.
    match = re.match(pattern=r"^\\text\{(?P<x>.+?)\}$", string=text)
    if match:
        text = match.group("x")

    # Strip escaped math delimiters \( \) \[ \].
    text = re.sub(pattern=r"\\\(|\\\)|\\\[|\\\]", repl="", string=text)

    # Table-driven cleanups (\left/\right, spacing commands, \cdot, \dfrac/\tfrac, ...).
    for pat, rep in LATEX_FIXES:
        text = re.sub(pat, rep, text)

    # Drop percent signs (escaped or not) and dollar signs.
    text = text.replace("\\%", "%").replace("$", "").replace("%", "")
    # \sqrt{...} -> sqrt(...). The group stops at the first "}", so nested braces
    # inside the radicand are not handled.
    text = re.sub(
        pattern=r"\\sqrt\s*\{([^}]*)\}",
        repl=lambda match: f"sqrt({match.group(1)})",
        string=text
    )
    # \sqrt x (brace-less, whitespace-separated argument) -> sqrt(x).
    text = re.sub(
        pattern=r"\\sqrt\s+([^\\\s{}]+)",
        repl=lambda match: f"sqrt({match.group(1)})",
        string=text
    )
    # \frac{a}{b} -> (a)/(b); only matches when a and b contain no braces, which is
    # why \sqrt is rewritten first (e.g. \frac{\sqrt{3}}{2} becomes convertible).
    text = re.sub(
        r"\\frac\s*\{([^{}]+)\}\s*\{([^{}]+)\}",
        lambda match: f"({match.group(1)})/({match.group(2)})",
        text,
    )
    # \frac a b -> (a)/(b). NOTE(review): compact forms such as \frac12 match
    # neither \frac pattern and are left as-is.
    text = re.sub(
        r"\\frac\s+([^\s{}]+)\s+([^\s{}]+)",
        lambda match: f"({match.group(1)})/({match.group(2)})",
        text,
    )

    # LaTeX exponent -> Python/SymPy exponent.
    text = text.replace("^", "**")
    # Mixed numbers: "2 1/2" -> "2+1/2".
    text = re.sub(
        r"(?<=\d)\s+(\d+/\d+)",
        lambda match: "+" + match.group(1),
        text,
    )

    # Thousands separators: drop a comma between digits when exactly three digits
    # follow it (then a non-digit or end), e.g. "1,000,000" -> "1000000".
    text = re.sub(
        r"(?<=\d),(?=\d\d\d(\D|$))",
        "",
        text,
    )

    # Remove any leftover braces, trim and lower-case.
    return text.replace("{", "").replace("}", "").strip().lower()

In [19]:
def sympy_parser(expr: str) -> Optional[Expr]:
    """Best-effort parse of a normalized answer string into a SymPy expression.

    Uses the standard transformations plus implicit multiplication/application,
    with evaluation enabled.

    Returns:
        The parsed expression, or None if parsing fails or the result is not a
        SymPy ``Expr``.
    """
    try:
        e = spp.parse_expr(
            s=expr,
            transformations=(*spp.standard_transformations, spp.implicit_multiplication_application),
            evaluate=True,
        )
    except (SympifyError, SyntaxError, TypeError, IndexError, TokenError,
            AttributeError, ValueError):
        # parse_expr evaluates the transformed code, so malformed model output can
        # also raise AttributeError (e.g. "x.y" -> attribute lookup on a Symbol) or
        # ValueError. Treat these as "unparseable" rather than aborting a whole eval run.
        return None
    return e if isinstance(e, Expr) else None

In [20]:
def equality_check(expr_gtruth: str, expr_pred: str) -> bool:
    """Decide whether two normalized answer strings denote the same value.

    Identical strings are equal outright. Otherwise both are parsed with
    ``sympy_parser``; they are equal when ``simplify(gt - pred)`` compares equal
    to 0. Unparseable input, or a SympifyError/TypeError during the symbolic
    check, yields False.
    """
    if expr_gtruth == expr_pred:
        return True

    parsed_gt = sympy_parser(expr_gtruth)
    parsed_pred = sympy_parser(expr_pred)
    if parsed_gt is None or parsed_pred is None:
        return False

    try:
        is_zero = simplify(expr=parsed_gt - parsed_pred) == 0
    except (SympifyError, TypeError):
        is_zero = False
    return is_zero

In [21]:
def grade_answer(pred_text: str, gt_text: str) -> bool:
    """Return True if the predicted answer matches the ground-truth answer.

    Both strings are normalized with ``normalize_text`` and split with
    ``split_into_parts`` (a bracketed list such as ``(1, 2)`` or ``[0, 1)``
    becomes one part per element). The part counts must agree and every pair
    of parts must pass ``equality_check``. When the answers were split, their
    enclosing brackets must also agree, so an interval like ``[0, 1)`` no longer
    matches ``(0, 1]``. ``None`` for either argument yields False.
    """
    if pred_text is None or gt_text is None:
        return False

    gt_norm = normalize_text(gt_text)
    pred_norm = normalize_text(pred_text)
    gt_parts = split_into_parts(gt_norm)
    pred_parts = split_into_parts(pred_norm)

    if not gt_parts or not pred_parts or len(gt_parts) != len(pred_parts):
        return False

    # split_into_parts only yields several parts after stripping an outer
    # "("/"[" ... ")"/"]" pair. Those brackets distinguish open from closed
    # intervals (and tuples from intervals), so they must match as well.
    if len(gt_parts) > 1 and (gt_norm[0], gt_norm[-1]) != (pred_norm[0], pred_norm[-1]):
        return False

    return all(equality_check(gt, pred) for gt, pred in zip(gt_parts, pred_parts))

In [22]:
def evaluate_math500_stream_length(
    model: Qwen3Model,
    tokenizer: Qwen3Tokenizer,
    prompt_template: Callable[[str], str],
    device: torch.device,
    math_data: List[MathDatum],
    out_path: Optional[Path]=None,
    max_new_tokens: int=512,
    verbose: bool=False,
    cot_suffix: str=" \n\nExplain step by step.",
) -> Tuple[int, int, float]:
    """Run chain-of-thought generation on MATH-500-style rows and grade each answer.

    For every row, the problem is rendered with ``prompt_template`` and
    ``cot_suffix`` is appended. A response is then generated with
    ``generate_text_basic_stream_cache``. Its final answer is extracted with
    ``extract_final_candidate`` and graded against ``row["answer"]``. One JSON
    record per example is written to ``out_path`` (JSON Lines). A summary is
    printed at the end: accuracy, wall time, and average response length in
    tokens and characters.

    Args:
        model: Model passed through to the generation function.
        tokenizer: Used to encode prompts, decode output and re-count response tokens.
        prompt_template: Maps a raw problem string to the full prompt.
        device: Device for the prompt tensor; also embedded in the default log name.
        math_data: Rows with at least ``"problem"`` and ``"answer"`` keys.
        out_path: JSONL log path. Defaults to ``math500_{WHICH_MODEL}-{device}.jsonl``,
            reading the notebook-level ``WHICH_MODEL`` (":" in the device name becomes "-").
        max_new_tokens: Generation budget per problem.
        verbose: Stream tokens and print per-example details instead of a progress line.
        cot_suffix: Chain-of-thought cue appended to every prompt.

    Returns:
        ``(num_correct, num_examples, accuracy)``; accuracy is 0.0 when ``math_data`` is empty.
    """
    if out_path is None:
        dev_name = str(device).replace(":", "-")
        out_path = Path(f"math500_{WHICH_MODEL}-{dev_name}.jsonl")

    num_examples = len(math_data)
    num_correct = 0
    total_chars = 0
    total_tokens = 0
    print(f"MATH-500: 0/{num_examples}", end="\r", flush=True)

    start_time = time.perf_counter()
    with open(out_path, mode="w", encoding="utf-8") as f:
        for i, row in enumerate(math_data, start=1):
            prompt_cot = prompt_template(row["problem"]) + cot_suffix
            gen_text = generate_text_stream_concat_flex(
                model=model, tokenizer=tokenizer, prompt=prompt_cot, device=device,
                max_new_tokens=max_new_tokens, verbose=verbose,
                generate_func=generate_text_basic_stream_cache,
            )

            extracted = extract_final_candidate(text=gen_text)
            is_correct = grade_answer(pred_text=extracted, gt_text=row["answer"])
            num_correct += int(is_correct)

            gen_length = len(gen_text)
            total_chars += gen_length
            # Token length of the response, obtained by re-encoding the decoded text.
            total_tokens += len(tokenizer.encode(gen_text))

            record = {
                "index": i,
                "problem": row["problem"],
                "gtruth_answer": row["answer"],
                "generated_text": gen_text,
                "generated_length": gen_length,
                "extracted": extracted,
                "correct": bool(is_correct),
            }
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

            if verbose:
                print(
                    f"\n\n{'='*50}\nMATH-500: {i}/{num_examples}\n"
                    f"{'='*50}\nExtracted: {extracted}\n"
                    f"Expected:  {row['answer']}\n"
                    f"Correct so far: {num_correct}\n{'-'*50}"
                )
            else:
                print(f"MATH-500: {i}/{num_examples}", end="\r", flush=True)

    seconds_elapsed = time.perf_counter() - start_time
    # Guard every average against empty input (previously only accuracy was guarded).
    acc = num_correct / num_examples if num_examples else 0.0
    avg_tokens = total_tokens / num_examples if num_examples else 0.0
    avg_chars = total_chars / num_examples if num_examples else 0.0
    print(f"\nAccuracy: {acc*100:.1f}% ({num_correct}/{num_examples})")
    print(f"Total time: {seconds_elapsed/60:.1f} min")
    print(f"Logs written to: {out_path}")
    print(f"Average answer length: {avg_tokens:.2f} tokens - {avg_chars:.2f} characters")
    return num_correct, num_examples, acc

In [24]:
# Load the MATH-500 test split. Prefer a local copy; otherwise download it once
# and cache it next to the notebook so later runs (Restart & Run All) work offline.
local_path = Path("math500_test.json")
url = (
    "https://raw.githubusercontent.com/rasbt/reasoning-from-scratch/"
    "main/ch03/01_main-chapter-code/math500_test.json"
)

if local_path.exists():
    with local_path.open("r", encoding="utf-8") as f:
        math_data = json.load(f)
    print("loaded from local path")
else:
    with urlopen(url, timeout=60) as f:
        math_data = json.load(f)
    print("fetched from the web")
    # Cache the download so the next run does not hit the network.
    with local_path.open("w", encoding="utf-8") as f:
        json.dump(math_data, f, ensure_ascii=False)

print("Number of entries:", len(math_data))

fetched from the web
Number of entries: 500


In [25]:
# Quick CoT evaluation on the first 10 MATH-500 problems only (the recorded run
# needed about 7 minutes for these 10). The label follows WHICH_MODEL instead of
# a hard-coded "base", so it stays correct if another checkpoint is selected.
print(f"Model: {WHICH_MODEL} - CoT")
num_correct, num_examples, acc = evaluate_math500_stream_length(
    model=model,
    tokenizer=tokenizer,
    prompt_template=render_prompt,
    device=device,
    math_data=math_data[:10],
    max_new_tokens=2048,
    verbose=False,
)

Model: base - CoT
MATH-500: 10/10
Accuracy: 60.0% (6/10)
Total time: 7.1 min
Logs written to: math500_base-mps.jsonl
Average answer length: 618.80 tokens - 2011.90 characters


End of Exercise 4.1

# 4.4 Controlling output diversity with temperature scaling