In [93]:
import json
import re

from langchain.chains.base import Chain
from langchain.prompts import PromptTemplate
from langchain.schema import BaseOutputParser
from langchain_ollama.llms import OllamaLLM

class CustomCoTParser(BaseOutputParser):
    """
    Output parser for a single JSON reasoning step of the Tree of Thought (ToT) loop.

    A valid step is a JSON object containing at least the keys ``title``,
    ``content`` and ``next_action``.  Both ``parse`` and ``parse_old`` are
    best-effort: when the response cannot be decoded or validated they print
    a diagnostic and return a fallback step (``confidence`` 50,
    ``next_action`` ``"continue"``) instead of raising.
    """

    def clean_response(self, output: str) -> str:
        """Normalise raw model text before JSON decoding.

        Strips surrounding whitespace and drops a stray percent sign that
        directly follows the ``confidence`` number (``"confidence": 95%``),
        which would otherwise make the object undecodable.  Percent signs and
        escape sequences elsewhere are left untouched so step content such as
        "50% of learners" or multi-line code survives intact.

        Args:
            output: Raw text returned by the model.

        Returns:
            The cleaned text.  Raw newlines are not removed; decode the
            result with a non-strict JSON decoder.
        """
        cleaned = output.strip()
        return re.sub(r'("confidence"\s*:\s*"?-?\d+(?:\.\d+)?)\s*%', r"\1", cleaned)

    def validate_response(self, parsed_output: dict) -> dict:
        """Ensure the parsed response is an object with the required fields.

        Args:
            parsed_output: The decoded JSON value.

        Returns:
            ``parsed_output`` unchanged when it is valid.

        Raises:
            ValueError: If the value is not a dict or lacks ``title``,
                ``content`` or ``next_action``.  Missing keys are listed in
                sorted order so the message is deterministic.
        """
        if not isinstance(parsed_output, dict):
            raise ValueError("Response must be a JSON object")
        missing = sorted({"title", "content", "next_action"} - parsed_output.keys())
        if missing:
            raise ValueError(f"Missing required field: {', '.join(missing)}")
        return parsed_output

    def _decode_first_object(self, text: str) -> dict:
        """Decode the first JSON object embedded in ``text``.

        Models often wrap the object in code fences or surround it with
        prose, so every ``{`` is tried as a potential start until one decodes.
        ``strict=False`` accepts raw newlines/tabs inside string values.

        Raises:
            json.JSONDecodeError: If no position yields a decodable object.
        """
        decoder = json.JSONDecoder(strict=False)
        first_error = None
        start = text.find("{")
        while start != -1:
            try:
                candidate, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError as exc:
                # Remember the earliest failure; it is the most informative one.
                if first_error is None:
                    first_error = exc
            else:
                return candidate
            start = text.find("{", start + 1)
        if first_error is not None:
            raise first_error
        raise json.JSONDecodeError("No JSON object found", text, 0)

    def _fallback_step(self, content: str) -> dict:
        """Build the placeholder step returned when a response is unusable."""
        return {
            "title": "Fallback Reasoning",
            "content": content,
            "confidence": 50,
            "next_action": "continue",
        }

    def parse_old(self, output: str):
        """Legacy entry point, kept for callers that still use it.

        Behaves like ``parse`` but reports validation problems by placing the
        validation message itself in the fallback step's ``content``.
        """
        try:
            cleaned_output = self.clean_response(output)
            parsed_output = self._decode_first_object(cleaned_output)
            return self.validate_response(parsed_output)
        except json.JSONDecodeError:
            print(f"Error decoding JSON: {output}")
            return self._fallback_step("Unable to parse a valid response. Please clarify.")
        except ValueError as ve:
            print(f"Validation error: {ve}")
            return self._fallback_step(str(ve))

    def parse(self, output: str):
        """Parse one model response into a validated step dict.

        Args:
            output: Raw text returned by the model.

        Returns:
            The decoded step when it is a JSON object with the required keys,
            otherwise a fallback step.  Never raises.
        """
        try:
            parsed_output = self._decode_first_object(self.clean_response(output))
            return self.validate_response(parsed_output)
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON: {output}. Error: {e}")
            return self._fallback_step(
                "Unable to parse valid JSON. Please clarify your response."
            )
        except Exception as e:
            # Deliberate catch-all: the reasoning loop relies on always
            # receiving a step dict, even for malformed or non-string input.
            print(f"Error in parsing response: {output}. Error: {e}")
            return self._fallback_step(
                "An unexpected error occurred while processing the response."
            )

