# web_demo_gradio.py

# 整合前测试

In [2]:
pip install gradio

Looking in indexes: http://mirrors.aliyun.com/pypi/simple
Collecting gradio
  Downloading http://mirrors.aliyun.com/pypi/packages/a7/06/29af2174e7aa1f3890e46483902bbb10628a90ddc053b2b02fcd201da89e/gradio-4.25.0-py3-none-any.whl (17.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 17.1/17.1 MB 3.2 MB/s eta 0:00:00
Collecting python-multipart>=0.0.9
  Downloading http://mirrors.aliyun.com/pypi/packages/3d/47/444768600d9e0ebc82f8e347775d24aef8f6348cf00e9fa0e81910814e6d/python_multipart-0.0.9-py3-none-any.whl (22 kB)
Collecting ruff>=0.2.2
  Downloading http://mirrors.aliyun.com/pypi/packages/5a/8d/cc5d7078ffd2779a38a05bcff8ffaa24f524de51c46223a2dd2710a8c50e/ruff-0.3.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.7/8.7 MB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting orjson~=3.0
  Downloading http://mirrors.aliyun

In [1]:
import os
import gradio as gr
import torch
from threading import Thread

from typing import Union, Annotated
from pathlib import Path
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer
)

In [2]:
# Type aliases used in annotations below: a "model" is either a plain
# transformers model or a PEFT-wrapped causal LM; a "tokenizer" is either the
# slow or the fast transformers tokenizer class.
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

In [3]:
# Local directory holding the ChatGLM3-6B weights (path sits under a
# ModelScope cache directory).
MODEL_PATH = '/root/.cache/modelscope/hub/ZhipuAI/chatglm3-6b'
# NOTE(review): TOKENIZER_PATH is not referenced by any code visible in this
# file; load_model_and_tokenizer derives the tokenizer directory itself.
TOKENIZER_PATH = '/root/.cache/modelscope/hub/ZhipuAI/chatglm3-6b'

In [4]:
def _resolve_path(path: Union[str, Path]) -> Path:
    """Return *path* as a ``Path`` with ``~`` expanded and then resolved."""
    expanded = Path(path).expanduser()
    return expanded.resolve()

