In [59]:
from langchain.prompts import PromptTemplate
from langchain_ollama.llms import OllamaLLM
from langchain.schema import BaseOutputParser
from langchain.chains.base import Chain
import json

class CustomCoTParser(BaseOutputParser):
    """
    A custom output parser for forcing Chain of Thought (CoT) reasoning steps.
    """
    def parse(self, output: str):
        # Parse JSON-formatted CoT response
        try:
            parsed_output = json.loads(output)
            return parsed_output
        except json.JSONDecodeError:
            # Fallback to partial extraction of content and next action if JSON fails
            if "content" in output and "next_action" in output:
                return {
                    "title": "Fallback Reasoning",
                    "content": output.split("\ncontent: ")[1].split("\n")[0],
                    "confidence": 50,
                    "next_action": "continue"
                }
            raise ValueError("Invalid response format. Expected JSON or partial structure.")

# Define the system prompt to enforce Tree of Thought reasoning
system_prompt = '''
You are an expert AI assistant that creates advanced reasoning trees. For each branch:
1. Clearly progress towards exploring alternative solutions.
2. Avoid repeating the same content across branches.
3. Use diverse approaches to reason through sub-problems.
4. Stop when an optimal or final answer is reached or after a maximum of 10 steps.
5. Ensure each step's "next_action" field logically follows from the step's "content".

Example Query: "How many R's are in the word 'strawberry'?"
Example Response:
{{
  "title": "Counting R's in 'strawberry'",
  "content": "The word 'strawberry' contains the letters S-T-R-A-W-B-E-R-R-Y. Counting the R's, we find there are 3 R's in this word.",
  "confidence": 95,
  "next_action": "final_answer"
}}

Example Query: "Write a Hello World program in Java."
Example Response:
{{
  "title": "Java Hello World Program",
  "content": "Here is the code:\npublic class HelloWorld {{\n  public static void main(String[] args) {{\n    System.out.println(\"Hello, World!\");\n  }}\n}}",
  "confidence": 95,
  "next_action": "final_answer"
}}

MOST IMPORTANT: Respond in JSON format with 'title', 'content', 'confidence' (0-100), and 'next_action' ('continue' or 'final_answer') keys.
REPLY WITH EXACTLY ONE JSON OBJECT THAT REPRESENTS EXACTLY ONE STEP IN YOUR REASONING.

Example of a valid JSON response:
{{
    "title": "Initial Problem Analysis",
    "content": "To begin solving this problem, I'll break it down into its core components...",
    "confidence": 90,
    "next_action": "continue"
}}
'''

# Define the template and chain for enforcing CoT reasoning
prompt_template = PromptTemplate(
    input_variables=["input_query"],
    template=system_prompt + "\nUser Query: {input_query}"
)

local_llm = OllamaLLM(model="codeup", temperature=0.2)

# Updated chain configuration using pipe separator
from langchain.chains import SequentialChain

# Updated chain configuration using ToT
class ToTChain(Chain):
    """Custom chain for handling Tree of Thought responses."""
    prompt: PromptTemplate
    llm: OllamaLLM

    @property
    def input_keys(self):
        return ["input_query"]

    @property
    def output_keys(self):
        return ["output"]

    def _call(self, inputs: dict):
        formatted_prompt = self.prompt.format(input_query=inputs["input_query"])
        response = self.llm.invoke(formatted_prompt)
        return {"output": response}

    def invoke(self, inputs: dict):
        return self._call(inputs)

# Iterative Chain of Thought Reasoning
cot_chain = ToTChain(prompt=prompt_template, llm=local_llm)

def run_tot_chain(query, max_steps=10):
    branches = []
    current_query = query
    seen_responses = set()

    for _ in range(max_steps):
        result = cot_chain.invoke({"input_query": current_query})
        try:
            # Attempt to clean and parse the response
            raw_output = result["output"]
            if "%" in raw_output:
                raw_output = raw_output.replace("%", "")  # Remove percentage signs
            parsed_result = json.loads(raw_output)

            # Append the parsed step to the reasoning tree
            branches.append(parsed_result)

            # Use high-confidence branches to guide further exploration
            if parsed_result.get("confidence", 0) > 80:
                current_query = parsed_result["content"] + " What is the next step?"

            # Check for the final answer
            if parsed_result.get("next_action") == "final_answer":
                break

            # Detect repetition and avoid unnecessary loops
            if parsed_result["content"] in seen_responses:
                current_query = "Please provide the final answer based on your reasoning above."
            else:
                seen_responses.add(parsed_result["content"])
                current_query = parsed_result["content"]
            print(parsed_result["title"])
        except json.JSONDecodeError:
            # Handle malformed JSON fallback
            print("Error parsing response:", result)
            fallback_content = "Fallback: Unable to parse response."
            if "content" in result["output"]:
                fallback_content = result["output"]
            branches.append({
                "title": "Fallback Reasoning",
                "content": fallback_content,
                "confidence": 50,
                "next_action": "continue",
            })
            current_query = "Please provide the final answer based on your reasoning above."

    # Aggregate final results from high-confidence branches
    final_results = [branch for branch in branches if branch.get("confidence", 0) > 80]
    return final_results if final_results else branches

# Test the chain
query = "write hello world in python"
response_branches = run_tot_chain(query)
for branch in response_branches:
    print(json.dumps(branch, indent=2))

Error parsing response: {'output': '\n{\n  "title": "Python Hello World Program",\n  "content": "Here is the code:\nprint(\\"Hello, World!\\")\n",\n  "confidence": 95,\n  "next_action": "final_answer"\n}'}
Error parsing response: {'output': '\nPlease provide the final answer based on your reasoning above.'}
Error parsing response: {'output': '\nPlease find my response below:\n\n{\n  "title": "Final Answer",\n  "content": "The word \'strawberry\' contains 3 R\'s.",\n  "confidence": 95,\n  "next_action": "final_answer"\n}'}
Error parsing response: {'output': '\nPlease provide the final answer based on your reasoning above.'}
{
  "title": "Final Answer",
  "content": "The word 'strawberry' contains 3 R's.",
  "confidence": 95,
  "next_action": "final_answer"
}
