# 加载环境变量

In [1]:
import os
from dotenv import load_dotenv, find_dotenv

# Load variables from the .env file located by find_dotenv() into the
# process environment; the return value is discarded.
_ = load_dotenv(find_dotenv())
# Optional: Configure tracing to visualize and debug the agent

# These are set after load_dotenv, so they override any value the .env file
# provided for the same keys.
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "Tool Use"

# 定义模型

In [2]:
from langchain_openai import ChatOpenAI
# Chat model shared by the chains below; requests are sent to the custom
# OpenAI-compatible endpoint given in base_url.
llm = ChatOpenAI(model="gpt-4o-mini", base_url="https://api.chsdw.top/v1")

# 定义Tool

In [13]:
from pydantic import Field
from langchain_core.tools import BaseTool
from langchain_google_community import GoogleSearchAPIWrapper
from langchain_community.utilities import BingSearchAPIWrapper, BraveSearchWrapper, WikipediaAPIWrapper
from langchain_community.tools import WikipediaQueryRun
from langchain_community.tools.wikidata.tool import WikidataAPIWrapper, WikidataQueryRun
import requests

import numexpr as ne 
from langchain_core.tools import Tool
from langchain_experimental.utilities import PythonREPL

from googleapiclient import discovery
import json
from azure.core.credentials import AzureKeyCredential
from azure.ai.contentsafety.models import TextCategory, AnalyzeTextOptions
from azure.ai.contentsafety import ContentSafetyClient
from azure.core.exceptions import HttpResponseError

import diskcache as dc


class GoogleSearchTool(BaseTool):
    """Google search tool that memoizes results in an on-disk cache.

    A query that is already in the cache is answered from disk; otherwise it
    is sent to ``GoogleSearchAPIWrapper.results`` (one result requested) and
    the response is stored under the query text before being returned.
    """

    # Annotated so the pydantic-based BaseTool registers them as real fields.
    # ``object`` is used because the file does not import ``typing.Any``.
    cache: object = None
    google: object = None

    def __init__(
        self,
        name: str = "google_search",
        description: str = "Search Google for recent results. Input should be a search query. ",
        cache_dir: str = ".cache/google_search_tool",
    ):
        """Create the tool.

        Args:
            name: Tool name exposed to the model.
            description: Tool description exposed to the model.
            cache_dir: Directory of the on-disk result cache. The default is
                specific to this tool; it previously pointed at the Google
                Knowledge Graph tool's directory, so the two tools could
                return each other's cached answers for the same query text.
        """
        # cache_dir is not a BaseTool field, so it is not forwarded.
        super().__init__(name=name, description=description)
        self.cache = dc.Cache(cache_dir)
        self.google = GoogleSearchAPIWrapper()

    def _run(self, query: str) -> list:
        """Return the search results for ``query``, using the cache when possible."""
        if query in self.cache:
            print("Cache hit for text:", query)
            return self.cache[query]

        result = self.google.results(query=query, num_results=1)
        self.cache[query] = result
        return result

class BingSearchTool(BaseTool):
    """Bing search tool that memoizes results in an on-disk cache.

    A query that is already in the cache is answered from disk; otherwise it
    is sent to ``BingSearchAPIWrapper.results`` (one result requested) and
    the response is stored under the query text before being returned.
    """

    # Annotated so the pydantic-based BaseTool registers them as real fields.
    # ``object`` is used because the file does not import ``typing.Any``.
    cache: object = None
    bing: object = None

    def __init__(
        self,
        name: str = "bing_search",
        description: str = "Search Bing for recent results. Input should be a search query. ",
        cache_dir: str = ".cache/bing_search_tool",
    ):
        """Create the tool.

        Args:
            name: Tool name exposed to the model.
            description: Tool description exposed to the model.
            cache_dir: Directory of the on-disk result cache.
        """
        # cache_dir is not a BaseTool field, so it is not forwarded.
        super().__init__(name=name, description=description)
        self.cache = dc.Cache(cache_dir)
        self.bing = BingSearchAPIWrapper()

    def _run(self, query: str) -> list:
        """Return the search results for ``query``, using the cache when possible."""
        if query in self.cache:
            print("Cache hit for text:", query)
            return self.cache[query]

        result = self.bing.results(query=query, num_results=1)
        self.cache[query] = result
        return result
    