In [5]:
def load_model_and_tokenizer(
        model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
    """Load a causal LM and its tokenizer from ``model_dir``.

    If ``model_dir`` contains ``adapter_config.json`` the directory is loaded
    through ``AutoPeftModelForCausalLM`` and the tokenizer is read from the
    adapter's recorded ``base_model_name_or_path``; otherwise the directory
    is loaded through ``AutoModelForCausalLM`` and also supplies the
    tokenizer.

    Args:
        model_dir: Directory of a full model or of a PEFT adapter checkpoint.
        trust_remote_code: Forwarded unchanged to every ``from_pretrained``
            call.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model_dir = _resolve_path(model_dir)
    has_adapter = (model_dir / 'adapter_config.json').exists()

    # Both loader classes are called with the same arguments, so pick the
    # class first and call it once.
    loader_cls = AutoPeftModelForCausalLM if has_adapter else AutoModelForCausalLM
    model = loader_cls.from_pretrained(
        model_dir, trust_remote_code=trust_remote_code, device_map='auto'
    )

    if has_adapter:
        tokenizer_dir = model.peft_config['default'].base_model_name_or_path
    else:
        tokenizer_dir = model_dir

    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir, trust_remote_code=trust_remote_code
    )
    return model, tokenizer

In [6]:
# Load the model and tokenizer once at cell execution; predict() below reads
# these two module-level names.
model, tokenizer = load_model_and_tokenizer(MODEL_PATH, trust_remote_code=True)


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


In [7]:
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires when the newest token is a stop id."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of ``input_ids[0]`` is 0 or 2.

        Only the first sequence is inspected; ``scores`` and ``kwargs`` are
        unused.
        """
        # NOTE(review): 0 and 2 are presumably the model's pad/EOS token ids;
        # confirm against the tokenizer.
        newest_token = input_ids[0][-1]
        return any(newest_token == token_id for token_id in (0, 2))

In [8]:
# Characters replaced inside fenced code so the chat widget renders them
# literally instead of interpreting them as Markdown/HTML. None of the
# replacement strings contains another key, so a single translate() pass
# produces the same result as replacing them one after another.
_CODE_ESCAPES = str.maketrans({
    "`": "\\`",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})


def parse_text(text):
    """Convert user text with ``` fences into the HTML shown in the chat box.

    Empty lines are dropped. A line containing ``` toggles a code block:
    an opening fence becomes ``<pre><code class="language-X">`` where X is
    the text after the last backtick, and a closing fence becomes
    ``<br></code></pre>``. Every other line except the very first is
    prefixed with ``<br>``; lines inside a code block are additionally
    escaped via ``_CODE_ESCAPES``.

    Args:
        text: Raw multi-line user input.

    Returns:
        The converted text as a single string with no newlines.
    """
    import html  # local import keeps this block self-contained

    rendered = []
    fence_count = 0
    for line in text.split("\n"):
        if line == "":
            continue
        if "```" in line:
            fence_count += 1
            if fence_count % 2 == 1:
                # The language tag is user input placed inside an HTML
                # attribute, so escape it to keep it from breaking out.
                language = html.escape(line.split("`")[-1], quote=True)
                rendered.append(f'<pre><code class="language-{language}">')
            else:
                rendered.append("<br></code></pre>")
        elif not rendered:
            # The first kept line is emitted as-is, without a leading <br>.
            rendered.append(line)
        else:
            if fence_count % 2 == 1:
                line = line.translate(_CODE_ESCAPES)
            rendered.append("<br>" + line)
    return "".join(rendered)


In [9]:
def predict(history, max_length, top_p, temperature):
    """Stream the model's reply to the latest turn into ``history``.

    ``history`` is a list of ``[user_msg, model_msg]`` pairs. The pairs are
    flattened into role-tagged messages, passed through the tokenizer's chat
    template, and fed to ``model.generate`` on a background thread. Each
    non-empty chunk from the streamer is appended to ``history[-1][1]`` and
    the updated ``history`` is yielded.

    Args:
        history: Chat turns; the last one is expected to have an empty reply.
        max_length: Passed to ``generate`` as ``max_new_tokens``.
        top_p: Nucleus-sampling value passed to ``generate``.
        temperature: Sampling temperature passed to ``generate``.

    Yields:
        ``history`` after each streamed chunk (the same list, mutated in place).
    """
    messages = []
    final_idx = len(history) - 1
    for turn_idx, turn in enumerate(history):
        question, answer = turn
        # The last turn with no reply yet is the prompt being answered; its
        # user message is always included, even when empty.
        awaiting_reply = turn_idx == final_idx and not answer
        if awaiting_reply or question:
            messages.append({"role": "user", "content": question})
        if awaiting_reply:
            break
        if answer:
            messages.append({"role": "assistant", "content": answer})

    print("\n\n====conversation====\n", messages)
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
    )
    # Move the prompt tensor to the device of the model's first parameter.
    model_inputs = encoded.to(next(model.parameters()).device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=model_inputs,
        streamer=streamer,
        max_new_tokens=max_length,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
        repetition_penalty=1.2,
    )
    # Generation runs on a worker thread so this generator can consume the
    # streamer while tokens are still being produced.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    for new_token in streamer:
        if new_token == '':
            continue
        history[-1][1] += new_token
        yield history


In [10]:
# Build the chat UI: a chatbot pane, an input box with a submit button, and
# a side column with the clear button and the sampling controls.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">ChatGLM3-6B Gradio Simple Demo</h1>""")
    chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            # Lower bound is 1, not 0: predict() forwards this value as
            # max_new_tokens, and transformers rejects values <= 0.
            max_length = gr.Slider(1, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)


    def user(query, history):
        """Clear the textbox and append the new user turn with an empty reply."""
        # Tolerate a missing history (None) so the first turn cannot fail.
        return "", (history or []) + [[parse_text(query), ""]]


    # Step 1 (unqueued): move the query into the chat history.
    # Step 2: stream the model's reply into the last history entry.
    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
        predict, [chatbot, max_length, top_p, temperature], chatbot
    )
    # Clearing sets the chatbot value to None.
    emptyBtn.click(lambda: None, None, chatbot, queue=False)

