In [1]:
# Set custom cache directory for Hugging Face transformers
# This avoids using the default home directory and stores models in scratch space
import os
os.environ["TRANSFORMERS_CACHE"] = "/scratch/gilbreth/sramishe"

In [2]:
# Import required libraries
import re
import json
import time
from dataclasses import dataclass
from typing import List, Optional
from pathlib import Path

import torch
from datasets import load_dataset  # For loading the MATH-500 benchmark dataset
from transformers import AutoTokenizer, AutoModelForCausalLM  # For loading LLM model and tokenizer
from tqdm import tqdm  # Progress bar for batch processing
from sympy import simplify  # For symbolic math comparison
from sympy.parsing import sympy_parser as spp  # For parsing mathematical expressions
from sympy.core.sympify import SympifyError  # For handling parsing errors

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# GPU Configuration Check
# Verify CUDA availability and print GPU information
print("CUDA available:", torch.cuda.is_available())

# Number of GPUs
print("Number of GPUs:", torch.cuda.device_count())

# Name of each GPU (should show NVIDIA A100 80GB PCIe)
for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

# Current GPU device index
print("Current device:", torch.cuda.current_device())

CUDA available: True
Number of GPUs: 1
GPU 0: NVIDIA A100 80GB PCIe
Current device: 0


In [4]:

@dataclass
class ReasoningConfig:
    """Configuration for reasoning model evaluation"""
    model_name: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"  # The reasoning model to benchmark
    dataset_name: str = "HuggingFaceH4/MATH-500"  # Math benchmark dataset with 500 problems
    dataset_split: str = "test"  # Use test split for evaluation
    batch_size: int = 2  # Number of samples per batch (reduced due to large model size)
    max_new_tokens: int = 512  # Maximum tokens to generate for reasoning and answer
    limit_eval_samples: Optional[int] = 500  # Limit samples for faster testing (None = use all)
    output_file: Optional[str] = "math500_results_500.jsonl"  # File to save detailed results



In [5]:

# ============================================================================
# HELPER FUNCTIONS FROM RASCHKA'S CODE
# ============================================================================

def get_last_boxed(text):
    """
    Extract the content inside the last \\boxed{...} in the text.
    Handles nested braces correctly.
    
    Args:
        text: Generated text that may contain \\boxed{answer}
    
    Returns:
        Content inside the last \\boxed{} or None if not found
    """
    # Find the last occurrence of "\boxed"
    boxed_start_idx = text.rfind(r"\boxed")
    if boxed_start_idx == -1:
        return None

    # Get position after "\boxed"
    current_idx = boxed_start_idx + len(r"\boxed")

    # Skip any whitespace after "\boxed"
    while current_idx < len(text) and text[current_idx].isspace():
        current_idx += 1

    # Expect an opening brace "{"
    if current_idx >= len(text) or text[current_idx] != "{":
        return None

    # Parse the braces with nesting support
    current_idx += 1
    brace_depth = 1
    content_start_idx = current_idx

    while current_idx < len(text) and brace_depth > 0:
        char = text[current_idx]
        if char == "{":
            brace_depth += 1
        elif char == "}":
            brace_depth -= 1
        current_idx += 1

    # Check for unbalanced braces
    if brace_depth != 0:
        return None

    # Extract content inside the outermost braces
    return text[content_start_idx:current_idx-1]


# Regex pattern for extracting numbers (including fractions, decimals, scientific notation)
RE_NUMBER = re.compile(r"-?(?:\d+/\d+|\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)")

def extract_final_candidate(text, fallback="number_then_full"):
    """
    Extract the final answer from generated text.
    Prefers \\boxed{} format, falls back to extracting last number.
    
    Args:
        text: Generated text from the model
        fallback: Strategy when \\boxed{} not found ("number_then_full" or "number_only")
    
    Returns:
        Extracted answer string
    """
    result = ""

    if text:
        # Prefer the last boxed expression if present
        boxed = get_last_boxed(text.strip())
        if boxed:
            result = boxed.strip().strip("$ ")
        # If no boxed expression, try fallback
        elif fallback in ("number_then_full", "number_only"):
            m = RE_NUMBER.findall(text)
            if m:
                # Use last number found
                result = m[-1]
            elif fallback == "number_then_full":
                # Return full text if no number found
                result = text
    return result