class BraveSearchTool(BaseTool):
    """Brave search tool that memoizes results in an on-disk cache.

    A query that is already in the cache is answered from disk; otherwise it
    is sent to ``BraveSearchWrapper.run`` and the response is stored under
    the query text before being returned.
    """

    # Annotated so the pydantic-based BaseTool registers them as real fields.
    # ``object`` is used because the file does not import ``typing.Any``.
    cache: object = None
    brave: object = None

    def __init__(self, name: str = "brave_search",
                 description: str = "A search engine. useful for when you need to answer questions about current events. Input should be a search query. ",
                 cache_dir: str = ".cache/brave_search_tool", api_key: str = None, search_kwargs: dict = None):
        """Create the tool.

        Args:
            name: Tool name exposed to the model.
            description: Tool description exposed to the model.
            cache_dir: Directory of the on-disk result cache.
            api_key: Brave Search API key handed to ``BraveSearchWrapper``.
            search_kwargs: Extra search parameters for the wrapper. ``None``
                (the default) requests a single result, ``{"count": 1}``.
        """
        # cache_dir is not a BaseTool field, so it is not forwarded.
        super().__init__(name=name, description=description)
        self.cache = dc.Cache(cache_dir)
        # Build the default per call and copy caller input, so no dict object
        # is shared between instances or with the caller.
        if search_kwargs is None:
            search_kwargs = {"count": 1}
        self.brave = BraveSearchWrapper(api_key=api_key, search_kwargs=dict(search_kwargs))

    def _run(self, query: str) -> str:
        """Return the search output for ``query``, using the cache when possible."""
        if query in self.cache:
            print("Cache hit for text:", query)
            return self.cache[query]

        result = self.brave.run(query=query)
        self.cache[query] = result
        return result

class WikidataTool(WikidataQueryRun):
    """``WikidataQueryRun`` with an on-disk cache keyed by the input text."""

    # Field overrides carry annotations, as pydantic requires for overrides.
    name: str = "wikidata"

    # ``object`` is used because the file does not import ``typing.Any``.
    cache: object = None

    def __init__(self, api_wrapper: WikidataAPIWrapper = None, cache_dir: str = ".cache/wikidata_tool"):
        """Create the tool.

        Args:
            api_wrapper: Wikidata API wrapper to query through. When omitted,
                a new ``WikidataAPIWrapper()`` is created for this instance
                rather than one wrapper being built at definition time and
                shared by every tool.
            cache_dir: Directory of the on-disk result cache.
        """
        if api_wrapper is None:
            api_wrapper = WikidataAPIWrapper()
        super().__init__(api_wrapper=api_wrapper)
        self.cache = dc.Cache(cache_dir)

    def _run(self, text: str) -> str:
        """Return the Wikidata lookup for ``text``, using the cache when possible."""
        if text in self.cache:
            print("Cache hit for text:", text)
            return self.cache[text]

        results = super()._run(text)
        self.cache[text] = results
        return results

class WikipediaTool(WikipediaQueryRun):
    """``WikipediaQueryRun`` with an on-disk cache keyed by the input text."""

    # Field overrides carry annotations, as pydantic requires for overrides.
    name: str = "wikipedia"

    # ``object`` is used because the file does not import ``typing.Any``.
    cache: object = None

    def __init__(self, api_wrapper: WikipediaAPIWrapper = None, cache_dir: str = ".cache/wikipedia_tool"):
        """Create the tool.

        Args:
            api_wrapper: Wikipedia API wrapper to query through. When omitted,
                a new ``WikipediaAPIWrapper(top_k_results=1)`` is created for
                this instance rather than one wrapper being built at
                definition time and shared by every tool.
            cache_dir: Directory of the on-disk result cache.
        """
        if api_wrapper is None:
            api_wrapper = WikipediaAPIWrapper(top_k_results=1)
        super().__init__(api_wrapper=api_wrapper)
        self.cache = dc.Cache(cache_dir)

    def _run(self, text: str) -> str:
        """Return the Wikipedia lookup for ``text``, using the cache when possible."""
        if text in self.cache:
            print("Cache hit for text:", text)
            return self.cache[text]

        results = super()._run(text)
        self.cache[text] = results
        return results
    
