In [1]:
from __future__ import annotations
"""
Implements MCTS-FD based on the MCTS + Self-Refine algorithm from
`Accessing GPT-4 level Mathematical Olympiad Solutions via Monte
Carlo Tree Self-refine with LLaMa-3 8B: A Technical Report`
by Zhang et al.

"""


import random
import math
from collections import deque
from enum import Enum
from pydantic import BaseModel
import tqdm
import numpy as np

ROOT_UCT_SCORE = 10_000


class MCTSNode(BaseModel):
    """One candidate answer in the search tree.

    Alongside the answer text, a node holds links to its parent and
    children, a visit counter, every reward recorded for it, and the
    quality estimate ``Q`` derived from those rewards.
    """

    answer: str
    parent: MCTSNode | None = None
    children: list[MCTSNode] = []
    visits: int = 0
    Q: float = 0
    reward_samples: list[int] = []

    def add_child(self, child_node: MCTSNode):
        """Append ``child_node`` to this node's list of children."""
        self.children.append(child_node)

    def __repr__(self):
        return f"MCTSNode(answer={self.answer}, Q={self.Q:.2f}, visits={self.visits})"

    def add_reward(self, reward: int):
        """Record ``reward`` and recompute ``Q`` from all samples so far."""
        samples = self.reward_samples
        samples.append(reward)

        # Q is the midpoint between the worst sample and the sample mean.
        worst = np.min(samples)
        mean = np.mean(samples)
        self.Q = (worst + mean) / 2


class SelectionPolicy(Enum):
    """How `MCTSr.select_node` chooses among the not-fully-expanded candidates."""
    # Return the candidate with the highest UCT score.
    GREEDY = 1
    # Sample one candidate, weighted by its UCT score.
    IMPORTANCE_SAMPLING = 2
    # Sample a pair of candidates, weighted by the difference of their UCT
    # scores, and return the higher-scoring member of that pair.
    PAIRWISE_IMPORTANCE_SAMPLING = 3


class InitializeStrategy(Enum):
    """How `MCTSr.initialize` creates the root node of the search tree."""
    # The root answer is whatever `zero_shot()` returns.
    ZERO_SHOT = 1
    # The root answer is the fixed placeholder "I don't know.".
    DUMMY_ANSWER = 2


