In [22]:
from langchain.prompts import PromptTemplate
from langchain_ollama.llms import OllamaLLM
from langchain.schema import BaseOutputParser
from langchain.chains.base import Chain
import json

class CustomCoTParser(BaseOutputParser):
    """
    Output parser for single-step Tree of Thought (ToT) reasoning replies.

    Each reply is expected to contain one JSON object with the keys ``title``,
    ``content`` and ``next_action`` (``confidence`` is optional). ``parse``
    never raises: unusable replies are turned into a "Fallback Reasoning"
    step so that the calling loop can keep going.
    """

    @staticmethod
    def _fallback(content: str) -> dict:
        """Build the placeholder step returned when a reply cannot be used."""
        return {
            "title": "Fallback Reasoning",
            "content": content,
            "confidence": 50,
            "next_action": "continue",
        }

    @staticmethod
    def _decode_json_object(text: str):
        """Decode the first JSON value that starts at the first "{" in *text*.

        Models often wrap the object in Markdown fences (```json ... ```),
        add a sentence around it, or emit more than one object. Decoding from
        the first "{" with ``raw_decode`` ignores anything after the object.
        ``strict=False`` accepts literal newlines/tabs inside string values,
        which are common when ``content`` contains code.

        Raises:
            json.JSONDecodeError: if no JSON value can be decoded.
        """
        text = text.strip()
        decoder = json.JSONDecoder(strict=False)
        start = text.find("{")
        if start == -1:
            # No object at all: let the decoder report the error (or return a
            # scalar, which validate_response will reject).
            return decoder.decode(text)
        value, _ = decoder.raw_decode(text[start:])
        return value

    @staticmethod
    def _coerce_confidence(value):
        """Convert a confidence given as text (e.g. "90" or "90%") to a number.

        Ints and floats pass through unchanged. Anything that is not numeric
        becomes 0, so the step counts as low-confidence instead of breaking
        numeric comparisons in the caller.
        """
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return value
        try:
            number = float(str(value).strip().rstrip("%"))
        except ValueError:
            return 0
        return int(number) if number.is_integer() else number

    def clean_response(self, output: str) -> str:
        """Sanitize a raw reply for ``parse_old`` (legacy helper).

        Strips surrounding whitespace and deletes every "%" character. When the
        result looks like a JSON object, it also removes literal newlines and
        escaped "\\n" sequences. This rewrites string values as well, which is
        why ``parse`` does not use it.
        """
        output = output.strip()  # Remove leading and trailing whitespace
        if "%" in output:
            output = output.replace("%", "")  # Remove percentage signs
        if output.startswith("{") and output.endswith("}"):
            output = output.replace("\n", "").replace("\\n", "")  # Replace newlines
        return output

    def validate_response(self, parsed_output: dict) -> dict:
        """Return *parsed_output* unchanged if it is a dict with the required keys.

        Raises:
            ValueError: if *parsed_output* is not a dict, or lacks ``title``,
                ``content`` or ``next_action`` (checked in that order).
        """
        if not isinstance(parsed_output, dict):
            raise ValueError(
                f"Expected a JSON object, got {type(parsed_output).__name__}"
            )
        for key in ("title", "content", "next_action"):
            if key not in parsed_output:
                raise ValueError(f"Missing required field: {key}")
        return parsed_output

    def parse_old(self, output: str):
        """Legacy parser kept for backward compatibility; prefer ``parse``."""
        try:
            cleaned_output = self.clean_response(output)
            parsed_output = json.loads(cleaned_output)
            return self.validate_response(parsed_output)
        except json.JSONDecodeError:
            # Must come before ValueError: JSONDecodeError is a subclass of it.
            print(f"Error decoding JSON: {output}")
            return self._fallback("Unable to parse a valid response. Please clarify.")
        except ValueError as ve:
            print(f"Validation error: {ve}")
            return self._fallback(str(ve))

    def parse(self, output: str):
        """Parse one reasoning step from a raw LLM reply.

        The JSON object may be wrapped in code fences or surrounded by prose.
        A textual ``confidence`` such as "90" or "90%" is converted to a
        number, so callers can compare it with ints.

        Returns:
            The decoded step dict. If the reply cannot be used, returns a
            fallback step instead (``confidence`` 50, ``next_action``
            "continue").
        """
        try:
            parsed_output = self.validate_response(self._decode_json_object(output))
            if "confidence" in parsed_output:
                parsed_output["confidence"] = self._coerce_confidence(
                    parsed_output["confidence"]
                )
            return parsed_output
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON: {output}. Error: {e}")
            return self._fallback(
                "Unable to parse valid JSON. Please clarify your response."
            )
        except Exception as e:
            # Deliberate catch-all: the reasoning loop must never stop on a bad reply.
            print(f"Error in parsing response: {output}. Error: {e}")
            return self._fallback(
                "An unexpected error occurred while processing the response."
            )