class GoogleKnowledgeGraphTool(BaseTool):
    """Query the Google Knowledge Graph Search API, caching responses on disk."""

    # Field overrides carry annotations, as pydantic requires for overrides.
    name: str = "google_knowledge_graph"
    description: str = (
        "This tool searches for entities in the Google Knowledge Graph. "
        "It provides information about people, places, things, and concepts. "
        "Input should be an entity name."
    )
    api_key: str = Field(..., description="Google Knowledge Graph Search API key")
    # ``object`` is used because the file does not import ``typing.Any``.
    cache: object = None

    def __init__(self, api_key: str, cache_dir: str = ".cache/google_knowledge_graph_tool"):
        """Create the tool.

        Args:
            api_key: Google Knowledge Graph Search API key.
            cache_dir: Directory of the on-disk result cache.
        """
        super().__init__(api_key=api_key)
        self.cache = dc.Cache(cache_dir)

    def _run(self, query: str, limit: int = 1):
        """Search the Knowledge Graph for ``query``.

        Args:
            query: Entity name to search for.
            limit: Maximum number of entities requested from the API.

        Returns:
            The decoded JSON response (a dict) on success, or a string
            starting with "Failed to retrieve data:" when the request or the
            JSON decoding fails. Failures are not cached.
        """
        # Key on both inputs: keying on the query alone would serve a response
        # fetched with one limit to a later call that asks for another.
        cache_key = (query, limit)
        if cache_key in self.cache:
            print("Cache hit for text:", query)
            return self.cache[cache_key]

        service_url = "https://kgsearch.googleapis.com/v1/entities:search"
        params = {
            "query": query,
            "limit": limit,
            "indent": True,
            "key": self.api_key,
        }

        try:
            # Without a timeout a stalled connection would block the caller
            # indefinitely.
            response = requests.get(service_url, params=params, timeout=10)
            response.raise_for_status()  # Raise an exception for HTTP errors
            data = response.json()
        except (requests.RequestException, ValueError) as e:
            return f"Failed to retrieve data: {str(e)}"

        self.cache[cache_key] = data
        return data

    def _format_results(self, data: dict) -> str:
        """Render a response as one "name (score)" line per entity.

        Not called by ``_run``, which returns the raw response.

        Args:
            data: Decoded API response; entities are read from its
                "itemListElement" list.

        Returns:
            The joined lines, or "No results found." when there are none.
        """
        formatted_results = []
        for element in data.get("itemListElement", []):
            entity_name = element.get("result", {}).get("name", "N/A")
            score = element.get("resultScore", 0)
            formatted_results.append(f"{entity_name} ({score})")

        return "\n".join(formatted_results) if formatted_results else "No results found."

class CalculatorTool(BaseTool):
    """Evaluate a numeric expression with numexpr and report the result as text."""

    # Field overrides carry annotations, as pydantic requires for overrides.
    name: str = "calculator"
    description: str = ("Useful when you need to calculate the value of a mathematical expression, including basic arithmetic operations. "
                        "Use this tool for math operations. "
                        "Input should strictly follow the numexpr syntax. ")

    def _run(self, expression: str):
        """Evaluate ``expression`` and return a sentence with the result.

        Any evaluation failure is returned as an "Error in calculation: ..."
        string instead of being raised, so the caller always gets text back.
        """
        # NOTE(review): the expression is model-generated text; confirm the
        # installed numexpr version sanitizes input before evaluating it.
        try:
            result = ne.evaluate(expression).item()
        except Exception as e:
            return f"Error in calculation: {str(e)}"
        return f"The result of the expression of <{expression}> is: {result}."
      