class MCTSr(BaseModel):
    """Base class for a Monte Carlo Tree Self-refine search over answers.

    The class owns the tree (`root`), node selection, reward bookkeeping,
    backpropagation and the rollout loop.  Subclasses must implement
    `zero_shot`, `self_refine` and `_evaluate_answer`.
    """

    problem: str
    max_rollouts: int
    exploration_constant: float = 1.0
    max_children: int = 3
    # Added to the visit count in `uct` so the denominator is never zero.
    epsilon: float = 1e-10
    # Rewards above `reward_limit` are reduced by `excess_reward_penalty`.
    reward_limit: int = 95
    excess_reward_penalty: int = 5
    selection_policy: SelectionPolicy = SelectionPolicy.IMPORTANCE_SAMPLING
    initialize_strategy: InitializeStrategy = InitializeStrategy.ZERO_SHOT

    root: MCTSNode = MCTSNode(answer="I don't know.")

    # Logs
    critiques: list[str] = []
    refinements: list[str] = []
    rewards: list[float] = []
    selected_nodes: list[MCTSNode] = []
    errors: list[str] = []

    def self_refine(self, node: MCTSNode) -> MCTSNode:
        """Return a new child node holding a refined version of `node.answer`."""
        raise NotImplementedError()

    def _evaluate_answer(self, node: MCTSNode) -> int:
        """Return the raw reward for `node.answer`."""
        raise NotImplementedError()

    def self_evaluate(self, node: MCTSNode):
        """Score `node` once and record the (possibly penalised) reward on it.

        A reward above `reward_limit` is lowered by `excess_reward_penalty`
        before being passed to `node.add_reward`.
        """
        reward = self._evaluate_answer(node)

        if reward > self.reward_limit:
            reward -= self.excess_reward_penalty

        node.add_reward(reward)

    def backpropagate(self, node: MCTSNode):
        """Update every ancestor of `node`, from its parent up to the root.

        Each ancestor's Q becomes the mean of its current Q and the best Q
        among its children, and its visit count is incremented.
        """
        parent = node.parent
        while parent:
            best_child_Q = max(child.Q for child in parent.children)
            parent.Q = (parent.Q + best_child_Q) / 2
            parent.visits += 1
            parent = parent.parent

    def uct(self, node: MCTSNode):
        """Return the UCT score of `node` (a fixed high score for the root)."""
        if not node.parent:
            # Using an arbitrarily high UCT score for the root node.
            # helps to prioritize breadth.
            return ROOT_UCT_SCORE

        return node.Q + self.exploration_constant * math.sqrt(
            math.log(node.parent.visits + 1) / (node.visits + self.epsilon)
        )

    def is_fully_expanded(self, node: MCTSNode):
        """Return True if `node` has `max_children` children or a child that beats it."""
        return len(node.children) >= self.max_children or any(
            child.Q > node.Q for child in node.children
        )

    def _weighted_index(self, weights: list[float]) -> int:
        """Return an index drawn with probability proportional to `weights`.

        `random.choices` raises ValueError when the weights do not sum to a
        positive number.  That happens, for example, with pairwise sampling
        over a single candidate, where the only pair has weight 0.  In that
        case the index is drawn uniformly instead.
        """
        if sum(weights) <= 0:
            return random.randrange(len(weights))
        return random.choices(range(len(weights)), weights=weights, k=1)[0]

    def select_node(self):
        """Select a non-fully expanded node according to `selection_policy`.

        A node is fully expanded if either:
        1. It has reached the max number of children
        2. Any of its children have a Q value greater than its own

        The tree is scanned breadth-first for candidates.  If every node is
        fully expanded, the root is returned.

        Raises:
            ValueError: if `selection_policy` is not a known policy.
        """
        candidates: list[MCTSNode] = []
        to_consider = deque([self.root])

        while to_consider:
            current_node = to_consider.popleft()
            if not self.is_fully_expanded(current_node):
                candidates.append(current_node)
            to_consider.extend(current_node.children)

        if not candidates:
            return self.root

        if self.selection_policy == SelectionPolicy.GREEDY:
            return max(candidates, key=self.uct)
        elif self.selection_policy == SelectionPolicy.IMPORTANCE_SAMPLING:
            # Sample, weighted by UCT score
            uct_scores = [self.uct(node) for node in candidates]
            return candidates[self._weighted_index(uct_scores)]
        elif self.selection_policy == SelectionPolicy.PAIRWISE_IMPORTANCE_SAMPLING:
            # Sample, weighted by the difference in UCT scores between pairs
            uct_scores = [self.uct(node) for node in candidates]
            pairs = [
                (i, j) for i in range(len(candidates)) for j in range(len(candidates))
            ]
            pair_weights = [abs(uct_scores[i] - uct_scores[j]) for i, j in pairs]
            selected_pair = pairs[self._weighted_index(pair_weights)]
            # Of the sampled pair, keep the member with the higher UCT score.
            selected_candidate_idx = max(selected_pair, key=lambda x: uct_scores[x])
            return candidates[selected_candidate_idx]
        else:
            raise ValueError(f"Invalid selection policy: {self.selection_policy}")

    def zero_shot(self) -> str:
        """Generate a zero-shot answer."""
        raise NotImplementedError()

    def initialize(self):
        """Create the root node according to `initialize_strategy`.

        Raises:
            ValueError: if `initialize_strategy` is not a known strategy.
        """
        if self.initialize_strategy == InitializeStrategy.ZERO_SHOT:
            self.root = MCTSNode(answer=self.zero_shot())
        elif self.initialize_strategy == InitializeStrategy.DUMMY_ANSWER:
            self.root = MCTSNode(answer="I don't know.")
        else:
            raise ValueError(f"Invalid initialize strategy: {self.initialize_strategy}")

    def run(self):
        """Run `max_rollouts` select/evaluate/refine/backpropagate rounds.

        Returns the answer of the node with the highest Q after the last
        rollout.
        """
        self.initialize()
        for _ in tqdm.tqdm(range(self.max_rollouts)):
            node = self.select_node()
            # The selected node is re-scored on every selection, adding one
            # more reward sample to it.
            self.self_evaluate(node)
            child = self.self_refine(node)
            node.add_child(child)
            self.self_evaluate(child)
            self.backpropagate(child)

        return self.get_best_answer()

    def get_best_answer(self):
        """Return the answer of the node with the highest Q (breadth-first scan).

        Ties keep the node encountered first; the root wins if nothing has
        a strictly higher Q.
        """
        to_visit = deque([self.root])
        best_node = self.root

        while to_visit:
            current_node = to_visit.popleft()
            if current_node.Q > best_node.Q:
                best_node = current_node
            to_visit.extend(current_node.children)

        return best_node.answer

    def print(self):
        """Print the whole tree, one indented line per node."""
        print_tree(self.root)




