In [1]:
from typing import Annotated, List, Dict, Any, Literal
from typing_extensions import TypedDict
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_ollama import ChatOllama
import re
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langgraph.graph import StateGraph, END, START

In [2]:
# Local Qwen3 model served by Ollama; temperature 0 for (near-)deterministic decoding.
llm = ChatOllama(temperature=0, model="qwen3:4b")

# Debug trace shared by the parser and graph nodes below; each entry is one log line.
call_stack: List[str] = []

In [3]:
# Shared LangGraph state for the chain-of-thought pipeline.
#   messages     - conversation messages (seeded with the user's HumanMessage)
#   next_step    - routing label written by each node and read by `router`
#   problem      - the original question text
#   steps        - reasoning steps written by the decompose node
#   step_results - per-step answers keyed "step_1", "step_2", ...
#   final_answer - the synthesized answer
CoTState = TypedDict(
    "CoTState",
    {
        "messages": Annotated[List[Any], "Messages in the conversation"],
        "next_step": Literal["decompose", "step_1", "step_2", "step_3", "synthesize", "end"],
        "problem": str,
        "steps": List[str],
        "step_results": Dict[str, str],
        "final_answer": str,
    },
)


In [4]:
class StripThinkParser(BaseOutputParser[str]):
    """Remove ``<think>...</think>`` reasoning and surrounding whitespace from LLM text.

    Every complete ``<think>...</think>`` block is dropped. As a defensive
    measure, if a closing ``</think>`` is still present afterwards (its opening
    tag was missing from the model output), everything up to and including the
    last such tag is treated as reasoning and dropped too, so it cannot leak
    into the answer.

    Side effect: appends the raw, unmodified model text to ``call_stack``.
    """

    def parse(self, text: str) -> str:
        """Return ``text`` without reasoning blocks, stripped of outer whitespace."""
        # Non-greedy + DOTALL: each block is removed separately and may span lines.
        cleaned = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
        # Orphan closing tag: keep only what follows the last one.
        if "</think>" in cleaned:
            cleaned = cleaned.rsplit("</think>", 1)[1]
        call_stack.append(f"StripThinkParser.parse: {text}")
        return cleaned.strip()

In [5]:
# A line that *starts* a numbered step, optionally after markdown decoration:
# "1. ...", "2) ...", "3: ...", "Step 1: ...", "**Step 2:** ...", "### Step 3", "- 1. ...".
# Anchoring at the line start keeps prose such as "Here are the steps:" from
# being mistaken for a step.
_STEP_LINE_RE = re.compile(r"^\s*(?:[#>*\-]+\s*)*(?:step\s*\d+|\d+\s*[.):])", re.IGNORECASE)


def _parse_steps(steps_text: str, n_steps: int = 3) -> List[str]:
    """Return exactly ``n_steps`` step lines from ``steps_text``, padding with placeholders."""
    steps = [line.strip() for line in steps_text.splitlines() if _STEP_LINE_RE.match(line)]
    steps = steps[:n_steps]
    while len(steps) < n_steps:
        steps.append(f"Step {len(steps)+1}: Additional analysis required")
    return steps


def decompose_node(state: CoTState) -> Dict[str, Any]:
    """Break ``state["problem"]`` down into three sequential reasoning steps.

    The LLM output (with ``<think>`` reasoning stripped) is scanned for numbered
    step lines. The first three become ``steps``; if fewer were found, the list
    is padded with placeholder steps.

    Returns:
        Partial state update ``{"steps": [3 strings], "next_step": "step_1"}``.
        The update is also appended to ``call_stack`` for debugging.
    """
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an expert problem solver. Break down complex problems into 3 clear, logical steps that would help solve it."),
        ("user", "Problem: {problem}\n\nProvide exactly 3 numbered steps (1, 2, 3) to solve this problem. Keep each step concise.")
    ])

    chain = prompt | llm | StripThinkParser()
    steps_text = chain.invoke({"problem": state["problem"]})

    res = {
        "steps": _parse_steps(steps_text),
        "next_step": "step_1"
    }
    call_stack.append(f"decompose_node: {res}")
    return res