# Define the system prompt to enforce Tree of Thought reasoning.
# The literal is a raw string on purpose: the Java example contains the JSON
# escapes \n and \" and they must reach the model as written.  In a normal
# string Python would turn them into real newlines and bare quotes, so the
# example shown to the model would itself be invalid JSON.
# Doubled braces become single literal braces when the template is formatted.
system_prompt = r'''
You are an expert AI assistant that creates advanced reasoning trees. For each branch:
1. Clearly progress towards exploring alternative solutions.
2. Avoid repeating the same content across branches.
3. Use diverse approaches to reason through sub-problems.
4. Stop when an optimal or final answer is reached or after a maximum of 10 steps.
5. Ensure each step's "next_action" field logically follows from the step's "content".

Example Query: "How many R's are in the word 'strawberry'?"
Example Response:
{{
  "title": "Counting R's in 'strawberry'",
  "content": "The word 'strawberry' contains the letters S-T-R-A-W-B-E-R-R-Y. Counting the R's, we find there are 3 R's in this word.",
  "confidence": 95,
  "next_action": "final_answer"
}}

Example Query: "Write a Hello World program in Java."
Example Response:
{{
  "title": "Java Hello World Program",
  "content": "Here is the code:\npublic class HelloWorld {{\n  public static void main(String[] args) {{\n    System.out.println(\"Hello, World!\");\n  }}\n}}",
  "confidence": 95,
  "next_action": "final_answer"
}}

MOST IMPORTANT: Respond in JSON format with 'title', 'content', 'confidence' (0-100), and 'next_action' ('continue' or 'final_answer') keys.
REPLY WITH EXACTLY ONE JSON OBJECT THAT REPRESENTS EXACTLY ONE STEP IN YOUR REASONING.

Example of a valid JSON response:
{{
    "title": "Initial Problem Analysis",
    "content": "To begin solving this problem, I'll break it down into its core components...",
    "confidence": 90,
    "next_action": "continue"
}}
'''

# LLM handle used for every reasoning step (model "qwq", temperature 0.2).
local_llm = OllamaLLM(temperature=0.2, model="qwq")

# Full prompt: the system instructions followed by the user's query, which is
# substituted for the single template variable ``input_query``.
prompt_template = PromptTemplate(
    template="".join((system_prompt, "\nUser Query: {input_query}")),
    input_variables=["input_query"],
)

# NOTE(review): SequentialChain is imported here but is not referenced by the
# code visible in this cell.
from langchain.chains import SequentialChain

# Updated chain configuration using ToT
class ToTChain(Chain):
    """Custom chain for handling Tree of Thought responses.

    Renders ``prompt`` with the caller's ``input_query``, sends the text to
    ``llm`` and returns the model's raw reply under the ``output`` key.
    Parsing the reply is left to the caller.
    """

    prompt: PromptTemplate  # template rendered with the ``input_query`` value
    llm: OllamaLLM  # model that receives the rendered prompt

    @property
    def input_keys(self) -> list:
        """Names of the inputs this chain reads."""
        return ["input_query"]

    @property
    def output_keys(self) -> list:
        """Names of the outputs this chain produces."""
        return ["output"]

    def _call(self, inputs: dict, run_manager=None) -> dict:
        """Render the prompt, query the model and wrap the raw reply.

        Args:
            inputs: Mapping that must contain ``input_query``.
            run_manager: Optional callback manager supplied by the base
                chain machinery; accepted for interface compatibility and
                not used here.

        Returns:
            ``{"output": <model reply>}``.
        """
        formatted_prompt = self.prompt.format(input_query=inputs["input_query"])
        response = self.llm.invoke(formatted_prompt)
        return {"output": response}

    def invoke(self, inputs: dict, config=None, **kwargs) -> dict:
        """Run the chain directly and return only the ``output`` mapping.

        ``config`` and extra keyword arguments are accepted so the chain can
        be called like any other runnable, but they are ignored because this
        override calls ``_call`` directly.
        """
        return self._call(inputs)

# Module-level chain instance (prompt template + local LLM) that the iterative
# reasoning loop below invokes once per step.
cot_chain = ToTChain(prompt=prompt_template, llm=local_llm)

