![模型结构](image.png)

In [None]:
# code的生成格式
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.chat_models import ChatZhipuAI
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
import os

# Read a local .env file so the provider API keys end up in os.environ.
load_dotenv()

# Key for the active provider; switch to "ZHIPU_API_KEY" when using ChatZhipuAI instead.
api_key = os.environ.get("OPENAI_API_KEY")

# Structured-output schema the LLM must fill in (bound below via with_structured_output;
# the error messages in check_output refer to it as the "CodeSchema" tool).
# The Field descriptions are part of what the model sees, so they are left as-is.
class CodeSchema(BaseModel):
    # prefix: explanation of the code (the description asks for the answer in Chinese)
    prefix: str = Field(description="这些是代码的解释，需要用中文回答")
    # imports: the import statements the code needs
    imports: str = Field(description="这些是代码需要导入的库")
    # code: the code itself, without the explanation or the imports
    code: str = Field(description="这些是代码本身的内容，不包含解释和导入的库")

# Prompt for code generation: a system message with optional documentation in {context},
# followed by the conversation history supplied through the "messages" placeholder.
code_gen_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """You are a coding assistant with expertise in Python. \n 
    Here is a full set of documentation:  \n ------- \n  {context} \n ------- \n Answer the user 
    question based on the above provided documentation(if it is not provided, you should generate the code based on the user question). \n
    Ensure any code you provide can be executed \n 
    with all required imports and variables defined. Structure your answer with a description of the code solution. \n
    Then list the imports. And finally list the functioning code block. Here is the user question:""",
        ),
        ("placeholder", "{messages}"),
    ]
    # The "{messages}" placeholder is filled with a list of (role, content) tuples at invoke time.
)

# Chat model shared by every chain below (uncomment the ChatZhipuAI line to use GLM instead).
llm = ChatOpenAI(model="gpt-3.5-turbo", api_key=api_key)
# llm = ChatZhipuAI(api_key=api_key, model="glm-4-flash")

# include_raw=True makes the chain return a dict instead of a bare CodeSchema:
#   "raw"           - the model's original reply message
#   "parsed"        - the CodeSchema instance, or None when parsing failed
#   "parsing_error" - the exception raised while parsing, or None
code_gen_chain = code_gen_prompt | llm.with_structured_output(
    CodeSchema,
    include_raw=True,
)

res = code_gen_chain.invoke(
    {
        "context": '',
        "messages": [("user", "请生成一个简单的Python代码, 用于计算两个数的和")],
    }
)

    



In [85]:
# "parsed" holds the CodeSchema instance; display only its code field.
res["parsed"].code

'def add_numbers(num1, num2):\n    return num1 + num2'

**考虑到 LLM 可能没有按要求完成格式化输出（没有调用 CodeSchema 工具，或输出无法解析），我们可以先校验输出，再定义一个 fallback chain 进行重试。**

In [None]:
# With include_raw=True the structured-output chain returns a dict with the keys
# 'raw', 'parsed' and 'parsing_error' (see `res` above); check_output validates that dict.
def check_output(output):
    """Pass a successful structured-output result through unchanged, or raise ValueError.

    Args:
        output: dict with 'parsed', 'parsing_error' and 'raw' entries.

    Returns:
        ``output`` itself when no parsing error occurred and 'parsed' is not None.

    Raises:
        ValueError: if 'parsing_error' is set (the message embeds the error and the
            raw output), or if 'parsed' is None (the CodeSchema tool was not called).
            Raising is what lets the fallback chain retry.
    """
    parse_err = output['parsing_error']
    if parse_err is None:
        if output['parsed'] is not None:
            return output
        print("Failed to invoke the tool")
        raise ValueError(f"调用工具出现错误,你没有调用CodeSchema工具,确保你调用了CodeSchema工具来保证格式化输出")

    print("Parsing error")
    raw = output['raw']
    raise ValueError(f"解析时出现错误,确保你调用了CodeSchema工具。Parsing error: {parse_err}\nRaw output: {raw}")

# Same prompt + structured LLM as code_gen_chain, followed by the check_output validator.
code_chain_raw = code_gen_chain | check_output



In [87]:
# The validated chain raises ValueError instead of returning an unparsed result.
res = code_chain_raw.invoke(
    {
        "context": '',
        "messages": [("user", "请生成一个简单的Python代码，用于计算两个数的和")],
    }
)

In [88]:
# Code field of the CodeSchema returned by the validated chain.
res["parsed"].code

'def sum_two_numbers(a, b):\n    return a + b\n\nresult = sum_two_numbers(5, 3)\nprint(result)'

**这个 fallback 会在原始链（`code_chain_raw`，或上一个 fallback）执行失败时被调用。由于 `with_fallbacks` 设置了 `exception_key="error"`，LangChain 会把上一次抛出的异常以 `"error"` 为键加入原始输入字典，再把这个字典作为 `inputs` 传给 fallback：**