# System prompt asking the model for Tree of Thought style reasoning, one JSON
# step per reply (the keys requested here are the ones CustomCoTParser checks).
# This text becomes the start of a PromptTemplate template (prompt_template
# appends "\nUser Query: {input_query}"), so literal braces are doubled:
# "{{" / "}}" render as single braces under PromptTemplate's default
# f-string formatting. Keep them doubled when editing the example below.
system_prompt = '''
You are an expert AI assistant that creates advanced reasoning trees. For each branch:
1. Clearly progress towards exploring alternative solutions.
2. Avoid repeating the same content across branches.
3. Use diverse approaches to reason through sub-problems.
4. Stop when an optimal or final answer is reached or after a maximum of 10 steps.
5. Ensure each step's "next_action" field logically follows from the step's "content".

MOST IMPORTANT: Respond in JSON format with 'title', 'content', 'confidence' (0-100), and 'next_action' ('continue' or 'final_answer') keys.
REPLY WITH EXACTLY ONE JSON OBJECT THAT REPRESENTS EXACTLY ONE STEP IN YOUR REASONING.

Example of a valid JSON response:
{{
    "title": "Initial Problem Analysis",
    "content": "To begin solving this problem, I'll break it down into its core components...",
    "confidence": 90,
    "next_action": "continue"
}}
'''

# Prompt template: the system prompt followed by the user's query.
# Only "input_query" is a template variable. The f-string below inserts
# system_prompt verbatim, and "{{input_query}}" yields the literal
# placeholder "{input_query}".
prompt_template = PromptTemplate(
    template=f"{system_prompt}\nUser Query: {{input_query}}",
    input_variables=["input_query"],
)

# Local Ollama model, sampled at a low temperature.
local_llm = OllamaLLM(temperature=0.2, model="qwen2.5:32b")

# NOTE(review): SequentialChain is not used anywhere in this cell.
from langchain.chains import SequentialChain

# Minimal custom chain: format the reasoning prompt and call the LLM once.
class ToTChain(Chain):
    """Custom chain that produces one raw Tree of Thought reply per call.

    Input key ``input_query`` is substituted into ``prompt``. Output key
    ``output`` holds the LLM's raw text; parsing is left to the caller.
    """

    prompt: PromptTemplate
    llm: OllamaLLM

    @property
    def input_keys(self):
        return ["input_query"]

    @property
    def output_keys(self):
        return ["output"]

    def _invoke_llm(self, inputs: dict, config=None):
        """Format the prompt with ``inputs["input_query"]`` and query the LLM.

        *config* (a RunnableConfig dict or None) is forwarded to
        ``llm.invoke`` so callbacks, tags, etc. reach the model call.
        """
        formatted_prompt = self.prompt.format(input_query=inputs["input_query"])
        response = self.llm.invoke(formatted_prompt, config=config)
        return {"output": response}

    def _call(self, inputs: dict, run_manager=None):
        """``Chain._call`` hook: runs the LLM, passing on child callbacks if any.

        ``run_manager`` matches the ``Chain._call`` signature, so
        ``Chain.invoke`` / ``Chain.__call__`` can supply their callback manager.
        """
        config = {"callbacks": run_manager.get_child()} if run_manager is not None else None
        return self._invoke_llm(inputs, config)

    def invoke(self, inputs: dict, config=None, **kwargs):
        """Run the chain directly and return ``{"output": <LLM reply>}``.

        Like the original override, this skips ``Chain.invoke``'s extra
        processing (memory, merging inputs into the result). ``config`` is
        accepted so the override stays call-compatible with ``Runnable.invoke``
        (e.g. ``Runnable.batch`` passes it positionally) and is forwarded to
        the LLM. Extra keyword arguments are accepted and ignored.
        """
        return self._invoke_llm(inputs, config)

# Shared chain instance that run_tot_chain calls once per reasoning step.
cot_chain = ToTChain(llm=local_llm, prompt=prompt_template)