# LaTeX formatting patterns to normalize
LATEX_FIXES = [
    (r"\\left\s*", ""),
    (r"\\right\s*", ""),
    (r"\\,|\\!|\\;|\\:", ""),
    (r"\\cdot", "*"),
    (r"\u00B7|\u00D7", "*"),
    (r"\\\^\\circ", ""),
    (r"\\dfrac", r"\\frac"),
    (r"\\tfrac", r"\\frac"),
    (r"°", ""),
]

# Regex to strip chat special tokens like <|assistant|>
RE_SPECIAL = re.compile(r"<\|[^>]+?\|>")

def normalize_text(text):
    """
    Normalize mathematical text by handling LaTeX, unicode, fractions, etc.
    This is crucial for comparing model outputs with ground truth answers.
    
    Args:
        text: Raw mathematical expression (may contain LaTeX)
    
    Returns:
        Normalized lowercase string suitable for comparison
    """
    if not text:
        return ""
    
    text = RE_SPECIAL.sub("", text).strip()
    
    # Map for converting unicode superscripts to normal characters
    SUPERSCRIPT_MAP = {
        "⁰": "0", "¹": "1", "²": "2", "³": "3", "⁴": "4",
        "⁵": "5", "⁶": "6", "⁷": "7", "⁸": "8", "⁹": "9",
        "⁺": "+", "⁻": "-", "⁽": "(", "⁾": ")",
    }
    
    # Remove angle-degree markers
    text = re.sub(r"\^\s*\{\s*\\circ\s*\}", "", text)   # ^{\circ}
    text = re.sub(r"\^\s*\\circ", "", text)             # ^\circ
    text = text.replace("°", "")                        # Unicode degree

    # Unwrap \text{...} if the whole string is wrapped
    match = re.match(r"^\\text\{(?P<x>.+?)\}$", text)
    if match:
        text = match.group("x")

    # Strip inline/display math wrappers \( \) \[ \]
    text = re.sub(r"\\\(|\\\)|\\\[|\\\]", "", text)

    # Apply LaTeX canonicalization
    for pat, rep in LATEX_FIXES:
        text = re.sub(pat, rep, text)

    def convert_superscripts(s, base=None):
        """Convert unicode superscripts to exponent notation"""
        converted = "".join(
            SUPERSCRIPT_MAP[ch] if ch in SUPERSCRIPT_MAP else ch
            for ch in s
        )
        if base is None:
            return converted
        return f"{base}**{converted}"

    # Convert unicode superscripts into exponent form (e.g., 2² -> 2**2)
    text = re.sub(
        r"([0-9A-Za-z\)\]\}])([⁰¹²³⁴⁵⁶⁷⁸⁹⁺⁻]+)",
        lambda m: convert_superscripts(m.group(2), base=m.group(1)),
        text,
    )
    text = convert_superscripts(text)
    
    # Handle percentages and dollar signs
    text = text.replace("\\%", "%").replace("$", "").replace("%", "")
    
    # Convert \sqrt{...} to sqrt(...)
    text = re.sub(
        r"\\sqrt\s*\{([^}]*)\}",
        lambda match: f"sqrt({match.group(1)})",
        text,
    )
    text = re.sub(
        r"\\sqrt\s+([^\\\s{}]+)",
        lambda match: f"sqrt({match.group(1)})",
        text,
    )

    # Convert \frac{a}{b} to (a)/(b)
    text = re.sub(
        r"\\frac\s*\{([^{}]+)\}\s*\{([^{}]+)\}",
        lambda match: f"({match.group(1)})/({match.group(2)})",
        text,
    )
    text = re.sub(
        r"\\frac\s+([^\s{}]+)\s+([^\s{}]+)",
        lambda match: f"({match.group(1)})/({match.group(2)})",
        text,
    )

    # Convert exponents: ^ to **
    text = text.replace("^", "**")
    
    # Handle mixed numbers (e.g., "2 1/2" -> "2+1/2")
    text = re.sub(
        r"(?<=\d)\s+(\d+/\d+)",
        lambda match: "+" + match.group(1),
        text,
    )

    # Remove thousand separators (e.g., 1,234 -> 1234)
    text = re.sub(
        r"(?<=\d),(?=\d\d\d(\D|$))",
        "",
        text,
    )

    # Remove remaining braces and convert to lowercase
    return text.replace("{", "").replace("}", "").strip().lower()