In [6]:
def execute_step_node(step_name: str, step_index: int):
    """Build a LangGraph node that answers reasoning step ``step_index`` (1-based).

    Args:
        step_name: Node label supplied by the caller. Not used by the node
            itself: the result is always stored under ``f"step_{step_index}"``,
            which is the key format ``synthesize_node`` reads.
        step_index: 1-based position of the step in ``state["steps"]``.

    Returns:
        A node function ``step_node(state) -> dict``. It returns an updated copy
        of ``step_results`` and the label of the next node: ``step_{i+1}``, or
        ``"synthesize"`` after step 3.
    """
    def step_node(state: CoTState) -> Dict[str, Any]:
        steps = state.get("steps", [])
        current_step = steps[step_index-1] if len(steps) >= step_index else f"Step {step_index}"
        prior_results = state.get("step_results", {})

        # Include earlier step results as context for this step.
        previous_results = ""
        for i in range(1, step_index):
            step_key = f"step_{i}"
            if step_key in prior_results:
                previous_results += f"\nStep {i} result: {prior_results[step_key]}"

        # Dynamic text is passed as template *variables*, never baked into the
        # template string: LLM answers routinely contain braces (e.g. "\boxed{3}"),
        # which ChatPromptTemplate would otherwise parse as missing variables.
        prompt = ChatPromptTemplate.from_messages([
            ("system", "You are solving a problem step by step. Provide a clear, logical answer for this specific step."),
            ("user", "Original Problem: {problem}\n{previous_results}\n\nCurrent Step: {current_step}\n\nProvide a detailed answer for this step:")
        ])

        chain = prompt | llm | StripThinkParser()
        step_result = chain.invoke({
            "problem": state["problem"],
            "previous_results": previous_results,
            "current_step": current_step,
        })

        # Copy so the incoming state's dict is not mutated.
        step_results = dict(prior_results)
        step_results[f"step_{step_index}"] = step_result

        next_step = f"step_{step_index+1}" if step_index < 3 else "synthesize"

        res = {
            "step_results": step_results,
            "next_step": next_step
        }
        call_stack.append(f"execute_step_node.step_node: {res}")
        return res

    return step_node

In [7]:
def synthesize_node(state: CoTState) -> Dict[str, Any]:
    """Combine the problem, its steps and their answers into one final answer.

    Each step in ``state["steps"]`` is paired with ``step_results["step_<i>"]``
    (or "No result available" if missing). The LLM is then asked for a final answer.

    Returns:
        Partial state update ``{"final_answer": str, "next_step": "end"}``,
        also appended to ``call_stack``.
    """
    step_results = state.get("step_results", {})

    reasoning_chain = f"Problem: {state['problem']}\n\n"
    reasoning_chain += "Reasoning Steps:\n"
    for i, step in enumerate(state["steps"], 1):
        step_result = step_results.get(f"step_{i}", "No result available")
        reasoning_chain += f"{step}\nAnswer: {step_result}\n\n"

    # Pass the assembled text as a variable: it contains raw LLM output, and any
    # "{...}" in it would break ChatPromptTemplate if baked into the template.
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are synthesizing a final answer based on step-by-step reasoning. Present a clear, concise final answer that incorporates all previous reasoning."),
        ("user", "{reasoning_chain}\n\nBased on the above reasoning steps, provide a final comprehensive answer to the original problem:")
    ])

    chain = prompt | llm | StripThinkParser()
    final_answer = chain.invoke({"reasoning_chain": reasoning_chain})

    res = {
        "final_answer": final_answer,
        "next_step": "end"
    }
    call_stack.append(f"synthesize_node: {res}")
    return res

In [8]:
def router(state: CoTState) -> str:
    """Return the routing label stored in ``state["next_step"]``, logging it to ``call_stack``."""
    destination = state["next_step"]
    call_stack.append(f"router: {destination}")
    return destination

In [9]:
# Fixed linear pipeline:
#   START -> decompose -> step_1 -> step_2 -> step_3 -> synthesize -> (router) -> END
builder = StateGraph(CoTState)