def _confidence(branch):
    """Return ``branch["confidence"]`` as a float; missing or non-numeric values count as 0.

    The model is asked for a 0-100 number but may answer with text such as
    "90" or "90%". Comparing such a string with an int raises TypeError.
    """
    value = branch.get("confidence", 0)
    if isinstance(value, str):
        value = value.strip().rstrip("%")
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def _build_step_prompt(query, history, follow_up):
    """Compose the ``input_query`` text for one reasoning step.

    Args:
        query: The original problem statement.
        history: Earlier step dicts to show the model (may be empty).
        follow_up: Instruction for the next step, or ``None`` on the first step.

    Returns:
        *query* unchanged when there is nothing to add. Otherwise the stripped
        query, a "Reasoning so far" list and the follow-up, separated by
        blank lines.
    """
    if not history and follow_up is None:
        return query
    sections = [query.strip()]
    if history:
        lines = "\n".join(
            f"- {step.get('title', 'Step')}: {step.get('content', '')}" for step in history
        )
        sections.append(f"Reasoning so far:\n{lines}")
    if follow_up:
        sections.append(follow_up)
    return "\n\n".join(sections)


def run_tot_chain(query, max_steps=10, *, context_window=3, confidence_threshold=80):
    """Run the step-by-step reasoning loop for *query* and collect the steps.

    Each iteration asks the LLM (via the module-level ``cot_chain``) for one
    JSON reasoning step and parses it with ``CustomCoTParser``. The loop ends
    after ``max_steps`` calls or at a step whose ``next_action`` is
    ``"final_answer"``. The original *query* is repeated in every prompt, so
    the model never loses track of the problem.

    Args:
        query: The problem statement.
        max_steps: Maximum number of LLM calls.
        context_window: How many recent high-confidence steps are shown back
            to the model.
        confidence_threshold: Steps whose confidence is strictly above this
            value are used as context and kept in the result.

    Returns:
        The high-confidence steps plus the final-answer step (if any). When
        more than one step is kept, a synthesized "Comprehensive Final Answer"
        is appended. If no step qualifies, every step produced is returned.
    """
    refine_prompt = "Please refine and provide the final answer based on your reasoning above."
    next_step_prompt = "Based on this, provide the next reasoning step."
    parser = CustomCoTParser()  # stateless, so one instance serves every step
    branches = []  # every step produced, in order
    context = []  # steps above the confidence threshold, shown back to the model
    seen_responses = set()
    follow_up = None
    last_step = None

    for _ in range(max_steps):
        # context[-0:] would be the whole list, so treat a non-positive window as "none".
        history = context[-context_window:] if context_window > 0 else []
        # Always show the latest step (even a low-confidence one) so the model
        # continues from where it stopped instead of starting over.
        if last_step is not None and all(step is not last_step for step in history):
            history = [*history, last_step]
        query_with_context = _build_step_prompt(query, history, follow_up)

        result = cot_chain.invoke({"input_query": query_with_context})
        try:
            parsed_result = parser.parse(result["output"])
        except Exception as e:
            # parse() already turns bad replies into fallback steps; this only
            # guards against an unexpected chain result shape.
            print(f"Error during processing: {e}")
            branches.append({
                "title": "Fallback Reasoning",
                "content": f"Fallback: Unable to process response. Error: {e}",
                "confidence": 50,
                "next_action": "continue",
            })
            follow_up = refine_prompt
            continue

        branches.append(parsed_result)
        print(parsed_result.get("title", ""))
        if _confidence(parsed_result) > confidence_threshold:
            context.append(parsed_result)
        if parsed_result.get("next_action") == "final_answer":
            break

        # Avoid loops: if the model repeats itself, push it towards a conclusion.
        # str() keeps the set lookup safe if "content" is not a string.
        content_key = str(parsed_result.get("content", ""))
        if content_key in seen_responses:
            follow_up = refine_prompt
        else:
            seen_responses.add(content_key)
            follow_up = next_step_prompt
        last_step = parsed_result

    # Keep reliable steps, and never drop the model's final answer.
    final_results = [
        branch
        for branch in branches
        if _confidence(branch) > confidence_threshold
        or branch.get("next_action") == "final_answer"
    ]

    # Synthesize the kept steps into one combined entry when there are several.
    if len(final_results) > 1:
        combined_content = " ".join(str(branch.get("content", "")) for branch in final_results)
        final_results.append({
            "title": "Comprehensive Final Answer",
            "content": combined_content,
            "confidence": 100,
            "next_action": "done",
        })

    return final_results if final_results else branches

