In [1]:
from transformers import AutoModel, AutoTokenizer
import gradio as gr
import mdtex2html

from typing import Dict, Tuple, Union, Optional
from utils import load_model_on_gpus

In [2]:
# Relative path (from the notebook's working directory) to the local ChatGLM2-6B checkpoint.
model_path = '../chatglm2-6b-model'

In [3]:
# trust_remote_code=True lets transformers run the tokenizer code shipped with
# the checkpoint in model_path -- only do this for a checkpoint you trust.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)

In [4]:
# Load the checkpoint (no device_map here, so nothing is placed on a GPU yet)
# and cast it to float16; it is split across devices in the following cells.
model = AutoModel.from_pretrained(
    model_path,
    trust_remote_code=True,
).half()

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [5]:
from accelerate import infer_auto_device_map
# Plan which device each submodule goes to under the given per-device memory
# caps; GLMBlock modules are kept whole (never split across devices).
device_map = infer_auto_device_map(
    model,
    max_memory={0: "4GiB", 1: "10GiB", "cpu": "30GiB"},
    no_split_module_classes=["GLMBlock"],
)

In [6]:
# Show the planned module -> device assignment.
print(f"{device_map}")

{'transformer.embedding': 0, 'transformer.rotary_pos_emb': 0, 'transformer.encoder.layers.0': 0, 'transformer.encoder.layers.1': 0, 'transformer.encoder.layers.2': 0, 'transformer.encoder.layers.3': 0, 'transformer.encoder.layers.4': 0, 'transformer.encoder.layers.5': 0, 'transformer.encoder.layers.6': 0, 'transformer.encoder.layers.7': 1, 'transformer.encoder.layers.8': 1, 'transformer.encoder.layers.9': 1, 'transformer.encoder.layers.10': 1, 'transformer.encoder.layers.11': 1, 'transformer.encoder.layers.12': 1, 'transformer.encoder.layers.13': 1, 'transformer.encoder.layers.14': 1, 'transformer.encoder.layers.15': 1, 'transformer.encoder.layers.16': 1, 'transformer.encoder.layers.17': 1, 'transformer.encoder.layers.18': 1, 'transformer.encoder.layers.19': 1, 'transformer.encoder.layers.20': 1, 'transformer.encoder.layers.21': 1, 'transformer.encoder.layers.22': 1, 'transformer.encoder.layers.23': 1, 'transformer.encoder.layers.24': 1, 'transformer.encoder.layers.25': 1, 'transformer

#device_map['transformer.final_layernorm']=0
#device_map['lm_head']=0
#print(device_map)

In [7]:
from accelerate import dispatch_model 
# Place each submodule on the device chosen in device_map.
model = dispatch_model(
    model,
    device_map=device_map,
)

In [8]:
# Switch to inference mode; train(False) is exactly what nn.Module.eval() does.
model = model.train(False)

In [9]:
# Confirm the placement actually recorded on the dispatched model.
print(f"{model.hf_device_map}")

{'transformer.embedding': 0, 'transformer.rotary_pos_emb': 0, 'transformer.encoder.layers.0': 0, 'transformer.encoder.layers.1': 0, 'transformer.encoder.layers.2': 0, 'transformer.encoder.layers.3': 0, 'transformer.encoder.layers.4': 0, 'transformer.encoder.layers.5': 0, 'transformer.encoder.layers.6': 0, 'transformer.encoder.layers.7': 1, 'transformer.encoder.layers.8': 1, 'transformer.encoder.layers.9': 1, 'transformer.encoder.layers.10': 1, 'transformer.encoder.layers.11': 1, 'transformer.encoder.layers.12': 1, 'transformer.encoder.layers.13': 1, 'transformer.encoder.layers.14': 1, 'transformer.encoder.layers.15': 1, 'transformer.encoder.layers.16': 1, 'transformer.encoder.layers.17': 1, 'transformer.encoder.layers.18': 1, 'transformer.encoder.layers.19': 1, 'transformer.encoder.layers.20': 1, 'transformer.encoder.layers.21': 1, 'transformer.encoder.layers.22': 1, 'transformer.encoder.layers.23': 1, 'transformer.encoder.layers.24': 1, 'transformer.encoder.layers.25': 1, 'transformer

In [10]:
def postprocess(self, y):
    """Render chat messages as HTML; installed as ``gr.Chatbot.postprocess``.

    Args:
        self: The ``gr.Chatbot`` instance (unused).
        y: List of ``(message, response)`` pairs, or ``None``. Either element
            of a pair may be ``None``.

    Returns:
        A new list of pairs in which every non-``None`` element has been
        passed through ``mdtex2html.convert``; ``[]`` when ``y`` is ``None``.
    """
    if y is None:
        return []
    # Build a new list instead of overwriting y in place: predict() yields the
    # same `chatbot` list object on every streamed chunk, so mutating it here
    # would hand already-converted HTML back to mdtex2html.convert on the next
    # chunk.
    return [
        (
            None if message is None else mdtex2html.convert(message),
            None if response is None else mdtex2html.convert(response),
        )
        for message, response in y
    ]


# Monkey-patch every gr.Chatbot so its values are rendered by postprocess above.
setattr(gr.Chatbot, "postprocess", postprocess)


def parse_text(text):
    """Convert chat text into HTML for display in the Chatbot.

    Adapted from https://github.com/GaiZhenbiao/ChuanhuChatGPT/.

    Empty lines are dropped and the remaining lines are joined with
    ``<br>``. A line containing a triple-backtick fence opens (odd-numbered
    fence) or closes (even-numbered fence) a ``<pre><code>`` block; the text
    after the last backtick of an opening fence becomes the ``language-...``
    class. Inside a code block, characters that HTML or Markdown would
    otherwise interpret are replaced by entities (backticks are
    backslash-escaped).

    Args:
        text: Raw text (a user query or a model response).

    Returns:
        The converted string.
    """
    lines = [line for line in text.split("\n") if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = '<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    # Escape "&" first so literal entities in code (e.g. "&lt;")
                    # survive and the entities added below are not re-escaped.
                    line = line.replace("&", "&amp;")
                    # Raw string: "\`" is an invalid escape sequence (SyntaxWarning
                    # since Python 3.12); the resulting value is unchanged.
                    line = line.replace("`", r"\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    return "".join(lines)


def predict(input, chatbot, max_length, top_p, temperature, history):
    """Stream a model reply to ``input`` into the chat window (Gradio handler).

    Appends a ``(query, "")`` pair to ``chatbot`` and, for each partial
    response produced by ``model.stream_chat``, replaces that pair with the
    query and the response so far, then yields.

    Args:
        input: Raw user query text. (Shadows the builtin; the parameter name
            is kept for compatibility.)
        chatbot: List of ``(query_html, reply_html)`` pairs shown in the
            Chatbot; mutated in place.
        max_length: Forwarded to ``model.stream_chat``.
        top_p: Forwarded to ``model.stream_chat``.
        temperature: Forwarded to ``model.stream_chat``.
        history: Conversation history passed to ``model.stream_chat``; the
            value it yields replaces this one.

    Yields:
        ``(chatbot, history)`` after every streamed chunk.
    """
    # The query does not change while the reply streams, so format it once.
    query_html = parse_text(input)
    chatbot.append((query_html, ""))
    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
                                               temperature=temperature):
        chatbot[-1] = (query_html, parse_text(response))
        yield chatbot, history


def reset_user_input():
    """Return a Gradio update that empties the input Textbox."""
    cleared = gr.update(value='')
    return cleared


def reset_state():
    """Return fresh empty values for the chatbot display and the chat history."""
    chatbot_value = []
    history_value = []
    return chatbot_value, history_value

In [11]:
# Gradio UI: chat window, input box with submit button, and generation controls.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">ChatGLM</h1>""")

    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                # `container=False` goes in the constructor: the `.style()` helper
                # is deprecated in Gradio 3.x (emits a warning) and removed in 4.x.
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10,
                                        container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            # Generation settings that predict() forwards to model.stream_chat.
            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)

    # Conversation history exchanged with model.stream_chat, kept per session.
    history = gr.State([])

    # Submit streams the reply and, via a second listener, clears the input box.
    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
                    show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

# queue() is needed so the generator handler predict() can stream updates.
# NOTE(review): share=True publishes the app on a public *.gradio.live URL with
# no authentication; consider share=False or launch(..., auth=...) unless intended.
demo.queue().launch(share=True, inbrowser=True)


Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://ac0a87426a4de60d86.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


