# Topic 3: Tree-of-Thoughts, Graph-of-Thoughts, and Self-Ask

**Learning Objectives:**

* Examine the code and prompts behind ToT, GoT, and Self-Ask.
* Understand how the prompts work on various tasks.

**Outline:**

1. **Tree-of-Thoughts**
2. **Graph-of-Thoughts**
3. **Self-Ask**

**Paper Links:**

1. **Tree-of-Thoughts**: [Link](https://arxiv.org/abs/2305.10601)
2. **Graph-of-Thoughts**: [Link](https://arxiv.org/abs/2308.09687)
3. **Self-Ask**: [Link](https://arxiv.org/abs/2210.03350)

### Environment Setup

In [None]:
# Install openai
!pip install -qU openai

In [None]:
# Install langchain
!pip install -qU langchain

In [None]:
# Install langchain-openai
!pip install -qU langchain-openai

**Set API key**

In [None]:
# Set API key
OPENAI_API_KEY="your_api_key_here"
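
A key pasted into a notebook is easy to leak when the file is shared. A minimal sketch of an alternative, assuming the key was exported as the `OPENAI_API_KEY` environment variable before the notebook started. `load_api_key` is a hypothetical helper name, not part of any library:

```python
import os

def load_api_key(env=os.environ, name="OPENAI_API_KEY"):
    """Return the key stored under `name`, or raise a clear error if it is absent."""
    key = env.get(name, "").strip()
    if not key:
        raise RuntimeError(f"Set the {name} environment variable before running the notebook.")
    return key

# Usage in the notebook (commented out so the sketch runs without a key):
# OPENAI_API_KEY = load_api_key()
```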

## 1. Tree-of-Thoughts

We will test prompts with a simplified version of the Tree-of-Thoughts code.

Original code source: [Link](https://github.com/princeton-nlp/tree-of-thought-llm)

Simplified code source: [Link](https://github.com/ayushtues/tot_from_scratch/tree/main)

### Using Language Models

In [None]:
# Prepare model
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.7, api_key=OPENAI_API_KEY,  max_tokens=1000)

### Define Output Parser

In [None]:
# Import simple output parser
from langchain_core.output_parsers import StrOutputParser

parser = StrOutputParser()

### Define Prompt Templates

In [None]:
from langchain_core.prompts import PromptTemplate

In [None]:
propose_prompt = PromptTemplate(
    input_variables=["state"],
    template="""
    Your goal is to use the given numbers and the basic arithmetic operations (+, -, *, /) to obtain the number 24.
    You can use each number only once, but you can use the operations in any order and as many times as you want.
    This task will take multiple steps. For the current step, you choose two numbers and perform an arithmetic operation on them.

    Examples
    Input: 2 8 8 14
    Possible next steps:
    Output1: 2 + 8 = 10 (left: 8 10 14)
    Output2: 8 / 2 = 4 (left: 4 8 14)
    Output3: 14 + 2 = 16 (left: 8 8 16)
    Output4: 2 * 8 = 16 (left: 8 14 16)

    Input: 4 10 12 1
    Possible next steps:
    Output1: 12 - 10 = 2 (left: 4 2 1)
    Output2: 4 * 10 = 40 (left: 40 12 1)
    Output3: 12 + 1 = 13 (left: 4 10 13)
    Output4: 12 / 4 = 3 (left: 3 10 1)

    Now for the below input
    Input: {state}
    Possible next steps:
    """
)

In [None]:
eval_prompt = PromptTemplate(
    input_variables=["proposal"],
    template="""
    Evaluate if given numbers can reach 24 using basic arithmetic operations (+, -, *, /).
    You must use each number only once, but you can use the operations in any order and as many times as you want.

    Some examples are:
    Input: 10, 14 -> 10 + 14 = 24. -> Output: "sure"
    Input: 4, 9, 10, 13 -> (10 - 4) * (13 - 9) = 24. -> Output: "sure"
    Input: 20, 10: Not possible -> Output: "impossible"

    Can the numbers {proposal} reach 24?
    """
)

In [None]:
# Example usage of propose_prompt
state = "2 8 8 14"

propose_chain = propose_prompt | llm | parser

proposal = propose_chain.invoke({"state": state})
print(proposal)

In [None]:
def extract_proposals(text):
    text = text.split("\n")

    text = [item for item in text if "Output" in item]

    proposals = []
    for x in text:
        x = x.lower()
        x = x.split("left:")
        if len(x) == 2 :
            x = x[1].split(')')[0]
            proposals.append(x)
    return proposals
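
To see what the parsing step keeps and discards, here is a small offline sketch that runs on a hard-coded model reply. It is a regex-based variant written for illustration (the name `parse_remaining_numbers` and the sample text are assumptions), not the function defined above:

```python
import re

# Matches the "(left: ...)" suffix that the propose prompt asks the model to emit.
LEFT_PATTERN = re.compile(r"\(left:\s*([^)]*)\)", re.IGNORECASE)

def parse_remaining_numbers(model_output):
    """Return one list of number strings per line that contains a '(left: ...)' part."""
    states = []
    for line in model_output.splitlines():
        match = LEFT_PATTERN.search(line)
        if match:
            states.append(match.group(1).split())
    return states

sample_reply = (
    "Output1: 2 + 8 = 10 (left: 8 10 14)\n"
    "Output2: 8 / 2 = 4 (left: 4 8 14)\n"
    "I hope this helps!"
)
states = parse_remaining_numbers(sample_reply)
print(states)
```

Lines without a `(left: ...)` part, such as trailing chatter from the model, are silently dropped.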

In [None]:
proposals = extract_proposals(proposal)

In [None]:
proposals

In [None]:
# Example usage of eval_prompt
proposal = "2, 8, 6"

eval_chain = eval_prompt | llm | parser

eval = eval_chain.invoke({"proposal": proposal})
print(eval)

In [None]:
def extract_evaluation(text):
    text  = text.lower()
    if "impossible" in text:
        return 0
    elif "sure" in text:
        return 1
    else:
        return 0.5

In [None]:
extract_evaluation(eval)
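
The model's "sure" / "impossible" verdicts can be wrong, so it may help to compare them with an exact answer. The sketch below is a brute-force checker, independent of the LLM; `can_reach_24` is a hypothetical helper, and `Fraction` is used so that divisions stay exact:

```python
from fractions import Fraction
from itertools import combinations

def can_reach_24(numbers, target=24):
    """Return True if the numbers can be combined with +, -, *, / to hit the target."""
    values = [Fraction(n) for n in numbers]
    if len(values) == 1:
        return values[0] == target
    # Pick any two numbers, replace them with one result, and recurse.
    for i, j in combinations(range(len(values)), 2):
        a, b = values[i], values[j]
        rest = [v for k, v in enumerate(values) if k not in (i, j)]
        candidates = {a + b, a * b, a - b, b - a}
        if b != 0:
            candidates.add(a / b)
        if a != 0:
            candidates.add(b / a)
        if any(can_reach_24(rest + [c], target) for c in candidates):
            return True
    return False

print(can_reach_24([4, 9, 10, 13]))
print(can_reach_24([20, 10]))
```

A `True` result corresponds to the score 1 returned by `extract_evaluation`, and `False` to the score 0.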

### Run Tree-of-Thoughts

In [None]:
curr_states = ["6 7 9 9"]

TREE_DEPTH = 3
PROPOSAL_RUNS_PER_STATE = 2
EVAL_RUNS_PER_STATE = 1
BRANCH_FACTOR = 2

for depth in range(TREE_DEPTH):
    proposal_and_score = []
    for state in curr_states:
        proposals = []
        for _ in range(PROPOSAL_RUNS_PER_STATE):
            proposals += extract_proposals(propose_chain.invoke({"state": state}))
            print("current proposals:", proposals)

        for proposal in proposals:
            score = 0
            for _ in range(EVAL_RUNS_PER_STATE):
                score += extract_evaluation(eval_chain.invoke({"proposal": proposal}))

            proposal_and_score.append((proposal, score/EVAL_RUNS_PER_STATE))
            print("current proposal_and_score: ", proposal_and_score)

    # sort proposals by score
    proposal_and_score.sort(key=lambda x: x[1], reverse=True)
    curr_states = [item[0] for item in proposal_and_score[:BRANCH_FACTOR]]

print(curr_states[0])
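
The loop above is a breadth-first search that keeps only the best `BRANCH_FACTOR` states at each depth. The following sketch reproduces that control flow without any API calls. The functions `propose` and `evaluate` are deterministic stand-ins assumed here for `propose_chain` and `eval_chain`, and the three-number start state is chosen so the stub evaluator can judge every candidate exactly:

```python
from fractions import Fraction
from itertools import combinations

def propose(state):
    """Stand-in for propose_chain: every state reachable by combining two numbers."""
    children = set()
    for i, j in combinations(range(len(state)), 2):
        a, b = state[i], state[j]
        rest = [v for k, v in enumerate(state) if k not in (i, j)]
        results = {a + b, a * b, a - b, b - a}
        if b != 0:
            results.add(a / b)
        if a != 0:
            results.add(b / a)
        for r in results:
            children.add(tuple(sorted(rest + [r])))
    return sorted(children)

def evaluate(state, target=24):
    """Stand-in for eval_chain: 1 = sure, 0 = impossible, 0.5 = undecided."""
    if len(state) == 1:
        return 1.0 if state[0] == target else 0.0
    if len(state) == 2:
        return 1.0 if any(child == (target,) for child in propose(state)) else 0.0
    return 0.5

def tree_search(start, depth, beam_width):
    """Breadth-first search that keeps only the best `beam_width` states per level."""
    current = [tuple(Fraction(n) for n in start)]
    for _ in range(depth):
        scored = [(child, evaluate(child)) for state in current for child in propose(state)]
        scored.sort(key=lambda item: item[1], reverse=True)
        current = [child for child, _ in scored[:beam_width]]
    return current

best_states = tree_search([2, 8, 6], depth=2, beam_width=2)
print(best_states[0])
```

With four starting numbers the stub evaluator would return 0.5 for every first-level state, which is the situation where an LLM evaluator (or a larger beam) earns its keep.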

## 2. Graph-of-Thoughts

We will test prompts with a simplified version of Graph-of-Thoughts.
Please check the GoT GitHub repo for the original code.

Original code source: [Link](https://github.com/spcl/graph-of-thoughts)

### Define Prompt Templates

In [None]:
split_prompt = PromptTemplate(
    input_variables=["input"],
    template="""
<Instruction> Split the following list of 32 numbers into 2 lists of 16 numbers each, the first list should contain the first 16 numbers and the second list the second 16 numbers.
Only output the final 2 lists in the following format without any additional text or thoughts!:
    "List 1": [3, 4, 3, 5, 7, 8, 1, ...]
    "List 2": [2, 9, 2, 4, 7, 1, 5, ...]
</Instruction>

<Example>
Input: [9, 6, 7, 7, 2, 0, 2, 2, 3, 5, 0, 9, 2, 2, 4, 4, 5, 2, 5, 1, 2, 8, 3, 8, 3, 9, 6, 0, 4, 2, 2, 3]
Output:
    "List 1": [9, 6, 7, 7, 2, 0, 2, 2, 3, 5, 0, 9, 2, 2, 4, 4]
    "List 2": [5, 2, 5, 1, 2, 8, 3, 8, 3, 9, 6, 0, 4, 2, 2, 3]

</Example>

Input: {input}
Output: """
)

sort_prompt =  PromptTemplate(
    input_variables=["input"],
    template="""
<Instruction> Sort the following list of numbers in ascending order. Output only the sorted list of numbers, no additional text. </Instruction>

<Examples>
Input: [5, 1, 0, 1, 2, 0, 4, 8, 1, 9, 5, 1, 3, 3, 9, 7]
Output: [0, 0, 1, 1, 1, 1, 2, 3, 3, 4, 5, 5, 7, 8, 9, 9]

Input: [3, 7, 0, 2, 8, 1, 2, 2, 2, 4, 7, 8, 5, 5, 3, 9, 4, 3, 5, 6, 6, 4, 4, 5, 2, 0, 9, 3, 3, 9, 2, 1]
Output: [0, 0, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 9]

Input: [4, 4, 9, 7, 9, 7, 0, 0, 4, 9, 1, 7, 9, 5, 8, 7, 5, 6, 3, 8, 6, 7, 5, 8, 5, 0, 6, 3, 7, 0, 5, 3, 7, 5, 2, 4, 4, 9, 0, 7, 8, 2, 7, 7, 7, 2, 1, 3, 9, 9, 7, 9, 6, 6, 4, 5, 4, 2, 0, 8, 9, 0, 2, 2]
Output: [0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9]
</Examples>

Input: {input}
Output: """
)

merge_prompt = PromptTemplate(
    input_variables=["input1", "input2", "length1", "length2"],
    template="""
<Instruction> Merge the following 2 sorted lists of length {length1} each, into one sorted list of length {length2} using a merge sort style approach.
Only output the final merged list without any additional text or thoughts!:</Instruction>

<Approach>
To merge the two lists in a merge-sort style approach, follow these steps:
1. Compare the first element of both lists.
2. Append the smaller element to the merged list and move to the next element in the list from which the smaller element came.
3. Repeat steps 1 and 2 until one of the lists is empty.
4. Append the remaining elements of the non-empty list to the merged list.
</Approach>

Merge the following two lists into one sorted list:
1: {input1}
2: {input2}

Merged list: """
)

value_prompt = PromptTemplate(
    input_variables=["input", "variant"],
    template="""
<Instruction> The following two lists represent an input list of numbers and a variant of that list. Evaluate if the variant is correctly sorted, with respect to the input list.
Answer 1 if the variant list is correctly sorted; deduct 0.1 for each error in the variant.
Only output the value without any additional text or thoughts!:</Instruction>

<Approach>
To score the variant list follow these steps:
1. For each number from 0 to 9, compare the frequency of that number in the variant list to the frequency of that number in the input list.
2. Iterate through the variant list and add or remove numbers as needed to make the frequency of each number in the variant list match the frequency of that number in the input list.
3. Count the number of errors in the variant list and deduct from 1.
</Approach>

<Examples>
Input: [3, 7, 0, 2, 8, 1, 2, 2, 2, 4, 7, 8, 5, 5, 3, 9]
Variant: [0, 0, 0, 0, 0, 1, 2, 2, 3, 3, 4, 4, 4, 5, 5, 7, 7, 8, 8, 9, 9, 9, 9]
Reason: The variant list contains four extra 0s, two extra 4s and three extra 9s and is missing two 2s.
Output: 0

Input: [6, 4, 5, 7, 5, 6, 9, 7, 6, 9, 4, 6, 9, 8, 1, 9, 2, 4, 9, 0, 7, 6, 5, 6, 6, 2, 8, 3, 9, 5, 6, 1]
Variant: [0, 1, 1, 2, 2, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 7, 7, 7, 8, 8, 9, 9, 9, 9, 9]
Reason: The variant list contains two extra 4s and is missing two 6s and one 9.
Output: 0.5

Input: [4, 4, 9, 7, 9, 7, 0, 0, 4, 9, 1, 7, 9, 5, 8, 7, 5, 6, 3, 8, 6, 7, 5, 8, 5, 0, 6, 3, 7, 0, 5, 3, 7, 5, 2, 4, 4, 9, 0, 7, 8, 2, 7, 7, 7, 2, 1, 3, 9, 9, 7, 9, 6, 6, 4, 5, 4, 2, 0, 8, 9, 0, 2, 2]
Variant: [0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9]
Reason: The variant list contains one extra 8 and is missing two 2s, one 3, three 4s, two 5s, one 6, six 7s and one 9.
Output: 0
</Examples>

Input: {input}
Variant: {variant}

Output: """
)

improve_prompt = PromptTemplate(
    input_variables=["input","incorrectly_sorted","length"],
    template="""
<Instruction> The following two lists represent an unsorted list of numbers and a sorted variant of that list. The sorted variant is not correct. Fix the sorted variant so that it is correct.
Make sure that the output list is sorted in ascending order, has the same number of elements as the input list ({length}), and contains the same elements as the input list. </Instruction>

<Approach>
To fix the incorrectly sorted list follow these steps:
1. For each number from 0 to 9, compare the frequency of that number in the incorrectly sorted list to the frequency of that number in the input list.
2. Iterate through the incorrectly sorted list and add or remove numbers as needed to make the frequency of each number in the incorrectly sorted list match the frequency of that number in the input list.
</Approach>

<Examples>
Input: [3, 7, 0, 2, 8, 1, 2, 2, 2, 4, 7, 8, 5, 5, 3, 9]
Incorrectly Sorted: [0, 0, 0, 0, 0, 1, 2, 2, 3, 3, 4, 4, 4, 5, 5, 7, 7, 8, 8, 9, 9, 9, 9]
Reason: The incorrectly sorted list contains four extra 0s, two extra 4s and three extra 9s and is missing two 2s.
Output: [0, 1, 2, 2, 2, 2, 3, 3, 4, 5, 5, 7, 7, 8, 8, 9]

Input: [6, 4, 5, 7, 5, 6, 9, 7, 6, 9, 4, 6, 9, 8, 1, 9, 2, 4, 9, 0, 7, 6, 5, 6, 6, 2, 8, 3, 9, 5, 6, 1]
Incorrectly Sorted: [0, 1, 1, 2, 2, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 7, 7, 7, 8, 8, 9, 9, 9, 9, 9]
Reason: The incorrectly sorted list contains two extra 4s and is missing two 6s and one 9.
Output: [0, 1, 1, 2, 2, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 8, 8, 9, 9, 9, 9, 9, 9]

Input: [4, 4, 9, 7, 9, 7, 0, 0, 4, 9, 1, 7, 9, 5, 8, 7, 5, 6, 3, 8, 6, 7, 5, 8, 5, 0, 6, 3, 7, 0, 5, 3, 7, 5, 2, 4, 4, 9, 0, 7, 8, 2, 7, 7, 7, 2, 1, 3, 9, 9, 7, 9, 6, 6, 4, 5, 4, 2, 0, 8, 9, 0, 2, 2]
Incorrectly Sorted: [0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9]
Reason: The incorrectly sorted list contains one extra 8 and is missing two 2s, one 3, three 4s, two 5s, one 6, six 7s and one 9.
Output: [0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9]
</Examples>

Input: {input}
Incorrectly Sorted: {incorrectly_sorted}
"""
)

### Create a random list of 32 numbers

In [None]:
import random

def create_random_list(length, min_val, max_val):
  """Creates a list of random integers within a specified range.

  Args:
    length: The desired length of the list.
    min_val: The minimum value for the random integers.
    max_val: The maximum value for the random integers.

  Returns:
    A list of random integers.
  """
  return [random.randint(min_val, max_val) for _ in range(length)]

random_list = create_random_list(32, 0, 9)
print(random_list)

### Split into two lists

In [None]:
split_chain = split_prompt | llm | parser
split_result = split_chain.invoke({"input": str(random_list)})
print(split_result)

def extract_lists(text):
    text = text.lower()
    text = text.split("\n")

    text = [item for item in text if "list" in item]

    lists = []
    for x in text:
        x = x.split(":")
        if len(x) == 2 :
            x = x[1]
            lists.append(x)
    return lists

# Parse the reply once and unpack the first two lists.
list1_str, list2_str = extract_lists(split_result)[:2]

print(list1_str,list2_str)
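
Splitting is deterministic, so the model's split can be checked against plain list slicing. A minimal sketch, assuming the strings have the shape returned by `extract_lists`; `parse_list` is a hypothetical helper and the sample strings are made up:

```python
import ast

def parse_list(text):
    """Parse the bracketed part of a line such as ' [9, 6, 7]' into a Python list."""
    start, end = text.index("["), text.rindex("]") + 1
    return ast.literal_eval(text[start:end])

numbers = [9, 6, 7, 7, 2, 0, 2, 2]
# Hypothetical strings in the shape returned by extract_lists.
first_str, second_str = " [9, 6, 7, 7]", " [2, 0, 2, 2]"

half = len(numbers) // 2
split_is_correct = (parse_list(first_str) == numbers[:half]
                    and parse_list(second_str) == numbers[half:])
print(split_is_correct)
```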

### Sort each list twice

In [None]:
sort_chain = sort_prompt | llm | parser
sort_result1_1 = sort_chain.invoke({"input": list1_str})
sort_result1_2 = sort_chain.invoke({"input": list1_str})
sort_result2_1 = sort_chain.invoke({"input": list2_str})
sort_result2_2 = sort_chain.invoke({"input": list2_str})

print(sort_result1_1, sort_result1_2, sort_result2_1, sort_result2_2)

### Merge two lists into one list

In [None]:
merge_chain = merge_prompt | llm | parser
merge_result1 = merge_chain.invoke({"input1": sort_result1_1, "input2": sort_result2_1, "length1": 16, "length2": 32})
merge_result2 = merge_chain.invoke({"input1": sort_result1_2, "input2": sort_result2_1, "length1": 16, "length2": 32})
merge_result3 = merge_chain.invoke({"input1": sort_result1_1, "input2": sort_result2_2, "length1": 16, "length2": 32})
merge_result4 = merge_chain.invoke({"input1": sort_result1_2, "input2": sort_result2_2, "length1": 16, "length2": 32})
print(merge_result1, merge_result2, merge_result3, merge_result4)
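
The merge prompt asks the model to follow the standard two-pointer merge. As a reference for checking the model's output, the same steps can be sketched directly in Python (`merge_sorted` is a name chosen for this sketch):

```python
def merge_sorted(left, right):
    """Merge two ascending lists into one ascending list."""
    merged = []
    i = j = 0
    # Steps 1-3: repeatedly take the smaller head element until one list is empty.
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            merged.append(left[i])
            i += 1
        else:
            merged.append(right[j])
            j += 1
    # Step 4: append whatever remains in the non-empty list.
    merged.extend(left[i:])
    merged.extend(right[j:])
    return merged

print(merge_sorted([0, 2, 5], [1, 2, 9]))
```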

### Score each merge result and keep the best one

In [None]:
value_chain = value_prompt | llm | parser
value_result1 = value_chain.invoke({"input": random_list, "variant": merge_result1})
value_result2 = value_chain.invoke({"input": random_list, "variant": merge_result2})
value_result3 = value_chain.invoke({"input": random_list, "variant": merge_result3})
value_result4 = value_chain.invoke({"input": random_list, "variant": merge_result4})
print(value_result1, value_result2, value_result3, value_result4)

In [None]:
value_results = [
    (merge_result1, float(value_result1)),
    (merge_result2, float(value_result2)),
    (merge_result3, float(value_result3)),
    (merge_result4, float(value_result4))
]

# Sort the results by the value in descending order
value_results.sort(key=lambda x: x[1], reverse=True)

# Select the merge result with the highest value
best_merge_result = value_results[0][0]

print("Best Merge Result:", best_merge_result)
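
The scoring rule in the value prompt (start at 1, deduct 0.1 per error) can also be computed exactly, which is useful for checking how reliable the model's scores are. This is a sketch under assumptions: `score_variant` is a hypothetical helper, and counting adjacent out-of-order pairs as additional errors goes beyond the frequency comparison the prompt describes:

```python
from collections import Counter

def score_variant(original, variant):
    """Score a candidate sorting: 1.0 is perfect, minus 0.1 per error, floored at 0."""
    original_counts, variant_counts = Counter(original), Counter(variant)
    # Frequency errors: numbers that are missing from or extra in the variant.
    values = set(original_counts) | set(variant_counts)
    frequency_errors = sum(abs(original_counts[v] - variant_counts[v]) for v in values)
    # Order errors (assumption of this sketch): adjacent pairs in descending order.
    order_errors = sum(1 for a, b in zip(variant, variant[1:]) if a > b)
    errors = frequency_errors + order_errors
    return max(0, 10 - errors) / 10

# First example from the value prompt: 11 frequency errors, so the score is floored.
original = [3, 7, 0, 2, 8, 1, 2, 2, 2, 4, 7, 8, 5, 5, 3, 9]
variant = [0, 0, 0, 0, 0, 1, 2, 2, 3, 3, 4, 4, 4, 5, 5, 7, 7, 8, 8, 9, 9, 9, 9]
print(score_variant(original, variant))
```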

### Refine the best thought

In [None]:
refine_chain = improve_prompt | llm | parser
refine_result1 = refine_chain.invoke({"input": random_list, "incorrectly_sorted": best_merge_result, "length": 32})
refine_result2 = refine_chain.invoke({"input": random_list, "incorrectly_sorted": best_merge_result, "length": 32})
print("Refine result 1: ", refine_result1,"\n\n\nRefine result 2: ", refine_result2)

### Compare with the correctly sorted list

In [None]:
sorted_list = sorted(random_list)
print(sorted_list)

## 3. Self-Ask

Code source: [Link](https://github.com/ofirpress/self-ask)

### Test Search Engine

We will use DuckDuckGo as the search engine.

In [None]:
!pip install -qU duckduckgo_search

In [None]:
from duckduckgo_search import DDGS

# Example usage:
query = "What is the capital of France?"
results = DDGS().text(query, max_results=1)

for result in results:
  print(result['title'])
  print(result['body'])
  print("----")

### Define prompt with 4 few-shot examples

In [None]:
self_ask_prompt = ['''Question: Who lived longer, Muhammad Ali or Alan Turing?
Are follow up questions needed here: Yes.
Follow up: How old was Muhammad Ali when he died?
Intermediate answer: Muhammad Ali was 74 years old when he died.
Follow up: How old was Alan Turing when he died?
Intermediate answer: Alan Turing was 41 years old when he died.
So the final answer is: Muhammad Ali

Question: When was the founder of craigslist born?
Are follow up questions needed here: Yes.
Follow up: Who was the founder of craigslist?
Intermediate answer: Craigslist was founded by Craig Newmark.
Follow up: When was Craig Newmark born?
Intermediate answer: Craig Newmark was born on December 6, 1952.
So the final answer is: December 6, 1952

Question: Who was the maternal grandfather of George Washington?
Are follow up questions needed here: Yes.
Follow up: Who was the mother of George Washington?
Intermediate answer: The mother of George Washington was Mary Ball Washington.
Follow up: Who was the father of Mary Ball Washington?
Intermediate answer: The father of Mary Ball Washington was Joseph Ball.
So the final answer is: Joseph Ball

Question: Are both the directors of Jaws and Casino Royale from the same country?
Are follow up questions needed here: Yes.
Follow up: Who is the director of Jaws?
Intermediate answer: The director of Jaws is Steven Spielberg.
Follow up: Where is Steven Spielberg from?
Intermediate answer: The United States.
Follow up: Who is the director of Casino Royale?
Intermediate answer: The director of Casino Royale is Martin Campbell.
Follow up: Where is Martin Campbell from?
Intermediate answer: New Zealand.
So the final answer is: No

Question: ''',
'''
Are follow up questions needed here:''']


In [None]:
from openai import OpenAI

client = OpenAI(api_key=OPENAI_API_KEY)

### Define Functions

In [None]:
def promptf(question, prompt, intermediate="\nIntermediate answer:", followup="Follow up:", finalans="\nSo the final answer is:"):
    # Build the full prompt: few-shot examples + new question + the opening cue.
    cur_prompt = prompt[0] + question + prompt[1]

    print(cur_prompt, end='')

    # Generate until the model is about to write an intermediate answer.
    ret_text = call_gpt(cur_prompt, intermediate)

    while followup in get_last_line(ret_text):
        cur_prompt += ret_text
        question = extract_question(ret_text)
        external_answer = get_answer(question)

        if external_answer is not None:
            # Insert the search result as the intermediate answer.
            cur_prompt += intermediate + ' ' + external_answer + '.'
            print(intermediate + ' ' + yellowfy(external_answer) + '.', end='')
        else:
            # We only get here in the rare case that the search engine returns no answer,
            # so the model writes the intermediate answer itself.
            cur_prompt += intermediate
            print(intermediate + ' ')
            gpt_answer = call_gpt(cur_prompt, ['\n' + followup, finalans])
            cur_prompt += gpt_answer

        # Continue from the updated prompt in both branches; without this call in the
        # fallback branch, ret_text never changes and the loop does not terminate.
        ret_text = call_gpt(cur_prompt, intermediate)

    if finalans not in ret_text:
        cur_prompt += finalans
        print(finalans, end='')
        ret_text = call_gpt(cur_prompt, '\n')

    return cur_prompt + ret_text

def get_answer(question):
  # Return the snippet of the top search hit, or None when nothing is found.
  results = DDGS().text(question, max_results=1)
  if not results:
    return None
  return results[0]['body']

def call_gpt(cur_prompt, stop):
  completion = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[
      {"role": "system", "content": "You are a helpful assistant."},
      {
          "role": "user",
          "content": cur_prompt
      }
    ],
    max_tokens=256,
    temperature=0.0,
    stop=stop
  )

  returned = completion.choices[0].message.content
  print(greenify(returned), end='')
  return returned


def extract_answer(generated):
    if '\n' not in generated:
        last_line =  generated
    else:
        last_line = generated.split('\n')[-1]

    if ':' not in last_line:
        after_colon = last_line
    else:
        after_colon = generated.split(':')[-1]

    # startswith/endswith avoid an IndexError when nothing follows the colon.
    if after_colon.startswith(' '):
        after_colon = after_colon[1:]
    if after_colon.endswith('.'):
        after_colon = after_colon[:-1]

    return after_colon

def extract_question(generated):
    if '\n' not in generated:
        last_line =  generated
    else:
        last_line = generated.split('\n')[-1]

    if 'Follow up:' not in last_line:
      print('we probably should never get here...' + generated)

    if ':' not in last_line:
        after_colon = last_line
    else:
        after_colon = generated.split(':')[-1]

    # startswith/endswith avoid an IndexError when nothing follows the colon.
    if after_colon.startswith(' '):
        after_colon = after_colon[1:]
    if not after_colon.endswith('?'):
        print('we probably should never get here...' + generated)

    return after_colon

def get_last_line(generated):
    if '\n' not in generated:
        last_line =  generated
    else:
        last_line = generated.split('\n')[-1]


    return last_line

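def greenify(text):
  # ANSI code 102: bright green background.
  return "\x1b[102m" + text + "\x1b[0m"

def yellowfy(text):
  # ANSI code 103: bright yellow background (106 would render as cyan).
  return "\x1b[103m" + text + "\x1b[0m"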

### Test Self-Ask

In [None]:
question = "What is the hometown of the reigning men's U.S. Open champion?"

ret = promptf(question, self_ask_prompt)
clean_ans = extract_answer(ret)

In [None]:
question = "Who was the president of the U.S. when the semiconductor was discovered?"

ret = promptf(question, self_ask_prompt)
clean_ans = extract_answer(ret)
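
The control flow of Self-Ask can be exercised offline by replacing the model and the search engine with scripted stand-ins. The sketch below is a simplified loop written for illustration, not the `promptf` function above. `self_ask_loop`, `scripted_generate`, and `fake_search` are hypothetical names, and the canned replies mimic what a model might produce for one of the few-shot questions:

```python
def self_ask_loop(question, generate, search, max_steps=5):
    """Alternate between model-written follow-ups and looked-up intermediate answers."""
    transcript = f"Question: {question}\nAre follow up questions needed here:"
    for _ in range(max_steps):
        step = generate(transcript)  # the model continues until it needs a fact
        transcript += step
        last_line = step.strip().split("\n")[-1]
        if last_line.startswith("So the final answer is:"):
            return transcript, last_line.split(":", 1)[1].strip()
        if last_line.startswith("Follow up:"):
            sub_question = last_line.split(":", 1)[1].strip()
            # The answer comes from the search tool, not from the model.
            transcript += f"\nIntermediate answer: {search(sub_question)}."
    return transcript, None

# Scripted stand-in for the LLM: returns one canned continuation per call.
replies = iter([
    " Yes.\nFollow up: Who was the founder of craigslist?",
    "\nFollow up: When was Craig Newmark born?",
    "\nSo the final answer is: December 6, 1952",
])

def scripted_generate(transcript):
    return next(replies)

# Stand-in for the search engine: a fixed lookup table.
facts = {
    "Who was the founder of craigslist?": "Craigslist was founded by Craig Newmark",
    "When was Craig Newmark born?": "Craig Newmark was born on December 6, 1952",
}

def fake_search(sub_question):
    return facts[sub_question]

transcript, answer = self_ask_loop("When was the founder of craigslist born?",
                                   scripted_generate, fake_search)
print(transcript)
print(answer)
```

The `max_steps` cap is a design choice of this sketch: it bounds the number of follow-up rounds in case the model never emits a final answer.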