class MCTS_FD(MCTSr):
    """MCTSr specialisation whose rewards and known errors come from graph checks.

    Answers are generated, critiqued and refined through
    `openai_chat_completion`; they are scored with
    `calculate_network_quality` and checked with `build_critique`.
    """

    def zero_shot(self) -> str:
        """Ask the model to solve `self.problem` directly and return its text."""
        response = openai_chat_completion(
            messages=[
                {
                    "role": "system",
                    "content": "The user will provide a problem. Solve the problem. Think step by step.",
                },
                {
                    "role": "user",
                    "content": f"<problem>\n{self.problem}\n</problem>",
                },
            ],
            model=gpt_4o_prompt_config.model,
            max_tokens=4000,
        )
        assert response.choices[0].message.content is not None
        return response.choices[0].message.content

    def self_refine(self, node: MCTSNode) -> MCTSNode:
        """Critique `node.answer`, refine it, and return the refinement as a child node.

        The rule-based error report, the model critique and the parsed
        refinement are appended to `errors`, `critiques` and `refinements`.
        The returned node has `parent=node` but is not added to
        `node.children` here.
        """
        # Rule-based error report computed from the answer's system description.
        error = build_critique(node.answer)

        self.errors.append(error)

        critique_response = openai_chat_completion(
            messages=[
                {
                    "role": "system",
                    "content": gpt_4o_prompt_config.critic_system_prompt,
                },
                {
                    "role": "user",
                    "content": "\n\n".join(
                        [
                            f"<problem>\n{self.problem}\n</problem>",
                            f"<current_answer>\n{node.answer}\n</current_answer>",
                        ]
                    ),
                },
            ],
            model=gpt_4o_prompt_config.model,
            # JSON-object response mode (originally added for a test).
            response_format={"type": "json_object"},
            max_tokens=4000,
        )
        critique = critique_response.choices[0].message.content
        assert critique is not None
        self.critiques.append(critique)

        refined_answer_response = openai_chat_completion(
            messages=[
                {
                    "role": "system",
                    "content": gpt_4o_prompt_config.refine_system_prompt,
                },
                {
                    "role": "user",
                    "content": "\n\n".join(
                        [
                            f"<problem>\n{self.problem}\n</problem>",
                            f"<current_answer>\n{node.answer}\n</current_answer>",
                            f"<critique>\n{critique}\n</critique>",
                            f"<known_errors>\n{error}\n</known_errors>",
                        ]
                    ),
                },
            ],
            model=gpt_4o_prompt_config.model,
            max_tokens=4000,
            response_format={"type": "json_object"},
        )
        # Same guard as for the other two completions: fail here, with a
        # clear location, rather than inside JSON validation.
        refined_content = refined_answer_response.choices[0].message.content
        assert refined_content is not None
        refined_answer = RefineResponse.model_validate_json(refined_content)

        # NOTE(review): `refinements` is annotated list[str] but receives the
        # parsed RefineResponse object -- confirm which one consumers expect.
        self.refinements.append(refined_answer)

        return MCTSNode(
            answer=f"# Thought {refined_answer.thought}\n\n# Answer\n{refined_answer.answer}",
            parent=node,
        )

    def _evaluate_answer(self, node: MCTSNode) -> int:
        """Score the system description embedded in `node.answer`.

        NOTE(review): `correct_json_format` evaluates the extracted text, so
        an answer without a valid literal (for example the "I don't know."
        placeholder root) raises here instead of scoring low -- confirm
        whether that is intended.
        """
        return calculate_network_quality(correct_json_format(parse_json(node.answer)))

