In [1]:
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv, find_dotenv
import os

_=load_dotenv(find_dotenv())
DEEPSEEK_API = os.getenv("DEEPSEEK_API")
BASE_URL = os.getenv("DEEPSEEK_URL")
MODEL_NAME = os.getenv("DEEPSEEK_MODEL")
llm = ChatOpenAI(api_key=DEEPSEEK_API, base_url=BASE_URL, model=MODEL_NAME)

In [2]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI

tagging_prompt = ChatPromptTemplate.from_template(
    """
Extract the desired information from the following passage.

Only extract the properties mentioned in the 'Classification' function.

Passage:
{input}
"""
)


class Classification(BaseModel):
    sentiment: str = Field(description="The sentiment of the text")
    aggressiveness: int = Field(
        description="How aggressive the text is on a scale from 1 to 10"
    )
    language: str = Field(description="The language the text is written in")


# llm=llm.with_structured_output(Classification)
tagging_chain = tagging_prompt | llm

In [3]:
from typing import Optional

from langchain_core.pydantic_v1 import BaseModel, Field

In [4]:
from typing import Optional

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel, Field

# Define a custom prompt to provide instructions and any additional context.
# 1) You can add examples into the prompt template to improve extraction quality
# 2) Introduce additional parameters to take context into account (e.g., include metadata
#    about the document from which the text was extracted.)
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are an expert extraction algorithm. "
            "Only extract relevant information from the text. "
            "If you do not know the value of an attribute asked to extract, "
            "return null for the attribute's value.",
        ),
        # Please see the how-to about improving performance with
        # reference examples.
        # MessagesPlaceholder('examples'),
        ("human", "{text}"),
    ]
)

In [5]:
prompt_performance_template = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are an expert extraction algorithm. "
            "Only extract relevant information from the text. "
            "If you do not know the value of an attribute asked to extract, "
            "return null for the attribute's value.",
        ),
        # Please see the how-to about improving performance with
        # reference examples.
        # MessagesPlaceholder('examples'),
        ("human", "{text},现在日期是{date}"),
    ]
)

In [6]:
import sys
sys.path.append('..')  # 将 src 目录添加到 PYTHONPATH  # 假设当前工作目录是notebook目录
print(sys.path)

['/home/yepeng/miniconda3/envs/chatbi/lib/python312.zip', '/home/yepeng/miniconda3/envs/chatbi/lib/python3.12', '/home/yepeng/miniconda3/envs/chatbi/lib/python3.12/lib-dynload', '', '/home/yepeng/.local/lib/python3.12/site-packages', '/home/yepeng/miniconda3/envs/chatbi/lib/python3.12/site-packages', '/home/yepeng/miniconda3/envs/chatbi/lib/python3.12/site-packages/setuptools/_vendor', '..']


In [7]:
from entity.extraction import PerformanceQuerySchema

In [8]:


runnable =prompt_performance_template | llm.with_structured_output(schema=PerformanceQuerySchema)

In [None]:
# text = "去年国贸能化公司的利润率"
# runnable.invoke({"text": text,"date":"2024-08-25"}).dict()

In [9]:
from entity.extraction_example import Example, tool_example_to_messages
from entity.extraction import *

examples = [
    (
        "去年集团利润率为负的公司,当前日期是2024-08-25，查询用户为集团用户",
        PerformanceQuerySchema(
            indicator="GROSS_MARGIN_RATE",
            aggregation="YEAR",
            start_time="2023-01-01",
            end_time="2023-12-31",
            scope="GROUP",
            sort_type="DESC",
            operator="<",
            value="0",
        ),
    ),
    (
        "国贸能化公司今年上半年的销售额大于1000万的月份,当前日期是2024-08-25",
        PerformanceQuerySchema(
            indicator="SALES",
            aggregation="MONTH",
            start_time="2024-01-01",
            end_time="2024-06-30",
            scope="GROUP",
            sort_type="DESC",
            operator=">",
            value="10000000",
            company_name="国贸能化",
        ),
    ),
]


messages = []

for text, tool_call in examples:
    messages.extend(
        tool_example_to_messages({"input": text, "tool_calls": [tool_call]})
    )

