In [7]:
from pathlib import Path
import torch
from reasoning_from_scratch.ch02 import (
        get_device
)
from reasoning_from_scratch.qwen3 import (
        download_qwen3_small,
        Qwen3Tokenizer,
        Qwen3Model,
        QWEN_CONFIG_06_B
)
from reasoning_from_scratch.qwen3 import KVCache


In [3]:
def load_model_and_tokenizer(which_model, device, use_compile, local_dir="qwen3"):
    """Download (if needed) and load a Qwen3 0.6B model together with its tokenizer.

    Args:
        which_model: ``"base"`` for the base model or ``"reasoning"`` for the
            reasoning model. The reasoning tokenizer is created with the chat
            template, generation prompt and thinking prompt enabled.
        device: Device the model is moved to after its weights are loaded.
        use_compile: If True, wrap the model with ``torch.compile``.
        local_dir: Directory the weights and tokenizer files are downloaded to.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If ``which_model`` is neither ``"base"`` nor ``"reasoning"``.
    """
    if which_model not in ("base", "reasoning"):
        raise ValueError(f"Invalid choice: which_model={which_model}")

    # Both variants share the same download call and file-naming scheme.
    download_qwen3_small(kind=which_model, tokenizer_only=False, out_dir=local_dir)
    tokenizer_path = Path(local_dir) / f"tokenizer-{which_model}.json"
    model_path = Path(local_dir) / f"qwen3-0.6B-{which_model}.pth"

    if which_model == "base":
        tokenizer = Qwen3Tokenizer(tokenizer_file_path=tokenizer_path)
    else:
        tokenizer = Qwen3Tokenizer(
            tokenizer_file_path=tokenizer_path,
            apply_chat_template=True,
            add_generation_prompt=True,
            add_thinking=True,
        )

    model = Qwen3Model(QWEN_CONFIG_06_B)
    # The checkpoint is a plain state_dict, so weights_only=True is sufficient.
    # It refuses to unpickle arbitrary objects and avoids torch.load's warning
    # about the weights_only default. map_location="cpu" lets the file load even
    # on machines without the device it was saved from; the model is moved to
    # `device` right below.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)

    if use_compile:  # optional: compile the model for faster inference
        torch._dynamo.config.allow_unspec_int_on_nn_module = True
        model = torch.compile(model)
    return model, tokenizer


In [4]:
device = get_device()
WHICH_MODEL = "base"  # default: the base model, as in chapter 2 ("reasoning" is the alternative)
model, tokenizer = load_model_and_tokenizer(
    which_model=WHICH_MODEL,
    device=device,
    use_compile=False,  # set to True to enable torch.compile
)

Using NVIDIA CUDA GPU
✓ qwen3\qwen3-0.6B-base.pth already up-to-date


  model.load_state_dict(torch.load(model_path))


In [5]:
@torch.inference_mode()
def generate_text_basic_stream_cache(model, input_ids, max_new_tokens, eos_token_id=None):
    """Greedy decoding with a KV cache, yielding each new token as it is chosen.

    The whole prompt is run through the model once. Every later step feeds
    only the previously chosen token, together with the same ``KVCache``
    instance. Each step picks the argmax of the model output at the last
    position.

    Args:
        model: Model exposing ``cfg['n_layers']`` and ``reset_kv_cache()`` and
            accepting a ``cache=`` keyword argument.
        input_ids: Prompt token ids. Assumed shape (1, seq_len) — TODO confirm.
            The EOS check uses ``.item()``, which only works when a single
            token is selected, i.e. batch size 1.
        max_new_tokens: Maximum number of tokens to yield.
        eos_token_id: If given, generation stops as soon as this token is
            selected. The EOS token itself is not yielded.

    Yields:
        The selected token-id tensor for each step (argmax with ``keepdim=True``).
    """
    model.eval()
    kv_cache = KVCache(n_layers=model.cfg['n_layers'])
    model.reset_kv_cache()

    # Prefill: process the full prompt and keep only the last position's output.
    logits = model(input_ids, cache=kv_cache)[:, -1]

    for _step in range(max_new_tokens):
        next_token = logits.argmax(dim=-1, keepdim=True)
        if eos_token_id is not None and next_token.item() == eos_token_id:
            return
        yield next_token
        # Decode step: only the new token is passed in; earlier positions are
        # presumably served from kv_cache.
        logits = model(next_token, cache=kv_cache)[:, -1]
        
    


In [8]:
prompt = (  # MATH problem used throughout this notebook
    r"If $a+b=3$ and $ab=\tfrac{13}{6}$, "
    r"what is the value of $a^2+b^2$?"
)

# Encode the prompt and prepend a batch dimension of size 1.
input_ids = torch.tensor(tokenizer.encode(prompt), device=device).unsqueeze(0)

all_token_ids = []
for token in generate_text_basic_stream_cache(
    model,
    input_ids,
    max_new_tokens=2048,
    eos_token_id=tokenizer.eos_token_id,
):
    token_id = token.squeeze(0)
    # Show each token immediately so the answer appears as it is generated.
    decoded_id = tokenizer.decode([token_id])
    print(decoded_id, end="", flush=True)
    all_token_ids.append(token_id)

# The complete answer decoded as one string.
all_tokens = tokenizer.decode(all_token_ids)


 To find the value of \( a^2 + b^2 \) given that \( a + b = 3 \) and \( ab = \frac{13}{6} \), we can use the following algebraic identity:

\[
a^2 + b^2 = (a + b)^2 - 2ab
\]

**Step 1:** Substitute the given values into the equation.

\[
a^2 + b^2 = (3)^2 - 2 \left( \frac{13}{6} \right)
\]

**Step 2:** Calculate \( (3)^2 \).

\[
(3)^2 = 9
\]

**Step 3:** Calculate \( 2 \times \frac{13}{6} \).

\[
2 \times \frac{13}{6} = \frac{26}{6} = \frac{13}{3}
\]

**Step 4:** Subtract the second result from the first.

\[
a^2 + b^2 = 9 - \frac{13}{3}
\]

**Step 5:** Convert 9 to a fraction with a denominator of 3 to perform the subtraction.

\[
9 = \frac{27}{3}
\]

\[
a^2 + b^2 = \frac{27}{3} - \frac{13}{3} = \frac{14}{3}
\]

**Final Answer:**

\[
\boxed{\dfrac{14}{3}}
\]

In [9]:
from IPython.display import Latex, display

# The generated answer contains LaTeX markup; render it rather than showing raw text.
display(Latex(all_tokens))

<IPython.core.display.Latex object>

Now we can wrap the token-streaming loop above in a reusable function that returns the full decoded answer.

In [10]:
def generate_text_stream_concat(model, tokenizer, prompt, device, max_new_tokens, verbose=False):
    """Generate a completion for ``prompt`` and return it decoded as one string.

    The prompt is encoded, given a leading batch dimension of size 1, and placed
    on ``device``. Tokens then come from ``generate_text_basic_stream_cache``,
    which stops at ``tokenizer.eos_token_id`` or after ``max_new_tokens`` tokens.

    Args:
        model: Model passed through to ``generate_text_basic_stream_cache``.
        tokenizer: Tokenizer providing ``encode``, ``decode`` and ``eos_token_id``.
        prompt: Input text.
        device: Device for the input tensor.
        max_new_tokens: Maximum number of generated tokens.
        verbose: If True, print each token as soon as it is generated.

    Returns:
        The decoded generated text (the prompt is not included).
    """
    prompt_ids = tokenizer.encode(prompt)
    input_ids = torch.tensor(prompt_ids, device=device).unsqueeze(0)

    generated = []
    for token in generate_text_basic_stream_cache(
        model,
        input_ids,
        max_new_tokens=max_new_tokens,
        eos_token_id=tokenizer.eos_token_id,
    ):
        token_id = token.squeeze(0)
        generated.append(token_id)
        if not verbose:
            continue
        print(tokenizer.decode([token_id]), end="", flush=True)

    return tokenizer.decode(generated)


In [14]:
model_answer = generate_text_stream_concat(
    model,
    tokenizer,
    prompt,
    device,
    max_new_tokens=2048,
    verbose=True,  # stream the answer while it is generated
)

 To find the value of \( a^2 + b^2 \) given that \( a + b = 3 \) and \( ab = \frac{13}{6} \), we can use the following algebraic identity:

\[
a^2 + b^2 = (a + b)^2 - 2ab
\]

**Step 1:** Substitute the given values into the equation.

\[
a^2 + b^2 = (3)^2 - 2 \left( \frac{13}{6} \right)
\]

**Step 2:** Calculate \( (3)^2 \).

\[
(3)^2 = 9
\]

**Step 3:** Calculate \( 2 \times \frac{13}{6} \).

\[
2 \times \frac{13}{6} = \frac{26}{6} = \frac{13}{3}
\]

**Step 4:** Subtract the second result from the first.

\[
a^2 + b^2 = 9 - \frac{13}{3}
\]

**Step 5:** Convert 9 to a fraction with a denominator of 3 to perform the subtraction.

\[
9 = \frac{27}{3}
\]

\[
a^2 + b^2 = \frac{27}{3} - \frac{13}{3} = \frac{14}{3}
\]

**Final Answer:**

\[
\boxed{\dfrac{14}{3}}
\]

In [13]:
def get_last_boxed(text):
    r"""Return the brace-balanced content of the last ``\boxed{...}`` in ``text``.

    Whitespace is allowed between ``\boxed`` and its opening brace. Only the
    last ``\boxed`` occurrence is considered. Returns ``None`` when:

    - there is no ``\boxed``;
    - the last one is not followed by ``{``;
    - its braces never balance.
    """
    marker = r"\boxed"
    marker_pos = text.rfind(marker)
    if marker_pos < 0:
        return None

    end = len(text)
    pos = marker_pos + len(marker)
    # Accept "\boxed {...}" as well as "\boxed{...}".
    while pos < end and text[pos].isspace():
        pos += 1
    if pos == end or text[pos] != "{":
        return None

    body_start = pos + 1
    depth = 1
    for idx in range(body_start, end):
        ch = text[idx]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                # Matching closing brace found; it is excluded from the result.
                return text[body_start:idx]
    # Reached the end of the text before the braces balanced.
    return None

In [19]:
from IPython.display import Math

# Render only the content of the last \boxed{...} in the model's answer.
extracted_answer = get_last_boxed(model_answer)
display(Math(extracted_answer))

<IPython.core.display.Math object>

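We can also handle cases in which the model does not put its final answer inside `\boxed{}`: as a fallback, we take the last number in the text, or the whole text if it contains no number.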

In [20]:
import re
# Integers, decimals, scientific notation or simple fractions like 14/3,
# optionally with a leading minus sign.
RE_NUMBER = re.compile(
    r"-?(?:\d+/\d+|\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)"
)


def extract_final_candidate(text, fallback="number_then_full"):
    r"""Pick the most likely final answer out of a model response.

    Preference order:

    1. The content of the last ``\boxed{...}``, stripped of surrounding
       whitespace and ``$`` signs.
    2. If there is no usable boxed content and ``fallback`` is
       ``"number_then_full"`` or ``"number_only"``, the last number matched
       by ``RE_NUMBER``.
    3. If no number is found either and ``fallback`` is ``"number_then_full"``,
       the whole text.

    Returns ``""`` for empty/None input and whenever no rule applies.
    """
    if not text:
        return ""

    boxed_content = get_last_boxed(text.strip())
    if boxed_content:
        return boxed_content.strip().strip("$ ")

    if fallback not in ("number_then_full", "number_only"):
        return ""

    numbers = RE_NUMBER.findall(text)
    if numbers:
        return numbers[-1]
    return text if fallback == "number_then_full" else ""


In [22]:
# Boxed answer with padding spaces and a trailing period.
final_candidate = extract_final_candidate(
    r"\boxed{ 14/3. }"
)
display(Math(final_candidate))

<IPython.core.display.Math object>

In [23]:
# No \boxed{...} here, so the fallback returns the last number: 14/3.
final_candidate = extract_final_candidate(
    "abc < > 14/3 abc"
)
display(Math(final_candidate))

<IPython.core.display.Math object>

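Next, we normalize the extracted answer. This turns equivalent answers written in different LaTeX or Unicode forms into comparable plain-text expressions.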

In [24]:
import re

LATEX_FIXES = [  # A: (regex, replacement) pairs applied in order by normalize_text
    (r"\\left\s*", ""),
    (r"\\right\s*", ""),
    (r"\\,|\\!|\\;|\\:", ""),
    (r"\\cdot", "*"),
    (r"\u00B7|\u00D7", "*"),
    (r"\\\^\\circ", ""),
    (r"\\dfrac", r"\\frac"),
    (r"\\tfrac", r"\\frac"),
    (r"°", ""),
]
RE_SPECIAL = re.compile(r"<\|[^>]+?\|>")  # B: special-token markup of the form <|...|>
SUPERSCRIPT_MAP = {  # C: Unicode superscript characters -> ASCII equivalents
    "⁰": "0", "¹": "1", "²": "2", "³": "3", "⁴": "4",
    "⁵": "5", "⁶": "6", "⁷": "7", "⁸": "8", "⁹": "9",
    "⁺": "+", "⁻": "-", "⁽": "(", "⁾": ")",
}

# Rules turning \sqrt and \frac into plain-text math, applied in this order on
# every pass of normalize_text's conversion loop. The braced \sqrt rule accepts
# one level of inner braces, so "\sqrt{x^{2}+1}" keeps "+1" under the root and
# "\sqrt{\frac{1}{2}}" captures the whole fraction.
_ROOT_FRAC_RULES = (
    (re.compile(r"\\sqrt\s*\{((?:[^{}]|\{[^{}]*\})*)\}"),
     lambda m: f"sqrt({m.group(1)})"),
    (re.compile(r"\\sqrt\s+([^\\\s{}]+)"),
     lambda m: f"sqrt({m.group(1)})"),
    (re.compile(r"\\frac\s*\{([^{}]+)\}\s*\{([^{}]+)\}"),
     lambda m: f"({m.group(1)})/({m.group(2)})"),
    (re.compile(r"\\frac\s+([^\s{}]+)\s+([^\s{}]+)"),
     lambda m: f"({m.group(1)})/({m.group(2)})"),
)


def normalize_text(text):
    r"""Normalize a (possibly LaTeX-formatted) answer string for comparison.

    The steps, in order:

    - strip ``<|...|>`` special tokens;
    - drop a leading single-letter choice label such as ``"A. "``;
    - remove degree markers;
    - unwrap a string that is entirely ``\text{...}``;
    - remove ``\( \) \[ \]`` delimiters and apply ``LATEX_FIXES``;
    - convert Unicode superscripts (``x²`` -> ``x**2``);
    - drop ``$`` and ``%``;
    - convert ``\sqrt``/``\frac`` to ``sqrt(...)``/``(a)/(b)``, including
      nested forms;
    - turn ``^`` into ``**`` and mixed numbers ``"3 1/2"`` into ``"3+1/2"``;
    - remove thousands separators;
    - drop remaining braces and lowercase the result.

    Returns:
        The normalized string, or ``""`` for empty/None input.
    """
    if not text:
        return ""
    text = RE_SPECIAL.sub("", text).strip()

    # D: drop a leading choice label such as "A. " or "b: "
    match = re.match(r"^[A-Za-z]\s*[.:]\s*(.+)$", text)
    if match:
        text = match.group(1)

    # E: remove degree markers (^{\circ}, ^\circ, °)
    text = re.sub(r"\^\s*\{\s*\\circ\s*\}", "", text)
    text = re.sub(r"\^\s*\\circ", "", text)
    text = text.replace("°", "")

    # F: unwrap a string that is entirely \text{...}
    match = re.match(r"^\\text\{(?P<x>.+?)\}$", text)
    if match:
        text = match.group("x")

    # G: remove inline/display math delimiters \( \) \[ \]
    text = re.sub(r"\\\(|\\\)|\\\[|\\\]", "", text)

    # H: generic LaTeX clean-ups
    for pat, rep in LATEX_FIXES:
        text = re.sub(pat, rep, text)

    def convert_superscripts(s, base=None):
        """Map superscript chars in s to ASCII; with base, return 'base**exp'."""
        converted = "".join(
            SUPERSCRIPT_MAP[ch] if ch in SUPERSCRIPT_MAP else ch
            for ch in s
        )
        if base is None:
            return converted
        return f"{base}**{converted}"

    # Superscripts attached to a base (x², )³) become exponents; any leftover
    # superscript characters are mapped to their plain equivalents.
    text = re.sub(
        r"([0-9A-Za-z\)\]\}])([⁰¹²³⁴⁵⁶⁷⁸⁹⁺⁻]+)",
        lambda m: convert_superscripts(m.group(2), base=m.group(1)),
        text,
    )
    text = convert_superscripts(text)

    # I: drop $ and % (including escaped \%)
    text = text.replace("\\%", "%").replace("$", "").replace("%", "")

    # I/J: \sqrt and \frac -> sqrt(...) and (a)/(b). Repeat until nothing
    # changes so nested forms such as \frac{1}{\frac{1}{2}} are fully converted.
    # Every successful substitution removes one \sqrt/\frac token and adds
    # none, so the loop terminates.
    while True:
        previous = text
        for pattern, replacement in _ROOT_FRAC_RULES:
            text = pattern.sub(replacement, text)
        if text == previous:
            break

    # K: ^ -> ** and mixed numbers "3 1/2" -> "3+1/2"
    text = text.replace("^", "**")
    text = re.sub(
        r"(?<=\d)\s+(\d+/\d+)",
        lambda match: "+" + match.group(1),
        text,
    )

    # L: remove thousands separators, e.g. "1,000" -> "1000"
    text = re.sub(
        r"(?<=\d),(?=\d\d\d(\D|$))",
        "",
        text,
    )
    return text.replace("{", "").replace("}", "").strip().lower()


In [25]:
# Extract the boxed answer from the model output, then normalize it.
print(
    normalize_text(
        extract_final_candidate(model_answer)
    )
)

(14)/(3)


In [26]:
# The \text{...} wrapper, the \[ \] delimiters and \frac are all normalized away.
print(normalize_text(
    r"\text{\[\frac{14}{3}\]}"
))


(14)/(3)


We can now verify the mathematical equivalence between the answer extracted from the LLM output and the ground-truth (GT) answer.

In [27]:
from sympy.parsing import sympy_parser as spp
from sympy.core.sympify import SympifyError
from sympy.polys.polyerrors import PolynomialError
from tokenize import TokenError


In [33]:
def sympy_parser(expr):
    """Parse a normalized answer string into a SymPy object, or return None.

    Uses SymPy's standard transformations plus
    ``implicit_multiplication_application``. The latter allows omitted
    multiplication signs, e.g. ``2y`` -> ``2*y``. Parsing runs with
    ``evaluate=True``, so ``28/6`` comes back as ``14/3``.

    Any of the listed parse-time exceptions results in ``None``.

    NOTE(review): ``spp.parse_expr`` evaluates its input with ``eval``.
    Model output is untrusted text, so only feed it strings that have
    been normalized, and consider sandboxing before running on arbitrary
    data.
    """
    try:
        transformations = (
            *spp.standard_transformations,
            spp.implicit_multiplication_application,
        )
        return spp.parse_expr(expr, transformations=transformations, evaluate=True)
    except (SympifyError, SyntaxError, TypeError, AttributeError,
            IndexError, TokenError, ValueError, PolynomialError):
        return None




In [None]:
# Full pipeline: extract the final answer -> normalize it -> parse with SymPy.
print(sympy_parser(normalize_text(extract_final_candidate(model_answer))))

14/3


In [None]:
# With evaluate=True, SymPy reduces the fraction 28/6 to 14/3.
print(
    sympy_parser("28/6")
)

14/3


We can now build the equality-check function.

In [31]:
from sympy import simplify

def equality_check(expr_gt, expr_pred):
    """Return True if two normalized answer strings are mathematically equal.

    Identical strings match immediately. Otherwise both strings are parsed
    with ``sympy_parser``, and they match when their difference simplifies
    to zero (so ``14/3`` and ``28/6`` are equal).

    The result is False when:

    - either string fails to parse;
    - the difference raises ``SympifyError`` or ``TypeError``, e.g. for
      tuples, which are not supported yet.
    """
    if expr_gt == expr_pred:
        return True

    gt = sympy_parser(expr_gt)
    pred = sympy_parser(expr_pred)
    if gt is None or pred is None:
        return False

    try:
        return simplify(gt - pred) == 0
    except (SympifyError, TypeError):
        return False

In [34]:
# A trailing period and explicit parentheses should not affect equality.
print(
    equality_check(normalize_text("13/4."), normalize_text(r"(13)/(4)"))
)

True


In [35]:
# A decimal and the equivalent fraction are treated as equal.
print(
    equality_check(normalize_text("0.5"), normalize_text(r"(1)/(2)"))
)

True


In [36]:
# Different values must not be reported as equal.
print(
    equality_check(normalize_text("14/3"), normalize_text("15/3"))
)

False


In [None]:
# Tuples are not handled yet, so this prints False even though the entries match.
print(
    equality_check(normalize_text("(14/3, 2/3)"), normalize_text("(14/3, 4/6)"))
)

False
