In [None]:
from dotenv import load_dotenv

# Load variables from a local .env file into os.environ. Variables that are
# already set in the environment win (override=False); presumably this is where
# the Anthropic API key comes from — verify against your setup.
load_dotenv(override=False)

In [None]:
from langchain_anthropic.chat_models import ChatAnthropic

In [None]:
# Fast, inexpensive model without extended thinking.
llm = ChatAnthropic(model_name="claude-3-5-haiku-20241022")

# Claude 3.7 Sonnet with extended thinking enabled.
# `max_tokens` and `thinking` are passed as ChatAnthropic parameters rather than
# through `model_kwargs`: routing real fields through `model_kwargs` either
# triggers a warning or bypasses the field entirely, depending on the version.
# The thinking budget must stay below `max_tokens`.
llm_reasoning = ChatAnthropic(
    model_name="claude-3-7-sonnet-latest",
    max_tokens=20000,
    thinking={"type": "enabled", "budget_tokens": 1024},
)

In [None]:
# Ask the thinking-enabled model a simple question and show the raw AIMessage;
# with thinking on, its content is expected to be a list of content blocks.
result = llm_reasoning.invoke(input="What is the root of 12")
result

In [None]:
# Select the reasoning block by its "type" instead of assuming it sits at
# content[0] (other blocks, e.g. redacted thinking, can shift positions).
# Evaluates to None when the reply carries no thinking block.
thinking_blocks = result.content if isinstance(result.content, list) else []
thinking_text = next(
    (
        block["thinking"]
        for block in thinking_blocks
        if isinstance(block, dict) and block.get("type") == "thinking"
    ),
    None,
)
thinking_text

In [None]:
# Collect the answer from the "text" blocks instead of assuming it sits at
# content[1]. A plain-string content (no thinking blocks) is used as-is.
if isinstance(result.content, str):
    answer_text = result.content
else:
    answer_text = "".join(
        block["text"]
        for block in result.content
        if isinstance(block, dict) and block.get("type") == "text"
    )
answer_text

In [None]:
# Small, fast model used only to grade question difficulty.
llm_classifier = ChatAnthropic(model_name="claude-3-5-haiku-20241022")

# Claude 3.7 Sonnet with extended thinking explicitly disabled; used for every
# question the router does not send to the thinking model.
llm_regular = ChatAnthropic(
    model_name="claude-3-7-sonnet-latest",
    max_tokens=3000,
    thinking={"type": "disabled"},
)

# Claude 3.7 Sonnet with extended thinking; used for "very hard" questions.
# `max_tokens` is set through the real ChatAnthropic field. The previous
# `max_tokens_to_sample` key inside `model_kwargs` is a legacy alias, not a
# Messages API parameter, and was inconsistent with `llm_reasoning`. The thinking
# budget (1024) must stay below `max_tokens`.
llm_thinking = ChatAnthropic(
    model_name="claude-3-7-sonnet-latest",
    max_tokens=20000,
    thinking={"type": "enabled", "budget_tokens": 1024},
)

In [None]:
from langgraph.graph import END, StateGraph
from typing import Literal
from typing import TypedDict
from pydantic import BaseModel, Field
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate

In [None]:
class DifficultyGrade(BaseModel):
    """Model for capturing difficulty classification."""

    # NOTE: pydantic puts the class docstring and this field description into the
    # model's JSON schema, which with_structured_output presumably forwards to the
    # LLM — treat both as prompt text when editing.
    # `...` marks the field as required, exactly like omitting a default.
    difficulty: str = Field(..., description="One of: easy, mid, hard, or very hard.")


class DifficultyState(TypedDict):
    # State dict passed between the graph's nodes.
    # messages: the conversation; nodes read the last entry as the current question.
    messages: list[BaseMessage]
    # difficulty: label written by classify_difficulty, read by the router.
    difficulty: str

