In [5]:
"""
基于 Graph API 的提示链工作流实现
"""

from typing_extensions import TypedDict

from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, START, END

from dotenv import load_dotenv

load_dotenv()


class State(TypedDict):
    """Shared state passed between the nodes of the joke graph.

    Each node function returns a dict holding only the key(s) it writes.
    """

    # Subject of the joke; supplied by the caller in ``chain.invoke``.
    topic: str
    # First draft, written by ``generate_joke``.
    joke: str
    # Word-play revision, written by ``improve_joke`` ("Fail" branch only).
    improved_joke: str
    # Version with a twist, written by ``polish_joke`` ("Fail" branch only).
    final_joke: str


# Single chat-model client shared by every node function below.
# NOTE(review): the API key / endpoint are presumably supplied through the
# environment populated by load_dotenv() — confirm against the .env file.
llm = ChatOpenAI(
    model="Qwen/Qwen2.5-7B-Instruct",
)


def generate_joke(state: State):
    """Ask the model for a short joke about ``state["topic"]``.

    Returns:
        A partial state update containing only the ``joke`` key, set to the
        text content of the model's reply.
    """
    prompt = f"写一个关于{state['topic']}的简短笑话"
    reply = llm.invoke(prompt)
    return {"joke": reply.content}


def check_punchline(state: State):
    """Gate after ``generate_joke``: decide whether the joke needs more work.

    A question mark or exclamation mark is taken as a sign that the joke
    already has a punchline. Both the ASCII marks ("?", "!") and their
    full-width forms ("？", "！") are recognised. The prompts are Chinese, and
    the model's Chinese replies use full-width punctuation (see the sample
    output of this cell), so an ASCII-only check would never match.

    Returns:
        "Pass" when a punchline mark is present (the graph goes to END),
        otherwise "Fail" (the graph routes to ``improve_joke``).
    """
    punchline_marks = ("?", "!", "？", "！")
    joke = state["joke"]
    # Previously inverted: a joke *with* a punchline was sent for improvement
    # while one without it ended the graph.
    if any(mark in joke for mark in punchline_marks):
        return "Pass"
    return "Fail"


def improve_joke(state: State):
    """Ask the model to make the first-draft joke funnier via word play.

    Returns:
        A partial state update containing only ``improved_joke``.
    """
    prompt = f"通过添加文字游戏使这个笑话更有趣：{state['joke']}"
    response = llm.invoke(prompt)
    improved = response.content
    return {"improved_joke": improved}


def polish_joke(state: State):
    """Ask the model to add a surprising twist to the improved joke.

    Returns:
        A partial state update containing only ``final_joke``.
    """
    prompt = f"为这个笑话添加一个令人惊讶的转折：{state['improved_joke']}"
    response = llm.invoke(prompt)
    polished = response.content
    return {"final_joke": polished}


# Prompt-chaining graph:
#   START -> generate_joke --check_punchline--> "Pass" -> END
#                                           \-> "Fail" -> improve_joke -> polish_joke -> END
workflow = StateGraph(State)

# Register the nodes in the same order as the chain runs them.
for _node_name, _node_fn in (
    ("generate_joke", generate_joke),
    ("improve_joke", improve_joke),
    ("polish_joke", polish_joke),
):
    workflow.add_node(_node_name, _node_fn)

workflow.add_edge(START, "generate_joke")
# check_punchline returns "Fail" or "Pass"; the mapping picks the next node.
workflow.add_conditional_edges(
    "generate_joke",
    check_punchline,
    {"Fail": "improve_joke", "Pass": END},
)
workflow.add_edge("improve_joke", "polish_joke")
workflow.add_edge("polish_joke", END)

chain = workflow.compile()

# Run the compiled graph once for a fixed topic.
state = chain.invoke({"topic": "cats"})

print("初始笑话：", state["joke"], sep="\n")

# The membership test assumes keys no node wrote are absent from the result;
# improved_joke / final_joke are only written on the "Fail" branch.
if "improved_joke" in state:
    print("改进后的笑话：", state["improved_joke"], sep="\n")
    print("最终笑话：", state["final_joke"], sep="\n")

初始笑话：
当然可以！这里有一个简短的猫咪笑话：

为什么猫咪不会玩扑克牌？

因为它们总是把A牌当猫！