class PythonREPLTool(BaseTool):
    """Run Python source through ``PythonREPL`` and return what it outputs."""

    # Field overrides carry annotations, as pydantic requires for overrides.
    name: str = "python_repl"
    description: str = ("A Python shell. Use this to execute python commands. Input should be a valid python command. "
                        "If you want to see the output of a value, you should print it out with `print(...)`.")

    def _run(self, code: str) -> str:
        """Execute ``code`` and return the REPL output, or "Error: ..." on failure.

        A new ``PythonREPL`` is created for every call.
        """
        # SECURITY: this executes arbitrary model-generated code in the current
        # process; only use it in an environment where that is acceptable.
        try:
            return PythonREPL().run(code)
        except Exception as e:
            return f"Error: {str(e)}"

class PerspectiveTool(BaseTool):
    """Score text with the Google Perspective API, caching responses on disk."""

    # Field overrides carry annotations, as pydantic requires for overrides.
    name: str = "perspective"
    # The two original fragments concatenated to "...Perspective APIIt detects..."
    # and listed categories that are not among the attributes requested below.
    description: str = (
        "This tool analyzes text for safety using Google Perspective API. "
        "It scores attributes such as toxicity, severe toxicity, identity attack, insult, "
        "threat, profanity, sexually explicit content, and flirtation. "
        "Input should be a text string. "
    )
    # ``object`` is used because the file does not import ``typing.Any``.
    client: object = None
    cache: object = None

    def __init__(self, api_key: str = None, cache_dir: str = ".cache/perspective_tool"):
        """Create the tool.

        Args:
            api_key: Perspective API key. When omitted, the
                GOOGLE_PERSPECTIVE_API_KEY environment variable is read at
                construction time, not once when the class is defined.
            cache_dir: Directory of the on-disk result cache.
        """
        super().__init__()
        if api_key is None:
            api_key = os.environ.get("GOOGLE_PERSPECTIVE_API_KEY")
        self.client = discovery.build(
            "commentanalyzer",
            "v1alpha1",
            discoveryServiceUrl="https://commentanalyzer.googleapis.com/$discovery/rest?version=v1alpha1",
            developerKey=api_key,
            static_discovery=False,
        )
        self.cache = dc.Cache(cache_dir)

    def _run(self, text: str):
        """Return the Perspective API response for ``text``, using the cache when possible."""
        if text in self.cache:
            print("Cache hit for text:", text)
            return self.cache[text]

        # Attributes requested from the API for every comment.
        PERSPECTIVE_API_ATTRIBUTES = (
            'TOXICITY',
            'SEVERE_TOXICITY',
            'IDENTITY_ATTACK',
            'INSULT',
            'THREAT',
            'PROFANITY',
            'SEXUALLY_EXPLICIT',
            'FLIRTATION'
        )

        analyze_request = {
            'comment': {'text': text},
            'requestedAttributes': {attr: {} for attr in PERSPECTIVE_API_ATTRIBUTES}
        }

        response = self.client.comments().analyze(body=analyze_request).execute()
        # Cache the result
        self.cache[text] = response
        return response

