
<br>
<font>
<div dir=ltr align=center>
<font color=0F5298 size=7>
    System 2 - Homework 2<br>
<font color=2565AE size=5>
    Spring 2025<br>
<font color=3C99D size=5>
    Inference-time Scaling <br>


### Full Name: Danial Parnian
---

# Assignment Overview (Part 1)

In this assignment, you will explore inference scaling techniques in large language models (LLMs) and evaluate their performance using the Math Benchmark. Throughout the notebook, you will learn about several inference methods, including:

- **Chain-of-Thought (CoT):** A method where the model generates intermediate reasoning steps before providing the final answer.
- **Best-of-n Sampling:** An approach that generates multiple candidate responses and selects the best one based on a scoring function.
- **Beam Search:** A technique that expands several possible sequences simultaneously, choosing the most promising ones based on probability.
- **Self-Refinement:** An iterative process where the model revises its output to improve accuracy and coherence.

The **Math Benchmark** is a suite of challenging mathematical problems designed to test the reasoning and problem-solving capabilities of LLMs. The benchmark includes a variety of questions ranging from basic arithmetic and algebra to more advanced topics such as geometry and calculus. For example, you might be asked to solve an equation like `2x + 5 = 15` or compute the derivative of a function, tasks that assess the model's ability to handle both straightforward and complex mathematical queries.

By the end of this assignment, you will have:
- Gained a deeper understanding of inference-time scaling methods in LLMs.
- Compared the effectiveness of different inference techniques using a rigorous math evaluation framework.

Let's dive into the notebook and begin exploring how these methods perform on a challenging set of math problems!


# Installing Dependencies

In [1]:
!pip install vllm
!pip install transformers accelerate datasets

from IPython.display import clear_output
clear_output()

* Use the next cell only if you're running the notebook on Google Colab: uncomment it to upgrade NumPy and restart the runtime. If you're using Kaggle, you don't need to run it.

In [2]:
# !pip install --upgrade numpy
# import os
# os.kill(os.getpid(), 9)

## vLLM: Accelerated Inference Engine for LLMs

vLLM is an open-source project designed to optimize the loading and inference of large language models. By leveraging advanced memory management techniques and dynamic batching, vLLM significantly speeds up the inference process, making it easier to deploy and experiment with LLMs even on hardware with limited resources. We therefore use vLLM to get results faster.

## vLLM Server Setup and Initialization

In this section, we install the required packages, ensure that only one server instance is running, and start the vLLM server using the model `deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B`.

**Installation and Cleanup:**
- The necessary packages (`vllm`, `transformers`, `accelerate`, and `datasets`) are installed in a cell with hidden output to keep the notebook clean.
- Any previously running vLLM server instances are terminated before starting a new one. This prevents multiple servers from running simultaneously.

**Server Initialization:**
- The server is launched as a background process using `subprocess.Popen`.
- **Initialization Time:**  
  The server typically takes about **1 minute** to fully initialize.
- **GPU Memory Utilization:**  
  Monitor your GPU memory usage. Initially, it will be at **0 GB**, and then it will gradually increase until it reaches approximately **12 GB** when the server is fully up and running.

Please wait until the GPU memory stabilizes around **12 GB** before proceeding to the next steps.
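
As an alternative to watching GPU memory by hand, the server can be polled until it answers. The sketch below assumes the server exposes the OpenAI-compatible `/v1/models` route on port 8000 (the port passed in the launch command). `server_is_up` and `wait_until_ready` are hypothetical helper names, not part of vLLM.

```python
import time
import urllib.error
import urllib.request


def server_is_up(url: str = "http://localhost:8000/v1/models") -> bool:
    """Return True if the server answers the given URL with HTTP 200."""
    try:
        with urllib.request.urlopen(url, timeout=5) as resp:
            return resp.status == 200
    except (urllib.error.URLError, OSError):
        # Connection refused, timeout, or an HTTP error: not ready yet.
        return False


def wait_until_ready(probe=server_is_up, timeout_s: float = 300.0,
                     interval_s: float = 5.0, sleep=time.sleep) -> bool:
    """Call `probe` until it returns True or `timeout_s` of sleeping has passed."""
    waited = 0.0
    while waited <= timeout_s:
        if probe():
            return True
        sleep(interval_s)
        waited += interval_s
    return False
```

Calling `wait_until_ready()` after the launch cell blocks until the server responds or about five minutes of polling have elapsed. The `probe` and `sleep` parameters exist only so the loop can be exercised without a running server.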


In [3]:
import subprocess
import os

# Kill any running VLLM server instances for the specified model
kill_cmd = "pkill -f 'vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B'"
subprocess.run(kill_cmd, shell=True)

# Command to start the VLLM server
cmd = [
    "vllm", "serve", "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    "--port", "8000", "--dtype=half", "--max-model-len", "5192"
]

# Redirect all output to os.devnull to suppress logs
with open(os.devnull, "w") as devnull:
    server_process = subprocess.Popen(cmd, stdout=devnull, stderr=devnull)

print("Server started!")

Server started!


* If the previous cell does not work, you can debug it by running the server in the foreground with the cell below. If the previous cell works, do NOT run this cell.

In [4]:
# !vllm serve "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"   --port 8000   --dtype=half   --max-model-len 5192

# Helper Functions Overview

This section contains a series of helper functions designed to facilitate the evaluation of mathematical problem solving using the MATH-500 dataset and a local LLM server. These functions handle tasks such as dataset loading, answer extraction, normalization of various mathematical expressions, answer comparison, and result management. Below is an explanation of each group of functions:

---

## Dataset Loading

- **`load_math500_dataset()`**  
  Loads the test split of the MATH-500 dataset from the Hugging Face repository (`HuggingFaceH4/MATH-500`). This dataset provides the math problems and corresponding solutions used for evaluation.

---

## Answer Extraction

- **`extract_answer(response: str) -> Optional[str]`**  
  Searches the provided text for the last occurrence of the LaTeX command `\boxed{...}` and extracts the content within it. This function is essential for retrieving the final answer from the formatted solutions.
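
To see why the extraction counts braces rather than relying on a simple regular expression, consider nested LaTeX such as `\boxed{\frac{1}{3}}`. The sketch below is a condensed, illustrative re-implementation (`last_boxed` is a hypothetical name); the notebook's own `extract_answer` is defined in the helper code cell.

```python
import re
from typing import Optional


def last_boxed(text: str) -> Optional[str]:
    """Return the content of the last \\boxed{...}, honoring nested braces."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    depth = 1
    pos = start + len(marker)
    while pos < len(text) and depth > 0:
        if text[pos] == "{":
            depth += 1
        elif text[pos] == "}":
            depth -= 1
        pos += 1
    # depth == 0 means the opening brace was closed at index pos - 1
    return text[start + len(marker):pos - 1].strip() if depth == 0 else None


sample = r"First try \boxed{2}. Final answer: \boxed{\frac{1}{3}}"
naive = re.findall(r"\\boxed\{(.*?)\}", sample)[-1]  # stops at the first '}'
print(naive)               # \frac{1
print(last_boxed(sample))  # \frac{1}{3}
```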

---

## Normalization Functions

Normalization functions standardize the format of mathematical expressions to enable accurate comparisons between predicted and correct answers. These functions account for various representations, ensuring that equivalent answers written in different formats are recognized as equal.

- **`normalize_number(num_str: str) -> str`**  
  Cleans and normalizes numeric strings by removing extraneous characters (e.g., commas, currency symbols, and measurement units) and formatting them into a consistent number format.

- **`numerically_equal(str1: str, str2: str) -> bool`**  
  Checks if two numeric strings represent the same value within a small tolerance, accounting for floating point precision issues.

- **`normalize_fraction(fraction_str: str) -> str`**  
  Converts various representations of fractions (with or without braces or using a slash) into a standard LaTeX format: `\frac{numerator}{denominator}`.

- **`normalize_matrix_entry(entry: str) -> str`**  
  Standardizes individual matrix entries, especially handling fractions and slash-separated numbers, to ensure consistency within matrix representations.

- **`normalize_matrix(matrix_str: str) -> str`**  
  Processes a LaTeX matrix (formatted with `\begin{pmatrix}` and `\end{pmatrix}`) by normalizing each row and each entry using the matrix entry normalization.

- **`normalize_algebraic_expression(expr: str) -> str`**  
  Standardizes algebraic expressions by handling coefficients, variables, exponents, and special terms like π (pi). This helps compare algebraic answers regardless of minor formatting differences.

- **`normalize_interval_bound(bound: str) -> str`**  
  Normalizes the boundary of an interval, ensuring that symbols like infinity (`\infty`) and other numeric boundaries are consistently formatted.

- **`normalize_interval(interval_str: str) -> str`**  
  Standardizes an interval provided in LaTeX, ensuring that both bounds are normalized and that the overall format (including brackets) is consistent.

- **`normalize_ordered_tuple(tuple_str: str) -> str`**  
  Normalizes an ordered tuple by splitting its elements and applying answer normalization to each component, ensuring a standard tuple representation.

- **`normalize_answer(answer: str) -> str`**  
  The central normalization function that applies the various normalization steps to a given answer. It cleans up LaTeX formatting, removes unnecessary spaces, and calls the specialized normalization functions to standardize numeric, fractional, algebraic, and other mathematical expressions.
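
As an illustration of why normalization matters, the same value can be written as `0.5`, `1/2`, or `\dfrac{1}{2}`. The sketch below is a deliberately simplified stand-in that covers only plain numbers and simple numeric fractions. `to_fraction` and `same_value` are hypothetical names and do not replace the notebook's `normalize_answer`.

```python
import re
from fractions import Fraction
from typing import Optional


def to_fraction(answer: str) -> Optional[Fraction]:
    """Parse a plain number, 'a/b', or \\frac{a}{b} into an exact Fraction."""
    s = "".join(answer.split()).replace("\\dfrac", "\\frac")
    m = re.fullmatch(r"(-?)\\frac\{(\d+)\}\{(\d+)\}", s)
    if m:
        s = f"{m.group(1)}{m.group(2)}/{m.group(3)}"
    try:
        return Fraction(s)
    except (ValueError, ZeroDivisionError):
        return None


def same_value(a: str, b: str) -> bool:
    """True only if both strings parse and denote the same rational number."""
    fa, fb = to_fraction(a), to_fraction(b)
    return fa is not None and fa == fb


print(same_value(r"\dfrac{1}{2}", "0.5"))  # True
print(same_value("1/2", r"\frac{2}{4}"))   # True
print(same_value("1/3", "0.333"))          # False
```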

---

## Answer Comparison

- **`compare_answers(correct_answer: str, predicted_answer: Optional[str]) -> bool`**  
  Compares the normalized versions of the correct answer and the predicted answer. This function ensures that answers are compared in a standardized format so that minor differences in formatting do not affect the evaluation outcome.

---

## Result Management Functions

These functions handle saving and analyzing the results of the evaluation process.

- **`load_existing_results(filename: str) -> list[Dict]`**  
  Loads previously saved evaluation results from a JSON file. If the file does not exist, it returns an empty list.

- **`save_result(filename: str, result: Dict)`**  
  Appends a single evaluation result (including problem details, the LLM response, and correctness) to the results file in JSON format.

- **`analyze_results(results: list[Dict])`**  
  Analyzes the evaluation outcomes by summarizing the total number of problems, counting the correct answers, calculating the accuracy, and printing details for any problems that were answered incorrectly.
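
The save-and-resume pattern these functions enable can be sketched in isolation. The example below writes to a temporary directory and uses hypothetical names (`load_results`, `append_result`) that mirror the behavior described above, not the notebook's exact functions.

```python
import json
import os
import tempfile


def load_results(path: str) -> list:
    """Return the saved result list, or an empty list if the file is missing."""
    try:
        with open(path, "r") as f:
            return json.load(f)
    except FileNotFoundError:
        return []


def append_result(path: str, result: dict) -> None:
    """Re-read the file, append one record, and write the whole list back."""
    results = load_results(path)
    results.append(result)
    with open(path, "w") as f:
        json.dump(results, f, indent=2)


path = os.path.join(tempfile.mkdtemp(), "demo_results.json")
append_result(path, {"index": 0, "is_correct": True})
append_result(path, {"index": 1, "is_correct": False})

# On restart, already-processed indexes are skipped.
done = {r["index"] for r in load_results(path)}
todo = [i for i in range(4) if i not in done]
print(todo)  # [2, 3]
```

Because the whole file is rewritten on every save, an interrupted run loses at most the problem in flight, at the price of a write cost that grows with the number of stored results.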

---

## Main Evaluation and Response Handling

- **`evaluate()`**  
  The primary function that orchestrates the evaluation process:
  - Creates a results directory if it doesn't already exist.
  - Loads the MATH-500 dataset.
  - Iterates over each problem (while skipping already processed ones).
  - Sends the problem text to the local LLM server using `get_llm_response`.
  - Extracts and compares the answers, then saves the result.
  - Finally, it analyzes and prints a summary of the evaluation.

- **`get_llm_response(prompt: str) -> str`**  
  Sends a prompt to the locally running LLM server (via an HTTP POST request to `http://localhost:8000/v1/chat/completions`) and returns the server's response. This function is key to obtaining the model's predicted answer.
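
The request and response shapes follow the OpenAI-style chat-completions format accepted by the `/v1/chat/completions` route. The sketch below separates payload construction and response parsing from the HTTP call so both can be checked without a running server. `build_payload` and `parse_content` are hypothetical helper names, and the sample response body is hand-written for illustration.

```python
from typing import Any, Dict

MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"


def build_payload(prompt: str, max_tokens: int = 1700,
                  temperature: float = 0.3) -> Dict[str, Any]:
    """Build a single-turn chat-completions request body."""
    return {
        "model": MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
    }


def parse_content(body: Dict[str, Any]) -> str:
    """Return the first choice's message text, or raise a descriptive error."""
    try:
        return body["choices"][0]["message"]["content"].strip()
    except (KeyError, IndexError, TypeError, AttributeError) as exc:
        raise ValueError(f"Unexpected response shape: {body!r}") from exc


# Hand-written example of a successful response body (not real model output).
fake_body = {"choices": [{"message": {"role": "assistant",
                                      "content": " \\boxed{4} "}}]}
print(parse_content(fake_body))  # \boxed{4}
```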


In [5]:
import json
import os
import re
from typing import Dict, Optional, Union
from datasets import load_dataset
from tqdm import tqdm
import torch

# Load the MATH-500 dataset
def load_math500_dataset():
    dataset = load_dataset("HuggingFaceH4/MATH-500")["test"]
    return dataset

# Extract the last boxed answer from text
def extract_answer(response: str) -> Optional[str]:
    if not response:
        return None
    start_idx = response.rfind('\\boxed{')
    if start_idx == -1:
        return None
    brace_count = 1
    pos = start_idx + 7  # length of '\boxed{'
    while pos < len(response) and brace_count > 0:
        if response[pos] == '{':
            brace_count += 1
        elif response[pos] == '}':
            brace_count -= 1
        pos += 1
    if brace_count == 0:
        answer = response[start_idx + 7:pos - 1]
        return answer.strip()
    return None

# Normalization and comparison functions (unchanged from original)
def normalize_number(num_str: str) -> str:
    try:
        cleaned = re.sub(r'[,\$\\]|\s*(?:cm|m|kg|ft|in|lb|oz|ml|L)$|\s*\\text{[^}]+}', '', num_str).strip()
        if cleaned.startswith('.'):
            cleaned = '0' + cleaned
        num = float(cleaned)
        if abs(num) < 1 and '.' in cleaned:
            decimal_places = len(cleaned.split('.')[1])
            format_str = f"{{:.{decimal_places}f}}"
            result = format_str.format(num)
        else:
            result = str(num)
        return result
    except:
        return num_str

def numerically_equal(str1: str, str2: str) -> bool:
    try:
        return abs(float(str1) - float(str2)) < 1e-10
    except:
        return False

def normalize_fraction(fraction_str: str) -> str:
    try:
        fraction_str = fraction_str.replace('\\dfrac', '\\frac')
        fraction_str = ''.join(fraction_str.split())
        fraction_str = re.sub(r'\s*\\text{[^}]+}', '', fraction_str)
        mixed_brace = re.match(r'^\\frac(\d+)\{(\d+)\}$', fraction_str)
        if mixed_brace:
            num, den = mixed_brace.groups()
            return f"\\frac{{{num}}}{{{den}}}"
        no_braces = re.match(r'^\\frac(\d+)(\d+)$', fraction_str)
        if no_braces:
            num, den = no_braces.groups()
            return f"\\frac{{{num}}}{{{den}}}"
        if '/' in fraction_str and not any(c in fraction_str for c in '\\{}'):
            num, den = fraction_str.split('/')
            return f"\\frac{{{num.strip()}}}{{{den.strip()}}}"
        standard = re.match(r'^\\frac\{([^{}]+)\}\{([^{}]+)\}$', fraction_str)
        if standard:
            num, den = standard.groups()
            return f"\\frac{{{num}}}{{{den}}}"
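        # No known pattern matched: return the cleaned string instead of None
        return fraction_str
    except Exception:
        return fraction_str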

def normalize_matrix_entry(entry: str) -> str:
    entry = ''.join(entry.split())
    if '/' in entry and not any(c in entry for c in '\\{}'):
        if entry.startswith('-'):
            num, den = entry[1:].split('/')
            return f"-{num.strip()}/{den.strip()}"
        else:
            num, den = entry.split('/')
            return f"{num.strip()}/{den.strip()}"
    entry = entry.replace('\\dfrac', '\\frac')
    frac_match = re.match(r'^(-)?\\frac\{(\d+)\}\{(\d+)\}$', entry)
    if frac_match:
        sign, num, den = frac_match.groups()
        sign = sign if sign else ''
        return f"{sign}{num}/{den}"
    return entry

def normalize_matrix(matrix_str: str) -> str:
    try:
        matrix_str = ''.join(matrix_str.split())
        match = re.match(r'^\\begin\{pmatrix\}(.*?)\\end\{pmatrix\}$', matrix_str)
        if not match:
            return matrix_str
        content = match.group(1)
        rows = content.split('\\\\')
        normalized_rows = []
        for row in rows:
            if '&' in row:
                entries = [normalize_matrix_entry(entry) for entry in row.split('&')]
            else:
                entries = [normalize_matrix_entry(row)]
            normalized_rows.append('&'.join(entries))
        result = "\\begin{pmatrix}" + "\\\\".join(normalized_rows) + "\\end{pmatrix}"
        return result
    except:
        return matrix_str

def normalize_algebraic_expression(expr: str) -> str:
    try:
        expr = ''.join(expr.split())
        monomial_match = re.match(r'^(-?\d*\.?\d*)?([a-zA-Z])(?:\^(-?\d+))?$', expr)
        if monomial_match:
            coeff, var, exp = monomial_match.groups()
            coeff = coeff if coeff and coeff not in ['+', '-'] else ('1' if not coeff else '-1')
            exp = exp if exp else '1'
            if coeff == '1' and exp == '1':
                return var
            elif coeff == '1':
                return f"{var}^{exp}"
            elif coeff == '-1' and exp == '1':
                return f"-{var}"
            elif coeff == '-1':
                return f"-{var}^{exp}"
            elif exp == '1':
                return f"{coeff}{var}"
            else:
                return f"{coeff}{var}^{exp}"
        pi_term_match = re.match(r'^(-?\d*\.?\d*)\\?pi$', expr)
        if pi_term_match:
            coeff = pi_term_match.group(1)
            if not coeff or coeff == '-':
                coeff = '-1' if coeff == '-' else '1'
            return f"{coeff}\\pi"
        frac_pi_match = re.match(r'^\\frac{([^{}]+)}{([^{}]+)}\\?pi$', expr)
        if frac_pi_match:
            num, den = frac_pi_match.groups()
            return f"\\frac{{{num}}}{{{den}}}\\pi"
        frac_match = re.match(r'^\\frac{([^{}]+)}{([^{}]+)}$', expr)
        if frac_match:
            num, den = frac_match.groups()
            return f"\\frac{{{num}}}{{{den}}}"
    except:
        return expr.lower()

def normalize_interval_bound(bound: str) -> str:
    if '\\infty' in bound:
        sign = '-' if bound.startswith('-') else ''
        return f"{sign}\\infty"
    return normalize_answer(bound) or bound

def normalize_interval(interval_str: str) -> str:
    try:
        interval_str = ''.join(interval_str.split())
        match = re.match(r'^\\left?([\[\(])(.*?),(.*?)\\right?([\]\)])$', interval_str)
        if not match:
            match = re.match(r'^([\[\(])(.*?),(.*?)([\]\)])$', interval_str)
            if not match:
                return interval_str
        left_bracket, left_bound, right_bound, right_bracket = match.groups()
        norm_left = normalize_interval_bound(left_bound)
        norm_right = normalize_interval_bound(right_bound)
        return f"\\left{left_bracket}{norm_left},{norm_right}\\right{right_bracket}"
    except:
        return interval_str

def normalize_ordered_tuple(tuple_str: str) -> str:
    try:
        tuple_str = tuple_str.replace('\\dfrac', '\\frac')
        tuple_str = tuple_str.replace('\\left', '').replace('\\right', '')
        tuple_str = re.sub(r'\\?\s+', '', tuple_str)
        inner = tuple_str.strip('()')
        parts = inner.split(',')
        normalized_parts = [normalize_answer(part.strip()) for part in parts if normalize_answer(part.strip())]
        return f"({','.join(normalized_parts)})"
    except:
        return None

def normalize_answer(answer: str) -> str:
    if answer is None:
        return ""
    answer = re.sub(r'\\text{[^}]+(?:inches|feet|meters|cm|m|kg|ft|in|lb|oz|ml|L|per|second|minute|hour)[^}]*}', '', answer)
    answer = re.sub(r'(?<!\\)\s+', '', answer)
    ordered_pair_match = re.match(r'^(?:\\left)?\((.*?)(?:\\right)?\)$', answer)
    if ordered_pair_match:
        content = ordered_pair_match.group(1)
        parts = content.split(',')
        normalized_parts = [normalize_answer(part) for part in parts if normalize_answer(part)]
        return f"({','.join(normalized_parts)})"
    answer = ''.join(answer.split())
    if not answer:
        return None
    pm_match = re.match(r'^(.*?)(?:\\pm|-)(.*?)$', answer)
    if pm_match:
        left, right = pm_match.groups()
        norm_left = normalize_answer(left) if left else ""
        norm_right = normalize_answer(right) if right else ""
        if norm_left or norm_right:
            return f"{norm_left}\\pm{norm_right}"
    trig_match = re.match(r'^\\(?:sin|cos|tan|cot|sec|csc)\s*([a-zA-Z])$', answer)
    if trig_match:
        variable = trig_match.group(1)
        func_name = re.match(r'^\\(.*?)(?:\s|$)', answer).group(1)
        return f"\\{func_name}{variable}"
    text_match = re.match(r'^(?:\\text{)?([A-Za-z]+)(?:})?$', answer)
    if text_match:
        return text_match.group(1).lower()
    if (answer.startswith('\\left[') or answer.startswith('\\left(') or
        answer.startswith('[') or answer.startswith('(')) and \
       (answer.endswith('\\right]') or answer.endswith('\\right)') or
        answer.endswith(']') or answer.endswith(')')):
        return normalize_interval(answer)
    if answer.startswith('\\begin{pmatrix}') and answer.endswith('\\end{pmatrix}'):
        return normalize_matrix(answer)
    answer = answer.replace('\\dfrac', '\\frac')
    if '\\frac' in answer or '/' in answer:
        return normalize_fraction(answer)
    neg_sqrt_match = re.match(r'^-\\sqrt\{?(\d+)\}?$', answer)
    if neg_sqrt_match:
        num = neg_sqrt_match.group(1)
        return f"-\\sqrt{{{num}}}"
    sqrt_match = re.match(r'^(\d*)?\\sqrt\{?(\d+)\}?$', answer)
    if sqrt_match:
        coeff, num = sqrt_match.groups()
        coeff = coeff if coeff else '1'
        return f"\\sqrt{{{num}}}" if coeff == '1' else f"{coeff}\\sqrt{{{num}}}"
    sqrt_with_coeff_match = re.match(r'^(\d+)\\sqrt\{?(\d+)\}?$', answer)
    if sqrt_with_coeff_match:
        coeff, num = sqrt_with_coeff_match.groups()
        return f"{coeff}\\sqrt{{{num}}}"
    base_match = re.match(r'^(\d+)(?:_\{?(\d+)\}?|_(\d+))$', answer)
    if base_match:
        number, base1, base2 = base_match.groups()
        base = base1 if base1 else base2
        return f"{number}_{base}"
    percent_match = re.match(r'^(\d+(?:\.\d*)?)\s*\\?%$', answer)
    if percent_match:
        return normalize_number(percent_match.group(1))
    unit_match = re.match(r'^(\d+(?:\.\d*)?)\s*(?:(?:\\[,\s])|,)?\s*(?:\\\\)?(?:\\text{(\w+)}|\\?(?:cm|m|kg|ft|in|lb|oz|ml|L))$', answer)
    if unit_match:
        return normalize_number(unit_match.group(1))
    currency_match = re.match(r'^\\?\$?([\d,]+\.?\d*)$', answer)
    if currency_match:
        return normalize_number(currency_match.group(1))
    if re.match(r'^-?[\d,]+$', answer):
        return normalize_number(answer)
    unit_match = re.match(r'^(-?[\d,]+(?:\.\d*)?)\s*(?:\\(?:mbox|text|hbox|displaystyle)\{[^}]+\})?(?:\^?\d)?$', answer)
    if unit_match:
        return normalize_number(unit_match.group(1))
    mc_match = re.match(r'^\\text{\(?([A-Za-z])\)?}$|^\(?([A-Za-z])\)?$', answer)
    if mc_match:
        return (mc_match.group(1) or mc_match.group(2)).lower()
    degree_match = re.match(r'^(-?[\d,]+(?:\.\d*)?)\s*(?:(?:\^?\\circ)|(?:{\\circ})|(?:°))?$', answer)
    if degree_match:
        return normalize_number(degree_match.group(1))
    answer = re.sub(r'\\text{([^{}]+)}', r'\1', answer)
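    # normalize_algebraic_expression returns None when no pattern matches;
    # only return its result when it recognized the expression, otherwise
    # fall through to the generic cleanup below.
    algebraic = normalize_algebraic_expression(answer)
    if algebraic is not None:
        return algebraic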
    answer = answer.replace('\\left', '').replace('\\right', '')
    answer = answer.replace('\\(', '(').replace('\\)', ')')
    answer = answer.replace('\\[', '[').replace('\\]', ']')
    answer = answer.replace('\\{', '{').replace('\\}', '}')
    answer = re.sub(r'\\sqrt\{?(\d+)\}?', r'\\sqrt{\1}', answer)
    answer = re.sub(r'\\sqrt{([^{}]+)}', r'\\sqrt\1', answer)
    if re.match(r'^\d+\\%$', answer) or re.match(r'^\d+$', answer):
        answer = re.sub(r'\\%$', '', answer)
    answer = re.sub(r'\\text{([^{}]+)}', r'\1', answer)
    while len(answer) >= 2 and answer[0] == '{' and answer[-1] == '}':
        if '\\frac' in answer:
            break
        answer = answer[1:-1]
    return answer.lower() if answer else None

def compare_answers(correct_answer: str, predicted_answer: Optional[str]) -> bool:
    if predicted_answer is None:
        return False
    if numerically_equal(correct_answer, predicted_answer):
        return True
    normalized_correct = normalize_answer(correct_answer)
    normalized_predicted = normalize_answer(predicted_answer)
    # An empty or None normalization on either side counts as a mismatch
    if not normalized_correct or not normalized_predicted:
        return False
    return normalized_correct == normalized_predicted

# Load existing results
def load_existing_results(filename: str) -> list[Dict]:
    try:
        with open(filename, 'r') as f:
            return json.load(f)
    except FileNotFoundError:
        return []

# Save a single result
def save_result(filename: str, result: Dict):
    results = load_existing_results(filename)
    results.append(result)
    with open(filename, 'w') as f:
        json.dump(results, f, indent=2)

# Analyze and print results
def analyze_results(results: list[Dict]):
    total = len(results)
    correct = sum(1 for r in results if r['is_correct'])
    accuracy = correct / total if total > 0 else 0
    print("\n=== Results Summary ===")
    print(f"Total problems: {total}")
    print(f"Correct answers: {correct}")
    print(f"Accuracy: {accuracy:.2%}")
    print("\n=== Incorrect Problems ===")
    for r in results:
        if not r['is_correct']:
            print(f"Problem {r['index']}:")
            print(f"Expected: {r['correct_answer']}")
            print(f"Predicted: {r['predicted_answer']}")
            print("---")

# Main evaluation function
def evaluate():
    os.makedirs("results", exist_ok=True)
    results_file = "evaluation_results_math500_deepseek.json"
    dataset = load_math500_dataset()
    existing_results = load_existing_results(results_file)
    processed_indexes = {result['index'] for result in existing_results}
    cnt = 0
    t=0
    for idx, item in enumerate(tqdm(dataset, desc="Evaluating problems")):
        if idx in processed_indexes:
            continue
        t += 1
        problem_text = item['problem']
        correct_answer = extract_answer(item['solution'])  # Extract from 'solution', not 'answer'
        response = get_llm_response(problem_text)
        predicted_answer = extract_answer(response)
        is_correct = compare_answers(correct_answer, predicted_answer)
        result = {
            "index": idx,
            "problem": problem_text,
            "response": response,
            "correct_answer": correct_answer,
            "predicted_answer": predicted_answer,
            "is_correct": is_correct
        }
        save_result(results_file, result)
        if is_correct:
            cnt += 1
        # cnt = correct answers in this run, t = problems processed in this run
        print(f"correct: {cnt} processed: {t}")
    final_results = load_existing_results(results_file)
    analyze_results(final_results)



## LLM Query Function

* This Python function sends a prompt to the locally hosted LLM API and returns the generated response.
* You can change `max_tokens` and `temperature` as needed.

In [6]:
import requests
def get_llm_response(prompt, max_tokens=1700, temperature=0.3):
    """Send `prompt` to the local vLLM server and return the model's reply."""
    url = "http://localhost:8000/v1/chat/completions"

    payload = {
        "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ],
        "max_tokens": max_tokens,
        "temperature": temperature
    }
    response = requests.post(url, json=payload)
    response.raise_for_status()  # surface HTTP errors instead of a KeyError below
    return response.json()['choices'][0]['message']['content'].strip()

## Test Prompt: Evaluating an Integral

In this cell, we define a new math benchmark question to verify that the LLM server is correctly set up and that responses can be retrieved.

**Question:**  
What is the value of the integral  
$$\int_0^1 x^2\,dx$$  

**Expected Answer:**  
$$\boxed{\frac{1}{3}}$$

The cell sends this prompt to the LLM server using the `get_llm_response` function and prints the response. This helps confirm that the integration between the notebook and the LLM server is working properly.

In [7]:
os.environ["LOG_LEVEL"] = "WARNING"
# Define a new math benchmark question for testing
question = "What is the value of the integral $$\\int_0^1 x^2 dx$$ answer it directly in one sentence?"
# Real answer: \boxed{\frac{1}{3}}

# Get response from the LLM server using the provided get_llm_response function
response = get_llm_response(question)

# Print the response to verify that the setup is working correctly
print("Response:", response)


Response: To evaluate the integral of \( x^2 \) from 0 to 1, I start by finding the antiderivative of \( x^2 \), which is \( \frac{x^3}{3} \).

Next, I apply the Fundamental Theorem of Calculus by substituting the upper limit 1 into the antiderivative and subtracting the value of the antiderivative at the lower limit 0.

This results in \( \frac{1^3}{3} - \frac{0^3}{3} = \frac{1}{3} \).

Therefore, the value of the integral is \( \frac{1}{3} \).
</think>

The value of the integral is \(\boxed{\dfrac{1}{3}}\).


# Customizable CoT Prompt Template
* Modify the CoT prompt, then evaluate it on the math benchmark.


In [8]:
import requests

# Define the CoT instructions (prepended to the problem in the user message)
COT_PROMPT = '''You are solving mathematics problems.

Please think step by step.

Important: Always end your solution with the final answer in this format:

\\[
\\boxed{your_answer_here}
\\]

The entire answer should be contained completely within the \\boxed{} command.'''

def get_COT_response(problem):
    """Prepend the CoT instructions to `problem` and query the local server."""
    prompt = COT_PROMPT + "\n" + problem
    # Reuse the shared helper; same endpoint and model, with the CoT settings.
    return get_llm_response(prompt, max_tokens=1900, temperature=0.3)

# Evaluate CoT
* Modify the response generation part to evaluate this method.

In [9]:
def evaluate_cot():
    os.makedirs("results", exist_ok=True)
    results_file = "evaluation_results_math500_deepseek_cot.json"
    dataset = load_math500_dataset()
    existing_results = load_existing_results(results_file)
    processed_indexes = {result['index'] for result in existing_results}
    cnt = 0
    for idx, item in enumerate(tqdm(dataset, desc="Evaluating problems")):
        if idx in processed_indexes:
            continue
        # Evaluate only the first 30 problems
        if idx >= 30:
            break
        problem_text = item['problem']
        correct_answer = extract_answer(item['solution'])
        ##########################################################
        response = get_COT_response(problem_text)
        predicted_answer = extract_answer(response)
        ##########################################################
        is_correct = compare_answers(correct_answer, predicted_answer)
        result = {
            "index": idx,
            "problem": problem_text,
            "response": response,
            "correct_answer": correct_answer,
            "predicted_answer": predicted_answer,
            "is_correct": is_correct
        }
        save_result(results_file, result)
        if is_correct:
          cnt += 1
        print(f"corrects :  {cnt} idx: {idx}")
    final_results = load_existing_results(results_file)
    analyze_results(final_results)

In [None]:
evaluate_cot()

=== Results Summary ===

Total problems: 30

Correct answers: 15

Accuracy: 50.00%

Runtime: 10:06

## Best-of-N 

The **Best-of-N** approach aims to improve math problem-solving by generating *N* solutions and selecting the one with the highest average token log-likelihood. Each solution is produced with a prompt that encourages step-by-step reasoning and asks for a formatted final answer, so the selected response is the one the model itself scores as most likely, presented in a consistent format.

### Steps
1. **Generate**: Produce *N* responses using a structured guiding prompt.
2. **Evaluate**: Compute the average log-likelihood for each response based on token probabilities.
3. **Select**: Identify and choose the response with the highest score.

This method favors the solution the model itself finds most likely and keeps the answer format consistent, but a high likelihood does not guarantee a correct answer.
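
The selection step can be sketched as follows, assuming each candidate arrives together with its per-token log-probabilities (for example, requested from an OpenAI-compatible server through its `logprobs` option). `Candidate`, `average_logprob`, and `select_best` are hypothetical names, and the numbers are made up for illustration.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Candidate:
    text: str
    token_logprobs: List[float]  # one log-probability per generated token


def average_logprob(c: Candidate) -> float:
    """Mean token log-probability; empty responses get the worst score."""
    if not c.token_logprobs:
        return float("-inf")
    return sum(c.token_logprobs) / len(c.token_logprobs)


def select_best(candidates: List[Candidate]) -> Candidate:
    """Return the candidate with the highest average token log-probability."""
    if not candidates:
        raise ValueError("need at least one candidate")
    return max(candidates, key=average_logprob)


# Made-up scores for three sampled answers to the same problem.
samples = [
    Candidate("... \\boxed{7}", [-0.9, -1.2, -0.6]),        # mean -0.9
    Candidate("... \\boxed{8}", [-0.2, -0.3, -0.1, -0.2]),  # mean -0.2
    Candidate("... \\boxed{7}", [-0.5, -0.7]),              # mean -0.6
]
print(select_best(samples).text)  # ... \boxed{8}
```

In this made-up case the minority answer wins because its tokens were individually more probable, which shows that likelihood-based selection and majority voting can disagree.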


## Verification Methods in Best‑of‑N Evaluation

When sampling multiple candidate solutions for each math problem, we need a reliable way to choose the single best answer. We support two complementary approaches:

### Log‑Probability Scoring

**Concept**  
Each generated solution comes with token‑level log‑likelihoods. By averaging these values across all tokens in the response, we obtain a single score reflecting how “confident” the model is in that entire output.
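
In symbols (notation introduced here for illustration, not taken from the notebook's code): for a response with tokens $y_1, \dots, y_T$ generated from prompt $x$, the score is

$$s(y) = \frac{1}{T}\sum_{t=1}^{T} \log p(y_t \mid x, y_{<t})$$

Dividing by $T$ keeps longer responses from being penalized simply for containing more tokens, and $\exp(s(y))$ is the geometric mean of the per-token probabilities.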

**Why Use It**  
- **Self‑Contained & Fast**: Requires no external calls or additional models.  
- **Cost‑Effective**: Purely internal computation, so it adds negligible expense.  

**Limitations**  
- A high likelihood does not always imply a correct or well‑reasoned solution, especially on complex math problems.

---

### LLM‑Based Verification

**Concept**  
Instead of trusting raw likelihoods, we hand all sampled responses off to a second, high‑quality language model (e.g. Gemini 2.0 Flash-Lite, the model used in the code below). That model reads the original problem and the list of candidate boxed answers, then selects the one it judges to be correct.

**Why Use It**  
- **Deeper Reasoning**: A dedicated verifier can compare alternative answers and catch subtle mistakes.  
- **Improved Robustness**: Mitigates cases where a flawed but high‑probability output would otherwise be chosen.

**Trade‑Offs**  
- **Slower**: Requires additional API calls and round‑trip latency.  
- **External Cost**: Incurs usage fees on the verification model.

---

### Balancing Speed, Cost, and Accuracy

By exposing a simple toggle between these two methods, you can:

- **Optimize for Speed**: Use log‑prob scoring when you need rapid, low‑cost evaluation.  
- **Optimize for Accuracy**: Use LLM‑based verification when correctness is paramount.  

Experiment on your dataset to find the right trade‑off for your needs.

In [8]:
from openai import OpenAI

# Initialize an OpenAI-compatible client pointed at the Gemini API endpoint
client = OpenAI(
    api_key="YOUR_API_KEY_HERE",  # Replace with your actual API key
    base_url="https://generativelanguage.googleapis.com/v1beta/"
)

# Use the cheapest Gemini model
LLM_API_MODEL = "gemini-2.0-flash-lite"

def get_api_response(prompt: str) -> str:
    """
    Send `prompt` to Gemini and return its reply.
    """
    resp = client.chat.completions.create(
        model=LLM_API_MODEL,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.0,
        max_tokens=1024,
    )
    return resp.choices[0].message.content

# test the model
print(get_api_response("hi. how you doing?"))

I'm doing well, thank you for asking! As a large language model, I don't experience emotions, but I am functioning optimally and ready to assist you. How can I help you today?



In [9]:
import re
import time
import math
import heapq
import requests
from typing import List, Optional


SYSTEM_PROMPT = '''You are solving mathematics problems.

Please think step by step.

Important: Always end your solution with the final answer in this format:

\\[
\\boxed{your_answer_here}
\\]

The entire answer should be contained completely within the \\boxed{} command.'''


def verify_with_gemini(problem: str, outputs: List[str]) -> Optional[str]:
    """
    Given the original problem and a list of candidate full-response texts,
    asks a Gemini API to pick the correct final answer (boxed).
    """
    # Deduplicate `outputs` into `unique_answers: List[str]` by extracting each boxed answer.
    unique_answers = []
    for output in outputs:
        answer = extract_answer(output)
        if answer is None:
            answer = "<no_boxed_answer>"
        if answer not in unique_answers:
            unique_answers.append(answer)

    # Build `options` as numbered lines of the form "1. \\boxed{...}" from `unique_answers`.
    options = "\n".join([f"{i+1}. \\boxed{{{ans}}}" for i, ans in enumerate(unique_answers)])
    # Compose `verify_prompt` with the problem and options.
    verify_prompt = f"""Given the problem: {problem}\n\n
Here are the candidate answers:\n{options}\n\n Choose the correct answer. If none are correct, choose 0.\n\n""" + '''Important: Always end your solution with the final answer in this format:

\\[
\\boxed{your_answer_here}
\\]

The entire answer should be contained completely within the \\boxed{} command.'''

    # Call `get_api_response(verify_prompt)` and strip whitespace → `chosen`.
    chosen = get_api_response(verify_prompt).strip()

    # Return `extract_answer(chosen)` to normalize formatting.
    return extract_answer(chosen)


def best_of_n_response(
    problem: str,
    N: int = 5,
    use_logprob: bool = True,
    model: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    port: int = 8000,
    temp: float = 0.2
) -> Optional[str]:
    """
    Run N samples on your vLLM server, then:
    - if use_logprob: pick the candidate with highest avg log-prob
    - else: hand off all N outputs to Gemini to choose the best boxed answer
    """
    url = f"http://localhost:{port}/v1/chat/completions"
    prompt = SYSTEM_PROMPT + "\n" + problem

    samples = []
    for _ in range(N):
        # Build the `payload` dict with model, messages, max_tokens, temperature, and logprobs.
        payload = {
            "model": model,
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "max_tokens": 1900,
            "temperature": temp,
            "logprobs": True,
            "top_logprobs": 1
        }
        # POST to `requests.post(url, json=payload)` and parse `.json()` → `resp`.
        response = requests.post(url, json=payload)
        resp = response.json()
        # Extract `text` from `resp['choices'][0]['message']['content']`.
        text = resp['choices'][0]['message']['content']

        # Compute `avg_lp` by collecting all `choice['logprobs']['content'][*]['logprob']` values.
        logprobs = resp['choices'][0]['logprobs']['content']
        avg_lp = sum([item['logprob'] for item in logprobs]) / len(logprobs) if logprobs else float('-inf')

        # Append `{"text": text, "avg_lp": avg_lp}` to `samples`.
        samples.append({"text": text, "avg_lp": avg_lp})

    if use_logprob:
        # Select the `sample` with the highest `avg_lp`.
        best = max(samples, key=lambda x: x["avg_lp"])
        # Return `extract_answer(best["text"])`.
        return extract_answer(best["text"])
    else:
        # Gather `outs = [s["text"] for s in samples]`.
        outs = [s["text"] for s in samples]
        # Return `verify_with_gemini(problem, outs)`.
        return verify_with_gemini(problem, outs)


# Evaluate best of n

* Modify the response generation part to evaluate this method.

In [10]:
def evaluate_best_of_n(use_logprob: bool = True, N: int = 3):
    os.makedirs("results", exist_ok=True)
    results_file = (
        "evaluation_results_math500_deepseek_best_of_n_logprob.json"
        if use_logprob else
        "evaluation_results_math500_deepseek_best_of_n_gpt.json"
    )

    dataset = load_math500_dataset()
    existing = load_existing_results(results_file)
    seen = {r['index'] for r in existing}
    correct = 0

    for idx, item in enumerate(tqdm(dataset, desc="Evaluating problems")):
        if idx in seen or idx >= 30:
            continue

        prob = item['problem']
        true_ans = extract_answer(item['solution'])
        pred_ans = best_of_n_response(prob, N=N, use_logprob=use_logprob)
        is_corr = compare_answers(true_ans, pred_ans)

        save_result(results_file, {
            "index": idx,
            "problem": prob,
            "correct_answer": true_ans,
            "predicted_answer": pred_ans,
            "is_correct": is_corr
        })
        if is_corr:
            correct += 1
        print(f"corrects: {correct} / {idx+1}")

    analyze_results(load_existing_results(results_file))


In [None]:
# ─── Example calls ─────────────────────────────────────────────────────────
evaluate_best_of_n(use_logprob=True,  N=3)

=== Results Summary ===

Total problems: 30

Correct answers: 15

Accuracy: 50.00%

Runtime: 33:57

In [None]:
evaluate_best_of_n(use_logprob=False, N=3)

=== Results Summary ===

Total problems: 30

Correct answers: 19

Accuracy: 63.33%

Runtime: 36:13

## Beam Search

This cell implements a beam search strategy for generating candidate reasoning chains. The method generates multiple continuations at each reasoning step, scoring each candidate based on its average token log-likelihood. By retaining and expanding only the top candidates, the approach efficiently searches for the most promising chain-of-thought that leads to the final answer in the required format.

**Key Components:**

- **Model Invocation & Token Scoring:**  
  The `call_qwen_model_raw` function sends requests to a local Qwen model endpoint using step-specific prompts. It returns generated text together with the average token log-probability, which is used as a quality metric.

- **Candidate Representation:**  
  The `BeamCandidate` class encapsulates a reasoning chain. It stores the generated text (sequence), cumulative log-probability, per-step scores, token count, and a finished flag (indicating if the candidate contains the final answer).

- **Step-wise Reasoning Generation:**  
  The `generate_reasoning_steps` function creates multiple candidate continuations for each reasoning step. Different prompts guide the generation for understanding the problem, planning a strategy, and producing the final answer (which is always enclosed in a `\boxed{}` block).

- **Beam Search Process:**  
  The `beam_search` function expands candidate chains over several steps. At each step, a candidate is extended by appending the new reasoning text, and its score is updated as the token-weighted average log-probability over all tokens generated so far (tracked via `num_token`). Only the top candidates (by this running score) are retained for further expansion.

- **Final Answer Extraction:**  
  The `run_qwen_beam_search` function initializes the prompt with the problem statement, runs the beam search, and extracts the final answer from the best candidate if it is complete.

This structured approach ensures efficient exploration of possible reasoning paths while focusing on the most promising ones to arrive at the final answer in the expected format.
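
As a small illustration of the score update and pruning described above, the sketch below recomputes a token-weighted running mean and keeps the top candidates. The helper name `update_running_mean` and all numbers are assumptions chosen for illustration; they are not taken from a real run.

```python
from typing import List, Tuple


def update_running_mean(prev_mean: float, prev_tokens: int,
                        step_mean: float, step_tokens: int) -> float:
    """Merge a step's average log-prob into the chain's running average,
    weighting each part by its token count."""
    total = prev_tokens + step_tokens
    if total == 0:
        return step_mean
    return (prev_mean * prev_tokens + step_mean * step_tokens) / total


# Step 1 produced 4 tokens with mean log-prob -0.5.
m1 = update_running_mean(0.0, 0, -0.5, 4)
# Step 2 produced 12 tokens with mean log-prob -0.25.
m2 = update_running_mean(m1, 4, -0.25, 12)
print(m1, m2)

# Pruning: sort by score (higher is better) and keep the top `beam_width`.
scored: List[Tuple[str, float]] = [("chain_a", -0.40), ("chain_b", m2), ("chain_c", -0.90)]
beam_width = 2
kept = sorted(scored, key=lambda c: c[1], reverse=True)[:beam_width]
print([name for name, _ in kept])
```

Because the longer second step carries three times the weight of the first, the merged score sits closer to -0.25 than to -0.5.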


In [13]:
import os
import requests
from typing import Optional, List
from tqdm import tqdm


def score_with_gpt(problem: str, reasoning_step: str) -> float:
    """
    Ask a high‑quality LLM (via get_api_response) to rate the given
    reasoning step on a 0–1 scale. Returns the numeric score.
    """
    prompt = (
        "You are a rigorous math reasoning evaluator.\n\n"
        f"Problem:\n{problem}\n\n"
        "Candidate reasoning step:\n"
        f"\"\"\"\n{reasoning_step}\n\"\"\"\n\n"
        "On a scale from 0 (completely incorrect) to 1 (perfectly correct), "
        "rate how valid and useful this step is toward solving the problem. "
        "Reply with only a number between 0 and 1."
    )
    # Call get_api_response(prompt), strip the result
    resp = get_api_response(prompt).strip()
    # Parse float(resp), fallback to 0.0 on ValueError, and return it
    try:
        return float(resp)
    except ValueError:
        return 0.0


def call_qwen_model_raw(prompt: str, step_num: int, temperature: float = 0.3):
    """
    Sends a request to the local Qwen endpoint and returns the generated text
    along with the average token log-probability and token count.
    """
    max_tokens = {1: 500, 2: 800, 3: 1700}.get(step_num, 500)
    url = "http://localhost:8000/v1/chat/completions"
    payload = {
        "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "logprobs": True,
    }
    # POST the request and parse the JSON body once
    resp = requests.post(url, json=payload).json()
    choice = resp['choices'][0]
    # Extract the generated text
    text = choice['message']['content'].strip()
    # Per-token entries, each carrying a 'logprob' field
    token_logprobs = choice['logprobs']['content']
    # Compute avg_logprob and num_token, then return (text, avg_logprob, num_token)
    avg_logprob = sum([item['logprob'] for item in token_logprobs]) / len(token_logprobs) if token_logprobs else float('-inf')
    num_token = len(token_logprobs)

    return text, avg_logprob, num_token


class BeamCandidate:
    def __init__(
        self,
        sequence: str,
        cumulative_log_prob: float,
        step_scores: List[float],
        finished: bool,
        num_token: int
    ):
        self.sequence = sequence
        self.cumulative_log_prob = cumulative_log_prob
        self.step_scores = step_scores
        self.finished = finished
        self.num_token = num_token

    def __repr__(self):
        return (
            f"BeamCandidate(score={self.cumulative_log_prob:.3f}, "
            f"finished={self.finished}, sequence=[...])"
        )


def generate_reasoning_steps(
    context: str,
    step_num: int,
    top_k: int,
    use_logprob: bool = True,
    problem: Optional[str] = None
):
    """
    Generate top_k candidate continuations for the current reasoning step.
    """
    candidates = []

    # Build `suffix` based on step_num (1=understand, 2=plan, 3=solve)
    suffixes = {
        1: "\n\nFirst, I'll understand the problem by identifying what's given and what's being asked:\n",
        2: "\n\nNow, I'll plan my solution strategy by breaking down the steps:\n",
        3: "\n\nFinally, I'll solve step-by-step and provide the final answer in a boxed format:\n"
    }
    suffix = suffixes.get(step_num, "\n\nNext, I'll continue my reasoning:\n")

    for i in range(top_k):
        # prompt = context + suffix
        prompt = context + suffix
        # output, avg_token_prob, num_token = call_qwen_model_raw(prompt, step_num)
        output, avg_token_prob, num_token = call_qwen_model_raw(prompt, step_num)

        # if use_logprob: score = avg_token_prob else: assert problem, score = score_with_gpt(problem, output)
        if use_logprob:
            score = avg_token_prob
        else:
            assert problem is not None, "Problem must be provided when not using logprob"
            score = score_with_gpt(problem, output)

        # finished = "\\boxed{" in output
        finished = "\\boxed{" in output
        # candidates.append((output.strip(), score, num_token, finished))
        candidates.append((output.strip(), score, num_token, finished))

    return candidates


def beam_search(
    init_problem_prompt: str,
    beam_width: int = 3,
    max_steps: int = 3,
    top_k: int = 2,
    use_logprob: bool = True
):
    """
    Beam search over reasoning steps. If use_logprob=False, uses GPT verifier
    to score each node instead of token log-prob.
    """
    # Extract `problem` from init_problem_prompt
    problem = init_problem_prompt.split("\n\n")[0].strip()

    # initial = BeamCandidate(sequence=init_problem_prompt, cumulative_log_prob=0.0, step_scores=[], finished=False, num_token=0)
    initial = BeamCandidate(
        sequence=init_problem_prompt,
        cumulative_log_prob=0.0,
        step_scores=[],
        finished=False,
        num_token=0
    )
    # beams = [initial]
    beams = [initial]

    for step in range(1, max_steps + 1):
        new_beams = []

        # For each cand in beams:
        for cand in beams:
            # If this candidate already has a boxed answer, keep it
            if cand.finished:
                new_beams.append(cand)
                continue

            # Generate new candidates from this state
            step_cands = generate_reasoning_steps(cand.sequence, step, top_k, use_logprob, problem)

            for text, score, n_tok, finished in step_cands:
                seq = cand.sequence + "\n" + text
                total_tokens = cand.num_token + n_tok

                # Calculate the new cumulative log probability
                if total_tokens > 0:
                    cum = ((cand.cumulative_log_prob * cand.num_token) + score * n_tok) / total_tokens
                else:
                    cum = score

                new_beams.append(BeamCandidate(
                    seq, cum, cand.step_scores + [score], finished, total_tokens
                ))

        # Sort candidates by score (higher is better) and keep only the top beam_width
        new_beams.sort(key=lambda x: x.cumulative_log_prob, reverse=True)
        beams = new_beams[:beam_width]

        # If all candidates have finished, we can stop early
        if all(beam.finished for beam in beams):
            break

    # Select the best finished candidate, or the best overall if none finished
    finished = [b for b in beams if b.finished]
    best = max(finished, key=lambda x: x.cumulative_log_prob) if finished else beams[0]
    return best


def run_qwen_beam_search(
    problem: str,
    beam_width: int,
    max_steps: int,
    top_k: int,
    log_level,
    use_logprob: bool = True
):
    """
    Performs beam search and extracts the final boxed answer.
    """
    prompt = f"Consider this problem:\n{problem}\n\nI'll solve this step by step."
    best = beam_search(prompt, beam_width, max_steps, top_k, use_logprob)

    if best.finished:
        ans = extract_answer(best.sequence)
        print(f"\nExtracted Final Answer: {ans}")
        return ans
    else:
        print("No final answer found.")
        return None


# Evaluate beam search
* Modify the response generation part to evaluate this method.

In [14]:
def evaluate_beam_search(use_logprob: bool = True):
    """
    Evaluate beam search on MATH‑500, toggling between log‑prob scoring
    and GPT/Gemini–based verification for each reasoning node.
    """
    os.makedirs("results", exist_ok=True)

    suffix = "logprob" if use_logprob else "gpt"
    results_file = f"evaluation_results_math500_deepseek_beam_search_{suffix}.json"

    dataset = load_math500_dataset()
    existing_results = load_existing_results(results_file)
    processed_indexes = {r['index'] for r in existing_results}

    cnt = 0
    for idx, item in enumerate(tqdm(dataset, desc="Evaluating problems")):
        if idx in processed_indexes or idx >= 30:
            continue

        problem_text   = item['problem']
        correct_answer = extract_answer(item['solution'])

        # Run beam search with the desired scoring method
        response = run_qwen_beam_search(
            problem     = problem_text,
            beam_width  = 3,
            max_steps   = 3,
            top_k       = 2,
            log_level   = 1,              # existing parameter
            use_logprob = use_logprob     # True = token log‑prob, False = GPT verifier
        )
        predicted_answer = response

        is_correct = compare_answers(correct_answer, predicted_answer)
        save_result(results_file, {
            "index": idx,
            "problem": problem_text,
            "response": response,
            "correct_answer": correct_answer,
            "predicted_answer": predicted_answer,
            "is_correct": is_correct
        })

        if is_correct:
            cnt += 1
        print(f"corrects: {cnt} / {idx+1}")

    final_results = load_existing_results(results_file)
    analyze_results(final_results)



In [None]:
# Example usage:
evaluate_beam_search(use_logprob=True)   # pick by avg token log‑prob

=== Results Summary ===

Total problems: 30

Correct answers: 16

Accuracy: 53.33%

Runtime: 1:26:59

In [None]:
evaluate_beam_search(use_logprob=False)  # pick by GPT/Gemini verification

=== Results Summary ===

Total problems: 30

Correct answers: 20

Accuracy: 66.67%

Runtime: 1:34:34

## Self-Refinement 

This cell implements a self-refinement approach to solving math problems. Initially, it generates a solution using a fixed system prompt that enforces a step-by-step reasoning process and a final answer format enclosed in `\boxed{}`. Then, through iterative feedback, the model is asked to analyze its own output and refine it if necessary. This loop aims to make the final answer more likely to be correct and clearly formatted, although it depends on the model's ability to critique its own work.


In [17]:
import re
import requests

SYSTEM_PROMPT = '''You are solving mathematics problems.

Please think step by step.

Important: Always end your solution with the final answer in this format:

\\[
\\boxed{your_answer_here}
\\]

The entire answer should be contained completely within the \\boxed{} command.'''


def generate_content(prompt: str) -> str:
    """
    Sends `prompt` to the local Qwen endpoint and returns the generated text.
    """
    url = "http://localhost:8000/v1/chat/completions"
    payload = {
        "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 1504,
        "temperature": 0.3,
    }
    # Send HTTP POST to `url` with `payload`, parse the JSON response
    response = requests.post(url, json=payload)
    resp = response.json()
    # Extract `content` from `resp['choices'][0]['message']['content']` and strip whitespace
    content = resp['choices'][0]['message']['content'].strip()
    # Return the resulting string
    return content


def self_refine(problem: str, max_iter: int = 2) -> Optional[str]:
    """
    Iteratively refines the model’s output on `problem` using feedback loops.
    """
    # Build initial `prompt` by concatenating SYSTEM_PROMPT and `problem`
    prompt = f"{SYSTEM_PROMPT}\n\nProblem: {problem}"
    # Call `generate_content(prompt)` to get `current_output`
    current_output = generate_content(prompt)

    for iteration in range(max_iter):
        # Construct `feedback_prompt`
        feedback_prompt = (
            f"Imagine you are a proof reader.\n"
            f"PROBLEM:\n{problem}\n\n"
            f"CURRENT SOLUTION TO EVALUATE:\n{current_output}\n\n"
            f"Your job is to evaluate this solution. No need to solve the problem! just analyze and provide feedback for the current solution.\n"
            f"INSTRUCTIONS:\n"
            f"1. Check if the solution has CORRECT mathematics\n"
            f"2. Verify the final answer is inside \\boxed{{}} format\n"
            f"3. Respond using EXACTLY this template:\n\n"
            f"[FEEDBACK]\n"
            f"Brief critique of mathematical correctness\n\n"
            f"[REFINEMENT_NEEDED]\n"
            f"yes OR no (only one word)\n\n"
        )

        # Call `generate_content(feedback_prompt)` → `feedback_response`
        feedback_response = generate_content(feedback_prompt)

        # Determine `refinement_needed` (default True, set False if flag is "no")
        refinement_needed = True

        # Pattern 1: Standard format with headers
        match = re.search(r"\[(?i:feedback)\](.*?)\[(?i:refinement_needed)\](.*)",
                         feedback_response, re.DOTALL)
        if match:
            feedback = match.group(1).strip()
            refinement_flag = match.group(2).strip().lower()

            # Check if refinement is needed
            if refinement_flag in ["no", "no."]:
                refinement_needed = False
        else:
            # Pattern 2: Look for yes/no statements about refinement
            no_patterns = [
                r"refinement(?:\s+is)?\s+(?:not|NOT)\s+needed",
                r"no refinement(?:\s+is)?\s+needed",
                r"solution is correct",
                r"\[refinement_needed\]:\s*no",
                r"\[refinement_needed\]\s*no"
            ]

            for pattern in no_patterns:
                if re.search(pattern, feedback_response, re.IGNORECASE):
                    refinement_needed = False
                    break

            # If regex doesn't match, use the whole response as feedback
            feedback = feedback_response

        # If `not refinement_needed`: break out of loop
        if not refinement_needed:
            break

        # Build `refine_prompt`
        refine_prompt = (
            f"Consider this initial prompt and problem: \n"
            f"\"{prompt}\" \n \n"
            f"Here is my current solution: \n"
            f"\"{current_output}\" \n\n"
            f"Feedback on the solution: \n"
            f"\"{feedback}\" \n\n"
            f"Please provide an improved solution based on this feedback. \n"
            f"Remember to follow the step-by-step approach and end with the final answer in a boxed format."
        )

        # Call `generate_content(refine_prompt)` → `refined_output`
        refined_output = generate_content(refine_prompt)

        # Check if the solution has changed
        if refined_output.strip() == current_output.strip():
            break

        current_output = refined_output
    
    # Use `extract_answer(current_output)` to get the final boxed answer
    answer = extract_answer(current_output)
    # Return that answer (or None if no boxed answer found)
    return answer


# Evaluate self refiner
* Modify the response generation part to evaluate this method.

In [18]:
def evaluate_self_refiner():
    os.makedirs("results", exist_ok=True)
    results_file = "evaluation_results_math500_deepseek_self_refiner.json"
    dataset = load_math500_dataset()
    existing_results = load_existing_results(results_file)
    processed_indexes = {result['index'] for result in existing_results}
    cnt = 0
    for idx, item in enumerate(tqdm(dataset, desc="Evaluating problems")):
        if idx in processed_indexes:
            continue
        if idx >= 30:
            break
        problem_text = item['problem']
        correct_answer = extract_answer(item['solution'])
        ##########################################################
        response = self_refine(problem_text, 3)
        predicted_answer = response
        ##########################################################
        is_correct = compare_answers(correct_answer, predicted_answer)
        result = {
            "index": idx,
            "problem": problem_text,
            "response": response,
            "correct_answer": correct_answer,
            "predicted_answer": predicted_answer,
            "is_correct": is_correct
        }
        save_result(results_file, result)
        if is_correct:
            cnt += 1
        print(f"corrects :  {cnt} idx: {idx}")
    final_results = load_existing_results(results_file)
    analyze_results(final_results)

In [None]:
evaluate_self_refiner()

=== Results Summary ===

Total problems: 30

Correct answers: 15

Accuracy: 50.00%

Runtime: 41:16

# 🚀 Overview (Part 2): Implementing A*, Monte Carlo Tree Search (MCTS), and Tree of Thoughts (ToT) 

With Part 1 complete, Part 2 explores three more sophisticated search and reasoning algorithms, **A\***, **Monte Carlo Tree Search (MCTS)**, and **Tree of Thoughts (ToT)**, for solving challenging mathematical problems from the MATH-500 dataset with a large language model (LLM). Before diving into the implementation, we give an overview of each algorithm, highlighting its core mechanisms, practical considerations, and potential challenges.

---

## 🌟 1. A* Search Algorithm

**A\*** is an informed search algorithm designed for efficiently finding the shortest path or optimal solution in a search space using heuristics.

### 🔹 Core Principles:
- **Best-first search:** Expands nodes based on a cost function, \(f(n)\), prioritizing paths that seem closer to a goal.
- **Heuristic evaluation:** Uses a heuristic function \(h(n)\) to estimate the cost from the current node to the goal.

### 🔹 Components:
- **Cost function \(f(n) = g(n) + h(n)\)**, where:
  - \(g(n)\): Actual cost from start node to node \(n\).
  - \(h(n)\): Estimated cost from node \(n\) to the goal (heuristic).

### 🔹 Practical Considerations:
- Heuristic function must be **efficient** and **accurate**.
- A good heuristic drastically reduces computation time and search complexity.

### ⚠️ Challenges in Implementation:
- **Designing an effective heuristic:**
  - It is hard to accurately estimate the "distance" from a partial reasoning chain to the solution.
  - Doing so requires leveraging language models to score plausibility.
- **Computational efficiency:**
  - Heuristic evaluation via LLM queries could be computationally costly if not managed carefully.
- **Admissibility and consistency:**
  - Ideally, the heuristic should be admissible (it never overestimates the true cost) to guarantee optimality.
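
The cost function \(f(n) = g(n) + h(n)\) can be illustrated on a toy graph. This is a generic sketch of A\*, not the implementation used later in this notebook; the graph, the edge costs, and the heuristic values are invented for illustration (the heuristic was chosen to be admissible). In the math-reasoning setting, nodes would be partial reasoning chains and \(h\) would come from an LLM-based score.

```python
import heapq
from typing import Dict, List, Optional, Tuple


def a_star(graph: Dict[str, Dict[str, float]], h: Dict[str, float],
           start: str, goal: str) -> Optional[Tuple[float, List[str]]]:
    """Return (cost, path) of the cheapest path found, or None if unreachable."""
    # Frontier entries are (f, g, node, path), ordered by f = g + h.
    frontier = [(h[start], 0.0, start, [start])]
    best_g = {start: 0.0}
    while frontier:
        _, g, node, path = heapq.heappop(frontier)
        if node == goal:
            return g, path
        if g > best_g.get(node, float("inf")):
            continue  # stale entry: a cheaper route to this node was already found
        for nxt, cost in graph[node].items():
            new_g = g + cost
            if new_g < best_g.get(nxt, float("inf")):
                best_g[nxt] = new_g
                heapq.heappush(frontier, (new_g + h[nxt], new_g, nxt, path + [nxt]))
    return None


# Toy graph with edge costs, and a heuristic that never overestimates.
graph = {
    "S": {"A": 1.0, "B": 4.0},
    "A": {"B": 2.0, "G": 6.0},
    "B": {"G": 2.0},
    "G": {},
}
h = {"S": 4.0, "A": 3.0, "B": 2.0, "G": 0.0}

result = a_star(graph, h, "S", "G")
print(result)
```

The direct edge A→G (total cost 7) is pushed onto the frontier but never expanded, because the route through B reaches the goal with a lower \(f\).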

---

## 🎲 2. Monte Carlo Tree Search (MCTS)

**MCTS** is a probabilistic algorithm widely used in decision-making scenarios, especially effective in complex problems with uncertain outcomes, such as mathematical reasoning guided by language models.

### 🔹 Core Principles:
MCTS explores decision trees using **randomized simulations** (rollouts) and statistical sampling.

**Four main steps in MCTS**:
1. **Selection**: Uses UCT (Upper Confidence bounds applied to Trees) to balance exploration and exploitation.
2. **Expansion**: Adds new unexplored nodes to the tree.
3. **Simulation**: Conducts rollouts from newly expanded nodes to estimate potential outcomes.
4. **Backpropagation**: Updates statistical measures based on simulation results.

### 🔹 Components:
- **UCT formula** for selection:
  $$
  \text{UCT} = \frac{w_i}{n_i} + C\sqrt{\frac{\ln N}{n_i}}
  $$
  - \(w_i\): Total reward accumulated at node \(i\).
  - \(n_i\): Visits to node \(i\).
  - \(N\): Visits to parent node.
  - \(C\): Exploration constant (usually \(\sqrt{2}\)).

- **Rollout (simulation)**:
  - Typically involves letting the LLM complete reasoning steps to the end, evaluating correctness.
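
A minimal sketch of the UCT formula above, using invented visit counts and rewards. Treating unvisited nodes as having infinite score is a common convention and an assumption here, since the formula itself is undefined when \(n_i = 0\).

```python
import math


def uct(total_reward: float, visits: int, parent_visits: int,
        c: float = math.sqrt(2)) -> float:
    """UCT score: mean reward plus an exploration bonus."""
    if visits == 0:
        return float("inf")  # assumption: unvisited children are tried first
    exploit = total_reward / visits
    explore = c * math.sqrt(math.log(parent_visits) / visits)
    return exploit + explore


# Hypothetical children of a node visited 5 times: name -> (w_i, n_i).
children = {"a": (3.0, 4), "b": (0.5, 1)}
parent_visits = 5

scores = {name: uct(w, n, parent_visits) for name, (w, n) in children.items()}
greedy = {name: uct(w, n, parent_visits, c=0.0) for name, (w, n) in children.items()}
print(max(scores, key=scores.get), max(greedy, key=greedy.get))
```

Child `a` has the higher mean reward, so pure exploitation (`c=0`) picks it, while the default exploration constant favors the rarely visited child `b`.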

### 🔹 Practical Considerations:
- Effective rollout policies significantly impact accuracy and efficiency.
- Balance exploration (testing new reasoning paths) and exploitation (refining known good solutions).

### ⚠️ Challenges in Implementation:
- **Computational overhead**:
  - Running multiple LLM-based rollouts per node can be slow and computationally expensive.
- **Optimal parameter tuning**:
  - Choosing the exploration constant \(C\) and number of simulations impacts performance significantly.
- **Quality of simulation outcomes**:
  - Poor rollout outcomes (random or inaccurate completions) can misguide the search tree.

---

## 🌳 3. Tree of Thoughts (ToT)

**Tree of Thoughts (ToT)** is specifically designed for structured reasoning tasks with language models, extending their capabilities through explicit evaluation and pruning of reasoning paths.

### 🔹 Core Principles:
- Generate multiple candidate "thoughts" (reasoning paths).
- Evaluate each thought explicitly (often through LLM-based scoring or consistency checking).
- Iteratively prune weaker reasoning paths, keeping the most promising solutions.

### 🔹 Components:
- **Thought generation**: Multiple candidate reasoning steps generated at each node.
- **Thought evaluation**: Explicit scoring (via LLMs) to judge mathematical correctness or plausibility.
- **Pruning**: Remove less promising reasoning branches based on evaluation.

### 🔹 Practical Considerations:
- Explicit node evaluation adds a structured layer of reasoning not present in simpler methods.
- Allows LLMs to reason more deliberately by systematically exploring and eliminating alternatives.

### ⚠️ Challenges in Implementation:
- **Evaluation complexity**:
  - Frequent explicit evaluations by LLM can slow the process.
  - Requires efficient prompting and scoring techniques.
- **Self-consistency**:
  - Maintaining logical consistency across multiple branches can be difficult, especially for complex math problems.
- **Scaling**:
  - Managing multiple branches of reasoning can quickly become computationally expensive without careful control.

---


**Next, we will begin the practical implementation step-by-step.**


# Implementing Tree of Thoughts (ToT)

Before writing any code, it’s essential to map out the steps and functions we need for our Tree of Thoughts (ToT) implementation. In ToT, we aren’t just following one linear chain of reasoning but instead generating a tree of candidate reasoning paths (“thoughts”), evaluating them, and then expanding the most promising ones. We can leverage our existing helper functions (such as those for extracting and normalizing answers) as part of the evaluation process.

Below is an outline of the steps and functions we’ll need:

---

## 1. **Node Representation**

We need a way to represent each node in our reasoning tree. A node could include:
- **Current reasoning text**: The partial solution or thought generated so far.
- **Candidate thoughts**: A list of potential next steps (children nodes).
- **Evaluation score**: A score indicating how promising the node is (based on plausibility or correctness).
- **Metadata**: Such as depth in the tree or a reference to the parent node.

**Potential Functions/Classes:**
- `class Node`: A class that encapsulates the above properties.
- `add_child(self, child_node)`: A method to attach a new candidate thought.
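
One possible shape for such a class is sketched below. The field names (`text`, `score`, `depth`, `children`) and the `path_text` helper are assumptions made for illustration; the outline above fixes only `Node` and `add_child`.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass(eq=False)  # identity-based equality avoids recursing through parent/children
class Node:
    text: str                                   # reasoning text for this step
    score: float = 0.0                          # evaluation score
    depth: int = 0                              # distance from the root
    parent: Optional["Node"] = field(default=None, repr=False)
    children: List["Node"] = field(default_factory=list)

    def add_child(self, child_node: "Node") -> None:
        """Attach a candidate thought and set its parent link and depth."""
        child_node.parent = self
        child_node.depth = self.depth + 1
        self.children.append(child_node)

    def path_text(self) -> str:
        """Concatenate the reasoning from the root down to this node."""
        parts = []
        node: Optional["Node"] = self
        while node is not None:
            parts.append(node.text)
            node = node.parent
        return "\n".join(reversed(parts))


root = Node("Problem: solve 2x + 5 = 15")
step = Node("Subtract 5 from both sides: 2x = 10", score=0.9)
root.add_child(step)
print(step.depth, step.path_text())
```

Storing only the step's own text in each node and rebuilding the full chain on demand keeps nodes small; storing the full chain in every node would be an equally valid choice.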

---

## 2. **Candidate Thought Generation**

This function will use the LLM (via our `get_llm_response` function) to generate multiple candidate reasoning steps given a node's current state. 

**Key Points:**
- The prompt should be carefully crafted to ask the LLM for alternative reasoning steps.
- We can use techniques such as few-shot prompting to guide the LLM in generating diverse thoughts.

**Potential Function:**
- `def generate_candidate_thoughts(node: Node, num_candidates: int) -> List[str]:`
  - This function takes the current reasoning state from `node` and returns a list of candidate thoughts (as strings).

---

## 3. **Candidate Thought Evaluation**

Once we have multiple candidate thoughts, we need to score them. The evaluation could be based on:
- **Model’s self-assessment:** Ask the LLM to rate each candidate on a scale (e.g., 1 to 10) for mathematical plausibility or correctness.
- **Heuristics based on helper functions:** Use helper functions like `extract_answer` and `normalize_answer` to check whether a candidate thought moves closer to a correct answer or simplifies the expression.

**Potential Function:**
- `def evaluate_candidate_thought(candidate: str) -> float:`
  - This function might prompt the LLM with the candidate reasoning step, asking, “How plausible or correct is this step?” and return a numeric score.
  - Alternatively, it might combine an LLM score with our own heuristic checks.
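
Turning the evaluator's free-text reply into a number is the fragile part of this step. A minimal sketch, assuming the reply contains the score as its first number; `parse_score` is a hypothetical helper, and the 1 to 10 range simply mirrors the rating scale described above:

```python
import re
from typing import Optional

def parse_score(reply: str, low: float = 1.0, high: float = 10.0) -> Optional[float]:
    """Return the first number in `reply`, clamped to [low, high], or None."""
    match = re.search(r'\d+(?:\.\d+)?', reply)
    if match is None:
        return None  # the caller decides how to treat an unparseable reply
    return max(low, min(high, float(match.group())))

print(parse_score("8"))              # 8.0
print(parse_score("Score: 7.5/10"))  # 7.5
print(parse_score("42"))             # 10.0 (clamped)
print(parse_score("no idea"))        # None
```

Returning `None` instead of a default number lets the caller distinguish "the evaluator gave a low score" from "the evaluator's reply could not be parsed".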

---

## 4. **Pruning and Selection**

After evaluating the candidates, we must select the most promising ones for further expansion. Pruning involves:
- Ranking candidate thoughts by their evaluation score.
- Keeping only the top N candidates (to control the tree size).

**Potential Function:**
- `def select_best_candidates(candidates: List[str], scores: List[float], top_n: int) -> List[str]:`
  - This function will combine the candidate list and their scores to select the best ones for further exploration.
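
As a rough illustration of the pruning step, the sketch below ranks candidates by score and keeps the best `top_n`. The name `top_n_candidates` and the choice to return `(candidate, score)` pairs are assumptions made for this example, not part of the outline above:

```python
from typing import List, Tuple

def top_n_candidates(
    candidates: List[str], scores: List[float], top_n: int
) -> List[Tuple[str, float]]:
    """Return the `top_n` highest-scoring (candidate, score) pairs."""
    if len(candidates) != len(scores):
        raise ValueError("candidates and scores must have the same length")
    # sorted() is stable, so candidates with equal scores keep their input order.
    ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:top_n]

print(top_n_candidates(["a", "b", "c"], [3.0, 9.0, 7.0], top_n=2))
# [('b', 9.0), ('c', 7.0)]
```

Keeping the score next to each surviving candidate means it can be stored on the child node directly, without asking the evaluator a second time.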

---

## 5. **Node Expansion**

For each node, the process is:
1. **Generate candidate thoughts** using the generation function.
2. **Evaluate each candidate** using the evaluation function.
3. **Select the best candidate(s)** using the pruning/selection function.
4. **Expand the tree** by creating child nodes for each selected candidate.

**Potential Function:**
- `def expand_node(node: Node, num_candidates: int, top_n: int) -> None:`
  - This function integrates candidate generation, evaluation, and pruning to add new child nodes to the given `node`.

---

## 6. **Stopping Criteria**

We need clear criteria for when to stop expanding the tree:
- **Complete solution found:** When a node contains a complete solution (e.g., using `extract_answer` to verify that a boxed answer exists).
- **Depth or resource limits:** When a maximum depth is reached or computational resources are constrained.

**Potential Function:**
- `def is_solution(node: Node) -> bool:`
  - This function checks if a node’s reasoning contains a valid, complete answer.
- `def stop_expansion(node: Node, max_depth: int) -> bool:`
  - This function checks if the node has reached the maximum allowed depth.
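
Both criteria can be combined into one check. The sketch below uses a hypothetical `find_boxed` as a stand-in for the notebook's `extract_answer` (whose exact behavior is not shown here); it matches braces by counting so that nested LaTeX such as `\dfrac{1}{3}` is handled:

```python
from typing import Optional

BOXED = "\\boxed{"

def find_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in `text`, or None."""
    start = text.rfind(BOXED)
    if start == -1:
        return None
    begin = start + len(BOXED)
    depth = 1  # we are inside the opening brace of \boxed{
    for i in range(begin, len(text)):
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
            if depth == 0:
                return text[begin:i]
    return None  # braces never balanced

def should_stop(state: str, depth: int, max_depth: int) -> bool:
    """Stop when a boxed answer is present or the depth limit is reached."""
    return find_boxed(state) is not None or depth >= max_depth

print(find_boxed("So the result is \\boxed{\\dfrac{1}{3}}."))
print(should_stop("still working", depth=1, max_depth=3))
```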

---

## 7. **Integration with Helper Functions**

Our previous helper functions play a crucial role in the ToT implementation:
- **Extracting and normalizing answers:**  
  Use `extract_answer` and `normalize_answer` to interpret candidate outputs and compare them against the expected solution.
- **Comparison functions:**  
  Use `compare_answers` to help decide if a candidate thought is moving in the right direction.
- **LLM response function:**  
  `get_llm_response` is used both for generating candidate thoughts and possibly for scoring them.

---

## 8. **Overall ToT Flow**

Putting it all together, here is an outline of the overall ToT process:
1. **Initialize the root node** with the initial problem statement.
2. **While the stopping criteria are not met:**
   - For the current node, generate candidate thoughts.
   - Evaluate each candidate.
   - Select and expand the best candidates to form new child nodes.
3. **Once a candidate thought leads to a complete solution:**
   - Use helper functions to verify correctness.
   - Return or record the successful reasoning path.
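
The flow above can be sketched as a breadth-first loop with the model calls replaced by plain functions. Everything here (`SketchNode`, `tot_search`, and the toy `propose` and `score` functions) is hypothetical and only illustrates the control flow. Unlike the implementation that follows, this sketch returns as soon as the first goal node is found:

```python
from collections import deque
from dataclasses import dataclass, field
from typing import Callable, List, Optional

@dataclass(eq=False)
class SketchNode:
    state: str
    depth: int = 0
    parent: Optional["SketchNode"] = field(default=None, repr=False)
    children: List["SketchNode"] = field(default_factory=list, repr=False)

def tot_search(
    problem: str,
    propose: Callable[[SketchNode], List[str]],
    score: Callable[[str], float],
    is_goal: Callable[[str], bool],
    top_n: int = 1,
    max_depth: int = 3,
) -> Optional[SketchNode]:
    frontier = deque([SketchNode(problem)])
    while frontier:
        node = frontier.popleft()
        if is_goal(node.state):
            return node
        if node.depth >= max_depth:
            continue
        # Generate, score, prune, then attach the survivors as children.
        ranked = sorted(propose(node), key=score, reverse=True)[:top_n]
        for thought in ranked:
            child = SketchNode(thought, node.depth + 1, node)
            node.children.append(child)
            frontier.append(child)
    return None

def path_to(node: SketchNode) -> List[str]:
    """Walk parent links back to the root and return states in order."""
    states = []
    current: Optional[SketchNode] = node
    while current is not None:
        states.append(current.state)
        current = current.parent
    return states[::-1]

# Toy stand-ins for the LLM generator and evaluator.
def propose(node: SketchNode) -> List[str]:
    if node.depth == 1:
        return ["so the answer is \\boxed{4}", "try another substitution"]
    return ["add 2 and 2", "guess randomly"]

def score(thought: str) -> float:
    return 10.0 if "boxed" in thought or "add" in thought else 1.0

result = tot_search("What is 2 + 2?", propose, score, lambda s: "\\boxed{" in s)
print(path_to(result))
```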

---

In [20]:
import re
import time
from typing import List, Optional

def score_with_gpt(problem: str, candidate: str) -> float:
    """
    Ask a high‑quality LLM (via get_api_response) to rate the given
    candidate solution step on a scale from 1 (poor) to 10 (excellent).
    """
    # build `eval_prompt` string using `problem` and `candidate`
    eval_prompt = (
        f"You are a math problem evaluator.\n\n"
        f"Problem:\n{problem}\n\n"
        f"Candidate solution step:\n{candidate}\n\n"
        f"Rate this solution step from 1-10 based on these criteria:\n"
        f"- Mathematical correctness\n"
        f"- Relevance to solving the problem\n"
        f"- Progress toward the final answer\n"
        f"Respond with only a number between 1 and 10."
    )
    resp = get_api_response(eval_prompt).strip()

    # use `re.search` to extract the first numeric score from `resp`
    match = re.search(r'(\d+(?:\.\d+)?)', resp)

    # return float(match) or 0.0 on failure
    return float(match.group(1)) if match else 0.0


class Node:
    def __init__(self, state: str, depth: int = 0, parent: Optional['Node'] = None):
        self.state = state
        self.depth = depth
        self.parent = parent
        self.children: List[Node] = []      # list to hold child nodes
        self.score: Optional[float] = None  # will be set once evaluated

    def add_child(self, child_node: 'Node'):
        self.children.append(child_node)

    def __repr__(self):
        # easy: show depth, score, and a truncated preview of `state`
        preview = self.state if len(self.state) < 50 else self.state[:47] + "..."
        return f"Node(depth={self.depth}, score={self.score}, state='{preview}')"

    def get_complete_reasoning_path(self):
        """
        Returns the complete reasoning path from the root to this node as a string.
        """
        path = []
        current_node = self
        while current_node is not None:
            path.append(current_node.state)
            current_node = current_node.parent
        return "\n".join(reversed(path))


class TreeOfThoughts:
    def __init__(
        self,
        num_candidates: int = 3,
        top_n: int = 2,
        max_depth: int = 3,
        verbose: bool = False
    ):
        self.num_candidates = num_candidates
        self.top_n = top_n
        self.max_depth = max_depth
        self.verbose = verbose
        self.root_state: Optional[str] = None

    def generate_candidate_thoughts(self, node: Node) -> List[str]:
        """
        Prompt the LLM to generate candidate thoughts based on the current node's state.
        """
        # compose prompt from the full reasoning path (root to this node)
        prompt = (
            f"You are a mathematician solving a step-by-step problem.\n\n"
            f"PROBLEM AND CURRENT SOLUTION:\n{node.get_complete_reasoning_path()}\n\n"
            f"TASK: Generate exactly {self.num_candidates} different next steps to continue this solution.\n\n"
            f"REQUIREMENTS:\n"
            f"- Each step must be NEW and continue directly from where the solution stopped\n"
            f"- Keep each step brief (2-3 sentences maximum)\n"
            f"- Make sure steps are mathematically diverse (different approaches)\n"
            f"- Do NOT repeat any previous work or reasoning\n"
            f"- Do NOT provide complete solutions\n"
            f"- If a step leads to the answer, provide the final answer in this format: \\boxed{{your answer here}}.\n\n"
            f"FORMAT: Numbered list with each step on a new line\n"
            f"1. [first suggestion]\n"
            f"2. [second suggestion]\n"
            f"...\n\n"
        )
        response = get_llm_response(prompt, 1900)

        # Extract content after </think> tag
        parts = response.split("</think>", 1)
        if len(parts) > 1:
            # Take everything after the thinking tag
            post_thinking = parts[1].strip()
            candidates = [line.strip() for line in post_thinking.split('\n') if line.strip()]
        else:
            # If no </think> tag, just split by lines
            candidates = [line.strip() for line in response.split('\n') if line.strip()]

        # Remove numbering prefixes from the candidates (like "1.", "2)", "(3)", etc.)
        candidates = [re.sub(r'^\s*(?:\d+[.):]\s*|\(\d+\)\s*)', '', candidate).strip() for candidate in candidates]

        # If we don't have enough candidates, add lines from the thinking phase
        if len(candidates) < self.num_candidates and len(parts) > 1:
            thinking_lines = [line.strip() for line in parts[0].split('\n') if line.strip()]
            candidates.extend(thinking_lines[:self.num_candidates - len(candidates)])

        # Take only the required number of candidates
        candidates = candidates[:self.num_candidates]

        return candidates

    def evaluate_candidate_thought(self, candidate: str) -> float:
        # simple wrapper around score_with_gpt
        assert self.root_state is not None
        return score_with_gpt(self.root_state, candidate)

    def select_best_candidates(self, candidates: List[str]) -> List[str]:
        scored_candidates = []
        for cand in candidates:
            score = self.evaluate_candidate_thought(cand)
            scored_candidates.append((cand, score))
            time.sleep(0.5)  # Prevent rate limiting

        # Sort by score in descending order
        scored_candidates.sort(key=lambda x: x[1], reverse=True)

        # Return the top N candidates
        return [c[0] for c in scored_candidates[:self.top_n]]

    def expand_node(self, node: Node) -> None:
        if self.verbose:
            print(f"\nExpanding depth {node.depth} state:\n{node.state}\n")

        raw = self.generate_candidate_thoughts(node)
        best = self.select_best_candidates(raw)

        for ans in best:
            child = Node(state=ans, depth=node.depth + 1, parent=node)
            child.score = self.evaluate_candidate_thought(ans)  # evaluate the child

            # at the final depth, if the answer is not boxed, retry up to 3 times with a strict prompt
            attempts = 0
            while not extract_answer(child.state) and attempts < 3 and node.depth >= self.max_depth - 1:
                retry_prompt = (
                    f"Based on your work:\n{child.get_complete_reasoning_path()}\n\n"
                    f"Please provide a final answer in this format: \\boxed{{your answer here}}"
                )
                retry_response = get_llm_response(retry_prompt, 1900)
                child.state = retry_response
                attempts += 1

            node.add_child(child)
            if self.verbose:
                print(f"Added child: {child}")

    def is_solution(self, node: Node) -> bool:
        # return True if `extract_answer(node.state)` yields non-empty
        return bool(extract_answer(node.state))

    def stop_expansion(self, node: Node) -> bool:
        return node.depth >= self.max_depth

    def search(self, root_state: str) -> Node:
        """
        Build the tree until solutions found or max depth reached.
        """
        self.root_state = root_state
        root = Node(state=root_state, depth=0)
        frontier = [root]
        while frontier:
            node = frontier.pop(0)
            if self.is_solution(node) or self.stop_expansion(node):
                continue
            self.expand_node(node)
            frontier.extend(node.children)
        return root


# Testing the Tree of Thoughts Method with Minimal Hyperparameters

This cell demonstrates a test run of the Tree of Thoughts (ToT) framework using minimal hyperparameters. The goal is to ensure that a complete final answer (in the format `\boxed{...}`) is extracted from the model's output.

**Key Steps:**

- **Instantiate TOT:**  
  The TOT instance is created with `num_candidates=2`, `top_n=1`, and `max_depth=3` in verbose mode. This small setup keeps the test run quick.

- **Run the Search:**  
  The TOT search is executed on the sample problem:  
  *"Solve the integral: \( \int_0^1 x^2 \, dx \)"*

- **Print the TOT Tree:**  
  A recursive function (`print_tree`) prints the entire search tree, allowing inspection of each node's state and depth.

- **Extract Final Answer:**  
  All nodes in the tree are collected. If a node is found that contains a final answer (determined via the helper function `extract_answer`), then the node's state is overwritten to display only the final answer in the correct boxed format.  
  If no such node is found, a fallback prompt forces the model to output the final answer exclusively.

This setup helps verify that the TOT framework correctly isolates and formats the final answer, ensuring the result is comparable to the expected output.


In [21]:
# Test the TreeOfThoughts with minimal settings.
# Generation of candidate answers still uses your primary LLM via get_llm_response,
# while evaluation/scoring uses only the GPT/Gemini verifier.

tot = TreeOfThoughts(num_candidates=2,
                     top_n=1,
                     max_depth=3,
                     verbose=True)

initial_problem = "Solve the integral: \\( \\int_0^1 x^2 \\, dx \\)"
tot_tree = tot.search(initial_problem)

# Helper to print the whole tree
def print_tree(node: Node, indent: str = ""):
    print(indent + repr(node))
    for child in node.children:
        print_tree(child, indent + "  ")

print_tree(tot_tree)

# Collect all nodes and find the first complete solution
def collect_all_nodes(node: Node) -> List[Node]:
    nodes = [node]
    for child in node.children:
        nodes.extend(collect_all_nodes(child))
    return nodes

all_nodes = collect_all_nodes(tot_tree)
solution_node = next((n for n in all_nodes if tot.is_solution(n)), None)

print("---")
if solution_node:
    final = extract_answer(solution_node.state)
    solution_node.state = f"\\boxed{{{final}}}"
    print("\nFinal Answer Found:")
    print(solution_node.state)
else:
    # If no solution node was produced, force a final answer via GPT verifier
    fallback_prompt = (
        f"Based on the problem: \"{initial_problem}\", please provide ONLY your final answer "
        "in the exact format \\boxed{...}."
    )
    forced = get_api_response(fallback_prompt).strip()
    final = extract_answer(forced) or forced
    forced = f"\\boxed{{{final}}}"
    print("\nNo complete final answer was found. Forcing final answer:")
    print(forced)


Expanding depth 0 state:
Solve the integral: \( \int_0^1 x^2 \, dx \)

Added child: Node(depth=1, score=10.0, state='Use the antiderivative of x squared, which is (...')

Expanding depth 1 state:
Use the antiderivative of x squared, which is (x³)/3, and evaluate from 0 to 1 to get 1/3.

Added child: Node(depth=2, score=10.0, state='Subtract the value of the antiderivative at the...')

Expanding depth 2 state:
Subtract the value of the antiderivative at the lower limit 0, which is 0.

Added child: Node(depth=3, score=10.0, state='Okay, so I need to solve the integral of x squa...')
Node(depth=0, score=None, state='Solve the integral: \( \int_0^1 x^2 \, dx \)')
  Node(depth=1, score=10.0, state='Use the antiderivative of x squared, which is (...')
    Node(depth=2, score=10.0, state='Subtract the value of the antiderivative at the...')
      Node(depth=3, score=10.0, state='Okay, so I need to solve the integral of x squa...')
---

Final Answer Found:
\boxed{\dfrac{1}{3}}


# Evaluation of the Tree of Thoughts (ToT) Method on the Math500 Dataset

This evaluation framework is designed to test the ToT method on a subset of the Math500 dataset. It provides flexibility in hyperparameter configuration and is set up for both debugging with detailed output and large-scale evaluation. The framework saves results and summarizes key metrics, and it includes mechanisms to force the model to output a final answer in the expected format.

## Key Features

- **Unique Results File:**  
  Uses a dedicated results file (e.g., `evaluation_results_tot_test.json`) to store evaluation data. The file is cleared at the beginning of each run to prevent interference from previous evaluations.

- **Configurable Hyperparameters:**  
  You can adjust parameters such as:
  - `num_candidates`: Number of candidate thoughts (next reasoning steps) generated per node.
  - `top_n`: Number of top candidates selected for expanding the tree.
  - `max_depth`: Maximum depth of the search tree.
  
  These parameters enable you to test different reasoning strategies and trade-offs between search breadth and depth.

- **Sample Selection:**  
  The `max_samples` parameter controls the number of problems from the dataset to evaluate. This allows you to start by testing on a single sample and then scale up to 100 or even more problems as desired.

- **Debug and Fallback Mechanisms:**  
  - The framework prints the entire TOT tree for each problem to facilitate inspection.
  - If no node contains a complete final answer (i.e., a final answer wrapped in `\boxed{...}`), a fallback prompt is issued to force the model to output the final answer.
  - Detailed debug outputs help track the process and diagnose any issues in final answer extraction.

- **Final Answer Extraction:**  
  After processing the tree, the system extracts just the final answer from the model's output (using a helper function like `extract_answer`), ensuring that only the final answer (and no chain-of-thought explanations) is compared against the correct answer.

- **Result Saving and Analysis:**  
  The framework saves each problem’s evaluation (including problem text, responses, and correctness) and produces a summary report that includes metrics like total problems, correct answers, and overall accuracy.
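
The summary metrics can be computed from the saved records alone. A minimal sketch, assuming each record carries the `is_correct` field written by the evaluation cell; `summarize` is a hypothetical name, and the notebook's own `analyze_results` helper may report more:

```python
from typing import Dict, List

def summarize(results: List[Dict]) -> Dict[str, float]:
    """Compute total, correct count, and accuracy (percent) from saved records."""
    total = len(results)
    correct = sum(1 for record in results if record.get("is_correct"))
    accuracy = 100.0 * correct / total if total else 0.0
    return {"total": total, "correct": correct, "accuracy": round(accuracy, 2)}

# Illustrative records only: 16 correct out of 30.
records = [{"is_correct": True}] * 16 + [{"is_correct": False}] * 14
print(summarize(records))
```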

## Encouragement for Further Improvement

**Prompt Engineering & Fallback Strategies:**  
The current method forces the model to provide a final answer using a fallback prompt when the initial generation does not meet the required format. While this approach works, it is not perfect:
- **Prompt Tuning:** Experiment with different wording and structure in the prompts. For example, try different phrasings that emphasize "ONLY your final answer" and "no additional explanations" to see if the model can be nudged into generating a cleaner response.
- **Iterative Refinement:** Consider implementing iterative prompt refinement mechanisms or leveraging additional post-processing steps to filter out unwanted chain-of-thought text.
- **Open Research Problem:** The issue of controlling a language model’s output to include only the final answer (and not intermediate reasoning) is an active area of research. There is significant potential to explore improved strategies that maintain reasoning power while enforcing output constraints.

**Scalability:**  
Once you have tuned the hyperparameters and the prompting strategy on a small sample, test on a larger set (such as 100 or more problems). Evaluating on a larger dataset can reveal trends and potential improvements that are not apparent on a smaller scale.

**Experiment and Innovate:**  
Do not hesitate to modify the prompts, fallback mechanisms, and even the underlying structure of the TOT class. Every change you experiment with might lead to better results and a deeper understanding of how to steer the model toward producing just the final answer. Your experimentation is key to achieving a more robust and reliable evaluation system.

This framework is designed to be flexible—feel free to tweak the parameters and strategies to suit your research or production needs, and continue to iterate toward better performance!


In [None]:
import os
from tqdm import tqdm

def evaluate_tot(max_samples: int = 1):
    """
    Evaluate the Tree of Thoughts (ToT) method on a subset of the Math500 dataset.
    Uses GPT/Gemini exclusively for verification and fallback.
    Assumes the existence of helper functions:
      - load_math500_dataset()
      - extract_answer(solution_text)
      - compare_answers(correct_answer, predicted_answer)
      - save_result(results_file, result_dict)
      - load_existing_results(results_file)
      - analyze_results(results_list)
      - get_llm_response(prompt)     # for initial candidate generation
      - get_api_response(prompt)     # for GPT/Gemini–based verification
    """
    os.makedirs("results", exist_ok=True)
    results_file = "results/evaluation_results_tot_test.json"
    if os.path.exists(results_file):
        os.remove(results_file)

    dataset = load_math500_dataset()
    existing = load_existing_results(results_file)
    processed = {r['index'] for r in existing}

    tot = TreeOfThoughts(num_candidates=3, top_n=1, max_depth=3, verbose=True)
    correct_count = 0
    evaluated = 0

    def collect_all_nodes(node):
        nodes = [node]
        for child in node.children:
            nodes.extend(collect_all_nodes(child))
        return nodes

    for idx, item in enumerate(tqdm(dataset, desc="Evaluating problems")):
        if evaluated >= max_samples:
            break  # sample budget reached; no need to scan the rest of the dataset
        if idx in processed:
            continue

        problem_text = item['problem']
        correct_answer = extract_answer(item['solution'])

        # Run the Tree‑of‑Thoughts search
        tot_tree = tot.search(problem_text)

        # Debug print of the entire tree
        print(f"\n--- DEBUG: Full TOT tree for problem index {idx} ---")
        def print_tree(node, indent=""):
            print(indent + repr(node))
            for c in node.children:
                print_tree(c, indent + "  ")
        print_tree(tot_tree)
        print("--- End of TOT tree ---\n")

        # Find first node with a boxed answer
        all_nodes = collect_all_nodes(tot_tree)
        solution_node = next((n for n in all_nodes if tot.is_solution(n)), None)

        if solution_node:
            response = solution_node.state
            print("DEBUG: Found solution node:", response)
        else:
            # Fallback: ask GPT/Gemini directly for final answer
            print("DEBUG: No solution node found. Using GPT fallback.")
            fallback_prompt = (
                f"Based on the problem: \"{problem_text}\", provide ONLY your final answer "
                "in the exact format \\boxed{<final answer>}. Do not include any chain-of-thought."
            )
            response = get_api_response(fallback_prompt)
            print("DEBUG: Fallback response:", response)

        predicted = extract_answer(response) or ""
        if not predicted:
            print("DEBUG: predicted_answer is empty. Raw response was:\n", response)

        is_correct = compare_answers(correct_answer, predicted)
        save_result(results_file, {
            "index": idx,
            "problem": problem_text,
            "response": response,
            "correct_answer": correct_answer,
            "predicted_answer": predicted,
            "is_correct": is_correct
        })

        if is_correct:
            correct_count += 1
        evaluated += 1
        print(f"Correct: {correct_count}/{evaluated} | Index: {idx}")

    final = load_existing_results(results_file)
    analyze_results(final)

# Example: evaluate the first 30 problems
evaluate_tot(max_samples=30)

=== Results Summary ===

Total problems: 30

Correct answers: 16

Accuracy: 53.33%

Runtime: 1:10:36

# A* Search Algorithm for Mathematical Reasoning

The A* (A-star) search algorithm is a popular informed search method that combines both actual cost and an estimated cost to reach the goal. When applied to mathematical problem-solving with language models, each node in the search tree represents a partial reasoning process. The goal is to guide the search toward the correct final answer efficiently.

## Core Concepts

- **Nodes and States:**  
  Each node represents a state in the reasoning process—a partial solution or chain-of-thought step. The root node is the initial problem, and child nodes represent possible next steps in reasoning.

- **Cost Function (g(n)):**  
  This function measures the cost accumulated from the start node to the current node. In our context, it might represent the complexity or length of the reasoning chain so far.

- **Heuristic Function (h(n)):**  
  A heuristic estimates the cost (or “distance”) from the current node to the goal (a complete final answer). For mathematical reasoning, this could be designed to reflect how promising the current partial solution is—possibly by prompting the LLM to provide a confidence or plausibility score.

- **Evaluation Function (f(n)):**  
  A* uses the function:
  \[
  f(n) = g(n) + h(n)
  \]
  to choose which node to expand next. Nodes with lower f(n) are expanded first, steering the search toward the most promising reasoning paths.
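
A small worked example of the ordering, using made-up values for g and h: the node with the lowest f(n) is popped first, even if it sits deeper in the tree.

```python
import heapq

# (label, g, h) for three hypothetical frontier nodes.
nodes = [("A", 1, 6.0), ("B", 2, 1.0), ("C", 1, 3.0)]

frontier = []
for label, g, h in nodes:
    # Store f = g + h first so the heap orders entries by f.
    heapq.heappush(frontier, (g + h, label))

order = [heapq.heappop(frontier)[1] for _ in range(len(nodes))]
print(order)  # ['B', 'C', 'A'] since f(B)=3, f(C)=4, f(A)=7
```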

## How A* Works in Mathematical Reasoning

1. **Initialization:**  
   The algorithm starts with the initial problem as the root node, with an initial cost \(g(n)=0\).

2. **Expansion:**  
   From the current node, the model generates several potential next steps (child nodes). Each child node represents a possible continuation of the reasoning process.

3. **Cost and Heuristic Calculation:**  
   - **g(n):** Represents the cost accumulated so far (e.g., the number of reasoning steps taken).
   - **h(n):** An estimate of how “far” the current state is from a complete final answer. This can be derived via LLM-based evaluations or comparisons to known correct patterns.

4. **Priority Queue and Node Selection:**  
   The algorithm uses a priority queue to maintain nodes, sorted by their \( f(n) \) value. The node with the smallest \(f(n)\) (i.e., the most promising combination of current cost and estimated remaining cost) is expanded next.

5. **Goal Test:**  
   The process continues until a node is found that meets the goal—a node whose state contains a complete final answer in the expected format (e.g., a LaTeX expression wrapped in `\boxed{...}`).
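
The five steps can be sketched on a toy problem with the LLM replaced by lookup tables. The `steps` and `remaining` tables below are invented for illustration, and `a_star` is a hypothetical function name; the g cost is taken to be the number of reasoning steps, as in the description above:

```python
import heapq
import itertools
from typing import Callable, List, Optional

def a_star(
    start: str,
    expand: Callable[[str], List[str]],
    heuristic: Callable[[str], float],
    is_goal: Callable[[str], bool],
    max_depth: int = 3,
) -> Optional[List[str]]:
    counter = itertools.count()  # tie-breaker so equal-f entries never compare paths
    frontier = [(heuristic(start), next(counter), 0, [start])]
    while frontier:
        _, _, g, path = heapq.heappop(frontier)  # smallest f first
        state = path[-1]
        if is_goal(state):
            return path
        if g >= max_depth:
            continue
        for nxt in expand(state):
            g_next = g + 1  # each reasoning step costs 1
            f_next = g_next + heuristic(nxt)
            heapq.heappush(frontier, (f_next, next(counter), g_next, path + [nxt]))
    return None

# Toy successor table and hand-picked "work remaining" estimates.
steps = {
    "integrate x^2 on [0,1]": ["antiderivative x^3/3", "expand as a series"],
    "antiderivative x^3/3": ["\\boxed{1/3}"],
    "expand as a series": ["sum the terms"],
}
remaining = {
    "integrate x^2 on [0,1]": 2,
    "antiderivative x^3/3": 1,
    "expand as a series": 5,
    "sum the terms": 4,
    "\\boxed{1/3}": 0,
}

path = a_star(
    "integrate x^2 on [0,1]",
    expand=lambda s: steps.get(s, []),
    heuristic=lambda s: remaining[s],
    is_goal=lambda s: s.startswith("\\boxed{"),
)
print(path)
```

Because the series branch has a much larger f value, it stays in the queue and is never expanded before the goal is reached.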

## Challenges in Applying A* to Reasoning

- **Heuristic Design:**  
  Defining an effective h(n) is challenging. The heuristic must correlate well with the true “distance” to a correct final answer. For language models, this might involve model confidence scores or custom prompt evaluations.

- **Balancing Exploration and Exploitation:**  
  Weighting g(n) too heavily favors short chains that may stop before the reasoning is complete, while relying too heavily on h(n) lets an over-optimistic score keep partially correct answers ranked above better paths.

- **Computational Expense:**  
  Evaluating each node’s heuristic (potentially via additional LLM queries) can be computationally expensive, especially in a large search space.

- **Scalability:**  
  The state space for reasoning is vast. A well-tuned A* algorithm can efficiently prune irrelevant paths, but without a robust heuristic, the number of nodes to explore can grow exponentially.
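
One inexpensive way to reduce the cost of repeated heuristic queries is to memoize them. The sketch below uses `functools.lru_cache` around a stand-in scoring function; `expensive_score` is hypothetical and only counts calls, where a real version would query the evaluator model. Caching assumes that re-scoring the same state should give the same value, which pins the first score the evaluator returned:

```python
from functools import lru_cache

calls = {"count": 0}

def expensive_score(problem: str, state: str) -> float:
    """Stand-in for an LLM-based heuristic; counts how often it is invoked."""
    calls["count"] += 1
    return float(len(state) % 10)

@lru_cache(maxsize=4096)
def cached_score(problem: str, state: str) -> float:
    # Arguments are strings, so they are hashable and usable as cache keys.
    return expensive_score(problem, state)

cached_score("p", "step one")
cached_score("p", "step one")  # served from the cache, no new call
cached_score("p", "step two")
print(calls["count"])  # 2
```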



Next, we will move on to the code implementation of the A* algorithm for our reasoning framework.


In [23]:
import re
import time
import heapq
from typing import List, Optional


class AStarNode:
    def __init__(self, state: str, g: float = 0.0, parent: Optional['AStarNode'] = None):
        self.state = state                      # current problem or partial answer
        self.g = g                              # cost so far (depth)
        self.h: float = 0.0                     # will be set by heuristic evaluator
        self.f: float = 0.0                     # f = g + h
        self.parent = parent                    # link back to parent for solution path
        self.children: List[AStarNode] = []     # list to hold generated successors

    def __lt__(self, other: 'AStarNode'):
        return self.f < other.f                 # easy: allow heapq to compare nodes by f‑value

    def __repr__(self):
        preview = self.state if len(self.state) < 50 else self.state[:47] + "..."
        return f"AStarNode(g={self.g}, h={self.h}, f={self.f}, state='{preview}')"


class AStarSearch:
    def __init__(
        self,
        num_candidates: int = 3,
        max_depth: int = 3,
        verbose: bool = False,
        max_fallback: int = 3
    ):
        self.num_candidates = num_candidates    # how many answers to generate per node
        self.max_depth = max_depth              # search will stop at this depth
        self.verbose = verbose                  # if True, print debug info
        self.max_fallback = max_fallback        # retries for enforcing boxed format
        self.root_problem: Optional[str] = None  # will hold the original problem text

    def is_solution(self, state: str) -> bool:
        # return True if `extract_answer(state)` yields a non‑empty string
        return bool(extract_answer(state))

    def generate_candidates(self, node: AStarNode) -> List[str]:
        """
        Ask the LLM for `num_candidates` boxed answers to node.state.
        """
        # compose a prompt with node.state asking for LaTeX \\boxed{...} answers
        prompt = f"""You are solving a mathematics problem.
Given the following problem or partial solution:

{node.state}

Generate {self.num_candidates} different possible final answers.
Each answer should be clearly marked with a LaTeX \\boxed{{...}} command.
If it's a final answer, make sure it's enclosed in \\boxed{{...}}.
Write each answer on a new line. These answers should be distinct.
"""
        response = get_llm_response(prompt)
        
        # Split the response into non-empty lines
        lines = [line.strip() for line in response.split('\n') if line.strip()]
        
        # If fewer than num_candidates, replicate lines to reach that count
        while len(lines) < self.num_candidates:
            lines.append(lines[0] if lines else "No solution found.")
            
        # Return exactly num_candidates answer strings
        return lines[:self.num_candidates]

    def evaluate_heuristic(self, state: str) -> float:
        """
        Use GPT (via get_api_response) to score how incomplete a candidate is:
        0 means perfect boxed answer; higher means more incomplete.
        """
        assert self.root_problem is not None, "Root problem must be set before heuristic evaluation"
        
        # build `eval_prompt` using self.root_problem and the current `state`
        eval_prompt = f"""Given this original problem:
{self.root_problem}

And this current state of the solution:
{state}

On a scale of 0 to 10:
- Score 0 if this contains a complete, correctly formatted final answer inside \\boxed{{...}}
- Score higher (up to 10) the more work remains to reach a final boxed answer

Return ONLY a numeric score without explanation.
"""
        response = get_api_response(eval_prompt).strip()
        
        # extract the first numeric value with `re.search`
        match = re.search(r'(\d+(?:\.\d+)?)', response)
        
        # return that float, or a fallback like `self.max_depth * 10` if parsing fails
        return float(match.group(1)) if match else self.max_depth * 10

    def expand_node(self, node: AStarNode) -> List[AStarNode]:
        if self.verbose:
            print(f"\nExpanding node at depth {node.g}:\n{node.state}\n")
            
        # if self.root_problem is None, set it to node.state
        if self.root_problem is None:
            self.root_problem = node.state
            
        candidates = self.generate_candidates(node)
        children: List[AStarNode] = []

        for cand in candidates:
            attempts = 0
            while not self.is_solution(cand) and attempts < self.max_fallback:
                fallback_prompt = f"""You are a mathematics expert. For this problem:
{self.root_problem}

Using the partial solution:
{cand}

Provide ONLY the final answer enclosed in \\boxed{{...}} LaTeX format.
DO NOT show any work, ONLY the final boxed answer.
"""
                response = get_llm_response(fallback_prompt)
                cand = response.strip()
                attempts += 1
                if self.verbose:
                    print(f"Fallback attempt {attempts}: {'Solution found' if self.is_solution(cand) else 'Still no solution'}")
            
            child = AStarNode(state=cand, g=node.g + 1, parent=node)
            child.h = self.evaluate_heuristic(child.state)
            child.f = child.g + child.h
            node.children.append(child)
            children.append(child)
            
            if self.verbose:
                print(f"Generated child: {child}")
        
        return children

    def search(self, initial_problem: str) -> Optional[AStarNode]:
        """
        Run A* until a boxed solution is found or max_depth is exceeded.
        """
        self.root_problem = initial_problem
        root = AStarNode(state=initial_problem, g=0.0)
        root.h = self.evaluate_heuristic(root.state)
        root.f = root.g + root.h
        frontier: List[AStarNode] = []
        heapq.heappush(frontier, root)

        while frontier:
            node = heapq.heappop(frontier)
            if self.verbose:
                print(f"Expanding: {node}")
            
            if self.is_solution(node.state):
                return node
            
            if node.g >= self.max_depth:
                continue
            
            for child in self.expand_node(node):
                heapq.heappush(frontier, child)

        return None


# Test Code for A* Search Method with Minimal Hyperparameters

Below is a description of the test procedure for the A* search method using minimal hyperparameters. This test ensures that the algorithm returns only the final answer in the proper format, without extra chain-of-thought text.

- **Initialization:**
  - An A* search instance is created with the following settings:
    - `num_candidates`: 1 (only one candidate is generated per node)
    - `max_depth`: 1 (the search tree is kept shallow for testing)
    - `verbose`: True (detailed debug information is printed)
    - `max_fallback`: 3 (up to three fallback attempts are made to force a final answer)
  - The initial problem is set as:
    - "Solve the integral: \( \int_0^1 x^2 \, dx \)"

- **Search Execution:**
  - The A* search is executed on the initial problem to obtain a solution node.
  
- **Tree Printing:**
  - A recursive function (e.g., `print_astar_tree`) is used to print the entire A* search tree, starting from the root. This allows inspection of all nodes and the reasoning process.

- **Final Answer Extraction:**
  - If a solution node is found, the algorithm backtracks to the root to print the entire tree.
  - Then it extracts the final answer from the solution node using a helper function (e.g., `extract_answer`). The state of the solution node is reformatted to display only the final answer in the exact format (e.g., `\boxed{<final answer>}`).

- **Fallback Handling:**
  - If no solution node is found, a fallback prompt is issued. This prompt instructs the model to provide ONLY its final answer in the correct format, with no extra explanation.
  - The fallback final answer is then printed.

This test code is designed to verify that, with minimal hyperparameters, the A* search method consistently returns a complete final answer, ensuring the output is directly comparable with the expected result.
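
The final-answer extraction step depends on pulling the contents out of `\boxed{...}`. Here is a minimal sketch of such a helper, assuming the answer is the last boxed expression in the text. `extract_boxed` is a hypothetical name used for illustration; it is not the notebook's own `extract_answer`, whose implementation may differ. It counts braces so that nested LaTeX such as `\dfrac{1}{3}` comes back intact.

```python
from typing import Optional


def extract_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in `text`, or None if absent."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None

    i = start + len(marker)
    depth = 1          # we are already inside the opening brace of \boxed{
    collected = []
    while i < len(text):
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                # Matching close brace of \boxed{ reached.
                return "".join(collected)
        collected.append(ch)
        i += 1
    return None  # braces never balanced, so treat as "no answer"


print(extract_boxed("So the result is \\boxed{\\dfrac{1}{3}}."))
```

Returning `None` for unbalanced braces keeps truncated model output from being mistaken for a complete answer.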


In [24]:
# Test Code for A* Search Method with Minimal Hyperparameters

# Initialize the A* search instance with minimal settings.
# Candidate generation and the in-search fallback use get_llm_response(...);
# heuristic scoring and the final fallback below use get_api_response(...) (Gemini).
astar = AStarSearch(
    num_candidates=1,
    max_depth=1,
    verbose=True,
    max_fallback=3
)

initial_problem = "Solve the integral: \\( \\int_0^1 x^2 \\, dx \\)"
solution_node = astar.search(initial_problem)

# Recursive printer for the A* search tree.
def print_astar_tree(node, indent=""):
    print(indent + repr(node))
    for child in node.children:
        print_astar_tree(child, indent + "  ")

if solution_node:
    # Backtrack to the root.
    root = solution_node
    while root.parent is not None:
        root = root.parent

    # Print the entire tree from the root.
    print_astar_tree(root)

    # Extract and normalize the final answer.
    final = extract_answer(solution_node.state)
    if final:
        solution_node.state = f"\\boxed{{{final}}}"
    print("\nFinal Answer Found:")
    print(solution_node.state)
else:
    # No solution found: use GPT/Gemini directly for the final answer.
    fallback_prompt = (
        f"Based on the problem: \"{initial_problem}\", please provide ONLY your final answer "
        "in the exact format \\boxed{<final answer>}. Do not include any chain-of-thought or extra explanation."
    )
    forced_final_answer = get_api_response(fallback_prompt).strip()
    final = extract_answer(forced_final_answer) or forced_final_answer
    forced_final_answer = f"\\boxed{{{final}}}"
    print("\nNo complete solution node found. Forced final answer:")
    print(forced_final_answer)


Expanding: AStarNode(g=0.0, h=2.0, f=2.0, state='Solve the integral: \( \int_0^1 x^2 \, dx \)')

Expanding node at depth 0.0:
Solve the integral: \( \int_0^1 x^2 \, dx \)

Fallback attempt 1: Solution found
Generated child: AStarNode(g=1.0, h=10.0, f=11.0, state='Okay, so I need to solve the integral \( \int_0...')
Expanding: AStarNode(g=1.0, h=10.0, f=11.0, state='Okay, so I need to solve the integral \( \int_0...')
AStarNode(g=0.0, h=2.0, f=2.0, state='Solve the integral: \( \int_0^1 x^2 \, dx \)')
  AStarNode(g=1.0, h=10.0, f=11.0, state='Okay, so I need to solve the integral \( \int_0...')

Final Answer Found:
\boxed{\dfrac{1}{3}}


# Evaluation of the A* Search Method on the Math500 Dataset

This evaluation framework is designed to test the A* search method on a subset of the Math500 dataset. It provides flexibility in hyperparameter configuration and is set up for both detailed output and large-scale evaluation. The framework saves results, summarizes key metrics, and includes mechanisms to force the model to output a final answer in the expected format.

## Key Features

- **Unique Results File:**  
  Uses a dedicated results file (e.g., `results/evaluation_results_astar_random_test.json`) to store evaluation data. The file is deleted at the start of each run to prevent interference from previous evaluations.

- **Configurable Hyperparameters:**  
  You can adjust parameters such as:
  - `num_candidates`: Number of candidate final answers generated per node.
  - `max_depth`: Maximum depth of the search tree.
  
  These settings enable you to test different reasoning strategies and trade-offs between search exploration depth and the precision of the final answer.

- **Sample Selection:**  
  The `max_samples` parameter controls how many problems are drawn at random from the dataset. This allows you to start by testing on a single sample and then scale up to 100 or more problems as desired.

- **Fallback Mechanisms:**  
  If no node in the A* search tree contains a complete final answer (i.e., one wrapped in `\boxed{...}`), a fallback prompt is issued to force the model to output only the final answer. This guarantees that every evaluated problem produces a final answer in the proper format.

- **Final Answer Extraction:**  
  After processing the search tree, the framework extracts just the final answer from the model’s output (using a helper function like `extract_answer`). This ensures that only the final answer is compared against the expected solution, without any additional chain-of-thought text.

- **Result Saving and Analysis:**  
  The framework saves detailed evaluation data—including the problem text, the model's raw response, the extracted final answer, and correctness—and produces a summary report with metrics such as total problems evaluated, number of correct answers, and overall accuracy.
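
For the sample selection step, drawing indices without replacement avoids duplicates, and a fixed seed makes a run repeatable. The sketch below uses only the Python standard library. `pick_indices` and its `seed` parameter are hypothetical names for illustration, not part of the notebook's helpers.

```python
import random
from typing import List


def pick_indices(total: int, max_samples: int, seed: int = 0) -> List[int]:
    """Choose up to `max_samples` distinct indices from range(total)."""
    rng = random.Random(seed)          # local generator, so the global RNG state is untouched
    k = min(max_samples, total)        # never ask for more items than exist
    return sorted(rng.sample(range(total), k))


indices = pick_indices(total=500, max_samples=30, seed=42)
print(len(indices), "unique indices selected")
```

Capping `k` at `total` means a request for more samples than the dataset holds returns every index once instead of failing.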

## Encouragement for Further Improvement

**Prompt Engineering & Fallback Strategies:**  
The current method forces the model to provide a final answer using a fallback prompt if the initial search does not yield the required format. Experiment with:
- **Prompt Tuning:** Adjust the wording and structure to further emphasize "ONLY your final answer" and "no extra text."
- **Iterative Refinement:** Consider iterative prompt refinement or additional post-processing to isolate the final answer more reliably.
- **Innovative Approaches:** This challenge of extracting only the final answer is an active area of research. Exploring new strategies may lead to better performance and more robust results.

**Scalability:**  
Once the hyperparameters and prompting strategy are fine-tuned on a small set of problems, scale the evaluation to a larger sample (e.g., 100+ problems). A broader evaluation can reveal trends and help identify further improvements.

**Experiment and Innovate:**  
Feel free to modify prompts, adjust hyperparameters, and refine fallback mechanisms. Comparing the A* search method with other approaches, such as the Tree of Thoughts method, can provide valuable insights. Your experimentation is key to developing a more robust and reliable evaluation system.

This flexible framework is designed to meet diverse research and production needs—keep iterating and exploring until you achieve optimal results!


In [None]:
import os
import random
from tqdm import tqdm

def evaluate_astar_random(max_samples: int = 1):
    """
    Evaluate the A* search method on randomly selected problems from the Math500 dataset.
    Uses GPT/Gemini (via get_api_response) for any fallback final-answer requests.
    Assumes existence of:
      - load_math500_dataset()
      - extract_answer(solution_text)
      - compare_answers(correct_answer, predicted_answer)
      - save_result(results_file, result_dict)
      - load_existing_results(results_file)
      - analyze_results(results_list)
      - get_llm_response(prompt)   # for candidate generation
      - get_api_response(prompt)   # for GPT/Gemini–based fallback
    """
    os.makedirs("results", exist_ok=True)
    results_file = "results/evaluation_results_astar_random_test.json"
    
    # Clear previous test results
    if os.path.exists(results_file):
        os.remove(results_file)
    
    dataset = load_math500_dataset()
    total = len(dataset)
    
    # Pick random unique indices; cap at the dataset size so that
    # max_samples > total cannot cause an endless loop.
    selected = random.sample(range(total), min(max_samples, total))
    
    astar = AStarSearch(num_candidates=1, max_depth=1, verbose=True, max_fallback=3)
    correct_count = 0
    evaluated = 0
    
    def collect_all_nodes(node):
        nodes = [node]
        for c in node.children:
            nodes.extend(collect_all_nodes(c))
        return nodes
    
    for idx in selected:
        item = dataset[idx]
        problem_text = item['problem']
        correct_answer = extract_answer(item['solution'])
        
        solution_node = astar.search(problem_text)
        if solution_node:
            response = solution_node.state
            print("DEBUG: Found solution node:", response)
        else:
            print("DEBUG: No solution node found. Using GPT fallback.")
            fallback_prompt = (
                f"Based on the problem: \"{problem_text}\", provide ONLY your final answer "
                "in the exact format \\boxed{<final answer>}. Do not include any chain-of-thought or explanation."
            )
            response = get_api_response(fallback_prompt).strip()
            print("DEBUG: Fallback response:", response)
        
        predicted_answer = extract_answer(response)
        if not predicted_answer:
            print("DEBUG: predicted_answer is empty. Raw response was:\n", response)
        
        is_correct = compare_answers(correct_answer, predicted_answer or "")
        
        save_result(results_file, {
            "index": idx,
            "problem": problem_text,
            "response": response,
            "correct_answer": correct_answer,
            "predicted_answer": predicted_answer,
            "is_correct": is_correct
        })
        
        if is_correct:
            correct_count += 1
        evaluated += 1
        print(f"Correct: {correct_count}/{evaluated} | Index: {idx}")
    
    final_results = load_existing_results(results_file)
    analyze_results(final_results)

# Example usage:
evaluate_astar_random(max_samples=30)
# To test with different parameters, adjust max_samples.


=== Results Summary ===

Total problems: 30

Correct answers: 18

Accuracy: 60.00%


# Monte Carlo Tree Search (MCTS) for Mathematical Reasoning

Monte Carlo Tree Search (MCTS) is a probabilistic search algorithm particularly well-suited for decision-making tasks in large, complex search spaces. In the context of mathematical problem solving with language models, each node in the search tree represents a partial reasoning step. MCTS incrementally builds the tree by exploring the most promising reasoning paths through random simulations, balancing exploration of new paths with exploitation of those that appear promising.

## Core Components of MCTS

MCTS operates in four main stages:

1. **Selection:**  
   Starting at the root (the initial problem), the algorithm traverses the tree by selecting child nodes based on a policy that balances two factors:  
   - **Exploitation:** Favoring nodes that have already shown high promise (i.e., those with a good reward or low cost).  
   - **Exploration:** Giving a chance to less-visited nodes to discover potentially promising new paths.  
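   This balance is often managed by a criterion such as the Upper Confidence Bound (UCB), which scores each child by its average reward plus an exploration bonus that shrinks as the child is visited more often.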

2. **Expansion:**  
   When a leaf node is reached—one that has not been fully expanded—the algorithm expands it by generating one or more child nodes. Each new node represents a possible next step in the reasoning process, such as a candidate final answer generated by the language model.

3. **Simulation (Rollout):**  
   From the newly expanded node, the algorithm performs a simulation (or rollout) to estimate the outcome if that reasoning path were followed to completion. For mathematical reasoning, this may involve prompting the language model to complete the remaining reasoning and produce a final answer. The outcome of the simulation provides an estimated reward or cost for that path.

4. **Backpropagation:**  
   The result of the simulation is then propagated back up the tree. Each node along the path has its evaluation updated based on the outcome, which in turn refines the selection policy for future iterations. This backpropagation ensures that nodes contributing to more promising outcomes are prioritized.
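
The selection rule in step 1 can be written out explicitly. A common choice is UCB1; the form below is a sketch of the usual textbook definition, and the symbols $W_i$, $n_i$, $N$, and $c$ are notation introduced here for illustration:

$$
\mathrm{UCB}(i) = \frac{W_i}{n_i} + c \sqrt{\frac{\ln N}{n_i}}
$$

Here $W_i$ is the total reward accumulated by child $i$, $n_i$ is its visit count, $N$ is the parent's visit count, and $c$ is the exploration constant. As a worked example with made-up numbers, a child with $W_i = 4$ and $n_i = 5$ under a parent with $N = 6$ and $c = 1.41$ scores $0.8 + 1.41\sqrt{\ln 6 / 5} \approx 0.8 + 0.84 = 1.64$. An unvisited child ($n_i = 0$) is usually given an infinite score so that it is tried at least once.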

## How MCTS Applies to Mathematical Reasoning

- **State Representation:**  
  Each node encapsulates a partial solution or chain-of-thought produced by the language model. The goal is to eventually obtain a final answer formatted in a concise manner (e.g., wrapped in `\boxed{...}`).

- **Reward and Heuristic:**  
  The reward signal derived from the simulation reflects how close a reasoning path is to a correct final answer. The model's evaluations—such as plausibility scores—help guide the search by penalizing incomplete or incorrect paths.

- **Balancing Exploration and Exploitation:**  
  MCTS effectively manages the trade-off between exploring new, untested reasoning paths and exploiting those paths that have already demonstrated potential. This balance is critical in navigating the vast space of possible reasoning steps.

## Challenges and Considerations

- **Simulation Cost:**  
  Running multiple simulations per node can be computationally intensive, especially when each simulation involves multiple language model calls.

- **Reward Signal Design:**  
  Defining an accurate and meaningful reward (or evaluation) for partial solutions is challenging. The reward must correlate well with the likelihood of ultimately arriving at a complete and correct final answer.

- **Parameter Tuning:**  
  Effective implementation of MCTS requires careful tuning of parameters like the exploration constant and the number of simulations per node. This tuning is essential to balance the search effectively.

- **Ensuring Concise Final Answers:**  
  One of the key objectives is to force the model to output only the final answer (without additional chain-of-thought text). This requires precise prompt engineering and robust fallback strategies in both candidate generation and simulation phases.

## Conclusion

MCTS provides a powerful framework for exploring the reasoning process in language models by combining random simulation with informed backpropagation. Through iterative exploration of promising reasoning paths and careful balancing of exploration and exploitation, MCTS aims to identify the most promising route to a final answer. With a strong focus on ensuring that only a concise final answer is produced (wrapped in a format like `\boxed{...}`), this method offers significant potential for enhancing mathematical problem solving.

By adjusting hyperparameters and refining the simulation and evaluation processes, you can experiment with different configurations to improve the efficiency and accuracy of the final answers. This makes MCTS a flexible and promising approach for further research and practical applications in guided reasoning with language models.


In [27]:
import re
import time
import math
import heapq
from typing import List, Optional


def evaluate_with_gpt(problem: str, candidate: str) -> float:
    """
    Ask the verifier LLM (via get_api_response) whether the boxed answer is correct.
    Returns 1.0 for “yes”, 0.0 for “no”, or an intermediate score if provided.
    """
    prompt = (
        "You are a precise math grader.\n\n"
        f"Problem:\n\"{problem}\"\n\n"
        "Final answer to check (in \\boxed{...} format):\n"
        f"\"{candidate}\"\n\n"
        "Is this answer correct? Reply with a number between 0 (incorrect) and 1 (fully correct)."
    )
    # Send `prompt` to `get_api_response`, strip whitespace
    response = get_api_response(prompt).strip()
    # Extract the first numeric match with `re.search`
    match = re.search(r'(\d+(?:\.\d+)?)', response)
    # Convert to float (fallback to 0.0 on failure)
    score = float(match.group(1)) if match else 0.0
    # Clamp the result to [0.0, 1.0] and return it
    return max(0.0, min(1.0, score))


class MCTSNode:
    def __init__(self, state: str, parent: Optional['MCTSNode'] = None):
        self.state = state
        self.parent = parent
        self.children: List[MCTSNode] = []
        self.visits = 0
        self.total_reward = 0.0

    def add_child(self, child: 'MCTSNode'):
        self.children.append(child)

    def __repr__(self):
        preview = self.state if len(self.state) < 50 else self.state[:47] + "..."
        return f"MCTSNode(visits={self.visits}, reward={self.total_reward:.2f}, state='{preview}')"


class MCTSSearch:
    def __init__(
        self,
        num_simulations: int = 10,
        exploration_const: float = 1.41,
        max_depth: int = 3,
        max_fallback: int = 3,
        num_candidates: int = 1,
        verbose: bool = False
    ):
        self.num_simulations = num_simulations
        self.exploration_const = exploration_const
        self.max_depth = max_depth
        self.max_fallback = max_fallback
        self.num_candidates = num_candidates
        self.verbose = verbose
        self.root_problem: Optional[str] = None

    def is_solution(self, state: str) -> bool:
        # return True only if `extract_answer(state)` yields a non-empty boxed answer
        # (bool(...) rejects both None and an empty string)
        return bool(extract_answer(state))

    def generate_candidates(self, state: str) -> List[str]:
        # build a prompt asking for `self.num_candidates` LaTeX \\boxed{...} answers to `state`
        prompt = f"""You are solving a mathematics problem.
Given the following problem or partial solution:

{state}

Generate {self.num_candidates} different possible final answers.
Each answer should be clearly marked with a LaTeX \\boxed{{...}} command.
If it's a final answer, make sure it's enclosed in \\boxed{{...}}.
Write each answer on a new line. These answers should be distinct.
"""
        response = get_llm_response(prompt)

        # split response into non-empty lines
        lines = [line.strip() for line in response.split('\n') if line.strip()]

        # If fewer than num_candidates, replicate lines to reach that count
        while len(lines) < self.num_candidates:
            lines.append(lines[0] if lines else "No solution found.")

        # Return exactly num_candidates answer strings
        return lines[:self.num_candidates]

    def expand(self, node: MCTSNode) -> List[MCTSNode]:
        if self.verbose:
            print(f"\nExpanding node (depth {self._depth(node)}):\n{node.state}\n")

        # ensure we remember the original problem
        if self.root_problem is None:
            self.root_problem = node.state

        candidates = self.generate_candidates(node.state)
        children: List[MCTSNode] = []

        for cand in candidates:
            final = cand
            attempts = 0

            while not self.is_solution(final) and attempts < self.max_fallback:
                fallback_prompt = f"""Given this problem:
{self.root_problem}

And your current answer:
{final}

Provide ONLY the final answer enclosed in \\boxed{{...}} LaTeX format.
DO NOT show any work, ONLY the final boxed answer.
"""
                response = get_llm_response(fallback_prompt)
                final = response.strip()
                attempts += 1
                if self.verbose:
                    print(f"Fallback attempt {attempts}: {'Solution found' if self.is_solution(final) else 'Still no solution'}")

            child = MCTSNode(state=final, parent=node)
            node.add_child(child)
            children.append(child)

        return children

    def simulate(self, node: MCTSNode) -> float:
        """
        Instead of a blind rollout, directly evaluate the final boxed answer.
        """
        if self.root_problem is None:
            return 0.0

        if not self.is_solution(node.state):
            fallback_prompt = f"""Given this problem:
{self.root_problem}

Provide ONLY the final answer enclosed in \\boxed{{...}} LaTeX format.
DO NOT show any work, ONLY the final boxed answer.
"""
            response = get_llm_response(fallback_prompt)
            node.state = response.strip()

        return evaluate_with_gpt(self.root_problem, node.state)

    def ucb_score(self, child: MCTSNode, parent_visits: int) -> float:
        if child.visits == 0:
            return float('inf')

        exploit = child.total_reward / child.visits
        explore = self.exploration_const * math.sqrt(math.log(parent_visits) / child.visits)
        return exploit + explore

    def select(self, root: MCTSNode) -> MCTSNode:
        # starting at `root`, repeatedly pick the child with highest `ucb_score`
        # until you reach a node with no children, then return it
        node = root
        while node.children:
            best_score = float('-inf')
            best_child = None
            
            for child in node.children:
                score = self.ucb_score(child, node.visits)
                if score > best_score:
                    best_score = score
                    best_child = child
                    
            if best_child is None:
                break
            node = best_child
            
        return node

    def backpropagate(self, node: MCTSNode, reward: float):
        while node is not None:
            node.visits += 1
            node.total_reward += reward
            node = node.parent

    def _depth(self, node: MCTSNode) -> int:
        d = 0
        while node.parent:
            d += 1
            node = node.parent
        return d

    def search(self, initial_problem: str) -> Optional[MCTSNode]:
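        """
        Perform MCTS with the verifier LLM (via `get_api_response`) scoring answers:
        - Selection: `select`
        - Expansion: `expand`
        - Simulation/Backprop: `simulate` + `backpropagate`
        Returns the root's child with the highest average reward, or the root
        itself if it has no children.
        """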
        self.root_problem = initial_problem
        root = MCTSNode(state=initial_problem)
        
        for _ in range(self.num_simulations):
            leaf = self.select(root)
            
            if self._depth(leaf) >= self.max_depth:
                reward = self.simulate(leaf)
                self.backpropagate(leaf, reward)
                continue
                
            children = self.expand(leaf)
            if children:
                reward = self.simulate(children[0])
                self.backpropagate(children[0], reward)
                
        if root.children:
            best_child = max(root.children, key=lambda c: c.total_reward / max(c.visits, 1))
            return best_child
            
        return root


# Test Code for MCTS Search Method with Minimal Hyperparameters

Below is a description of the test procedure for the MCTS search method using minimal hyperparameters. This test ensures that the algorithm returns only the final answer in the proper format, without extra chain-of-thought text.

- **Initialization:**
  - An MCTS search instance is created with the following settings:
    - `num_simulations`: 5 (number of simulations to run for exploring reasoning paths)
    - `exploration_const`: 1.41 (parameter to balance exploration and exploitation)
    - `max_depth`: 1 (the search tree is kept shallow for testing)
    - `max_fallback`: 3 (up to three fallback attempts are made to force a final answer)
    - `num_candidates`: 1 (only one candidate is generated per node)
    - `verbose`: True (detailed debug information is printed)
  - The initial problem is set as:
    - "Solve the integral: \( \int_0^1 x^2 \, dx \)"

- **Search Execution:**
  - The MCTS search is executed on the initial problem to obtain a solution node.

- **Tree Printing:**
  - A recursive function (e.g., `print_mcts_tree`) prints the entire MCTS search tree starting from the root. This allows inspection of all nodes and the reasoning process.

- **Final Answer Extraction:**
  - If a solution node is found, the algorithm backtracks to the root and prints the full tree.
  - The final answer is then extracted from the solution node using a helper function (e.g., `extract_answer`) and reformatted to display only the final answer in the exact format (e.g., `\boxed{<final answer>}`), ensuring that no extra explanation or chain-of-thought text is present.

- **Fallback Handling:**
  - If no solution node is found, a fallback prompt is issued that instructs the model to provide ONLY the final answer in the correct format, with no additional explanation.
  - The fallback final answer is printed.

This test code is designed to verify that, with minimal hyperparameters, the MCTS search method consistently returns a complete final answer. Users are encouraged to experiment with these hyperparameters and refine the prompts to further improve performance and compare the results with other methods.


In [28]:
# Test Code for MCTS Search Method with Minimal Hyperparameters

# Initialize the MCTS search instance with minimal settings.
# Generation uses get_llm_response(...), evaluation/fallback uses get_api_response(...)
mcts = MCTSSearch(
    num_simulations=5,
    exploration_const=1.41,
    max_depth=1,
    max_fallback=3,
    num_candidates=1,
    verbose=True
)

initial_problem = "Solve the integral: \\( \\int_0^1 x^2 \\, dx \\)"
solution_node = mcts.search(initial_problem)

# Recursive printer for the MCTS tree.
def print_mcts_tree(node, indent=""):
    print(indent + repr(node))
    for child in node.children:
        print_mcts_tree(child, indent + "  ")

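# search() returns the root when nothing was expanded, so also require a boxed answer.
if solution_node and extract_answer(solution_node.state):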
    # Backtrack to the root node.
    root = solution_node
    while root.parent is not None:
        root = root.parent

    # Print the entire tree from the root.
    print_mcts_tree(root)

    # Extract and normalize the final boxed answer.
    final = extract_answer(solution_node.state)
    if final:
        solution_node.state = f"\\boxed{{{final}}}"
    print("\nFinal Answer Found:")
    print(solution_node.state)
else:
    # No solution found: use GPT/Gemini verifier directly for final answer.
    fallback_prompt = (
        f"Based on the problem: \"{initial_problem}\", please provide ONLY your final answer "
        "in the exact format \\boxed{<final answer>}. Do not include any chain-of-thought or extra explanation."
    )
    forced_final_answer = get_api_response(fallback_prompt).strip()
    final = extract_answer(forced_final_answer) or forced_final_answer
    forced_final_answer = f"\\boxed{{{final}}}"
    print("\nNo complete solution found. Forced final answer:")
    print(forced_final_answer)



Expanding node (depth 0):
Solve the integral: \( \int_0^1 x^2 \, dx \)

Fallback attempt 1: Solution found
MCTSNode(visits=5, reward=5.00, state='Solve the integral: \( \int_0^1 x^2 \, dx \)')
  MCTSNode(visits=5, reward=5.00, state='Okay, so I have to solve the integral \( \int_0...')

Final Answer Found:
\boxed{\dfrac{1}{3}}


# Evaluation of MCTS Search Method on the Math500 Dataset

This evaluation framework is designed to test the MCTS search method on a subset of the Math500 dataset. The framework is highly configurable via hyperparameters and forces the model to output only the final answer in the expected format (e.g., `\boxed{<final answer>}`) with no additional chain-of-thought text.

## Key Features

- **Unique Results File:**  
  A dedicated results file (e.g., `evaluation_results_mcts_test.json`) is used to save evaluation data. The file is cleared at the beginning of each run to ensure a fresh start without interference from previous evaluations.

- **Configurable Hyperparameters:**  
  You can adjust parameters such as:
  - `num_simulations`: Number of MCTS simulations used to explore reasoning paths.
  - `exploration_const`: The constant used in the UCB formula to balance exploration and exploitation.
  - `max_depth`: Maximum depth of the search tree.
  - `max_fallback`: Maximum number of fallback attempts to force the model to output a final answer.
  - `num_candidates`: Number of candidate final answers generated per node.
  
  These settings allow you to experiment with different reasoning strategies and trade-offs between search depth and the precision of the final answer.

- **Sample Selection:**  
  The `max_samples` parameter controls the number of problems from the dataset to evaluate. This lets you start by testing on a single sample and then scale the evaluation to larger subsets (e.g., 100 or more problems) as desired.

- **Final Answer Extraction and Fallback Mechanism:**  
  The evaluation process extracts only the final answer from the model’s output (using a helper like `extract_answer`), ensuring that only a concise final answer is compared against the expected solution. If no complete final answer is found within the search tree, a fallback prompt is issued to force the model to provide the final answer in the correct format.

- **Result Saving and Analysis:**  
  Each problem’s evaluation result—including the problem, the model’s raw response, the extracted final answer, and correctness—is saved to the results file. A summary report is then generated, which includes key metrics such as total evaluated problems, number of correct answers, and overall accuracy.
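
The summary metrics can be computed directly from the saved per-problem records. The sketch below assumes each record is a dict with an `is_correct` field, as in the result dictionaries saved by the evaluation code. `summarize_results` is a hypothetical helper shown for illustration and is not the notebook's `analyze_results`; the records are made up.

```python
from typing import Dict, List


def summarize_results(results: List[Dict]) -> Dict[str, float]:
    """Compute total, correct count, and accuracy (in percent) from result records."""
    total = len(results)
    correct = sum(1 for r in results if r.get("is_correct"))
    accuracy = 100.0 * correct / total if total else 0.0  # guard against an empty run
    return {"total": total, "correct": correct, "accuracy": accuracy}


# Made-up records for demonstration only.
toy_results = [
    {"index": 0, "is_correct": True},
    {"index": 1, "is_correct": False},
    {"index": 2, "is_correct": True},
    {"index": 3, "is_correct": True},
]
summary = summarize_results(toy_results)
print(f"Total problems: {summary['total']}")
print(f"Correct answers: {summary['correct']}")
print(f"Accuracy: {summary['accuracy']:.2f}%")
```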

## Encouragement for Further Improvement

- **Experiment with the Exploration Constant:**  
  Try using different values for the exploration constant (e.g., 0.5, 1.41, 2.0) to see how they affect the balance between exploring new nodes and exploiting known promising ones. Compare the results and observe how the search tree structure and the final answer accuracy change with each setting.

- **Tuning Other Parameters:**  
  Experiment with other hyperparameters such as `num_simulations`, `max_depth`, and `num_candidates`. Adjusting these values can impact the thoroughness of the search and the likelihood of obtaining a complete final answer. 

- **Document Your Observations:**  
  As you tweak the parameters, please comment on your expectations and the outcomes:
  - What changes do you observe when you adjust the exploration constant?
  - How does increasing the maximum depth influence the quality and correctness of the final answer?
  - Does generating more candidates per node lead to more accurate results?
  
  Sharing your observations and comparing them with results from other approaches (such as A* or the Tree of Thoughts method) can provide valuable insights and help guide further improvements.

This framework is designed to be flexible, so feel free to adjust the parameters and prompts to suit your research needs and push the performance of the MCTS method further.


In [None]:
import os
from tqdm import tqdm

def evaluate_mcts(max_samples=1):
    """
    Evaluate the MCTS search method on a subset of Math500, using GPT/Gemini
    (via get_api_response) for any fallback final‑answer requests.
    Assumes these helpers exist:
      - load_math500_dataset()
      - extract_answer(text)
      - compare_answers(correct, pred)
      - save_result(filename, result_dict)
      - load_existing_results(filename)
      - analyze_results(results_list)
      - get_llm_response(prompt)   # for candidate generation
      - get_api_response(prompt)   # for GPT/Gemini verification/fallback
    """
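    os.makedirs("results", exist_ok=True)
    results_file = "results/evaluation_results_mcts_test.json"
    # Start from a clean results file so every run re-evaluates from scratch.
    if os.path.exists(results_file):
        os.remove(results_file)

    dataset = load_math500_dataset()
    # The file was just removed, so this is empty here; the check is kept so
    # resuming works if the removal above is disabled.
    existing = load_existing_results(results_file)
    processed = {r['index'] for r in existing}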

    mcts = MCTSSearch(
        num_simulations=5,
        exploration_const=1.41,
        max_depth=1,
        max_fallback=3,
        num_candidates=1,
        verbose=True
    )

    correct_count = 0
    evaluated = 0

    def collect_all_nodes(node):
        nodes = [node]
        for child in node.children:
            nodes.extend(collect_all_nodes(child))
        return nodes

    for idx, item in enumerate(tqdm(dataset, desc="Evaluating problems")):
        # Stop once enough problems are evaluated instead of scanning the rest of the dataset.
        if evaluated >= max_samples:
            break
        if idx in processed:
            continue

        problem_text = item['problem']
        correct_answer = extract_answer(item['solution'])

        # Run MCTS search
        solution_node = mcts.search(problem_text)

        if solution_node:
            response = solution_node.state
            print("DEBUG: Found solution node:", response)
        else:
            print("DEBUG: No solution node found. Using GPT fallback.")
            fallback_prompt = (
                f"Based on the problem: \"{problem_text}\", provide ONLY your final answer "
                "in the exact format \\boxed{<final answer>}. Do not include any chain‑of‑thought or explanation."
            )
            response = get_api_response(fallback_prompt).strip()
            print("DEBUG: Fallback response:", response)

        predicted_answer = extract_answer(response) or ""
        if not predicted_answer:
            print("DEBUG: predicted_answer is empty. Raw response was:\n", response)

        is_correct = compare_answers(correct_answer, predicted_answer)
        save_result(results_file, {
            "index": idx,
            "problem": problem_text,
            "response": response,
            "correct_answer": correct_answer,
            "predicted_answer": predicted_answer,
            "is_correct": is_correct
        })

        if is_correct:
            correct_count += 1
        evaluated += 1
        print(f"Correct: {correct_count}/{evaluated} | Index: {idx}")

    final_results = load_existing_results(results_file)
    analyze_results(final_results)

# Example: evaluate the first 30 problems
evaluate_mcts(max_samples=30)
# Use max_samples=1 for a quick check, or increase it to test more problems


=== Results Summary ===

Total problems: 30

Correct answers: 18

Accuracy: 60.00%

Runtime: 1:04:14