def print_tree(node: MCTSNode | None, level: int = 0):
    """Print `node` and its descendants depth-first, indenting two spaces per level.

    Every line of a node's `repr` receives the same indentation.  A `None`
    node prints nothing.
    """
    if node is None:
        return

    prefix = " " * (level * 2)
    print("\n".join(prefix + line for line in repr(node).split("\n")))

    for child in node.children:
        print_tree(child, level + 1)

### ERROR DETECTION

In [8]:
def build_critique(answer):
    """Return a textual error report for the system description in `answer`.

    The description is extracted with `parse_json` and evaluated with
    `correct_json_format`.  The report is the `get_errors` list followed by
    the small disconnected groups, both rendered with `str`.
    """
    system = correct_json_format(parse_json(str(answer)))
    rule_violations = get_errors(system)
    isolated_groups = find_small_groups_with_connections(system)
    return f"{rule_violations!s}Functions that are unconnected:{isolated_groups!s}"

In [9]:
import json
def parse_json(unclean):
    """Return the slice of `unclean` from its first "{" through its last "}".

    If there is no "{", the slice starts at the beginning; if there is no
    "}", it runs to the end.  The text is only sliced, never parsed.
    """
    first_open = unclean.find("{")
    start = first_open if first_open != -1 else 0

    last_close = unclean.rfind("}")
    end = last_close + 1 if last_close != -1 else len(unclean)

    return unclean[start:end]

In [10]:
def find_small_groups_with_connections(system):
    """Map each small disconnected group of functions to its internal edges.

    An undirected graph is built from `system["functions"]` and the three
    connection lists.  A connected component counts as "small" when it
    holds fewer than half of all nodes.

    Keys are `frozenset`s of node names.  A tuple built from a set would
    have arbitrary element order, and `frozenset` is also what the later
    definition of this function in this file uses, so the result no longer
    depends on which definition was executed last.

    Returns:
        dict mapping frozenset(group) -> list of (u, v) edges inside it.
    """
    graph = nx.Graph()

    for func in system["functions"]:
        graph.add_node(func)
    for connection_type in ("energy_connections", "material_connections", "information_connections"):
        for connection in system[connection_type]:
            graph.add_edge(connection[0], connection[1])

    components = list(nx.connected_components(graph))

    # "Small" means strictly fewer than half of all nodes.
    size_threshold = len(graph.nodes) / 2
    small_groups = [group for group in components if len(group) < size_threshold]

    return {
        frozenset(group): [(u, v) for u, v in graph.edges if u in group and v in group]
        for group in small_groups
    }