In [11]:
# Enable Gradio's event queue for this app before launching it.
demo.queue()

Gradio Blocks instance: 3 backend functions
-------------------------------------------
fn_index=0
 inputs:
 |-<gradio.components.textbox.Textbox object at 0x7fe6006160b0>
 |-<gradio.components.chatbot.Chatbot object at 0x7fe600658160>
 outputs:
 |-<gradio.components.textbox.Textbox object at 0x7fe6006160b0>
 |-<gradio.components.chatbot.Chatbot object at 0x7fe600658160>
fn_index=1
 inputs:
 |-<gradio.components.chatbot.Chatbot object at 0x7fe600658160>
 |-<gradio.components.slider.Slider object at 0x7fe600617a90>
 |-<gradio.components.slider.Slider object at 0x7fe600616ec0>
 |-<gradio.components.slider.Slider object at 0x7fe600616ad0>
 outputs:
 |-<gradio.components.chatbot.Chatbot object at 0x7fe600658160>
fn_index=2
 inputs:
 outputs:
 |-<gradio.components.chatbot.Chatbot object at 0x7fe600658160>

In [12]:
# Serve on 127.0.0.1:7870, open a browser tab, and request a public share
# link (the captured output below shows the share link could not be created).
demo.launch(server_name="127.0.0.1", server_port=7870, inbrowser=True, share=True)

Running on local URL:  http://127.0.0.1:7870

Could not create share link. Please check your internet connection or our status page: https://status.gradio.app.





No chat template is defined for this tokenizer - using a default chat template that implements the ChatML format (without BOS/EOS tokens!). If the default is not appropriate for your model, please set `tokenizer.chat_template` to an appropriate template. See https://huggingface.co/docs/transformers/main/chat_templating for more information.