# Demo run: ask the model to reason step by step (via run_tot_chain) about a
# page-ordering puzzle, then pretty-print each returned step as JSON.
# NOTE(review): the query contains only the worked example from the puzzle
# statement; no separate puzzle input is supplied to the model.
query = '''
write the Python code to solve the following problem:
Satisfied with their search on Ceres, the squadron of scholars suggests subsequently scanning the stationery stacks of sub-basement 17.

The North Pole printing department is busier than ever this close to Christmas, and while The Historians continue their search of this historically significant facility, an Elf operating a very familiar printer beckons you over.

The Elf must recognize you, because they waste no time explaining that the new sleigh launch safety manual updates won't print correctly. Failure to update the safety manuals would be dire indeed, so you offer your services.

Safety protocols clearly indicate that new pages for the safety manuals must be printed in a very specific order. The notation X|Y means that if both page number X and page number Y are to be produced as part of an update, page number X must be printed at some point before page number Y.

The Elf has for you both the page ordering rules and the pages to produce in each update (your puzzle input), but can't figure out whether each update has the pages in the right order.

For example:

47|53
97|13
97|61
97|47
75|29
61|13
75|53
29|13
97|29
53|29
61|53
97|53
61|29
47|13
75|47
97|75
47|61
75|61
47|29
75|13
53|13

75,47,61,53,29
97,61,53,29,13
75,29,13
75,97,47,61,53
61,13,29
97,13,75,29,47
The first section specifies the page ordering rules, one per line. The first rule, 47|53, means that if an update includes both page number 47 and page number 53, then page number 47 must be printed at some point before page number 53. (47 doesn't necessarily need to be immediately before 53; other pages are allowed to be between them.)

The second section specifies the page numbers of each update. Because most safety manuals are different, the pages needed in the updates are different too. The first update, 75,47,61,53,29, means that the update consists of page numbers 75, 47, 61, 53, and 29.

To get the printers going as soon as possible, start by identifying which updates are already in the right order.

In the above example, the first update (75,47,61,53,29) is in the right order:

75 is correctly first because there are rules that put each other page after it: 75|47, 75|61, 75|53, and 75|29.
47 is correctly second because 75 must be before it (75|47) and every other page must be after it according to 47|61, 47|53, and 47|29.
61 is correctly in the middle because 75 and 47 are before it (75|61 and 47|61) and 53 and 29 are after it (61|53 and 61|29).
53 is correctly fourth because it is before page number 29 (53|29).
29 is the only page left and so is correctly last.
Because the first update does not include some page numbers, the ordering rules involving those missing page numbers are ignored.

The second and third updates are also in the correct order according to the rules. Like the first update, they also do not include every page number, and so only some of the ordering rules apply - within each update, the ordering rules that involve missing page numbers are not used.

The fourth update, 75,97,47,61,53, is not in the correct order: it would print 75 before 97, which violates the rule 97|75.

The fifth update, 61,13,29, is also not in the correct order, since it breaks the rule 29|13.

The last update, 97,13,75,29,47, is not in the correct order due to breaking several rules.

For some reason, the Elves also need to know the middle page number of each update being printed. Because you are currently only printing the correctly-ordered updates, you will need to find the middle page number of each correctly-ordered update. In the above example, the correctly-ordered updates are:

75,47,61,53,29
97,61,53,29,13
75,29,13
These have middle page numbers of 61, 53, and 29 respectively. Adding these page numbers together gives 143.

Of course, you'll need to be careful: the actual list of page ordering rules is bigger and more complicated than the above example.

Determine which updates are already in the correct order. What do you get if you add up the middle page number from those correctly-ordered updates?
'''
# Makes up to max_steps (default 10) calls to the local Ollama model.
response_branches = run_tot_chain(query)
# Pretty-print every returned step (or synthesized answer) as indented JSON.
for branch in response_branches:
    print(json.dumps(branch, indent=2))

Problem Breakdown
Define Problem Components
Define Core Concepts
Define Core Concepts
Define Core Concepts
Define Core Concepts
{
  "title": "Problem Breakdown",
  "content": "The problem involves parsing a set of rules and update sequences to determine if each sequence adheres to the given rules. The goal is to sum the middle elements of all valid sequences.",
  "confidence": 95,
  "next_action": "continue"
}
{
  "title": "Define Problem Components",
  "content": "Firstly, we need to clearly define what constitutes a rule, an update sequence, and how to identify if a sequence is valid based on these rules. Once defined, we can proceed to analyze each sequence for validity.",
  "confidence": 85,
  "next_action": "continue"
}
{
  "title": "Define Core Concepts",
  "content": "To start solving the problem, I need to define what constitutes a rule, an update sequence, and how to identify if a sequence is valid based on these rules. A 'rule' will be defined as a set of conditions that sequ