# Monte Carlo Tree Search (MCTS) for LLMs

The original paper is available at https://arxiv.org/html/2406.07394v2


## Setup


In [48]:
import ollama

# Smoke test: send a single message to confirm the Ollama chat API responds
response = ollama.chat(
    model="llama3.2",
    messages=[
        {
            "role": "user",
            "content": "Tell me an interesting fact about elephants",
        },
    ],
)
print(response["message"]["content"])

Here's an interesting fact about elephants:

Elephants have a highly developed sense of empathy and can recognize and mourn the death of their loved ones. In fact, they exhibit behaviors that are similar to human grief, such as touching the graves of their deceased family members with their trunks or visiting places where their relatives used to live.

One famous example is the story of Sudan, an African elephant who was separated from his matriarchal herd during a poaching operation. He wandered for 22 years, searching for her, before finally finding her ghostly image in a memory that had been etched into his brain. When he found her remains, he touched her trunk with his trunk and stood there, holding the space where she used to be, as if trying to hold onto her.

This incredible display of emotional intelligence and empathy is a testament to the complexity and depth of elephant cognition.


In [None]:
MODEL = "llama3.1:8b"

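def chat_with_ollama(prompt):
    """Send a single-turn prompt to the model and return the reply text, or None if it is empty."""
    messages = [
        {
            "role": "user",
            "content": prompt,
        },
    ]
    # Sampling options are passed to the chat call itself, not inside the message dict
    response = ollama.chat(
        model=MODEL,
        messages=messages,
        options={"temperature": 1},
    )

    # Extract the reply text from the Ollama response
    response_content = response["message"]["content"]
    return response_content if response_content else None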
    
def interactive_chat():
    print("Welcome to the Ollama chatbot! Type 'exit' to end the conversation.")
    while True:
        user_input = input("You: ")
        # Check for the exit command before sending anything to the model
        if user_input.strip().lower() == "exit":
            break
        print("You: ", user_input)
        response = chat_with_ollama(user_input)
        print("Ollama: ", response)

    
if __name__ == "__main__":
    INTERACTIVE_MODE = False
    SINGLE_QUERY_MODE = True

    if INTERACTIVE_MODE:
        interactive_chat()
    
    if SINGLE_QUERY_MODE:
        prompt = "Count the number of words in this sentence"

        response = chat_with_ollama(prompt)

        print(f"Prompt: {prompt}\n")
        print(f"Response: {response}")
    

    

### Preparing the Starting Prompts


In [50]:
initial_answers = [
    "I don't know the answer",
    "I'm not certain",
    "I can't say for sure",
]

### Critiquing an Answer


In [None]:
def assess_response(question, preliminary_answer):
    prompt = (
    f"Question: {question}\n"
    f"Preliminary Answer: {preliminary_answer}\n"
    "Please assess the draft answer. "
    "Do a careful review of whether the answer is correct or not and why. "
    "Consider multiple ways of verifying the correctness of the answer. "
    "Point out every flaw and hold the draft answer to high standards. "
    "Provide specific recommendations on how to improve the answer. "
    "Think step by step. "
    "Do not provide a revised answer."
)
    return chat_with_ollama(prompt)

if __name__ == "__main__":
    question = "In what year did the Titanic sink?"
    preliminary_answer = "The Titanic sank in 1910."

    evaluated_response = assess_response(question, preliminary_answer)

    print(evaluated_response)


### Refining the Response


In [None]:
def refine_response(question, preliminary_answer, evaluation):
    prompt = (
    f"Question: {question}\n"
    f"Preliminary Answer: {preliminary_answer}\n"
    f"Evaluation: {evaluation}\n"
    "Please improve the draft answer based on the critique. Follow this format: \n"
    "Reasoning Process: <step-by-step reasoning process> \n"
    "Verification: <verification of the facts> \n"
    "Final Answer: <the improved answer and verified answer>"
)
    
    return chat_with_ollama(prompt)


if __name__ == "__main__":

    question = "In what year did the Titanic sink?"
    preliminary_answer = "The Titanic sank in 1910."
    evaluation = "The Titanic sank in 1912, not 1910. The answer is incorrect."

    refined_response = refine_response(question, preliminary_answer, evaluation)

    print(f"--Polished Response--\n{refined_response}")
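
The critique and refinement helpers above make up one self-refine step: critique a draft, then rewrite it using that critique. The sketch below chains the two, with the model calls passed in as plain callables so the flow can be exercised without a running Ollama server. The name `self_refine_once` and the two stub functions are hypothetical and are not part of the notebook.

```python
def self_refine_once(question, draft, critique_fn, refine_fn):
    """Run one critique-then-refine pass and return (critique, refined_answer).

    critique_fn and refine_fn are assumed to take the same arguments as
    assess_response and refine_response above.
    """
    critique = critique_fn(question, draft)
    refined = refine_fn(question, draft, critique)
    return critique, refined


# Stand-in callables so the example runs offline (hypothetical, illustration only)
def fake_critique(question, draft):
    return "The year is wrong; the Titanic sank in 1912."


def fake_refine(question, draft, critique):
    return "Final Answer: The Titanic sank in 1912."


critique, refined = self_refine_once(
    "In what year did the Titanic sink?",
    "The Titanic sank in 1910.",
    fake_critique,
    fake_refine,
)
print(refined)
```

In the notebook itself the call would be `self_refine_once(question, preliminary_answer, assess_response, refine_response)`.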


### Rating Mechanism


In [None]:
import re

def grade_response(question, answer):
    prompt = (
        f"Question: {question}\n"
        f"Answer: {answer}\n"
        "As an expert on this topic, please provide a detailed critique of the answer. "
        "Provide only a critique, not a suggested answer. "
        "Then, rate the answer on a scale of 0 to 100. "
        "The response should be in the following format: \n"
        "Critique: <detailed critique of the answer> \n"
        "Rating: <rating from 0 to 100>"
    )

    graded_response = chat_with_ollama(prompt)

    try:
        match = re.search(r"Rating:\s*(\d+)", graded_response)
        if match:
            rating = int(match.group(1))
            if rating > 95:
                rating = 95
            
            rating = float(rating)/100
        else:
            raise ValueError("Rating not found in the response")
    except Exception as e:
        print(f"Error was encountered during grading: {e}")
        print(f"Response that caused the error was: {graded_response}")
        rating = 0

    return rating

if __name__ == "__main__":
    question = "In what year did the Titanic sink?"
    answer = "The Titanic sank in 1912."

    rating = grade_response(question, answer)

    print(f"\nRating: {rating}")
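
The grading cell only works when the model follows the `Rating: <number>` format, so the parsing is worth testing without a model in the loop. Below is a minimal sketch of that parsing as a pure function. The name `parse_rating` is hypothetical; it keeps the 95-point ceiling and the fallback of 0 used above, and additionally tolerates decimal ratings and a Markdown-bold label such as `**Rating:** 80`, on the assumption that the model sometimes produces those.

```python
import re


def parse_rating(graded_response, cap=95):
    """Extract 'Rating: N' from a grader reply and scale it to the range [0, cap/100].

    Returns 0.0 when no rating is found, matching the fallback in grade_response.
    """
    # \** skips optional Markdown asterisks between the label and the number
    match = re.search(r"Rating:\**\s*(\d+(?:\.\d+)?)", graded_response or "")
    if not match:
        return 0.0
    rating = min(float(match.group(1)), cap)  # same 95-point ceiling as above
    return rating / 100


print(parse_rating("Critique: correct and concise.\nRating: 90"))
print(parse_rating("Critique: flawless.\nRating: 100/100"))
print(parse_rating("The answer looks fine."))
```

The three calls cover a normal rating, a rating above the ceiling, and a reply with no rating at all.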

In [54]:
def direct_response(question):
    prompt = (
    f"Question: {question}\n"
    "Please provide the answer with detailed reasoning. Follow this format: \n"
    "Reasoning Process: <step-by-step reasoning process> \n"
    "Verification: <verification of the facts> \n"
    "Final Answer: <the improved answer and verified answer>\n"
    )

    ollama_response = chat_with_ollama(prompt)

    # Keep everything after "Final Answer:"; None if the model ignored the format
    match = re.search(r"Final Answer:\s*(.*)", ollama_response or "", re.DOTALL)
    final_answer = match.group(1).strip() if match and match.group(1).strip() else None

    return ollama_response, final_answer

## Monte Carlo Tree Search


In [None]:
import math
import random
import numpy as np

max_children = 3

class Node:
    def __init__(self, question, answer, parent=None):
        self.question = question
        self.answer = answer
        self.parent = parent
        self.children = []
        self.visits = 0
        self.value = 0.0

    def is_fully_expanded(self):
        return len(self.children) >= max_children
    
    def best_child(self, exploration_weight=1.41):
        choices_weights = []
        for child in self.children:
            if child.visits == 0:
                weight = float("inf")
            else:
                weight = child.value / child.visits + exploration_weight * math.sqrt(
                    2 * math.log(self.visits) / child.visits
                )
            choices_weights.append(weight)
        return self.children[np.argmax(choices_weights)]
    
    def most_visited_child(self):
        return max(self.children, key=lambda child: child.visits)
    
    def add_child(self, child_node):
        self.children.append(child_node)

class MCTS:
    def __init__(self, question, initial_answers, iterations=2):
        self.question = question
        self.initial_answers = initial_answers
        self.iterations = iterations
        self.root = Node(question, random.choice(initial_answers))

    def search(self):
        for i in range(self.iterations):
            print(f"\nIteration {i+1}/{self.iterations}")
            node = self.select(self.root)
            print(f"Selected Node: {node.answer}")
            if not node.is_fully_expanded():
                node = self.expand(node)
                print(f"\nExpanded Node: {node.answer}")
            reward = self.simulate(node)
            print(f"\nReward: {reward}")
            self.backpropagate(node, reward)
        print(f"Visits to most visited child: {self.root.most_visited_child().visits}")
        return self.root.most_visited_child().answer
    
    def select(self, node):
        while node.is_fully_expanded() and node.children:
            node = node.best_child()
        return node
    
    def expand(self, node):
        for j in range(max_children -len(node.children)):
            child_node = Node(self.question, node.answer, parent=node)
            node.add_child(child_node)

            evaluation = assess_response(self.question, child_node.answer)
            print(f"\n--Evaluation {j}--\n{evaluation}")

            refined_answer = refine_response(self.question, child_node.answer, evaluation)
            print(f"\n--Refined Answer {j}--\n{refined_answer}")

            child_node.answer = refined_answer
        return random.choice(node.children)
    

    def simulate(self, node):
        rating = grade_response(self.question, node.answer)
        return rating
    
    def backpropagate(self, node, reward):
        while node is not None:
            node.visits += 1
            node.value += reward
            node = node.parent


if __name__ == "__main__":
    question = "In what year did the Titanic sink?"
    mcts = MCTS(question, initial_answers, iterations=2)
    best_answer = mcts.search()

    print(f"\nBest Answer: {best_answer}")
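
The selection rule in `best_child` scores each child as its mean reward plus an exploration bonus that shrinks as the child is visited more often. The standalone function below reproduces that formula so the trade-off can be checked by hand; `uct_score` is a hypothetical name, and the visit counts and values are made-up numbers for illustration.

```python
import math


def uct_score(child_value, child_visits, parent_visits, exploration_weight=1.41):
    """Score used by Node.best_child: mean reward plus an exploration bonus."""
    if child_visits == 0:
        return float("inf")  # unvisited children are always tried first
    exploitation = child_value / child_visits
    exploration = exploration_weight * math.sqrt(
        2 * math.log(parent_visits) / child_visits
    )
    return exploitation + exploration


# Parent visited 10 times: a well-rated, often-visited child vs. a rarely visited one
often = uct_score(child_value=4.5, child_visits=6, parent_visits=10)
rarely = uct_score(child_value=0.6, child_visits=1, parent_visits=10)
print(f"often visited:  {often:.3f}")
print(f"rarely visited: {rarely:.3f}")
```

With these numbers the rarely visited child gets the higher score even though its mean reward is lower (0.6 against 0.75), which is how the search keeps exploring under-sampled branches.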

## Benchmarking on the MATH Dataset


In [56]:
import pandas as pd
from datasets import load_dataset
import re

def isolate_boxed_response(response):
    pattern = re.compile(r'\\boxed{((?:[^{}]|\{[^{}]*\})*)}')
    match = pattern.search(response)
    if match:
        return match.group(1)
    return None

def get_question_and_answer(row_number=None):
    dataset = load_dataset("lighteval/MATH", "all", split="test[:100]")
    df = pd.DataFrame(dataset)

    if row_number is None:
        row_number = int(input("Please enter the row you would like to use (0-99): "))

    if row_number < 0 or row_number > 99:
        raise ValueError("Row number must be between 0 and 99!")
    
    row = df.iloc[row_number]

    question = row["problem"]
    answer = row["solution"]
    short_answer = isolate_boxed_response(answer)

    return question, answer, short_answer    
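
The regular expression in `isolate_boxed_response` handles one level of nested braces, for example `\boxed{\frac{1}{2}}`, but it finds no match for deeper nesting such as `\boxed{\sqrt{\frac{1}{2}}}`. A brace-counting variant avoids that limit. The sketch below is one possible approach, with `extract_boxed` as a hypothetical name; it assumes braces in the solution are balanced and not escaped as `\{` or `\}`.

```python
def extract_boxed(text):
    """Return the contents of the first \\boxed{...}, tracking brace depth."""
    marker = "\\boxed{"
    start = text.find(marker)
    if start == -1:
        return None
    depth = 1
    begin = start + len(marker)
    for i in range(begin, len(text)):
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
            if depth == 0:
                return text[begin:i]
    return None  # braces never balanced


print(extract_boxed(r"so there are $\boxed{18}$ possible values."))
print(extract_boxed(r"$\boxed{\sqrt{\frac{1}{2}}}$"))
```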

In [57]:
dataset = load_dataset("lighteval/MATH", "all", split="test[:100]")
df = pd.DataFrame(dataset)
df.head()
level_5_questions = df[df["level"] == "Level 5"]
level_5_questions.head()

|    | problem | level | type | solution |
|---:|---|---|---|---|
| 3 | Evaluate $i^5+i^{-25}+i^{45}$. | Level 5 | Algebra | We have $i^5 = i^4\cdot i = 1\cdot (i) = i$. ... |
| 9 | What is the smallest value of $x$ such that $\|... | Level 5 | Algebra | There are two cases, when $5x-1=3x+2$ and when... |
| 11 | Find the sum of all integers that satisfy thes... | Level 5 | Algebra | First, let's deal with $\|x\| + 1 > 7$. Subtrac... |
| 15 | What is the smallest real number $x$ in the do... | Level 5 | Algebra | A real number $x$ is in the domain of $g$ if a... |
| 20 | Suppose the roots of the polynomial $x^2 - mx ... | Level 5 | Algebra | Let $p$ and $q$ be the prime roots. Then, we k... |


In [58]:
from IPython.display import display, Markdown

question, answer, short_answer = get_question_and_answer(20)

display(Markdown(f"**Question:** {question}"))

ollama_response, final_answer = direct_response(question)

display(Markdown(f"**Ollama Response:** {ollama_response}"))
display(Markdown(f"**Ground Truth Answer:** {answer}"))
display(Markdown(f"**Ground Truth Short Answer:** {short_answer}"))

**Question:** Suppose the roots of the polynomial $x^2 - mx + n$ are positive prime integers (not necessarily distinct). Given that $m < 20,$ how many possible values of $n$ are there?

**Ollama Response:** Reasoning Process:

1.  Since we're dealing with a quadratic equation $x^2 - mx + n$ where roots are positive prime integers, according to Vieta's formulas for quadratic equations, if $\alpha$ and $\beta$ are the roots of the polynomial $ax^2+bx+c$, then $c=\alpha\beta$. Here, our polynomial is in the form $x^2-mx+n$. So, here it is $\alpha \beta=n$. Thus, we know both roots are primes.
2.  We need to remember that since the sum of the roots equals $m$ (from Vieta's formulas again), and the roots must be prime numbers, then their product $\alpha\beta = n$. The only values for which this could possibly work out well with primes as roots is when both roots are smaller than $\frac{20}{2} = 10$.
3.  Given that our equation has two equal roots where the value of each root equals $p$, then Vieta's formulas tell us $m= 2p$ and $n=p^2$. The only primes less than ten are 2, 3, 5, and 7 (4 is not prime).
4.  Testing these values, we have four potential sets: $(2,2)$, $(3,3)$, $(5,5)$, and $(7,7)$. For each set, we know $m = 2p$ which equals $\{4,6,10,14\}$ in order. Since m must be less than twenty, all these work.
5.  The number of possible values of n can then be determined by adding the number of prime squared numbers that come from $(2)^2$, $(3)^2$, $(5)^2$, and $(7)^2$. Thus, $n=\{4,9,25,49\}$ or $\boxed{4}$ possible values for n.
6.  So after detailed calculations, the value of $m$ works out well because it is less than twenty, since each root was a prime number.

Verification: The statement that each root has to be prime and m must be less than 20 can be checked using the following examples:
   - For $(2,2)$, both roots are indeed primes. So, $m=4$ and $n=4$. Indeed, $m < 20.$
   - Similarly, for $(3,3)$, both roots are prime numbers, so $m = 6$, which is also less than twenty, and the product of these two prime numbers is 9.
   - For $(5,5)$, we get $m=10$ and $n=25$. Again, m <20.
   - For $(7,7)$, we get $m = 14$ and $n = 49$, which also satisfies the condition that $m<20.$
   Thus, all sets of root prime numbers satisfy this condition.

Final Answer: The final answer is $\boxed{4}$.

**Ground Truth Answer:** Let $p$ and $q$ be the prime roots. Then, we know that $m = p+q$ and $n = pq$. Since $m < 20$, the primes $p$ and $q$ must both be less than $20$.

The primes less than $20$ are $2,$ $3,$ $5,$ $7,$ $11,$ $13,$ $17,$ $19.$ Now we list all possible pairs $(p, q)$ such that $p + q < 20$, remembering to also include the cases in which $p=q$: \[\begin{aligned} & (2,2),(2,3),(2,5),(2,7),(2,11),(2,13),(2,17) \\
&(3,3),(3,5),(3,7),(3,11),(3,13) \\
&(5,5),(5,7),(5,11),(5,13) \\
&(7,7),(7,11) \end{aligned}\]There are $7 + 5 + 4 + 2 = 18$ pairs in total. Each pair produces a value for $n$, and furthermore, these values are all distinct, because every positive integer has a unique prime factorization. Therefore, there are $\boxed{18}$ possible values for $n$.

**Ground Truth Short Answer:** 18

In [60]:
question, answer, short_answer = get_question_and_answer(20)

display(Markdown(f"**Question:** {question}"))

mcts = MCTS(question, initial_answers, iterations=10)
best_answer = mcts.search()


display(Markdown(f"**MCTS Best Answer:** {best_answer}"))
display(Markdown(f"**Ground Truth Answer:** {answer}"))
display(Markdown(f"**Ground Truth Short Answer:** {short_answer}"))

**Question:** Suppose the roots of the polynomial $x^2 - mx + n$ are positive prime integers (not necessarily distinct). Given that $m < 20,$ how many possible values of $n$ are there?


Iteration 1/10
Selected Node: I'm not certain

--Evaluation 0--
I'll review your draft answer carefully.

**Assessment:**
Your preliminary answer is "I'm not certain." While it's good that you're being cautious, I'd like to see more analysis and reasoning behind this uncertainty.

**Step-by-Step Review:**

1. **Understanding the Problem:** The problem states that the roots of the polynomial $x^2 - mx + n$ are positive prime integers (not necessarily distinct). We need to find how many possible values of $n$ there are, given that $m < 20.$
2. **Recall Vieta's Formulas:** To relate the coefficients of the polynomial ($m$ and $n$) to its roots, recall that for a quadratic equation $ax^2 + bx + c = 0$, the sum of the roots is $-b/a$ and the product of the roots is $c/a$. In this case, $a=1$, so the sum of the roots is $m$ and the product of the roots is $n$.
3. **Considering Possible Prime Roots:** Since the roots are positive prime integers (not necessarily distinct), consider possible p

**MCTS Best Answer:** I'll thoroughly review my previous thought process, applying the recommendations to arrive at a confident answer.

**Revised Reasoning Process**

1. **Relate roots to m using Vieta's formulas**: For a quadratic equation $x^2 - mx + n = 0$, the sum of its roots is equal to $m$ and the product of its roots is equal to $n$. Since both roots are positive prime integers, I'll explore possible combinations that satisfy this condition.
2. **Explore prime integer pairs**: Given that each root is a positive prime integer (not necessarily distinct), consider potential pairs of primes $(p_1,p_2)$ such that their sum satisfies the constraint $m < 20$. Since there are only a few small primes, enumerate these combinations to identify possible values for $n$.
3. **Apply the condition m < 20**: With this constraint in mind, I'll limit my consideration of prime pairs $(p_1,p_2)$ where their sum is less than $20$, leading to corresponding values of $n=p_1 \cdot p_2$.

**Verification**

Upon analyzing possible combinations of positive prime integers for the roots, considering both distinct and non-distinct cases:

- For (2, 3), m=5, n = 6
- For (3, 3), m = 6, n = 9
- For (2, 7), m = 9, n = 14
- For (3, 5), m = 8, n = 15
- For (5, 5), m = 10, n = 25

**Counting possible n values**: Based on these combinations and their corresponding values of $n$, I'll count the distinct possibilities to find the final answer.

**Final Answer**

There are $\boxed{4}$ possible values for $n$ that satisfy all given conditions.

**Ground Truth Answer:** Let $p$ and $q$ be the prime roots. Then, we know that $m = p+q$ and $n = pq$. Since $m < 20$, the primes $p$ and $q$ must both be less than $20$.

The primes less than $20$ are $2,$ $3,$ $5,$ $7,$ $11,$ $13,$ $17,$ $19.$ Now we list all possible pairs $(p, q)$ such that $p + q < 20$, remembering to also include the cases in which $p=q$: \[\begin{aligned} & (2,2),(2,3),(2,5),(2,7),(2,11),(2,13),(2,17) \\
&(3,3),(3,5),(3,7),(3,11),(3,13) \\
&(5,5),(5,7),(5,11),(5,13) \\
&(7,7),(7,11) \end{aligned}\]There are $7 + 5 + 4 + 2 = 18$ pairs in total. Each pair produces a value for $n$, and furthermore, these values are all distinct, because every positive integer has a unique prime factorization. Therefore, there are $\boxed{18}$ possible values for $n$.

**Ground Truth Short Answer:** 18
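
Both the direct response and the MCTS answer above end in `\boxed{4}`, while the ground truth is 18; that comparison was made by eye. A minimal sketch of automating it follows. The helper names `boxed_value` and `is_correct` are hypothetical, and the check is a plain string comparison, so equivalent forms such as `0.5` and `\frac{1}{2}` would count as different.

```python
import re


def boxed_value(text):
    """Return the contents of the last \\boxed{...} in text, or None (flat braces only)."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text or "")
    return matches[-1].strip() if matches else None


def is_correct(model_answer, short_answer):
    """Exact string match between the model's boxed value and the reference answer."""
    predicted = boxed_value(model_answer)
    return predicted is not None and predicted == short_answer.strip()


print(is_correct(r"There are $\boxed{4}$ possible values for $n$.", "18"))
print(is_correct(r"Therefore, there are $\boxed{18}$ possible values.", "18"))
```

In the notebook, the call would be `is_correct(best_answer, short_answer)` after `mcts.search()` returns.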