Exception in thread Thread-12 (generate):
Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/root/miniconda3/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/transformers/generation/utils.py", line 1575, in generate
    result = self._sample(
  File "/root/miniconda3/lib/python3.10/site-packages/transformers/generat



====conversation====
 [{'role': 'user', 'content': '你好'}]


Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.10/site-packages/gradio/queueing.py", line 522, in process_events
    response = await route_utils.call_process_api(
  File "/root/miniconda3/lib/python3.10/site-packages/gradio/route_utils.py", line 260, in call_process_api
    output = await app.get_blocks().process_api(
  File "/root/miniconda3/lib/python3.10/site-packages/gradio/blocks.py", line 1741, in process_api
    result = await self.call_function(
  File "/root/miniconda3/lib/python3.10/site-packages/gradio/blocks.py", line 1308, in call_function
    prediction = await utils.async_iteration(iterator)
  File "/root/miniconda3/lib/python3.10/site-packages/gradio/utils.py", line 575, in async_iteration
    return await iterator.__anext__()
  File "/root/miniconda3/lib/python3.10/site-packages/gradio/utils.py", line 568, in __anext__
    return await anyio.to_thread.run_sync(
  File "/root/miniconda3/lib/python3.10/site-packages/anyio/to_thread.py", line 56,

In [None]:
!wget "https://cdn-media.huggingface.co/frpc-gradio-0.2/frpc_linux_amd64" -O /root/miniconda3/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.2

--2024-04-03 10:32:47--  https://cdn-media.huggingface.co/frpc-gradio-0.2/frpc_linux_amd64
Resolving cdn-media.huggingface.co (cdn-media.huggingface.co)... 31.13.76.65, 2600:9000:2363:b400:19:6fb8:2ac0:93a1, 2600:9000:2363:7a00:19:6fb8:2ac0:93a1, ...
Connecting to cdn-media.huggingface.co (cdn-media.huggingface.co)|31.13.76.65|:443... 

In [None]:
# 下载失败，手动上传

In [1]:
ll -h frpc_linux_amd64

-rw-r--r-- 1 root 1.0M Apr  3 10:34 frpc_linux_amd64


In [2]:
mv frpc_linux_amd64 frpc_linux_amd64_v0.2

In [3]:
mv frpc_linux_amd64_v0.2 /root/miniconda3/lib/python3.10/site-packages/gradio

# 整合后的web_demo_gradio.py

In [1]:
import os
import gradio as gr
import torch
from threading import Thread

from typing import Union, Annotated
from pathlib import Path
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer
)

# Type aliases used in annotations below: a "model" is either a plain
# transformers model or a PEFT-wrapped causal LM; a "tokenizer" is either the
# slow or the fast transformers tokenizer class.
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

# MODEL_PATH = '/root/.cache/modelscope/hub/ZhipuAI/chatglm3-6b'
# TOKENIZER_PATH = '/root/.cache/modelscope/hub/ZhipuAI/chatglm3-6b'

def _resolve_path(path: Union[str, Path]) -> Path:
    """Return *path* as a ``Path`` with ``~`` expanded and then resolved."""
    expanded = Path(path).expanduser()
    return expanded.resolve()


def load_model_and_tokenizer(
        model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
    """Load a causal LM and its tokenizer from ``model_dir``.

    If ``model_dir`` contains ``adapter_config.json`` the directory is loaded
    through ``AutoPeftModelForCausalLM`` and the tokenizer is read from the
    adapter's recorded ``base_model_name_or_path``; otherwise the directory
    is loaded through ``AutoModelForCausalLM`` and also supplies the
    tokenizer.

    Args:
        model_dir: Directory of a full model or of a PEFT adapter checkpoint.
        trust_remote_code: Forwarded unchanged to every ``from_pretrained``
            call.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model_dir = _resolve_path(model_dir)
    has_adapter = (model_dir / 'adapter_config.json').exists()

    # Both loader classes are called with the same arguments, so pick the
    # class first and call it once.
    loader_cls = AutoPeftModelForCausalLM if has_adapter else AutoModelForCausalLM
    model = loader_cls.from_pretrained(
        model_dir, trust_remote_code=trust_remote_code, device_map='auto'
    )

    if has_adapter:
        tokenizer_dir = model.peft_config['default'].base_model_name_or_path
    else:
        tokenizer_dir = model_dir

    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir, trust_remote_code=trust_remote_code
    )
    return model, tokenizer

model_dir = 'output/checkpoint-2000' # path to the fine-tuned checkpoint directory
# Load the model and tokenizer once at import time; predict() below reads
# these two module-level names.
model, tokenizer = load_model_and_tokenizer(model_dir, trust_remote_code=True)


class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires when the newest token is a stop id."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of ``input_ids[0]`` is 0 or 2.

        Only the first sequence is inspected; ``scores`` and ``kwargs`` are
        unused.
        """
        # NOTE(review): 0 and 2 are presumably the model's pad/EOS token ids;
        # confirm against the tokenizer.
        newest_token = input_ids[0][-1]
        return any(newest_token == token_id for token_id in (0, 2))


# Characters replaced inside fenced code so the chat widget renders them
# literally instead of interpreting them as Markdown/HTML. None of the
# replacement strings contains another key, so a single translate() pass
# produces the same result as replacing them one after another.
_CODE_ESCAPES = str.maketrans({
    "`": "\\`",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})


def parse_text(text):
    """Convert user text with ``` fences into the HTML shown in the chat box.

    Empty lines are dropped. A line containing ``` toggles a code block:
    an opening fence becomes ``<pre><code class="language-X">`` where X is
    the text after the last backtick, and a closing fence becomes
    ``<br></code></pre>``. Every other line except the very first is
    prefixed with ``<br>``; lines inside a code block are additionally
    escaped via ``_CODE_ESCAPES``.

    Args:
        text: Raw multi-line user input.

    Returns:
        The converted text as a single string with no newlines.
    """
    import html  # local import keeps this block self-contained

    rendered = []
    fence_count = 0
    for line in text.split("\n"):
        if line == "":
            continue
        if "```" in line:
            fence_count += 1
            if fence_count % 2 == 1:
                # The language tag is user input placed inside an HTML
                # attribute, so escape it to keep it from breaking out.
                language = html.escape(line.split("`")[-1], quote=True)
                rendered.append(f'<pre><code class="language-{language}">')
            else:
                rendered.append("<br></code></pre>")
        elif not rendered:
            # The first kept line is emitted as-is, without a leading <br>.
            rendered.append(line)
        else:
            if fence_count % 2 == 1:
                line = line.translate(_CODE_ESCAPES)
            rendered.append("<br>" + line)
    return "".join(rendered)


def predict(history, max_length, top_p, temperature):
    """Stream the model's reply to the latest turn into ``history``.

    ``history`` is a list of ``[user_msg, model_msg]`` pairs. The pairs are
    flattened into role-tagged messages, passed through the tokenizer's chat
    template, and fed to ``model.generate`` on a background thread. Each
    non-empty chunk from the streamer is appended to ``history[-1][1]`` and
    the updated ``history`` is yielded.

    Args:
        history: Chat turns; the last one is expected to have an empty reply.
        max_length: Passed to ``generate`` as ``max_new_tokens``.
        top_p: Nucleus-sampling value passed to ``generate``.
        temperature: Sampling temperature passed to ``generate``.

    Yields:
        ``history`` after each streamed chunk (the same list, mutated in place).
    """
    messages = []
    final_idx = len(history) - 1
    for turn_idx, turn in enumerate(history):
        question, answer = turn
        # The last turn with no reply yet is the prompt being answered; its
        # user message is always included, even when empty.
        awaiting_reply = turn_idx == final_idx and not answer
        if awaiting_reply or question:
            messages.append({"role": "user", "content": question})
        if awaiting_reply:
            break
        if answer:
            messages.append({"role": "assistant", "content": answer})

    print("\n\n====conversation====\n", messages)
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
    )
    # Move the prompt tensor to the device of the model's first parameter.
    model_inputs = encoded.to(next(model.parameters()).device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=model_inputs,
        streamer=streamer,
        max_new_tokens=max_length,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
        repetition_penalty=1.2,
    )
    # Generation runs on a worker thread so this generator can consume the
    # streamer while tokens are still being produced.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    for new_token in streamer:
        if new_token == '':
            continue
        history[-1][1] += new_token
        yield history



# Build the chat UI: a chatbot pane, an input box with a submit button, and
# a side column with the clear button and the sampling controls.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">ChatGLM3-6B Gradio Simple Demo</h1>""")
    chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            # Lower bound is 1, not 0: predict() forwards this value as
            # max_new_tokens, and transformers rejects values <= 0.
            max_length = gr.Slider(1, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)


    def user(query, history):
        """Clear the textbox and append the new user turn with an empty reply."""
        # Tolerate a missing history (None) so the first turn cannot fail.
        return "", (history or []) + [[parse_text(query), ""]]


    # Step 1 (unqueued): move the query into the chat history.
    # Step 2: stream the model's reply into the last history entry.
    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
        predict, [chatbot, max_length, top_p, temperature], chatbot
    )
    # Clearing sets the chatbot value to None.
    emptyBtn.click(lambda: None, None, chatbot, queue=False)

# Enable Gradio's event queue for this app before launching it.
demo.queue()
# Serve on 127.0.0.1:7870, open a browser tab, and request a public share
# link (the captured output below shows the share link could not be created).
demo.launch(server_name="127.0.0.1", server_port=7870, inbrowser=True, share=True)

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


Running on local URL:  http://127.0.0.1:7870

Could not create share link. Please check your internet connection or our status page: https://status.gradio.app.





No chat template is defined for this tokenizer - using a default chat template that implements the ChatML format (without BOS/EOS tokens!). If the default is not appropriate for your model, please set `tokenizer.chat_template` to an appropriate template. See https://huggingface.co/docs/transformers/main/chat_templating for more information.





====conversation====
 [{'role': 'user', 'content': '你好'}]


====conversation====
 [{'role': 'user', 'content': '你好'}, {'role': 'assistant', 'content': '你好，我是人工智能助手。有什么我可以帮助你的吗？'}, {'role': 'user', 'content': '类型#裙*裙长#半身裙'}]