class AzureContentModerationTool(BaseTool):
    """Rate text with Azure AI Content Safety, caching formatted results on disk."""

    # Field overrides carry annotations, as pydantic requires for overrides.
    name: str = "azure_content_moderation"
    description: str = (
        "This tool analyzes text for safety using Azure AI Content Safety. "
        "It detects categories such as hate, self-harm, sexual content, and violence. "
        "Input should be a text string. "
    )

    key: str = Field(default=os.environ.get("AZURE_CONTENT_SAFETY_KEY"), description="Azure Content Safety API key")
    endpoint: str = Field(default=os.environ.get("AZURE_CONTENT_SAFETY_ENDPOINT"), description="Azure Content Safety endpoint")
    client: ContentSafetyClient = None
    # ``object`` is used because the file does not import ``typing.Any``.
    cache: object = None

    def __init__(self, endpoint: str = None, key: str = None, cache_dir: str = ".cache/azure_content_moderation_tool"):
        """Create the tool.

        Args:
            endpoint: Azure Content Safety endpoint. Falls back to the
                AZURE_CONTENT_SAFETY_ENDPOINT environment variable.
            key: Azure Content Safety API key. Falls back to the
                AZURE_CONTENT_SAFETY_KEY environment variable.
            cache_dir: Directory of the on-disk result cache.
        """
        # Resolve the fallbacks here: passing None straight through would
        # override the field defaults with None.
        if endpoint is None:
            endpoint = os.environ.get("AZURE_CONTENT_SAFETY_ENDPOINT")
        if key is None:
            key = os.environ.get("AZURE_CONTENT_SAFETY_KEY")
        super().__init__(key=key, endpoint=endpoint)
        self.client = ContentSafetyClient(self.endpoint, AzureKeyCredential(self.key))
        self.cache = dc.Cache(cache_dir)

    def _run(self, text: str) -> str:
        """Return per-category severities for ``text`` as formatted text.

        Successful results are cached under the input text. A failed request
        yields an "Analyze text failed." message, which is not cached.
        """
        if text in self.cache:
            print("Cache hit for text:", text)
            return self.cache[text]

        request = AnalyzeTextOptions(text=text)

        try:
            response = self.client.analyze_text(request)
        except HttpResponseError as e:
            error_message = "Analyze text failed."
            if e.error:
                error_message += f" Error code: {e.error.code}. Error message: {e.error.message}"
            return error_message

        results = {
            "hate": self._get_severity(response, TextCategory.HATE),
            "self_harm": self._get_severity(response, TextCategory.SELF_HARM),
            "sexual": self._get_severity(response, TextCategory.SEXUAL),
            "violence": self._get_severity(response, TextCategory.VIOLENCE),
        }

        formatted_results = self._format_results(results)
        self.cache[text] = formatted_results
        return formatted_results

    def _get_severity(self, response, category: TextCategory):
        """Return the severity reported for ``category``, or "Not found" if absent."""
        result = next((item for item in response.categories_analysis if item.category == category), None)
        return result.severity if result else "Not found"

    def _format_results(self, results: dict) -> str:
        """Render ``results`` as one "<Category> severity: <n> out of 7." line per entry."""
        return "\n".join(f"{category.capitalize()} severity: {severity} out of 7." for category, severity in results.items())



# 测试tool

In [14]:
# Instantiate one of each tool. Credentials are read from environment
# variables where a constructor argument is given; the commented-out lines
# are manual smoke tests. Only the Perspective call is executed here.
google_search = GoogleSearchTool()
# print(google_search.run("What is the capital of France?"))
bing_search = BingSearchTool()
# print(bing_search.run("What is the capital of China?"))
brave = BraveSearchTool(api_key=os.environ.get("BRAVE_API_KEY"))
# print(brave.run("What is the capital of India?"))
azure_content_moderation = AzureContentModerationTool(endpoint=os.environ.get("AZURE_CONTENT_SAFETY_ENDPOINT"), key=os.environ.get("AZURE_CONTENT_SAFETY_KEY"))
# print(azure_content_moderation.run("I hate you"))
wikidata = WikidataTool()
# print(wikidata.run("Japan"))
wikipedia = WikipediaTool()
# print(wikipedia.run("Japan"))
google_knowledge_graph = GoogleKnowledgeGraphTool(api_key=os.environ.get("GOOGLE_API_KEY"))
# google_knowledge_graph.run("China")
calculator = CalculatorTool()
# calculator.run("2+2")
python_repl = PythonREPLTool()
# python_repl.run("print('Hello, World!')")
perspective = PerspectiveTool()
print(perspective.run("I hate you"))