```python
inputs = {
    "context": ...,        # 原始输入中的 context 变量
    "messages": [...],     # 当前对话历史
    "error": Exception(),  # 上一次尝试抛出的异常（例如 check_output 抛出的 ValueError）
}
```

In [89]:
def insert_error(inputs):
    """Build the input for a retry attempt after the structured-output chain failed.

    Args:
        inputs: the original chain input ("context", "messages") plus the exception
            of the failed attempt under "error" (added by
            ``with_fallbacks(..., exception_key="error")``).

    Returns:
        dict with "context" and a NEW "messages" list: the original messages plus one
        assistant message describing the error. The caller's list is not modified,
        so the error text does not leak into the caller's conversation (e.g. the
        LangGraph state). Each retry therefore sees the original conversation plus
        only the most recent error.
    """
    error = inputs['error']
    # The "assistant" role is deliberate: the model sees the error as its own previous
    # turn, together with the instruction that it must call the tool.
    messages = list(inputs['messages']) + [
        ("assistant", f"你必须修复下面这个错误Error: {error}， 你必须调用工具")
    ]
    return {"context": inputs['context'],
        "messages": messages}

# Fallback: feed the previous failure back to the model, then rerun the validated chain.
fall_back_chain = insert_error | code_chain_raw
retry_n = 3
# with_fallbacks tries code_chain_raw first, then each fallback in turn until one succeeds.
# exception_key="error" hands the last exception to the fallback as inputs["error"].
code_gen = code_chain_raw.with_fallbacks(
    fallbacks=[fall_back_chain for _ in range(retry_n)],
    exception_key="error",
)

# Unwrap the include_raw result dict down to the structured (CodeSchema) output.
def prased_output(output):
    """Return the value stored under the 'parsed' key of ``output`` (KeyError if absent)."""
    parsed_value = output['parsed']
    return parsed_value

# Full pipeline: generate -> validate -> retry up to retry_n times -> unwrap the CodeSchema.
chain_with_retry = code_gen | prased_output
chain_with_retry.invoke(
    {
        "context": '',
        "messages": [("user", "请生成一个简单的Python代码，用于计算两个数的和")],
    }
)

CodeSchema(prefix='calculate_sum', imports=' ', code='def calculate_sum(num1, num2):\n    return num1 + num2\n\nresult = calculate_sum(5, 3)\nprint(result)')

**用 LangGraph 实现“生成 → 检查 → 反思”的循环**

In [90]:
# First, define the graph state
from typing import Any, TypedDict, List, Optional, Tuple

class State(TypedDict):
    """State shared by the generate / code_check / reflect nodes of the graph."""
    messages: List[Tuple[str, str]]  # conversation history as (role, content) pairs
    error: str  # "yes" if the last code_check failed, "no" if it passed, "" before the first check
    generation: Any  # CodeSchema returned by the generate node ("" before the first generation)
    iterations: int  # number of times the generate node has run


In [92]:
# Node that generates code
def generate(state:State)->State:
    """Generate a code solution for the conversation in ``state`` using ``chain_with_retry``.

    If the previous check failed (``state['error'] == 'yes'``), a retry instruction is
    appended before calling the chain. The structured answer is then appended to the
    conversation as an assistant message.

    Returns:
        Partial state update: ``generation`` (the CodeSchema returned by the chain),
        the extended ``messages`` list (a copy; the input state is not mutated) and
        ``iterations`` incremented by one.
    """
    messages = list(state['messages'])
    if state['error'] == 'yes':
        # Ask the model to try again and to answer through the structured-output tool.
        messages += [("user", "重试一次， 你必须调用工具，并且生成格式化的代码")]

    code_solution = chain_with_retry.invoke({"context": '', "messages": messages})
    messages += [
        (
            "assistant",
            f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        )
    ]
    return {"generation": code_solution, "messages": messages, "iterations": state['iterations'] + 1}


In [93]:
def code_check(state:State)->State:
    """Check that the latest generated solution imports and runs without raising.

    The imports are executed first, then imports + code together. Each check runs in
    a fresh, private namespace dict (used as both globals and locals, with
    ``__name__ == "__main__"``). As a result:
      * names bound by the imports or by top-level definitions are visible inside
        functions defined by the generated code (module semantics, not class-body
        semantics as with a bare ``exec`` inside a function);
      * the check cannot pass just because this notebook happens to define a name,
        and generated code cannot touch the notebook's variables.

    Returns:
        Partial state update. On failure a ("user", feedback) message is appended and
        ``error`` is "yes"; on success ``error`` is "no". ``generation`` and
        ``iterations`` are passed through unchanged.

    SECURITY: this executes model-generated code with full interpreter privileges.
    Only run it in a sandboxed / trusted environment.
    """
    messages = list(state['messages'])
    iterations = state['iterations']
    code_solution = state['generation']
    imports = code_solution.imports
    code = code_solution.code

    # Check imports
    try:
        exec(imports, {"__name__": "__main__"})
    except Exception as e:
        print("---CODE IMPORT CHECK: FAILED---")
        messages += [("user", f"Your solution failed the import test: {e}")]
        return {
            "generation": code_solution,
            "messages": messages,
            "iterations": iterations,
            "error": "yes",
        }

    # Check execution
    try:
        exec(imports + "\n" + code, {"__name__": "__main__"})
    except Exception as e:
        print("---CODE BLOCK CHECK: FAILED---")
        messages += [("user", f"Your solution failed the code execution test: {e}")]
        return {
            "generation": code_solution,
            "messages": messages,
            "iterations": iterations,
            "error": "yes",
        }

    # No errors
    print("---CODE CHECK: SUCCESS---")
    return {
        "generation": code_solution,
        "messages": messages,
        "iterations": iterations,
        "error": "no",
    }