def sympy_parser(expr):
    """
    Parse a mathematical expression using SymPy with appropriate transformations.
    
    Args:
        expr: Normalized mathematical expression string
    
    Returns:
        SymPy expression object or None if parsing fails
    """
    try:
        return spp.parse_expr(
            expr,
            transformations=(
                # Standard transformations like handling parentheses
                *spp.standard_transformations,
                # Allow omitted multiplication symbols (e.g., "2x" -> "2*x")
                spp.implicit_multiplication_application,
            ),
            # Evaluate during parsing so simple constants simplify (e.g., 2+3 -> 5)
            evaluate=True,
        )
    except (SympifyError, SyntaxError, TypeError, IndexError):
        return None


def equality_check(expr_gtruth, expr_pred):
    """
    Check if two mathematical expressions are equivalent.
    Uses both string comparison and symbolic SymPy comparison.
    
    Args:
        expr_gtruth: Ground truth expression (normalized)
        expr_pred: Predicted expression (normalized)
    
    Returns:
        True if expressions are mathematically equivalent
    """
    # First, check exact string match
    if expr_gtruth == expr_pred:
        return True

    # Parse both expressions into SymPy objects
    gtruth, pred = sympy_parser(expr_gtruth), sympy_parser(expr_pred)

    # If both parsed successfully, try symbolic comparison
    if gtruth is not None and pred is not None:
        try:
            # If the difference simplifies to 0, they are equivalent
            return simplify(gtruth - pred) == 0
        except (SympifyError, TypeError):
            pass

    return False


def split_into_parts(text):
    """
    Split tuple/list answers into individual parts for comparison.
    Example: "(1, 2, 3)" -> ["1", "2", "3"]
    
    Args:
        text: Answer text that may be a tuple or list
    
    Returns:
        List of individual answer parts
    """
    result = [text]

    if text:
        # Check if text looks like a tuple or list, e.g. "(a, b)" or "[a, b]"
        if (
            len(text) >= 2
            and text[0] in "([" and text[-1] in ")]"
            and "," in text[1:-1]
        ):
            # Split on commas inside brackets and strip whitespace
            items = [p.strip() for p in text[1:-1].split(",")]
            if all(items):
                result = items
    else:
        # If text is empty, return an empty list
        result = []

    return result


def grade_answer(pred_text, gt_text):
    """
    Grade a predicted answer against ground truth.
    Handles normalization, tuple/list splitting, and symbolic comparison.
    
    Args:
        pred_text: Model's predicted answer
        gt_text: Ground truth answer
    
    Returns:
        True if answer is correct, False otherwise
    """
    result = False  # Default outcome if checks fail

    # Only continue if both inputs are non-empty strings
    if pred_text is not None and gt_text is not None:
        # Normalize and split both answers into comparable parts
        gt_parts = split_into_parts(normalize_text(gt_text))
        pred_parts = split_into_parts(normalize_text(pred_text))

        # Ensure both sides have same number of valid parts
        if (gt_parts and pred_parts and len(gt_parts) == len(pred_parts)):
            # Check each part for mathematical equivalence
            result = all(
                equality_check(gt, pred)
                for gt, pred in zip(gt_parts, pred_parts)
            )

    return result



In [6]:

def render_prompt(question):
    """
    Build the prompt for the model following Raschka's format.
    Instructs model to output answer in \\boxed{} format.
    
    Args:
        question: The math problem to solve
    
    Returns:
        Formatted prompt string
    """
    template = (
        "You are a helpful math assistant.\n"
        "Answer the question and write the final result on a new line as:\n"
        "\\boxed{ANSWER}\n\n"
        f"Question:\n{question}\n\nAnswer:"
    )
    return template

In [7]:
def eta_progress_message(processed, total, start_time, show_eta=False, label="Progress"):
    """
    Calculate and format progress message with ETA.
    
    Args:
        processed: Number of items completed
        total: Total number of items
        start_time: Start time of the process
        show_eta: Whether to show estimated time remaining
        label: Label for the progress message
    
    Returns:
        Formatted progress string
    """
    progress = f"{label}: {processed}/{total}"
    if not show_eta or processed <= 0:
        return progress

    elapsed = time.time() - start_time
    if elapsed <= 0:
        return progress

    remaining = max(total - processed, 0)

    if processed:
        avg_time = elapsed / processed
        eta_seconds = avg_time * remaining
    else:
        eta_seconds = 0

    eta_seconds = max(int(round(eta_seconds)), 0)
    minutes, rem_seconds = divmod(eta_seconds, 60)
    hours, minutes = divmod(minutes, 60)
    
    if hours:
        eta = f"{hours}h {minutes:02d}m {rem_seconds:02d}s"
    elif minutes:
        eta = f"{minutes:02d}m {rem_seconds:02d}s"
    else:
        eta = f"{rem_seconds:02d}s"

    return f"{progress} | ETA: {eta}"



