In [1]:
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv, find_dotenv
import os
import sys
sys.path.append('..')  # 将 src 目录添加到 PYTHONPATH  # 假设当前工作目录是notebook目录
from agent import ExtractorRunnable
from entity import InputData
from pprint import pprint
from entity.extraction import PerformanceQuerySchema
from prompt import (
    EXTRACTION_PROMPT,
    PERFORMANCE_OUTPUT_EXAMPLES_MESSAGES,
    COMPANY_NAME_EXAMPLES,
)
import datetime
# Today's local date as an ISO "YYYY-MM-DD" string. It is passed to the
# extractor as `date`; presumably used to resolve relative time phrases such as
# 今年 / 去年 ("this year" / "last year") — confirm against the extractor prompt.
date = datetime.date.today().strftime("%Y-%m-%d")



In [3]:
# Single smoke-test query issued as a company-level user ("公司用户").
# The query text asks for "this year's sales of the Guomao metals & minerals company".
extractor_runnable = ExtractorRunnable()
input_data = InputData(
    text="今年国贸金属矿公司的销售额",
    user_role="公司用户",
    date=date,
    company_name_example=COMPANY_NAME_EXAMPLES,
    examples=PERFORMANCE_OUTPUT_EXAMPLES_MESSAGES,
)
# Ask the extractor to parse the query into a PerformanceQuerySchema instance.
result = extractor_runnable.invoke(
    input_data,
    PerformanceQuerySchema,
)

In [4]:

# Show the structured extraction as a plain dict. `.dict()` is deprecated on
# pydantic v2 models (it only delegates to `model_dump()` with a warning), so
# prefer `model_dump()` and fall back to `.dict()` for pydantic v1 models.
pprint(result.model_dump() if hasattr(result, "model_dump") else result.dict())

{'aggregation': 'YEAR',
 'company_name': '湖北国贸金属矿产有限公司',
 'end_time': '2024-12-31',
 'indicator': 'SALES',
 'operator': None,
 'scope': 'COMPANY',
 'sort_type': 'DESC',
 'start_time': '2024-01-01',
 'value': None}


In [6]:
# Load the batch of test queries: one natural-language question per line.
with open("../test/user_input.txt", "r", encoding="utf-8") as file:
    # Strip the trailing newline (and any surrounding whitespace) from each
    # line, and drop lines that are blank after stripping so that empty strings
    # are never sent to the extractor as queries.
    stripped_lines = (line.strip() for line in file)
    lines = [line for line in stripped_lines if line]

# Print the list to verify what was loaded.
pprint(lines)

['今年销售额是多少',
 '今年的采购额是多少',
 '今年的毛利润是多少',
 '今年毛利率怎么样',
 '今年能化公司经营状况如何',
 '业绩变动的趋势如何',
 '去年销售额是多少',
 '去年的采购额是多少',
 '去年的毛利润是多少',
 '去年毛利率怎么样',
 '去年能化公司经营状况如何',
 '上半年那个月的业绩最好',
 '哪个月的利润率低于百分之5',
 '哪个月的利润率为负的',
 '总销售额现在是多少',
 '总采购额目前累计了多少',
 '毛利润现在累计到多少了',
 '现在的毛利率是多少',
 '哪个月的销售额最高',
 '哪个月的毛利润最高',
 '哪个月的毛利率最低',
 '哪个子公司的销售额最突出',
 '哪个子公司的毛利润最高',
 '哪个子公司的毛利率表现最好',
 '今年集团的销售额同比增长超过5%的公司',
 '今年利润率为负的公司']


In [17]:
import csv

# Destination for the (query, extraction result) report.
OUTPUT_CSV_PATH = "../test/output.csv"

# Run every loaded query through the extractor as a group-level user
# ("集团用户") and collect [query, result-as-string] rows for the report.
data = []
for line in lines:
    try:
        # Build the request the same way as the single-query cell, via InputData,
        # so the fields are validated instead of passing an untyped dict.
        result = extractor_runnable.invoke(
            InputData(
                text=line,
                date=date,
                user_role="集团用户",
                examples=PERFORMANCE_OUTPUT_EXAMPLES_MESSAGES,
                company_name_example=COMPANY_NAME_EXAMPLES,
            ),
            PerformanceQuerySchema,
        )
    except Exception as e:
        # Best-effort batch run: one failing query must not abort the rest.
        # Record the exception type and message (not just "Error") so failures
        # can be diagnosed from the CSV report alone.
        print(e)
        result = f"Error: {type(e).__name__}: {e}"

    data.append([line, str(result)])
    pprint(line)
    pprint(str(result))


# Write the CSV report. newline="" is required by the csv module so rows are
# not separated by blank lines on platforms that translate line endings.
with open(OUTPUT_CSV_PATH, "w", newline="", encoding="utf-8") as csvfile:
    csvwriter = csv.writer(csvfile)

    # Header row.
    csvwriter.writerow(["Line", "Result"])

    # One row per query.
    csvwriter.writerows(data)

print("数据已成功写入 output.csv 文件")

'今年销售额是多少'
("indicator='SALES' aggregation='YEAR' start_time='2024-01-01' "
 "end_time='2024-12-31' scope='GROUP' sort_type='DESC' operator=None "
 'value=None company_name=None')
'今年的采购额是多少'
("indicator='PROCUREMENT' aggregation='YEAR' start_time='2024-01-01' "
 "end_time='2024-12-31' scope='GROUP' sort_type='DESC' operator=None "
 'value=None company_name=None')
'今年的毛利润是多少'
("indicator='GROSS_PROFIT' aggregation='YEAR' start_time='2024-01-01' "
 "end_time='2024-12-31' scope='GROUP' sort_type='DESC' operator=None "
 'value=None company_name=None')
'今年毛利率怎么样'
("indicator='GROSS_MARGIN_RATE' aggregation='YEAR' start_time='2024-01-01' "
 "end_time='2024-12-31' scope='GROUP' sort_type='DESC' operator=None "
 'value=None company_name=None')
'今年能化公司经营状况如何'
("indicator='GROSS_PROFIT' aggregation='MONTH' start_time='2024-01-01' "
 "end_time='2024-08-31' scope='GROUP' sort_type='DESC' operator=None "
 "value=None company_name='湖北国贸能源化工有限公司'")
'业绩变动的趋势如何'
("indicator='SALES' aggregation='MONTH'