In [1]:
import json
import os
import re

from sparkai.core.messages import ChatMessage
from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler

In [2]:

# 星火认知大模型Spark Max的URL值，其他版本大模型URL值请前往文档（https://www.xfyun.cn/doc/spark/Web.html）查看
SPARKAI_URL = "wss://spark-api-n.xf-yun.com/v3.1/chat"
# SPARKAI_URL = "wss://spark-api-n.xf-yun.com/v3.1/chat"
# 星火认知大模型调用秘钥信息，请前往讯飞开放平台控制台（https://console.xfyun.cn/services/bm35）查看
SPARKAI_APP_ID = "1416480a"
SPARKAI_API_SECRET = "YmZlYTM4Nzg4ZjUyNmNiMjA5MTNjZmZj"
SPARKAI_API_KEY = "eb503c5b44ed3ff9477db7f745a933bb"
# 星火认知大模型Spark Max的domain值，其他版本大模型domain值请前往文档（https://www.xfyun.cn/doc/spark/Web.html）查看
SPARKAI_DOMAIN = "patchv3"

spark = ChatSparkLLM(
    spark_api_url=SPARKAI_URL,
    spark_app_id=SPARKAI_APP_ID,
    spark_api_key=SPARKAI_API_KEY,
    spark_api_secret=SPARKAI_API_SECRET,
    spark_llm_domain=SPARKAI_DOMAIN,
    streaming=False,
    top_k=1,
    max_tokens=8192,
    request_timeout=180,
)
spark.temperature = 0.1

In [3]:
def newHistory(systemContent=None):
    """Create a fresh conversation history.

    Args:
        systemContent: Optional text for a leading system message. When it is
            None the history starts empty.

    Returns:
        A new list: empty, or holding a single message with role "system"
        built by buildMessage().
    """
    messages = []
    if systemContent is not None:
        messages.append(buildMessage(systemContent, role="system"))
    return messages

In [4]:
def buildMessage(content, role="user"):
    """Wrap text in a ChatMessage.

    Args:
        content: Message text.
        role: Role to attach to the message; defaults to "user".

    Returns:
        The ChatMessage constructed from the given content and role.
    """
    message = ChatMessage(role=role, content=content)
    return message

In [5]:
def chat(userContent, history=None):
    """Send one user turn to the model and record both sides of the exchange.

    Args:
        userContent: Text of the user message.
        history: Existing list of messages to continue. It is extended in
            place. When None, a new empty history is created.

    Returns:
        A tuple ``(assistantContent, history)``: the text of the first
        generation, and the history list now ending with the user message
        followed by the assistant reply.
    """
    conversation = newHistory() if history is None else history
    conversation.append(buildMessage(userContent, role="user"))

    printer = ChunkPrintHandler()
    result = spark.generate([conversation], callbacks=[printer])
    reply = result.generations[0][0].text

    conversation.append(buildMessage(reply, role="assistant"))
    return (reply, conversation)

In [7]:



# Combined length (system prompt + chat text, in characters) below which an item
# is sent to the model in a single request.
MAX_REQUEST_CHARS = 8192
# Slice width used when an item is over that budget, and the number of characters
# from the end of the previous slice that are repeated at the start of the next.
CHUNK_SIZE = 6600
CHUNK_OVERLAP = 600


def splitChatText(text, blockSize=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    """Split text into consecutive blocks that overlap their predecessor.

    Block k covers ``text[k * blockSize - overlap : (k + 1) * blockSize]``,
    clamped to the bounds of the text. Every block after the first therefore
    starts ``overlap`` characters early, including the last one.

    Args:
        text: String to split.
        blockSize: Number of new characters contributed by each block.
        overlap: Number of characters repeated from the previous block.

    Returns:
        List of substrings in order; an empty list for empty text.
    """
    return [
        text[max(0, start - overlap) : start + blockSize]
        for start in range(0, len(text), blockSize)
    ]


def parseModelResponse(resp):
    """Decode the JSON payload of a model reply.

    A leading "```json" fence line and trailing "```" fences are stripped
    first. The text is then parsed as-is; only if that fails are single quotes
    swapped for double quotes and parsing retried, so apostrophes inside valid
    JSON strings are left intact.

    Args:
        resp: Raw reply text from the model.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If the text cannot be parsed either way.
    """
    cleaned = re.sub(r'^```json\n|```$', '', resp, flags=re.MULTILINE)
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        # Fallback for replies that use Python-style single-quoted strings.
        return json.loads(cleaned.replace("'", '"'))


if __name__ == "__main__":
    # Load the system prompt.
    try:
        with open("system_03.md", "r", encoding="utf-8") as f:
            systemContent = f.read()
    except FileNotFoundError:
        print("系统文件未找到")
        # SystemExit rather than exit(): exit() is injected by `site`/IPython
        # and may be missing or shut the kernel down.
        raise SystemExit(1)

    # Load the items to process.
    try:
        with open("dataset/test_data.json", "r", encoding="utf-8") as f:
            trainData = json.load(f)
    except FileNotFoundError:
        print("数据文件未找到")
        raise SystemExit(1)

    final = []
    for index, item in enumerate(trainData, start=1):
        chatText = item["chat_text"]
        requestLen = len(systemContent) + len(chatText)
        answer = {"index": index, "infos": []}
        print("index: ", index)
        print("request len: ", requestLen)
        try:
            if requestLen < MAX_REQUEST_CHARS:
                # Short enough for one request: the parsed reply is the result.
                resp, history = chat(
                    userContent=chatText, history=newHistory(systemContent)
                )
                answer["infos"] = parseModelResponse(resp)
            else:
                # Long text: query each overlapping block with a fresh history
                # and collect the first element of every parsed reply.
                for block in splitChatText(chatText):
                    resp, history = chat(
                        userContent=block, history=newHistory(systemContent)
                    )
                    answer["infos"].append(parseModelResponse(resp)[0])
        except Exception as e:  # best-effort: a failed item is recorded as empty
            print(f"处理过程中出现错误: {e}")
            # Partial block results are discarded so a failed item is always
            # recognisable by an empty "infos" list.
            answer["infos"] = []
        final.append(answer)

    # Write the results; ensure_ascii=False keeps non-ASCII characters readable.
    try:
        with open("final_retry_long02.json", "w", encoding="utf-8") as f:
            json.dump(final, f, indent=4, ensure_ascii=False)
    except IOError:
        print("写入文件时发生错误")

index:  1
request len:  11795
index:  2
request len:  2019
index:  3
request len:  2510
index:  4
request len:  11107
index:  5
request len:  5902
index:  6
request len:  2854
index:  7
request len:  17736
index:  8
request len:  2481
index:  9
request len:  2639
index:  10
request len:  9114
index:  11
request len:  13210
index:  12
request len:  5472
index:  13
request len:  21004
index:  14
request len:  3466
index:  15
request len:  1836
index:  16
request len:  1950
index:  17
request len:  4302
index:  18
request len:  6994
index:  19
request len:  8246
index:  20
request len:  4139
index:  21
request len:  2018
index:  22
request len:  13496
index:  23
request len:  3038
index:  24
request len:  17600
index:  25
request len:  6261
index:  26
request len:  4804
index:  27
request len:  4464
index:  28
request len:  1730
index:  29
request len:  3787
index:  30
request len:  16092
index:  31
request len:  10047
index:  32
request len:  8716
index:  33
request len:  31168
index:  3