In [8]:

# ============================================================================
# EVALUATION FUNCTIONS
# ============================================================================

def generate_batch(model, tokenizer, prompts, max_new_tokens=512):
    """
    Generate responses for a batch of prompts.
    
    Args:
        model: The language model
        tokenizer: The tokenizer
        prompts: List of prompt strings
        max_new_tokens: Maximum tokens to generate
    
    Returns:
        List of generated text strings (without the input prompt)
    """
    # Tokenize all prompts
    inputs = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(model.device)

    # Generate responses
    with torch.no_grad():
        generated = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # Greedy decoding for reproducibility
        )

    # Extract only generated tokens (remove input prompt)
    gen_only_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs["input_ids"], generated)
    ]

    # Decode to text
    outputs = tokenizer.batch_decode(gen_only_ids, skip_special_tokens=True)
    return outputs


In [9]:


def evaluate_math500(cfg: ReasoningConfig):
    """
    Main evaluation function using Raschka's improved pipeline.
    
    Args:
        cfg: Configuration object with model and dataset settings
    
    Returns:
        Tuple of (num_correct, num_total, accuracy)
    """
    # Load the MATH-500 benchmark dataset
    dataset = load_dataset(cfg.dataset_name, split=cfg.dataset_split)

    # Limit to subset if specified
    if cfg.limit_eval_samples is not None:
        dataset = dataset.select(range(cfg.limit_eval_samples))

    num_examples = len(dataset)
    
    # Load tokenizer and model
    print(f"Loading model: {cfg.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_name,
        torch_dtype="auto",
        device_map="auto",
    )
    model.eval()

    # Initialize counters
    num_correct = 0
    start_time = time.time()

    # Open output file for saving detailed results
    out_path = Path(cfg.output_file) if cfg.output_file else Path("math500_results_500.jsonl")
    
    with open(out_path, "w", encoding="utf-8") as f:
        # Process in batches
        for i in tqdm(range(0, num_examples, cfg.batch_size), desc="Evaluating"):
            batch_end = min(i + cfg.batch_size, num_examples)
            batch_data = dataset[i:batch_end]
            
            # Handle both single items and batches
            if isinstance(batch_data["problem"], str):
                batch_problems = [batch_data["problem"]]
                batch_answers = [batch_data["answer"]]
            else:
                batch_problems = batch_data["problem"]
                batch_answers = batch_data["answer"]
            
            # Build prompts
            prompts = [render_prompt(q) for q in batch_problems]
            
            # Generate responses
            generated_texts = generate_batch(model, tokenizer, prompts, cfg.max_new_tokens)
            
            # Grade each response
            for idx, (problem, gt_answer, gen_text) in enumerate(zip(batch_problems, batch_answers, generated_texts)):
                # Extract answer from generated text
                extracted = extract_final_candidate(gen_text)
                
                # Grade the answer
                is_correct = grade_answer(extracted, gt_answer)
                num_correct += int(is_correct)
                
                # Save detailed record
                record = {
                    "index": i + idx + 1,
                    "problem": problem,
                    "ground_truth": gt_answer,
                    "generated_text": gen_text,
                    "extracted_answer": extracted,
                    "correct": bool(is_correct),
                }
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
            
            # Print progress
            processed = batch_end
            progress_msg = eta_progress_message(
                processed, num_examples, start_time, 
                show_eta=True, label="MATH-500"
            )
            print(f"\r{progress_msg} | Correct: {num_correct}/{processed}", end="", flush=True)
    
    print()  # New line after progress
    
    # Calculate final metrics
    accuracy = num_correct / num_examples if num_examples else 0.0
    elapsed_time = time.time() - start_time
    
    # Print summary
    print(f"\n{'='*60}")
    print(f"MATH-500 Evaluation Results")
    print(f"{'='*60}")
    print(f"Total examples:     {num_examples}")
    print(f"Correct answers:    {num_correct}")
    print(f"Accuracy:           {accuracy*100:.2f}% ({num_correct}/{num_examples})")
    print(f"Total time:         {elapsed_time/60:.1f} minutes")
    print(f"Results saved to:   {out_path}")
    print(f"{'='*60}\n")
    
    return num_correct, num_examples, accuracy