_step_node_names = [f"step_{i}" for i in range(1, 4)]
_pipeline = ["decompose", *_step_node_names, "synthesize"]

# Nodes
builder.add_node("decompose", decompose_node)
for _idx, _name in enumerate(_step_node_names, start=1):
    builder.add_node(_name, execute_step_node(_name, _idx))
builder.add_node("synthesize", synthesize_node)

# Static edges between consecutive pipeline stages
builder.add_edge(START, _pipeline[0])
for _src, _dst in zip(_pipeline, _pipeline[1:]):
    builder.add_edge(_src, _dst)

# synthesize_node always sets next_step="end", so END is the only mapped target.
builder.add_conditional_edges("synthesize", router, {"end": END})

cot_graph = builder.compile()

In [10]:
def run_chain_of_thought(problem: str) -> Dict[str, Any]:
    """Run the compiled CoT graph on ``problem`` and summarize the final state.

    Returns a dict with keys ``problem``, ``decomposed_steps``,
    ``step_by_step_reasoning`` and ``final_answer``.
    """
    seed_state: CoTState = {
        "messages": [HumanMessage(content=problem)],
        "next_step": "decompose",
        "problem": problem,
        "steps": [],
        "step_results": {},
        "final_answer": "",
    }

    final_state = cot_graph.invoke(seed_state)

    return {
        "problem": problem,
        "decomposed_steps": final_state["steps"],
        "step_by_step_reasoning": final_state["step_results"],
        "final_answer": final_state["final_answer"],
    }

In [11]:
def _print_analysis(result: Dict[str, Any]) -> None:
    """Pretty-print the summary dict returned by ``run_chain_of_thought``."""
    print("=== Chain of Thought Analysis ===")
    print(f"Problem: {result['problem']}\n")

    print("Decomposed Steps:")
    for i, step in enumerate(result['decomposed_steps'], 1):
        print(f"  {i}. {step}")

    print("\nStep-by-Step Reasoning:")
    for step, answer in result['step_by_step_reasoning'].items():
        print(f"  {step.replace('_', ' ').title()}: {answer}")

    print(f"\nFinal Answer: {result['final_answer']}")


def _run_and_report(label: str, problem: str) -> Dict[str, Any]:
    """Run ``problem`` with a fresh debug trace, print the trace and the analysis.

    The trace is printed even if the run fails; the original exception is then
    re-raised with its traceback intact.
    """
    # Reset the shared trace so each example shows only its own calls. clear()
    # keeps the same list object, which the parser and nodes append to.
    call_stack.clear()
    try:
        result = run_chain_of_thought(problem)
    except Exception:
        print(f"{label}: ", call_stack)
        raise
    print(f"{label}: ", call_stack)
    _print_analysis(result)
    return result


# Example 1: Math problem
math_problem = "A train leaves station A at 60 mph and another leaves station B at 40 mph towards A. If they are 300 miles apart, when do they meet?"
result1 = _run_and_report("math_problem", math_problem)

print("\n" + "="*50 + "\n")

# Example 2: Logic problem
logic_problem = "If all Bloops are Razzies and all Razzies are Loppies, are all Bloops definitely Loppies?"
result2 = _run_and_report("logic_problem", logic_problem)

math_problem:  ['StripThinkParser.parse: <think>\nOkay, let\'s see. The problem is about two trains leaving stations A and B towards each other. Station A has a train going at 60 mph, and station B has another going at 40 mph. They\'re 300 miles apart. I need to figure out when they meet.\n\nFirst, I should probably figure out how fast they\'re approaching each other. Since they\'re moving towards each other, their speeds add up. So 60 plus 40 is 100 mph. That makes sense because when two objects move towards each other, their relative speed is the sum of their individual speeds.\n\nNext, if they\'re closing the distance at 100 mph and the initial distance is 300 miles, the time it takes for them to meet would be the total distance divided by their combined speed. So 300 divided by 100 is 3 hours. That seems straightforward.\n\nWait, but maybe I should check if there\'s any other factor. Like, do they start at the same time? The problem says they leave their respective stations, so I a