def _confidence_of(branch) -> float:
    """Return the branch's ``confidence`` as a float, or 0.0 if absent or non-numeric."""
    raw = branch.get("confidence", 0)
    if isinstance(raw, str):
        # Models sometimes answer with "95" or "95%" instead of a number.
        raw = raw.strip().rstrip("%")
    try:
        return float(raw)
    except (TypeError, ValueError):
        return 0.0


def run_tot_chain(query, max_steps=10):
    """Run the step-by-step reasoning loop for ``query``.

    Each iteration sends a prompt to ``cot_chain``, parses the reply into a
    step dict and decides how to continue.  Steps with confidence above 80
    are kept as context (the last three are fed back to the model).  The loop
    stops when a step declares ``next_action == "final_answer"`` or after
    ``max_steps`` iterations.  The title of every parsed step is printed as
    progress output.

    Args:
        query: The user's question.
        max_steps: Maximum number of model calls.

    Returns:
        The steps with confidence above 80, followed by a synthesized
        "Comprehensive Final Answer" step when there is more than one of
        them.  If no step exceeds that confidence, all recorded steps are
        returned instead.
    """
    refine_prompt = "Please refine and provide the final answer based on your reasoning above."
    parser = CustomCoTParser()  # stateless, so one instance serves every step
    branches = []
    context = []  # high-confidence steps that are fed back to the model
    seen_responses = set()
    current_query = query

    for step in range(max_steps):
        if step == 0:
            query_with_context = query
        else:
            # Use the last 3 context steps for brevity.
            recent = " ".join(str(branch["content"]) for branch in context[-3:])
            follow_up = f"{recent} Based on this, {current_query}" if recent else current_query
            # Restate the question on every follow-up; otherwise the model
            # only sees its own previous output and drifts or repeats itself.
            query_with_context = f"Original question: {query}\n{follow_up}"

        result = cot_chain.invoke({"input_query": query_with_context})
        try:
            parsed_result = parser.parse(result["output"])
            # Normalise before recording so a failure here cannot leave both
            # the step and a fallback for the same step in the result.
            content = str(parsed_result["content"])
            title = parsed_result["title"]
        except Exception as e:
            # Handle unexpected errors gracefully
            print(f"Error during processing: {e}")
            branches.append({
                "title": "Fallback Reasoning",
                "content": f"Fallback: Unable to process response. Error: {e}",
                "confidence": 50,
                "next_action": "continue",
            })
            current_query = refine_prompt
            continue

        branches.append(parsed_result)
        print(title)

        if _confidence_of(parsed_result) > 80:
            context.append(parsed_result)

        if parsed_result.get("next_action") == "final_answer":
            break

        # Avoid loops: a repeated answer triggers a request to wrap up.
        if content in seen_responses:
            current_query = refine_prompt
        else:
            seen_responses.add(content)
            current_query = content

    # Aggregate final results from high-confidence branches
    final_results = [branch for branch in branches if _confidence_of(branch) > 80]

    # Synthesize branches into a comprehensive final answer if possible
    if len(final_results) > 1:
        combined_content = " ".join(str(branch["content"]) for branch in final_results)
        final_results.append({
            "title": "Comprehensive Final Answer",
            "content": combined_content,
            "confidence": 100,
            "next_action": "done",
        })

    return final_results if final_results else branches

# Test the chain: run one sample query end to end and print every returned
# step as indented JSON.  These statements execute at module level, so running
# the file calls the LLM through run_tot_chain.
query = "teach me how to learn"
response_branches = run_tot_chain(query)
for branch in response_branches:
    print(json.dumps(branch, indent=2))

Introduction to Learning Strategies
Identifying Dominant Learning Style
Determining Dominant Learning Style
Determining Dominant Learning Style
Determining Dominant Learning Style
Initial Reflection on Learning Styles
Initial Reflection on Learning Styles
Initial Reflection on Learning Styles
Initial Reflection on Learning Styles
Initial Reflection on Learning Styles
{
  "title": "Introduction to Learning Strategies",
  "content": "To effectively learn, it's essential to understand that everyone has a unique learning style, which can be visual, auditory, or kinesthetic. The first step is to identify your dominant learning style and then explore various techniques such as setting clear goals, creating a conducive learning environment, and utilizing active recall methods.",
  "confidence": 92,
  "next_action": "continue"
}
{
  "title": "Identifying Dominant Learning Style",
  "content": "To tailor an effective learning approach, it's crucial to first determine whether one's dominant lear