In [11]:
def get_errors(system):
    """Return rule-violation messages for every function in `system`.

    The function type is recognised by substring in the function name,
    tested in this order: Store, Add, Seperate, Convert, Guide.  Only the
    first matching type is checked; names matching none are skipped.
    Counts come from `count_incoming_outgoing_connections`.
    """
    messages = []

    for name, counts in count_incoming_outgoing_connections(system).items():
        incoming = counts['incoming']
        outgoing = counts['outgoing']

        if "Store" in name:
            if incoming + outgoing > 1:
                messages.append(f"{name} has to many connetions, should be 1 for Store")
        elif "Add" in name:
            if incoming > 2:
                messages.append(f"{name} has to many incoming connetions, should be 2 for Add")
            if incoming < 2:
                messages.append(f"{name} is missing a incoming connetion,should be 2 for Add")
            if outgoing > 1:
                messages.append(f"{name} has to many out going connetions,should be 1 for Add ")
        elif "Seperate" in name:
            if incoming > 1:
                messages.append(f"{name} has to many incoming connetions, should be 1 for Seperate ")
            if outgoing < 2:
                messages.append(f"{name} is missing a outgoing connetion,should be 2 for Seperate")
            if outgoing > 2:
                messages.append(f"{name} has to many out going connetions, should be 2 for Seperate ")
        elif "Convert" in name:
            if incoming > 1:
                messages.append(f"{name} has to many incoming connetions, should be 1 for Convert ")
            if incoming < 1:
                messages.append(f"{name} is missing a incoming connetion, should be 1 for Convert ")
            if outgoing > 1:
                messages.append(f"{name} has to many outgoing connetions, should be 1 for Convert ")
            if outgoing < 1:
                messages.append(f"{name} is missing a outgoing connetion, should be 1 for Convert ")
        elif "Guide" in name:
            if incoming > 1:
                messages.append(f"{name} has to many incoming connetions, should be 1 for Guide ")
            if incoming < 1:
                messages.append(f"{name} is missing a incoming connetion, should be 1 for Guide ")
            if outgoing > 1:
                messages.append(f"{name} has to many outgoing connetions, should be 1 for Guide ")
            if outgoing < 1:
                messages.append(f"{name} is missing a outgoing connetion, should be 1 for Guide ")

    return messages

In [12]:
import networkx as nx
def count_incoming_outgoing_connections(system):
    """Return the in-degree and out-degree of every function in `system`.

    A directed graph is built with one node per entry of
    `system["functions"]` and one edge per connection (first element ->
    second element) across the energy, material and information lists.

    Returns:
        dict mapping node -> {"incoming": in_degree, "outgoing": out_degree}.
    """
    graph = nx.DiGraph()

    for func in system["functions"]:
        graph.add_node(func)

    for key in ("energy_connections", "material_connections", "information_connections"):
        for connection in system[key]:
            graph.add_edge(connection[0], connection[1])

    counts = {}
    for node in graph.nodes():
        counts[node] = {
            "incoming": graph.in_degree(node),
            "outgoing": graph.out_degree(node),
        }
    return counts

In [13]:
def find_small_groups_with_connections(system):
    """Map each small disconnected group of functions to its internal edges.

    The functions and all three connection lists form an undirected graph.
    Connected components holding fewer than half of all nodes are reported,
    keyed by a `frozenset` of their node names.
    """
    graph = nx.Graph()

    for func in system["functions"]:
        graph.add_node(func)
    for key in ("energy_connections", "material_connections", "information_connections"):
        for connection in system[key]:
            graph.add_edge(connection[0], connection[1])

    half_of_nodes = len(graph.nodes) / 2

    result = {}
    for component in list(nx.connected_components(graph)):
        if len(component) < half_of_nodes:
            result[frozenset(component)] = [
                (u, v) for u, v in graph.edges if u in component and v in component
            ]
    return result

In [14]:
import networkx as nx