Cache hit for text: I hate you
{'attributeScores': {'INSULT': {'spanScores': [{'begin': 0, 'end': 10, 'score': {'value': 0.4378843, 'type': 'PROBABILITY'}}], 'summaryScore': {'value': 0.4378843, 'type': 'PROBABILITY'}}, 'TOXICITY': {'spanScores': [{'begin': 0, 'end': 10, 'score': {'value': 0.6827122, 'type': 'PROBABILITY'}}], 'summaryScore': {'value': 0.6827122, 'type': 'PROBABILITY'}}, 'THREAT': {'spanScores': [{'begin': 0, 'end': 10, 'score': {'value': 0.043955702, 'type': 'PROBABILITY'}}], 'summaryScore': {'value': 0.043955702, 'type': 'PROBABILITY'}}, 'IDENTITY_ATTACK': {'spanScores': [{'begin': 0, 'end': 10, 'score': {'value': 0.28883415, 'type': 'PROBABILITY'}}], 'summaryScore': {'value': 0.28883415, 'type': 'PROBABILITY'}}, 'SEVERE_TOXICITY': {'spanScores': [{'begin': 0, 'end': 10, 'score': {'value': 0.0394905, 'type': 'PROBABILITY'}}], 'summaryScore': {'value': 0.0394905, 'type': 'PROBABILITY'}}, 'SEXUALLY_EXPLICIT': {'spanScores': [{'begin': 0, 'end': 10, 'score': {'value': 0.

# 构造ToolList

In [15]:
def get_tools_descriptions(tools:list):
    """Build the tool catalogue text inserted into prompts.

    Each tool contributes one "<name> - <description>" entry; entries are
    separated by a blank line. An empty list yields an empty string.
    """
    return "\n\n".join(f"{entry.name} - {entry.description}" for entry in tools)

def get_tools_dict(tools:list)->dict:
    """Index tools by lower-cased name.

    When two tools share a lower-cased name, the later one in ``tools`` wins.
    """
    return {entry.name.lower(): entry for entry in tools}

# Every tool instance created above, in the order it is described to the model.
tools = [google_search, bing_search, brave, azure_content_moderation, wikidata, wikipedia, google_knowledge_graph, calculator, python_repl, perspective]

# Prompt text listing the tools, and a lookup table keyed by lower-cased tool name.
tools_descriptions = get_tools_descriptions(tools)
tools_dict = get_tools_dict(tools)

In [24]:
from typing import List

from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field


# NOTE: this class rebinds the name `Tool`, which was imported from
# langchain_core.tools earlier in the file.
class Tool(BaseModel):
    """A single tool call: the name of the tool to use and the input to pass to it."""

    tool: str = Field(..., description="The tool used to find information which are useful for the question.")
    input: str = Field(..., description="The input for the tool.")
    


class ToolList(BaseModel):
    """Identifying all the tools needed to answer the question."""

    # Each entry is a `Tool` record (tool name plus the input for it).
    tool_list: List[Tool]


# Set up a parser
# Parser built from the ToolList schema; it also supplies the format
# instructions injected into the prompt below.
parser = PydanticOutputParser(pydantic_object=ToolList)

# Prompt
# The adjacent string literals of the system message are concatenated, so
# each sentence must end with a space or newline to keep the sentences apart.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You have access to the following tools:\n\n"
            "{tools_description}\n\n"
            "--------------------------------\n"
            "Inspect each tool's description and the input requirements carefully. "
            "Identify all the tools you may need to use and the corresponding tool input text based on the user's current step. "
            "You should also consider the overall context of the user's question. "
            "{format_instructions}",
        ),
        ("human", "{query}"),
    ]
).partial(tools_description=tools_descriptions ,format_instructions=parser.get_format_instructions())

In [22]:
# Sample question used to exercise the tool-selection prompt.
query = "If a bag of marbles costs $20 and the price increases by 20% of the original price every two months, how much would a bag of marbles cost after 36 months?"

# Render the prompt without calling the model. The query is passed as a bare
# string rather than a {"query": ...} mapping.
print(prompt.invoke(query).to_string())


System: You have access to the following tools:

google_search - Search Google for recent results. Input should be a search query. 

bing_search - Search Bing for recent results. Input should be a search query. 

brave_search - A search engine. useful for when you need to answer questions about current events. Input should be a search query. 

azure_content_moderation - This tool analyzes text for safety using Azure AI Content Safety. It detects categories such as hate, self-harm, sexual content, and violence. Input should be a text string. 

wikidata - A wrapper around Wikidata. Useful for when you need to answer general questions about people, places, companies, facts, historical events, or other subjects. Input should be the exact name of the item you want information about or a Wikidata QID.

wikipedia - A wrapper around Wikipedia. Useful for when you need to answer general questions about people, places, companies, facts, historical events, or other subjects. Input should be a sea

In [None]:
# Pipe the prompt into the model and the model output into the parser.
chain = prompt | llm | parser

print(chain.invoke({"query": query}))

In [None]:
# Keys of tools_dict are the lower-cased tool names, so the Knowledge Graph
# tool is registered as 'google_knowledge_graph' (underscores, not spaces).
tools_dict['google_knowledge_graph'].run('Google Knowledge Graph')
tools_dict['wikipedia'].run('How to make steak recipe')

# Plan

In [None]:
from langchain_community.tools.tavily_search import TavilySearchResults

# Rebinds `tools` (previously the list of custom tools) to a single Tavily
# search tool limited to three results.
tools = [TavilySearchResults(max_results=3)]

In [None]:
import operator
from typing import Annotated, List, Tuple, TypedDict


class PlanExecute(TypedDict):
    """Typed dictionary with the four keys declared below.

    NOTE(review): the field meanings in the comments are inferred from the
    names; confirm against the code that consumes this type.
    """

    input: str  # presumably the original question or objective
    plan: List[str]  # presumably the steps still to be carried out
    # operator.add is attached as annotation metadata; presumably it tells the
    # consuming framework to append new steps to the list (TODO confirm).
    past_steps: Annotated[List[Tuple], operator.add]
    response: str  # presumably the final answer

from langchain_core.pydantic_v1 import BaseModel, Field


class Plan(BaseModel):
    """Plan to follow in future"""

    # The field description below is part of the schema text handed to the
    # output parser, so it is left exactly as written.
    steps: List[str] = Field(
        description="different steps to follow, should be in sorted order"
    )

In [None]:
from langchain_core.output_parsers import PydanticOutputParser, JsonOutputParser

# Set up a parser + inject instructions into the prompt template.
# Rebinds `parser` (previously the ToolList parser) to a JSON parser built
# from the Plan schema.
parser = JsonOutputParser(pydantic_object=Plan)

print(parser.get_format_instructions())

In [None]:
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate

# Planner prompt: the system message states the planning rules; the human
# message carries the question ({messages}) followed by the parser's format
# instructions, which are filled in up front via partial().
planner_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "For the given objective, come up with a simple step by step plan. "
            "This plan should involve individual tasks, that if executed correctly will yield the correct answer. Do not add any superfluous steps. "
            "The result of the final step should be the final answer. Make sure that each step has all the information needed - do not skip steps. \n"
        ),
        ("human", "Here is the question you need to plan. \n{messages}\n{format_instructions}"),
    ]
).partial(format_instructions=parser.get_format_instructions())


