In [73]:
from langchain.prompts import PromptTemplate
from langchain_ollama.llms import OllamaLLM
from langchain.schema import BaseOutputParser
from langchain.chains.base import Chain
import json

class CustomCoTParser(BaseOutputParser):
    """
    A robust output parser for handling Tree of Thought (ToT) reasoning steps.
    """
    def clean_response(self, output: str) -> str:
        """Sanitize and clean the raw response to ensure valid JSON formatting."""
        output = output.strip()  # Remove leading and trailing whitespace
        if "%" in output:
            output = output.replace("%", "")  # Remove percentage signs
        if output.startswith("{") and output.endswith("}"):
            output = output.replace("\n", " ").replace("\\n", " ")  # Replace newlines
        return output

    def validate_response(self, parsed_output: dict) -> dict:
        """Ensure the parsed response has the required fields."""
        required_keys = {"title", "content", "next_action"}
        for key in required_keys:
            if key not in parsed_output:
                raise ValueError(f"Missing required field: {key}")
        return parsed_output

    def parse(self, output: str):
        """Parse and validate the output."""
        try:
            cleaned_output = self.clean_response(output)
            parsed_output = json.loads(cleaned_output)
            return self.validate_response(parsed_output)
        except json.JSONDecodeError:
            print(f"Error decoding JSON: {output}")
            return {
                "title": "Fallback Reasoning",
                "content": "Unable to parse a valid response. Please clarify.",
                "confidence": 50,
                "next_action": "continue",
            }
        except ValueError as ve:
            print(f"Validation error: {ve}")
            return {
                "title": "Fallback Reasoning",
                "content": str(ve),
                "confidence": 50,
                "next_action": "continue",
            }

# Define the system prompt to enforce Tree of Thought reasoning
system_prompt = '''
You are an expert AI assistant that creates advanced reasoning trees. For each branch:
1. Clearly progress towards exploring alternative solutions.
2. Avoid repeating the same content across branches.
3. Use diverse approaches to reason through sub-problems.
4. Stop when an optimal or final answer is reached or after a maximum of 10 steps.
5. Ensure each step's "next_action" field logically follows from the step's "content".

Example Query: "How many R's are in the word 'strawberry'?"
Example Response:
{{
  "title": "Counting R's in 'strawberry'",
  "content": "The word 'strawberry' contains the letters S-T-R-A-W-B-E-R-R-Y. Counting the R's, we find there are 3 R's in this word.",
  "confidence": 95,
  "next_action": "final_answer"
}}

Example Query: "Write a Hello World program in Java."
Example Response:
{{
  "title": "Java Hello World Program",
  "content": "Here is the code:\npublic class HelloWorld {{\n  public static void main(String[] args) {{\n    System.out.println(\"Hello, World!\");\n  }}\n}}",
  "confidence": 95,
  "next_action": "final_answer"
}}

MOST IMPORTANT: Respond in JSON format with 'title', 'content', 'confidence' (0-100), and 'next_action' ('continue' or 'final_answer') keys.
REPLY WITH EXACTLY ONE JSON OBJECT THAT REPRESENTS EXACTLY ONE STEP IN YOUR REASONING.

Example of a valid JSON response:
{{
    "title": "Initial Problem Analysis",
    "content": "To begin solving this problem, I'll break it down into its core components...",
    "confidence": 90,
    "next_action": "continue"
}}
'''

# Define the template and chain for enforcing CoT reasoning
prompt_template = PromptTemplate(
    input_variables=["input_query"],
    template=system_prompt + "\nUser Query: {input_query}"
)

local_llm = OllamaLLM(model="codeup", temperature=0.2)

# Updated chain configuration using pipe separator
from langchain.chains import SequentialChain

# Updated chain configuration using ToT
class ToTChain(Chain):
    """Custom chain for handling Tree of Thought responses."""
    prompt: PromptTemplate
    llm: OllamaLLM

    @property
    def input_keys(self):
        return ["input_query"]

    @property
    def output_keys(self):
        return ["output"]

    def _call(self, inputs: dict):
        formatted_prompt = self.prompt.format(input_query=inputs["input_query"])
        response = self.llm.invoke(formatted_prompt)
        return {"output": response}

    def invoke(self, inputs: dict):
        return self._call(inputs)

# Iterative Chain of Thought Reasoning
cot_chain = ToTChain(prompt=prompt_template, llm=local_llm)

def run_tot_chain(query, max_steps=10):
    branches = []
    current_query = query
    seen_responses = set()

    for _ in range(max_steps):
        result = cot_chain.invoke({"input_query": current_query})
        try:
            # Attempt to clean and parse the response
            raw_output = result["output"].strip()  # Remove whitespace
            if "%" in raw_output:
                raw_output = raw_output.replace("%", "")  # Remove percentage signs
            if raw_output.startswith("{") and raw_output.endswith("}"):
                raw_output = raw_output.replace("\n", " ").replace("\\n", " ")  # Remove newlines
            parsed_result = json.loads(raw_output)

            # Append the parsed step to the reasoning tree
            branches.append(parsed_result)

            # Check for the final answer
            if parsed_result.get("next_action") == "final_answer":
                break

            # Use high-confidence branches to guide further exploration
            if parsed_result.get("confidence", 0) > 80:
                current_query = parsed_result["content"] + " What is the next step?"

            # Detect repetition and avoid unnecessary loops
            if parsed_result["content"] in seen_responses:
                current_query = f"Please provide the final answer to: {query}"
            else:
                seen_responses.add(parsed_result["content"])
                current_query = parsed_result["content"]
            
            print(parsed_result["title"])
        except json.JSONDecodeError:
            # Handle malformed JSON fallback
            print("Error parsing response:", result)
            fallback_content = f"Unable to parse response for query: {query}. Retrying with clarification."
            branches.append({
                "title": "Fallback Reasoning",
                "content": fallback_content,
                "confidence": 50,
                "next_action": "continue",
            })
            current_query = f"Please clarify and provide a valid response to: {query}"

    # Aggregate final results from high-confidence branches
    final_results = [branch for branch in branches if branch.get("confidence", 0) > 80]
    return final_results if final_results else branches

# Test the chain
query = "what is the most effective way of reading a book in order to retain as much info as possible?"
response_branches = run_tot_chain(query)
for branch in response_branches:
    print(json.dumps(branch, indent=2))

Error parsing response: {'output': '\n{\n  "title": "Effective Book Reading Strategies",\n  "content": "To maximize information retention while reading a book, it\'s essential to engage your brain and maintain focus. Here are some strategies that can help:\n\n1. Active reading: Instead of just passively reading the text, actively engage with it by asking questions, making connections, and summarizing the content as you go along. This will help you retain more information and make the material more meaningful.\n2. Set goals: Before starting to read, set specific goals for what you want to achieve from the book. This could be to learn a new skill, understand a particular concept, or simply enjoy the story. Having clear goals will help you stay focused and motivated.\n3. Use visualization techniques: Visualize the information you\'re reading to make it more memorable. For example, if you\'re reading about a historical event, try to picture the setting, characters, and key moments in your 