def calculate_network_quality(
    system,
    *,
    function_weight=10,
    connection_weight=5,
    error_cost=100,
    small_group_cost=100,
):
    """Score a system description; higher is better, never below zero.

    score = functions * function_weight
            + connections * connection_weight
            - rule violations * error_cost
            - small disconnected groups * small_group_cost

    Args:
        system: mapping with "functions", "energy_connections",
            "material_connections" and "information_connections".
        function_weight: points awarded per function.
        connection_weight: points awarded per connection of any kind.
        error_cost: points deducted per message from `get_errors`.
        small_group_cost: points deducted per group reported by
            `find_small_groups_with_connections`.

    Returns:
        The score, clamped to a minimum of 0.
    """
    # Step 1: base complexity -- more functions and connections score higher.
    num_functions = len(system["functions"])
    num_connections = sum(
        len(system[key])
        for key in ("energy_connections", "material_connections", "information_connections")
    )
    base_complexity_score = (
        num_functions * function_weight + num_connections * connection_weight
    )

    # Step 2: each rule violation costs `error_cost` points.
    error_penalty = len(get_errors(system)) * error_cost

    # Step 3: each small disconnected group costs `small_group_cost` points.
    connectedness_penalty = (
        len(find_small_groups_with_connections(system)) * small_group_cost
    )

    # Clamp so the score is never negative.
    return max(0, base_complexity_score - error_penalty - connectedness_penalty)


In [15]:
import json
import re
def correct_json_format(input_data):
    """Evaluate the string form of `input_data` as a Python expression.

    Before evaluation, a line break between a word character and a
    following "(" is collapsed into a single space.

    NOTE(review): `eval` executes arbitrary expressions.  Callers in this
    file pass model-generated text, which is not a controlled input --
    consider a literal-only parser once the accepted input forms are
    confirmed.
    """
    text = str(input_data)

    # Join "word<newline>(" into "word (" so the text evaluates as one expression.
    joined = re.sub(r'(\w)\s*\n\s*(\()', r'\1 \2', text)

    return eval(joined)

In [58]:
import matplotlib.pyplot as plt
import networkx as nx
import matplotlib.lines as mlines

def draw_decomposition2(decomposition2):
    """Plot a functional decomposition with rotated node labels and a legend.

    `decomposition2["functions"]` maps each node name to its (x, y)
    position; the three `*_connections` entries list directed edges.
    Nodes with no incoming edge are drawn green, nodes with no outgoing
    edge red, and all others blue.  Energy edges are dotted, material
    edges solid and information edges dashed.
    """
    positions = decomposition2["functions"]
    node_names = list(positions.keys())

    # The graph only carries the nodes; edges are drawn from explicit edge lists.
    graph = nx.DiGraph()
    for name in node_names:
        graph.add_node(name)

    # Collect every node seen as an edge source or as an edge target.
    has_outgoing = set()
    has_incoming = set()
    for key in ("energy_connections", "information_connections", "material_connections"):
        for edge in decomposition2[key]:
            has_outgoing.add(edge[0])
            has_incoming.add(edge[1])

    def colour_of(name):
        if name not in has_incoming:
            return 'lightgreen'  # Inputs
        if name not in has_outgoing:
            return 'lightcoral'  # Outputs
        return 'lightblue'  # Intermediate processes

    node_colors = [colour_of(name) for name in node_names]

    plt.figure(figsize=(12, 8))

    nx.draw(graph, positions, with_labels=False, node_size=3000, node_color=node_colors, font_size=10, font_weight='bold')

    # Labels are drawn by hand so that they can be rotated.
    for name, (x, y) in positions.items():
        plt.text(x, y, s=name, fontsize=10, fontweight='bold', ha='center', va='center', rotation=-90)

    # One pass per flow type, each with its own line style.
    for key, line_style in (
        ("energy_connections", 'dotted'),
        ("material_connections", "solid"),
        ("information_connections", "dashed"),
    ):
        nx.draw_networkx_edges(graph, positions, edgelist=decomposition2[key], width=2, style=line_style, arrows=True, arrowsize=10, arrowstyle='-|>', min_source_margin=30, min_target_margin=30)

    # Legend entries for the node colours, then for the edge styles.
    colour_handles = [
        mlines.Line2D([], [], color=colour, marker='o', markersize=10, linestyle='None', label=label)
        for colour, label in (
            ('lightgreen', 'Input'),
            ('lightcoral', 'Output'),
            ('lightblue', 'Internal Process'),
        )
    ]
    line_handles = [
        mlines.Line2D([], [], color='black', linestyle=line_style, label=label)
        for line_style, label in (
            ('solid', 'Material Flow'),
            ('dotted', 'Energy Flow'),
            ('dashed', 'Information Flow'),
        )
    ]
    plt.legend(handles=colour_handles + line_handles, loc='lower right')

    plt.title("Functional Decomposition")
    plt.show()


