In [1]:
import asyncio
import os
import time
from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple

from dotenv import load_dotenv
from langchain_openai import ChatOpenAI

from utils import generate_arithmetic_expression, re_parse_json, calculate_time_difference

# Load the .env file into the process environment so settings defined there
# (presumably API_KEY, which is read via os.getenv in the next cell) are visible.
load_dotenv()

True

In [2]:
def _load_api_keys(env_var: str = "API_KEY") -> List[str]:
    """Read a comma-separated list of API keys from an environment variable.

    Whitespace around each entry is stripped and empty entries (e.g. from a
    trailing comma) are dropped, so ``"k1, k2,"`` yields ``["k1", "k2"]``.

    Raises:
        RuntimeError: if the variable is unset or contains no non-empty keys.
    """
    raw = os.getenv(env_var)
    keys = [key.strip() for key in (raw or "").split(",") if key.strip()]
    if not keys:
        raise RuntimeError(
            f"{env_var} is not set or contains no keys; "
            "define it (e.g. in .env) as a comma-separated list"
        )
    return keys


api_keys = _load_api_keys()

In [4]:
@dataclass
class LLMAPI:
    """Async wrapper around one ``ChatOpenAI`` client bound to a single API key.

    Attributes:
        api_key: Key used to authenticate the client. Excluded from ``repr`` so that
            displaying an ``LLMAPI`` (or a list of them) does not print the secret
            into notebook output.
        uid: Caller-assigned identifier for this wrapper.
        cnt: Number of times ``agenerate`` has been called on this wrapper, i.e. how
            many requests were dispatched with this key.
        model: Chat model name passed to ``ChatOpenAI``.
        base_url: OpenAI-compatible endpoint passed to ``ChatOpenAI``.
        llm: The underlying client, built in ``__post_init__``; not a constructor
            argument and not shown in ``repr``.
    """
    api_key: str = field(repr=False)
    uid: int
    cnt: int = 0
    model: str = "gpt-4o-mini"
    base_url: str = "https://api.chatfire.cn/v1/"
    llm: ChatOpenAI = field(init=False, repr=False)  # built automatically, never passed by callers

    def __post_init__(self):
        # Build the client eagerly so the wrapper is usable right after construction.
        self.llm = self.create_llm()

    def create_llm(self) -> ChatOpenAI:
        """Create the ``ChatOpenAI`` client for this wrapper's key, model and endpoint."""
        return ChatOpenAI(
            model=self.model,
            base_url=self.base_url,
            api_key=self.api_key,
        )

    async def agenerate(self, text):
        """Send a single prompt and return the raw result of ``ChatOpenAI.agenerate``.

        The generated text is available as ``result.generations[0][0].text``.
        ``cnt`` is incremented before the request is awaited, so it counts attempted
        calls, including ones that later raise.
        """
        self.cnt += 1
        return await self.llm.agenerate([text])

In [5]:
async def call_llm(llm: LLMAPI, text: str):
    """Send one prompt through the given ``LLMAPI`` wrapper and return its raw result."""
    response = await llm.agenerate(text)
    return response

In [6]:
async def run_api(
    keys: List[str],
    data: List[str],
    max_concurrency: Optional[int] = None,
) -> Tuple[List[Any], List[LLMAPI]]:
    """Send every prompt concurrently, assigning API keys round-robin.

    Prompt ``i`` is sent through the wrapper for key ``i % len(keys)``.

    Args:
        keys: API keys; one ``LLMAPI`` wrapper is created per key.
        data: Prompts to send.
        max_concurrency: Optional cap on the number of in-flight requests (e.g. to
            stay under provider rate limits). ``None`` (the default) schedules all
            requests at once.

    Returns:
        ``(results, llms)``. ``results[i]`` is the raw ``agenerate`` result for
        ``data[i]``, because ``asyncio.gather`` preserves input order. ``llms`` holds
        the wrappers; each one's ``cnt`` shows how many requests that key served.

    Raises:
        ValueError: if ``keys`` is empty (round-robin assignment needs at least one
            key) or ``max_concurrency`` is not a positive integer.
    """
    if not keys:
        raise ValueError("run_api needs at least one API key")
    if max_concurrency is not None and max_concurrency < 1:
        raise ValueError("max_concurrency must be a positive integer or None")

    llms = [LLMAPI(api_key=key, uid=i) for i, key in enumerate(keys)]

    if max_concurrency is None:
        # Unbounded: every request starts immediately.
        tasks = [call_llm(llms[i % len(llms)], text) for i, text in enumerate(data)]
    else:
        semaphore = asyncio.Semaphore(max_concurrency)

        async def _limited(llm: LLMAPI, text: str):
            # Hold one semaphore slot for the duration of a single request.
            async with semaphore:
                return await call_llm(llm, text)

        tasks = [_limited(llms[i % len(llms)], text) for i, text in enumerate(data)]

    results = await asyncio.gather(*tasks)
    return results, llms

In [142]:
# results, llms = await run_api(api_keys, questions)

In [None]:
# for item in results:
#     print(item.generations[0][0].text)

In [18]:
# Number of arithmetic questions generated for the evaluation run.
N_QUESTIONS = 90
# Passed unchanged to `generate_arithmetic_expression`; its exact meaning is
# defined in `utils` (NOTE(review): presumably expression size/operand count - confirm).
EXPRESSION_PARAM = 4

# The (Chinese) prompt asks the model to "return the result of the following
# expression in JSON format". `{{`/`}}` are escaped literal braces for str.format;
# only `{question}` is substituted.
prompt_template = """
    请将以下表达式的计算结果返回为 JSON 格式：
    {{
      "expression": "{question}",
      "result": ?
    }}
    """

# questions[i] is the full prompt text; labels[i] is the expected answer for it.
# NOTE(review): if the generator is random, seed it (in utils) for reproducible runs.
questions = []
labels = []

for _ in range(N_QUESTIONS):
    question, label = generate_arithmetic_expression(EXPRESSION_PARAM)
    questions.append(prompt_template.format(question=question))
    labels.append(label)

In [19]:
# Every prompt must have exactly one expected answer; the scoring cell zips the
# two lists, and zip() would silently drop unmatched items.
assert len(questions) == len(labels), "questions and labels are out of sync"
len(labels)

90

In [20]:
start_time = time.time()

# for jupyter
results, llms = await run_api(api_keys, questions)

# 运行程序
# results, llms = asyncio.run(run_api(api_keys, questions))
# results, llms = asyncio.run(run_api(api_keys[:1], questions))
right = 0
except_cnt = 0
not_equal = 0

for q, res, label in zip(questions, results, labels):
    res = res.generations[0][0].text
    try:
        res = re_parse_json(res)
        if res is None:
            except_cnt += 1
            continue

        res = res.get("result", None)
        if res is None:
            except_cnt += 1
            continue

        res = int(res)
        if res == label:
            right += 1
        else:
            not_equal += 1
    except Exception as e:
        print(e)
        print(f"question:{q}\nresult:{res}")

print("accuracy: {}%".format(right / len(questions) * 100))
end_time = time.time()
calculate_time_difference(start_time, end_time)

accuracy: 90.0%
executed in 00:00:13.345 (h:m:s.ms)


In [21]:
# Scoring breakdown: (correct, missing/unparseable result, wrong answer).
(right, except_cnt, not_equal)

(81, 0, 9)

In [22]:
# NOTE(review): scratch arithmetic with hard-coded values. 13 may be the elapsed
# seconds reported by the evaluation cell, but the meaning of 30 is not visible
# here; replace with named variables or delete this cell.
13 / 30

0.43333333333333335