In [94]:
def reflect(state:State)->State:
    """Ask the model to respond to the errors recorded in the conversation.

    This node is only reached after a failed code_check, so the last message is the
    failure feedback appended there. NOTE(review): no separate "please reflect"
    instruction is sent; the model only sees that feedback.

    ``code_gen_chain`` was built with ``include_raw=True`` and returns a dict with
    "raw", "parsed" and "parsing_error". Only the parsed CodeSchema (or, if parsing
    failed, the raw reply's content) is added to the conversation, not the repr of
    the whole dict.

    Returns:
        Partial state update with the extended ``messages`` (a copy), plus
        ``generation`` and ``iterations`` passed through unchanged.
    """
    messages = list(state["messages"])
    iterations = state["iterations"]
    code_solution = state["generation"]

    result = code_gen_chain.invoke({"context": '', "messages": messages})
    parsed = result["parsed"]
    if parsed is not None:
        reflection_text = f"{parsed.prefix} \n Imports: {parsed.imports} \n Code: {parsed.code}"
    else:
        # No valid CodeSchema was produced; fall back to the model's raw reply.
        raw = result["raw"]
        reflection_text = getattr(raw, "content", raw)

    messages += [("assistant", f"这是出现的error的反思: {reflection_text}")]
    return {"generation": code_solution, "messages": messages, "iterations": iterations}

In [95]:
# Upper bound on generate attempts; the graph stops after this many even if the code still fails.
max_iterations = 3
# Retry strategy after a failed check: "reflect" goes through the reflect node first,
# any other value routes straight back to generate.
flag = "reflect"

def decide_to_finish(state: State):
    """Routing function for the conditional edge after ``code_check``.

    Returns:
        "end" when the last check passed (``error == "no"``) or the iteration budget is
        used up; otherwise "reflect" or "generate", depending on ``flag``.
    """
    error = state["error"]
    iterations = state["iterations"]

    # ">=" instead of "==": the loop still terminates if iterations ever overshoots
    # max_iterations, instead of cycling until LangGraph's recursion limit aborts the run.
    if error == "no" or iterations >= max_iterations:
        print("---DECISION: FINISH---")
        return "end"

    print("---DECISION: RE-TRY SOLUTION---")
    return "reflect" if flag == "reflect" else "generate"

In [96]:
# Build the workflow graph: START -> generate -> code_check -> (END | reflect -> generate | generate)
from langgraph.graph import StateGraph, START, END

workflow = StateGraph(State)

# Nodes: produce a solution, test it, and respond to failures.
workflow.add_node("generate", generate)
workflow.add_node("code_check", code_check)
workflow.add_node("reflect", reflect)

# Fixed edges: every run starts by generating, and every generation is checked.
workflow.add_edge(START, "generate")
workflow.add_edge("generate", "code_check")
# decide_to_finish returns one of the keys below; each key maps to the next node.
workflow.add_conditional_edges(
    "code_check",
    decide_to_finish,
    {
        "end": END,
        "reflect": "reflect",
        "generate": "generate",
    },
)
# After reflecting, generate a new solution.
workflow.add_edge("reflect", "generate")
app = workflow.compile()

In [97]:
question = "如何解决一个N皇后问题"
# Initial state: no check has run yet, so error is "" and there is no generation.
solution = app.invoke(
    {
        "messages": [("user", question)],
        "iterations": 0,
        "error": "",
        "generation": "",
    }
)

---CODE CHECK: SUCCESS---
---DECISION: FINISH---


In [99]:
# "generation" holds the CodeSchema from the final generate step; show its code.
solution["generation"].code

"def solve_n_queens(n: int) -> List[List[str]]:\n    def is_safe(board, row, col):\n        for i in range(row):\n            if board[i] == col or board[i] - i == col - row or board[i] + i == col + row:\n                return False\n        return True\n\n    def backtrack(board, row):\n        if row == n:\n            result.append([''.join(['Q' if j == col else '.' for j in board]) for col in board])\n            return\n        for col in range(n):\n            if is_safe(board, row, col):\n                board[row] = col\n                backtrack(board, row + 1)\n                board[row] = 0\n\n    result = []\n    board = [0] * n\n    backtrack(board, 0)\n    return result"