In [None]:
# Bare expression: as the last line of a notebook cell its value is displayed.
parser.get_format_instructions()

In [None]:
# Planner chain: prompt -> chat model (temperature 0) -> parser. Because the
# chain ends with `parser`, `response` holds the parser's output rather than
# the raw chat message.
planner = planner_prompt | ChatOpenAI(
    model="gpt-4o-mini", temperature=0, base_url="https://api.chsdw.top/v1"
) | parser
response = planner.invoke(input={"messages": "Which sports event was first held at Worcester, Massachusetts in 1927?"})

In [None]:
# Show what the planner chain returned.
print(response)

In [None]:
# `planner` already ends with the JSON output parser, so `response` is the
# parsed JSON object, not a chat message with a `.content` attribute.
# Validate it against the Plan schema instead of parsing it a second time.
plan = Plan(**response)

In [None]:
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI


# Define your desired data structure.
class Joke(BaseModel):
    # Two required string fields; the descriptions are part of the schema text
    # handed to the output parser, so they are left exactly as written.
    setup: str = Field(description="question to set up a joke")
    punchline: str = Field(description="answer to resolve the joke")


# And a query intended to prompt a language model to populate the data structure.
joke_query = "Tell me a joke."

# Set up a parser + inject instructions into the prompt template.
# This rebinds `parser`, `prompt` and `chain` from the earlier cells.
parser = JsonOutputParser(pydantic_object=Joke)

prompt = PromptTemplate(
    template="Answer the user query.\n{format_instructions}\n{query}\n",
    input_variables=["query"],
    partial_variables={"format_instructions": parser.get_format_instructions()},
)

chain = prompt | llm | parser

chain.invoke({"query": joke_query})