In [None]:
def classify_difficulty(state: DifficultyState) -> DifficultyState:
    """Label the latest user question as easy, mid, hard, or very hard.

    The content of the last message in ``state["messages"]`` is sent to
    ``llm_classifier`` with ``DifficultyGrade`` as the structured-output schema.

    Args:
        state: Current graph state; must contain at least one message.

    Returns:
        A new state dict with the same entries as ``state`` plus ``"difficulty"``
        set to the lower-cased, stripped label. ``state`` itself is not mutated.
    """
    question = state["messages"][-1].content

    system_prompt = (
        "You are a difficulty classifier. "
        "Classify the user question into exactly one of these categories: "
        "easy, mid, hard, or very hard. Return ONLY the single word: "
        "easy, mid, hard, or very hard."
    )

    # The question is supplied as a template variable at invoke time rather than
    # interpolated into the template string. Otherwise any literal "{" or "}" in
    # the user's text (code, set notation, LaTeX) would be parsed as a
    # placeholder and the chain would fail.
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "User question: {question}"),
        ]
    )

    structured_llm = llm_classifier.with_structured_output(DifficultyGrade)
    chain = prompt | structured_llm

    result = chain.invoke({"question": question})

    # Build a fresh state instead of mutating the one the graph handed us.
    return {**state, "difficulty": result.difficulty.lower().strip()}

In [None]:
def route_based_on_difficulty(state: DifficultyState) -> Literal["thinking", "regular"]:
    """Choose the branch key for the conditional edge after classification.

    Only the exact label ``"very hard"`` selects the extended-thinking branch;
    every other label, including unexpected ones, falls through to ``"regular"``.
    """
    return "thinking" if state["difficulty"] == "very hard" else "regular"

In [None]:
def _reply_to_last_message(state: DifficultyState, model) -> DifficultyState:
    """Answer the latest message with ``model`` and return a new state.

    Only the last message's content is sent, as a fresh single-turn prompt, so
    earlier turns in ``state["messages"]`` are not visible to the model.

    Returns:
        A new state dict whose ``messages`` is a new list: the original messages
        followed by the model's reply. Neither ``state`` nor its list is mutated.
    """
    user_prompt = state["messages"][-1].content
    ai_response = model.invoke([HumanMessage(content=user_prompt)])
    return {**state, "messages": [*state["messages"], ai_response]}


def call_model_thinking(state: DifficultyState) -> DifficultyState:
    """Answer the latest question with the extended-thinking model ``llm_thinking``."""
    return _reply_to_last_message(state, llm_thinking)


def call_model_regular(state: DifficultyState) -> DifficultyState:
    """Answer the latest question with the non-thinking model ``llm_regular``."""
    return _reply_to_last_message(state, llm_regular)

In [None]:
# Build the graph over the DifficultyState schema and register its three nodes.
# The node names are the strings the edge definitions refer to, so keep them in sync.
workflow = StateGraph(DifficultyState)

workflow.add_node("classify_difficulty", classify_difficulty)
workflow.add_node("call_model_thinking", call_model_thinking)
workflow.add_node("call_model_regular", call_model_regular)

In [None]:
# Every run starts by grading the question.
workflow.set_entry_point("classify_difficulty")

# The router returns a branch key; this mapping turns that key into the node
# that runs next.
branch_targets = {
    "thinking": "call_model_thinking",
    "regular": "call_model_regular",
}
workflow.add_conditional_edges(
    "classify_difficulty", route_based_on_difficulty, branch_targets
)

# Both answering nodes end the run.
for answering_node in branch_targets.values():
    workflow.add_edge(answering_node, END)

graph = workflow.compile()

In [None]:
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod

# Render the compiled graph as a Mermaid PNG. MermaidDrawMethod.API delegates
# the rendering to a remote Mermaid service, so this cell needs network access.
mermaid_png = graph.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.API)
display(Image(data=mermaid_png))

In [None]:
# Easy question: expected to be routed to the regular (non-thinking) model.
graph.invoke({"messages": [HumanMessage(content="What is 2+2?")]})

In [None]:
# Hard question: the classifier's label decides whether it reaches the thinking model.
fermat_question = HumanMessage(
    content="Could you provide a detailed proof of Fermat's Last Theorem?"
)
result = graph.invoke({"messages": [fermat_question]})

In [None]:
# Final graph state: the messages (question followed by the model's reply) and
# the difficulty label the classifier assigned.
result

In [None]:
# The reply is the last message. Its content is a list of blocks (thinking +
# text) when the thinking model answered, but typically a plain string when the
# classifier routed the question to the regular model. Handle both cases
# instead of indexing content[1], which fails on a string.
final_reply = result["messages"][-1]
if isinstance(final_reply.content, str):
    final_text = final_reply.content
else:
    final_text = "".join(
        block["text"]
        for block in final_reply.content
        if isinstance(block, dict) and block.get("type") == "text"
    )
final_text