Commit: Add the ability to interrupt a response (加入中止回答的功能)
GaiZhenbiao committed Mar 23, 2023
1 parent 3fe8fc4 commit 2c5812c
Showing 8 changed files with 138 additions and 74 deletions.
101 changes: 58 additions & 43 deletions ChuanhuChatbot.py
@@ -5,10 +5,10 @@

import gradio as gr

from utils import *
from presets import *
from overwrites import *
from chat_func import *
from modules.utils import *
from modules.presets import *
from modules.overwrites import *
from modules.chat_func import *

logging.basicConfig(
level=logging.DEBUG,
@@ -54,7 +54,7 @@
gr.Chatbot.postprocess = postprocess
PromptHelper.compact_text_chunks = compact_text_chunks

with open("custom.css", "r", encoding="utf-8") as f:
with open("assets/custom.css", "r", encoding="utf-8") as f:
customCSS = f.read()

with gr.Blocks(
@@ -124,8 +124,7 @@
token_count = gr.State([])
promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
user_api_key = gr.State(my_api_key)
TRUECOMSTANT = gr.State(True)
FALSECONSTANT = gr.State(False)
outputing = gr.State(False)
topic = gr.State("未命名对话历史记录")

with gr.Row():
@@ -275,12 +274,9 @@

gr.Markdown(description)

keyTxt.submit(submit_key, keyTxt, [user_api_key, status_display])
keyTxt.change(submit_key, keyTxt, [user_api_key, status_display])
# Chatbot
user_input.submit(
predict,
[
chatgpt_predict_args = dict(
fn=predict,
inputs=[
user_api_key,
systemPromptTxt,
history,
@@ -294,40 +290,45 @@
use_websearch_checkbox,
index_files,
],
[chatbot, history, status_display, token_count],
outputs=[chatbot, history, status_display, token_count],
show_progress=True,
)
user_input.submit(reset_textbox, [], [user_input])

# submitBtn.click(return_cancel_btn, [], [submitBtn, cancelBtn])
submitBtn.click(
predict,
[
user_api_key,
systemPromptTxt,
history,
user_input,
chatbot,
token_count,
top_p,
temperature,
use_streaming_checkbox,
model_select_dropdown,
use_websearch_checkbox,
index_files,
],
[chatbot, history, status_display, token_count],
show_progress=True,
start_outputing_args = dict(
fn=start_outputing, inputs=[], outputs=[submitBtn, cancelBtn], show_progress=True
)

end_outputing_args = dict(
fn=end_outputing, inputs=[], outputs=[submitBtn, cancelBtn]
)

reset_textbox_args = dict(
fn=reset_textbox, inputs=[], outputs=[user_input], show_progress=True
)

keyTxt.submit(submit_key, keyTxt, [user_api_key, status_display])
keyTxt.change(submit_key, keyTxt, [user_api_key, status_display])
# Chatbot
cancelBtn.click(cancel_outputing, [], [])

user_input.submit(**start_outputing_args).then(
**chatgpt_predict_args
).then(**reset_textbox_args).then(
**end_outputing_args
)
submitBtn.click(**start_outputing_args).then(
**chatgpt_predict_args
).then(**reset_textbox_args).then(
**end_outputing_args
)
submitBtn.click(reset_textbox, [], [user_input])

emptyBtn.click(
reset_state,
outputs=[chatbot, history, token_count, status_display],
show_progress=True,
)
).then(**reset_textbox_args)

retryBtn.click(
retryBtn.click(**start_outputing_args).then(
retry,
[
user_api_key,
Expand All @@ -342,7 +343,7 @@
],
[chatbot, history, status_display, token_count],
show_progress=True,
)
).then(**end_outputing_args)

delLastBtn.click(
delete_last_conversation,
@@ -441,17 +442,31 @@
if dockerflag:
if authflag:
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
server_name="0.0.0.0", server_port=7860, auth=(username, password),
favicon_path="./assets/favicon.png"
server_name="0.0.0.0",
server_port=7860,
auth=(username, password),
favicon_path="./assets/favicon.png",
)
else:
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860, share=False, favicon_path="./assets/favicon.png")
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
favicon_path="./assets/favicon.png",
)
# if not running in Docker
else:
if authflag:
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(share=False, auth=(username, password), favicon_path="./assets/favicon.png", inbrowser=True)
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
share=False,
auth=(username, password),
favicon_path="./assets/favicon.png",
inbrowser=True,
)
else:
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(share=False, favicon_path="./assets/favicon.ico", inbrowser=True) # set share=True to create a public share link
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
share=False, favicon_path="./assets/favicon.ico", inbrowser=True
) # set share=True to create a public share link
# demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860, share=False) # the port can be customized
# demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860,auth=("在这里填写用户名", "在这里填写密码")) # a username and password can be set
# demo.queue(concurrency_count=CONCURRENT_COUNT).launch(auth=("在这里填写用户名", "在这里填写密码")) # suitable behind an Nginx reverse proxy
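
The start_outputing, end_outputing, and cancel_outputing helpers wired up above are not shown in this diff (they presumably live in modules/utils.py, one of the renamed files). A minimal sketch of how such helpers could work, assuming Gradio 3.x button updates and the shared State object added in modules/shared.py further down; this is an illustration, not the repository's actual code:

import gradio as gr
import modules.shared as shared

def start_outputing():
    # Hide the submit button and reveal the cancel button while generating.
    return gr.Button.update(visible=False), gr.Button.update(visible=True)

def end_outputing():
    # Restore the buttons once generation finishes or is interrupted.
    return gr.Button.update(visible=True), gr.Button.update(visible=False)

def cancel_outputing():
    # Flip the shared interrupt flag; the streaming loop in predict()
    # polls it on every yielded chunk and stops at the next opportunity.
    shared.state.interrupt()

Chaining with .then() is what makes this safe: start_outputing swaps the buttons before predict begins streaming, and end_outputing always runs afterwards, so the UI recovers even when predict returns early on an interrupt.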
File renamed without changes.
34 changes: 22 additions & 12 deletions chat_func.py → modules/chat_func.py
@@ -14,9 +14,10 @@
import asyncio
import aiohttp

from presets import *
from llama_func import *
from utils import *
from modules.presets import *
from modules.llama_func import *
from modules.utils import *
import modules.shared as shared

# logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")

@@ -29,7 +30,6 @@ class DataframeData(TypedDict):


initial_prompt = "You are a helpful assistant."
API_URL = "https://api.openai.com/v1/chat/completions"
HISTORY_DIR = "history"
TEMPLATES_DIR = "templates"

@@ -65,16 +65,18 @@ def get_response(
# If proxy settings exist, use them
proxies = {}
if http_proxy:
logging.info(f"Using HTTP proxy: {http_proxy}")
logging.info(f"使用 HTTP 代理: {http_proxy}")
proxies["http"] = http_proxy
if https_proxy:
logging.info(f"Using HTTPS proxy: {https_proxy}")
logging.info(f"使用 HTTPS 代理: {https_proxy}")
proxies["https"] = https_proxy

# If a proxy is configured, send the request through it; otherwise use the default settings
if shared.state.api_url != API_URL:
logging.info(f"使用自定义API URL: {shared.state.api_url}")
if proxies:
response = requests.post(
API_URL,
shared.state.api_url,
headers=headers,
json=payload,
stream=True,
Expand All @@ -83,7 +85,7 @@ def get_response(
)
else:
response = requests.post(
API_URL,
shared.state.api_url,
headers=headers,
json=payload,
stream=True,
@@ -268,10 +270,10 @@ def predict(
if files:
msg = "构建索引中……(这可能需要比较久的时间)"
logging.info(msg)
yield chatbot, history, msg, all_token_counts
yield chatbot+[(inputs, "")], history, msg, all_token_counts
index = construct_index(openai_api_key, file_src=files)
msg = "索引构建完成,获取回答中……"
yield chatbot, history, msg, all_token_counts
yield chatbot+[(inputs, "")], history, msg, all_token_counts
history, chatbot, status_text = chat_ai(openai_api_key, index, inputs, history, chatbot)
yield chatbot, history, status_text, all_token_counts
return
@@ -306,10 +308,15 @@ def predict(
all_token_counts.append(0)
else:
history[-2] = construct_user(inputs)
yield chatbot, history, status_text, all_token_counts
yield chatbot+[(inputs, "")], history, status_text, all_token_counts
return
elif len(inputs.strip()) == 0:
status_text = standard_error_msg + no_input_msg
logging.info(status_text)
yield chatbot+[(inputs, "")], history, status_text, all_token_counts
return

yield chatbot, history, "开始生成回答……", all_token_counts
yield chatbot+[(inputs, "")], history, "开始生成回答……", all_token_counts

if stream:
logging.info("使用流式传输")
@@ -327,6 +334,9 @@
display_append=link_references
)
for chatbot, history, status_text, all_token_counts in iter:
if shared.state.interrupted:
shared.state.recover()
return
yield chatbot, history, status_text, all_token_counts
else:
logging.info("不使用流式传输")
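
The interrupt check added above works because predict() is a generator: control returns to its loop on every streamed chunk, so a flag set by the cancel button's callback is noticed within one chunk's latency. A toy version of the pattern (names hypothetical, not the repository's code):

import modules.shared as shared

def predict_sketch(chunks):
    # chunks stands in for the streamed response deltas from the API.
    partial = ""
    for chunk in chunks:
        if shared.state.interrupted:
            shared.state.recover()  # reset the flag for the next request
            return                  # abandon the rest of the stream
        partial += chunk
        yield partial               # push the partial answer to the UI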
4 changes: 2 additions & 2 deletions llama_func.py → modules/llama_func.py
@@ -14,8 +14,8 @@
import colorama


from presets import *
from utils import *
from modules.presets import *
from modules.utils import *


def get_documents(file_src):
6 changes: 3 additions & 3 deletions overwrites.py → modules/overwrites.py
@@ -5,8 +5,8 @@
from typing import List, Tuple
import mdtex2html

from presets import *
from llama_func import *
from modules.presets import *
from modules.llama_func import *


def compact_text_chunks(self, prompt: Prompt, text_chunks: List[str]) -> List[str]:
@@ -51,5 +51,5 @@ def template_response(*args, **kwargs):
return res

gr.routes.templates.TemplateResponse = template_response

GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
3 changes: 2 additions & 1 deletion presets.py → modules/presets.py
@@ -14,9 +14,10 @@
proxy_error_prompt = "代理错误,无法获取对话。" # proxy error, unable to get a reply
ssl_error_prompt = "SSL错误,无法获取对话。" # SSL error, unable to get a reply
no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key is not 51 characters long
no_input_msg = "请输入对话内容。" # no input was provided

max_token_streaming = 3500 # maximum number of tokens when streaming
timeout_streaming = 30 # timeout when streaming
timeout_streaming = 5 # timeout when streaming
max_token_all = 3500 # maximum number of tokens when not streaming
timeout_all = 200 # timeout when not streaming
enable_streaming_option = True # whether to show the checkbox that toggles real-time streaming of answers
24 changes: 24 additions & 0 deletions modules/shared.py
@@ -0,0 +1,24 @@
from modules.presets import API_URL

class State:
interrupted = False
api_url = API_URL

def interrupt(self):
self.interrupted = True

def recover(self):
self.interrupted = False

def set_api_url(self, api_url):
self.api_url = api_url

def reset_api_url(self):
self.api_url = API_URL
return self.api_url

def reset_all(self):
self.interrupted = False
self.api_url = API_URL

state = State()
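
Because state is instantiated once at module scope, every module that imports modules.shared sees the same object; this is the channel through which a UI callback signals the request loop. A usage sketch (the endpoint URL is a placeholder, not a real service):

import modules.shared as shared

# Point requests at a custom endpoint instead of the default API_URL.
shared.state.set_api_url("https://example.com/v1/chat/completions")

# Ask a running generation to stop; predict() notices on its next chunk.
shared.state.interrupt()

# Clear the flag and restore the default endpoint in one call.
shared.state.reset_all()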