# Initialize configuration
cfg_reasoning = ReasoningConfig(
    limit_eval_samples=500,  # Evaluate on 50 samples for testing
    batch_size=2,  # Process 2 at a time due to model size
    max_new_tokens=512,  # Allow enough tokens for reasoning steps
)


In [10]:
# Display the configuration being used for evaluation
print("Configuration:")
print(f"  Model: {cfg_reasoning.model_name}")
print(f"  Dataset: {cfg_reasoning.dataset_name}")
print(f"  Samples: {cfg_reasoning.limit_eval_samples}")
print(f"  Batch size: {cfg_reasoning.batch_size}")
print(f"  Max new tokens: {cfg_reasoning.max_new_tokens}")
print(f"  Output file: {cfg_reasoning.output_file}")

Configuration:
  Model: deepseek-ai/DeepSeek-R1-Distill-Qwen-32B
  Dataset: HuggingFaceH4/MATH-500
  Samples: 500
  Batch size: 2
  Max new tokens: 512
  Output file: math500_results_500.jsonl


In [None]:
# Run the improved evaluation with Raschka's pipeline
# This uses:
# - Proper LaTeX normalization (handles \frac{}, \sqrt{}, etc.)
# - SymPy symbolic comparison (recognizes mathematically equivalent expressions)
# - Robust answer extraction from \boxed{} format with fallback to last number
# - Detailed logging to JSONL file for error analysis

num_correct, num_total, accuracy = evaluate_math500(cfg_reasoning)

Loading model: deepseek-ai/DeepSeek-R1-Distill-Qwen-32B


`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 8/8 [00:53<00:00,  6.65s/it]
Evaluating:   0%|          | 0/250 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


MATH-500: 2/500 | ETA: 2h 08m 01s | Correct: 1/2

Evaluating:   0%|          | 1/250 [00:30<2:08:00, 30.85s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


MATH-500: 4/500 | ETA: 2h 05m 27s | Correct: 3/4

Evaluating:   1%|          | 2/250 [01:00<2:05:05, 30.26s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


MATH-500: 6/500 | ETA: 1h 53m 23s | Correct: 4/6

Evaluating:   1%|          | 3/250 [01:22<1:48:54, 26.45s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


In [13]:
# Display final results
print(f"\nFinal Accuracy: {accuracy*100:.2f}%")
print(f"Correct: {num_correct}/{num_total}")

# Example: Load and inspect some results
import json
from pathlib import Path

results_file = Path(cfg_reasoning.output_file)
if results_file.exists():
    print(f"\n--- Sample Results from {results_file} ---")
    with open(results_file, 'r') as f:
        for i, line in enumerate(f):
            if i >= 3:  # Show first 3 examples
                break
            record = json.loads(line)
            print(f"\nExample {record['index']}:")
            print(f"  Problem: {record['problem'][:100]}...")
            print(f"  Ground Truth: {record['ground_truth']}")
            print(f"  Extracted: {record['extracted_answer']}")
            print(f"  Correct: {record['correct']}")


Final Accuracy: 54.00%
Correct: 27/50

--- Sample Results from math500_results.jsonl ---

Example 1:
  Problem: Convert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the...
  Ground Truth: \left( 3, \frac{\pi}{2} \right)
  Extracted: (3, \frac{\pi}{2})
  Correct: True

Example 2:
  Problem: Define
\[p = \sum_{k = 1}^\infty \frac{1}{k^2} \quad \text{and} \quad q = \sum_{k = 1}^\infty \frac{...
  Ground Truth: p - q
  Extracted: \frac{q}{2}
  Correct: False

Example 3:
  Problem: If $f(x) = \frac{3x-2}{x-2}$, what is the value of $f(-2) +f(-1)+f(0)$? Express your answer as a com...
  Ground Truth: \frac{14}{3}
  Extracted: \dfrac{14}{3}
  Correct: True


: 