In [None]:
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv, find_dotenv
import os
# NOTE(review): ExtractorRunnable is not used anywhere in this notebook, and this
# import runs before the later `sys.path.append('..')` cell. If `agent` lives in the
# src directory, it only resolves on a fresh kernel when that path is already
# importable — confirm, or move the path setup above this import.
from agent import ExtractorRunnable

# Load DeepSeek connection settings from the nearest .env file into the environment.
_ = load_dotenv(find_dotenv())
DEEPSEEK_API = os.getenv("DEEPSEEK_API")
BASE_URL = os.getenv("DEEPSEEK_URL")
MODEL_NAME = os.getenv("DEEPSEEK_MODEL")

# Fail fast with an actionable message instead of handing None to ChatOpenAI,
# which may silently fall back to its own OpenAI defaults (env vars / endpoint)
# and then fail much later with a far less obvious error.
_missing_settings = [
    env_name
    for env_name, value in (
        ("DEEPSEEK_API", DEEPSEEK_API),
        ("DEEPSEEK_URL", BASE_URL),
        ("DEEPSEEK_MODEL", MODEL_NAME),
    )
    if not value
]
if _missing_settings:
    raise RuntimeError(
        "Missing DeepSeek settings in the environment/.env: "
        + ", ".join(_missing_settings)
    )

# Low temperature (0.1) to reduce sampling randomness in the structured extraction.
llm = ChatOpenAI(
    api_key=DEEPSEEK_API, base_url=BASE_URL, model=MODEL_NAME, temperature=0.1
)

In [None]:
import os
import sys

# Make the project's src directory (the parent of the notebook directory)
# importable. Assumes the kernel's working directory is the notebook directory.
# Resolve it once to an absolute path so a later chdir cannot change its meaning,
# and only append it if absent so re-running this cell stays idempotent.
SRC_DIR = os.path.abspath("..")
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

In [None]:
import datetime

from entity.extraction import PerformanceQuerySchema
from prompt import (
    COMPANY_NAME_EXAMPLES,
    EXTRACTION_PROMPT,
    PERFORMANCE_OUTPUT_EXAMPLES_MESSAGES,
)

# Today's local date as "YYYY-MM-DD"; passed to the prompt as `date`, presumably
# so the model can resolve relative phrases such as 去年 / 今年 — confirm.
date = datetime.date.today().strftime("%Y-%m-%d")

# The LLM constrained to emit a PerformanceQuerySchema instance via tool/function
# calling (include_raw=False -> only the parsed object is returned).
# NOTE(review): function_calling is presumably chosen because the DeepSeek
# OpenAI-compatible endpoint lacks json_schema response support — confirm.
structured_llm = llm.with_structured_output(
    schema=PerformanceQuerySchema,
    method="function_calling",
    include_raw=False,
)

# Prompt -> structured LLM pipeline used by all extraction cells below.
runnable_with_examples = EXTRACTION_PROMPT | structured_llm

In [None]:
# Smoke test as a group-level user: "companies in the group whose profit margin
# was negative last year".
group_user_inputs = dict(
    text="去年集团利润率为负的公司",
    date=date,
    user_role="集团用户",
    examples=PERFORMANCE_OUTPUT_EXAMPLES_MESSAGES,
    company_name_example=COMPANY_NAME_EXAMPLES,
)
group_user_query = runnable_with_examples.invoke(group_user_inputs)
print(group_user_query.dict())

In [None]:
# Smoke test as a company-level user: "this year's sales of the 国贸金属矿 company".
company_user_inputs = dict(
    text="今年国贸金属矿公司的销售额",
    date=date,
    user_role="公司用户",
    examples=PERFORMANCE_OUTPUT_EXAMPLES_MESSAGES,
    company_name_example=COMPANY_NAME_EXAMPLES,
)
company_user_query = runnable_with_examples.invoke(company_user_inputs)
print(company_user_query.dict())

In [None]:
# Load the batch of test queries, one per line.
# - "utf-8-sig" reads plain UTF-8 unchanged but also drops a leading BOM (common in
#   files saved on Windows), which str.strip() would otherwise leave on query #1.
# - Surrounding whitespace/newlines are stripped and blank lines skipped, so empty
#   strings are never sent to the model.
with open("../test/user_input.txt", "r", encoding="utf-8-sig") as file:
    lines = [line.strip() for line in file if line.strip()]

# Print the list to sanity-check what was loaded.
print(lines)

In [None]:
import csv

# Batch-run the extractor over every query in `lines` and save (query, result)
# pairs to CSV for manual review.
#
# NOTE(review): the single-query cells above pass "date" and "user_role" to the
# same EXTRACTION_PROMPT, but this loop originally omitted them. If the prompt
# references those variables, every call failed while formatting the prompt, and
# the bare `except:` silently recorded "Error" for every row. Both are supplied
# here. The role used for the batch is an assumption — TODO confirm which role
# the test set is meant to exercise.
BATCH_USER_ROLE = "集团用户"

data = []
for line in lines:
    try:
        result = runnable_with_examples.invoke(
            {
                "text": line,
                "date": date,
                "user_role": BATCH_USER_ROLE,
                "examples": PERFORMANCE_OUTPUT_EXAMPLES_MESSAGES,
                "company_name_example": COMPANY_NAME_EXAMPLES,
            }
        ).dict()
    except Exception as exc:
        # Catch Exception rather than using a bare except, so KeyboardInterrupt
        # can still stop a long batch. Keep the "Error" marker in the CSV, but
        # surface the cause instead of discarding it.
        print(f"extraction failed for {line!r}: {type(exc).__name__}: {exc}")
        result = "Error"
    data.append([line, result])
    print(line, result)


# Write the results to a CSV file (newline="" as required by the csv module).
with open("../test/output.csv", "w", newline="", encoding="utf-8") as csvfile:
    csvwriter = csv.writer(csvfile)

    # Header row (optional).
    csvwriter.writerow(["Line", "Result"])

    # One row per input query.
    csvwriter.writerows(data)

print("数据已成功写入 output.csv 文件")