From f6a599461fabd462298d4367e94615808da4d88e Mon Sep 17 00:00:00 2001 From: KevinHuSh Date: Fri, 17 May 2024 17:07:33 +0800 Subject: [PATCH] fix zhipuAI stream issue (#825) ### What problem does this PR solve? ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- api/apps/api_app.py | 2 +- api/apps/conversation_app.py | 2 +- rag/llm/chat_model.py | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/api/apps/api_app.py b/api/apps/api_app.py index 0494490419..0c0b191b2b 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -222,7 +222,7 @@ def stream(): resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") return resp else: - ans = chat(dia, msg, False, **req) + ans = chat(dia, msg, **req) fillin_conv(ans) API4ConversationService.append_message(conv.id, conv.to_dict()) return get_json_result(data=ans) diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index ed52500441..6d06a05f84 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -162,7 +162,7 @@ def stream(): return resp else: - ans = chat(dia, msg, False, **req) + ans = chat(dia, msg, **req) fillin_conv(ans) ConversationService.update_by_id(conv.id, conv.to_dict()) return get_json_result(data=ans) diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 0c53279f69..eac3f8df85 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -193,10 +193,11 @@ def chat_streamly(self, system, history, gen_conf): if not resp.choices[0].delta.content:continue delta = resp.choices[0].delta.content ans += delta - tk_count = resp.usage.total_tokens if response.usage else 0 - if resp.output.choices[0].finish_reason == "length": + if resp.choices[0].finish_reason == "length": ans += "...\nFor the content length reason, it stopped, continue?" if is_english( [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" + tk_count = resp.usage.total_tokens + if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens yield ans except Exception as e: yield ans + "\n**ERROR**: " + str(e)