In [10]:
example_prompt = prompt.invoke({"text": "this is some text", "examples": messages})

for message in example_prompt.messages:
    print(f"{message.type}: {message}")

system: content="You are an expert extraction algorithm. Only extract relevant information from the text. If you do not know the value of an attribute asked to extract, return null for the attribute's value."
human: content='this is some text'


In [11]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from prompt import EXTRACTION_PROMPT


In [12]:
runnable_with_examples = EXTRACTION_PROMPT | llm.with_structured_output(
    schema=PerformanceQuerySchema,
    method="function_calling",
    include_raw=False,
)

messages = []

for text, tool_call in examples:
    messages.extend(
        tool_example_to_messages({"input": text, "tool_calls": [tool_call]})
    )

In [13]:
messages

[HumanMessage(content='去年集团利润率为负的公司,当前日期是2024-08-25，查询用户为集团用户'),
 AIMessage(content='', tool_calls=[{'name': 'PerformanceQuerySchema', 'args': {'indicator': 'GROSS_MARGIN_RATE', 'aggregation': 'YEAR', 'start_time': '2023-01-01', 'end_time': '2023-12-31', 'scope': 'GROUP', 'sort_type': 'DESC', 'operator': '<', 'value': '0', 'company_name': None}, 'id': 'a7c2e8f1-4fe1-42c6-b487-8a994958dbd6', 'type': 'tool_call'}]),
 ToolMessage(content='You have correctly called this tool.', tool_call_id='a7c2e8f1-4fe1-42c6-b487-8a994958dbd6'),
 HumanMessage(content='国贸能化公司今年上半年的销售额大于1000万的月份,当前日期是2024-08-25'),
 AIMessage(content='', tool_calls=[{'name': 'PerformanceQuerySchema', 'args': {'indicator': 'SALES', 'aggregation': 'MONTH', 'start_time': '2024-01-01', 'end_time': '2024-06-30', 'scope': 'GROUP', 'sort_type': 'DESC', 'operator': '>', 'value': '10000000', 'company_name': '国贸能化'}, 'id': '4420fadd-b986-46b2-b9b3-6a166f0c686d', 'type': 'tool_call'}]),
 ToolMessage(content='You have correctly called 

In [14]:
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
company_name_examples=[HumanMessage("company name examples:湖北国贸能源化工有限公司,湖北国贸金属矿产有限公司,湖北国贸汽车有限公司,湖北国际贸易集团有限公司,湖北国贸农产品有限公司,武汉鼎联丰国际贸易有限公司,湖北国贸农产品有限公司武汉分公司,湖北南方大集实业有限公司,湖北南方大集实业有限公司东西湖分公司,湖北南方大集实业有限公司慈惠分公司,湖北南方大集实业有限公司江汉分公司,湖北南方大集实业有限公司能源分公司,湖北南方工贸有限公司,湖北南方集团有限公司,湖北国贸供应链管理有限公司,湖北华中能源发展有限公司,湖北国贸汽车有限公司红安分公司,company_name如果要取值，提取后的名称必须从例子里选择，如果没有相符的公司名则返回company_name='company_name_not_found'")]

In [16]:
text = "去年集团利润率为负的公司,当前日期是2024-08-25，查询用户为集团用户"
print(runnable_with_examples.invoke({"text": text, "examples": messages,"company_name_example": company_name_examples}).dict())


{'indicator': 'GROSS_MARGIN_RATE', 'aggregation': 'YEAR', 'start_time': '2023-01-01', 'end_time': '2023-12-31', 'scope': 'GROUP', 'sort_type': 'DESC', 'operator': '<', 'value': '0', 'company_name': 'company_name_not_found'}


In [18]:
text = "今年国贸金属矿公司的销售额,当前日期是2024-08-25，查询用户为集团用户"
print(runnable_with_examples.invoke({"text": text, "examples": messages,"company_name_example": company_name_examples}).dict())

{'indicator': 'SALES', 'aggregation': 'YEAR', 'start_time': '2024-01-01', 'end_time': '2024-12-31', 'scope': 'COMPANY', 'sort_type': None, 'operator': None, 'value': None, 'company_name': '湖北国贸金属矿产有限公司'}