In [16]:
import matplotlib.pyplot as plt
import networkx as nx
import matplotlib.lines as mlines


def draw_decomposition(decomposition2, show_legend=False):
    """Plot a functional decomposition with networkx's built-in node labels.

    `decomposition2["functions"]` maps each node name to its position; the
    three `*_connections` entries list directed edges.  Nodes with no
    incoming edge are drawn green, nodes with no outgoing edge red, and all
    others blue.  Energy edges are dotted, material edges solid and
    information edges dashed.

    Args:
        decomposition2: the system description to draw.
        show_legend: when True, add a legend for node colours and edge
            styles in the lower-right corner.  Defaults to False, which
            matches the previous behaviour (the legend call was commented
            out while its handles were still being built).
    """
    positions = decomposition2["functions"]
    nodes = list(positions.keys())

    # The graph only carries the nodes; edges are drawn from explicit edge lists.
    G = nx.DiGraph()
    for node in nodes:
        G.add_node(node)

    # Sets give constant-time membership tests when colouring the nodes.
    has_outgoing = set()
    has_incoming = set()
    for key in ("energy_connections", "information_connections", "material_connections"):
        for edge in decomposition2[key]:
            has_outgoing.add(edge[0])
            has_incoming.add(edge[1])

    # Assign colors to nodes based on input (green), output (red), internal (blue)
    node_colors = []
    for node in nodes:
        if node not in has_incoming:
            node_colors.append('lightgreen')  # Inputs
        elif node not in has_outgoing:
            node_colors.append('lightcoral')  # Outputs
        else:
            node_colors.append('lightblue')  # Intermediate processes

    plt.figure(figsize=(12, 8))

    nx.draw(G, positions, with_labels=True, node_size=3000, node_color=node_colors, font_size=10, font_weight='bold')

    # One pass per flow type, each with its own line style.
    for key, line_style in (
        ("energy_connections", 'dotted'),
        ("material_connections", "solid"),
        ("information_connections", "dashed"),
    ):
        nx.draw_networkx_edges(G, positions, edgelist=decomposition2[key], width=2, style=line_style, arrows=True, arrowsize=10, arrowstyle='-|>', min_source_margin=30, min_target_margin=30)

    if show_legend:
        input_legend = mlines.Line2D([], [], color='lightgreen', marker='o', markersize=10, linestyle='None', label='Input')
        output_legend = mlines.Line2D([], [], color='lightcoral', marker='o', markersize=10, linestyle='None', label='Output')
        internal_legend = mlines.Line2D([], [], color='lightblue', marker='o', markersize=10, linestyle='None', label='Internal Process')
        solid_line = mlines.Line2D([], [], color='black', linestyle='solid', label='Material Flow')
        dotted_line = mlines.Line2D([], [], color='black', linestyle='dotted', label='Energy Flow')
        dashed_line = mlines.Line2D([], [], color='black', linestyle='dashed', label='Information Flow')
        plt.legend(handles=[input_legend, output_legend, internal_legend, solid_line, dotted_line, dashed_line], loc='lower right')

    plt.title("Functional Decomposition")
    plt.show()