In [1]:
class CFG:
    model_path = '/root/autodl-tmp/weights/chatglm3-6b'
    data_path = '/root/autodl-tmp/dataset/psychology-dataset/data/train.jsonl'
    lora_dir = '/root/autodl-tmp/checkpoints/previous/glm3-Rank64_27000' #'/root/autodl-tmp/checkpoints/glm3-3-dataset-Rank64'
    MAX_TURNS = 20

In [2]:
import os
import sys
import mdtex2html
import gradio as gr
from transformers import AutoModel, AutoTokenizer
sys.path.append('/root/tuning_space/Components/')
import model_tools



In [3]:
import os
import gradio as gr
import torch
from threading import Thread
from typing import Union
from pathlib import Path
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer
)

In [4]:
ls $CFG.model_path

MODEL_LICENSE                     pytorch_model-00005-of-00007.bin
README.md                         pytorch_model-00006-of-00007.bin
config.json                       pytorch_model-00007-of-00007.bin
configuration_chatglm.py          pytorch_model.bin.index.json
modeling_chatglm.py               quantization.py
pytorch_model-00001-of-00007.bin  tokenization_chatglm.py
pytorch_model-00002-of-00007.bin  tokenizer.model
pytorch_model-00003-of-00007.bin  tokenizer_config.json
pytorch_model-00004-of-00007.bin


In [5]:
tokenizer = AutoTokenizer.from_pretrained(CFG.model_path, trust_remote_code=True)
#tokenization_chatglm.py
#tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
#model = model_tools.merge_lora(CFG.model_path, CFG.lora_dir)
model = AutoModel.from_pretrained(CFG.model_path, trust_remote_code=True).cuda().half()
model = model.eval()

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


In [16]:
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [0, 2]#[0, 2, 31002, 64795]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


def parse_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = f'<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", "\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text


def predict(history, max_length, repetition_penalty, temperature):
    print(1)
    stop = StopOnTokens()
    messages = []
    for idx, (user_msg, model_msg) in enumerate(history):
        if idx == len(history) - 1 and not model_msg:
            messages.append({"role": "user", "content": user_msg})
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})

    print("\n\n====conversation====\n", messages)
    print(messages)
    model_inputs = tokenizer.apply_chat_template(messages,
                                                 add_generation_prompt=False,
                                                 tokenize=True,
                                                 return_tensors="pt").to(next(model.parameters()).device)
    print(model_inputs)
    streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_length,
        "do_sample": True,
        "temperature": temperature,
        "stopping_criteria": StoppingCriteriaList([stop]),
        "repetition_penalty": repetition_penalty,
    }
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    for new_token in streamer:
        if new_token != '':
            history[-1][1] += new_token
            yield history

In [17]:
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">心理对话微调 ChatGLM3-6B Gradio 简单 Demo</h1>""")
    chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="输入...", lines=10, container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("清除上下文")
            max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
            temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
            repetition_penalty = gr.Slider(1.0, 1.5, value=1.25, step=0.01, label="repetition_penalty", interactive=True)


    def user(query, history):
        return "", history + [[parse_text(query), ""]]


    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
        predict, [chatbot, max_length, repetition_penalty, temperature], chatbot
    )
    emptyBtn.click(lambda: None, None, chatbot, queue=False)

demo.queue()
demo.launch(share=True)

Running on local URL:  http://127.0.0.1:7861
Running on public URL: https://233c105eb9e5ade4b3.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




1


====conversation====
 [{'role': 'user', 'content': '当别人不注意我的时候，我就很愤怒，该怎么办？如题。我想做一个不用那么去争取别人注意力，又能被别人看到的人。<br>比如说，开会的时候，大家会讨论，而忽略我；或者新的环境里，有很多小团体，他们也不会主动叫上我。<br>一开始是尴尬，时间长了次数多了，我开始感受到被忽略的愤怒，知道自己不应该这样，可是就是好生气，该怎么办呢？'}]
[{'role': 'user', 'content': '当别人不注意我的时候，我就很愤怒，该怎么办？如题。我想做一个不用那么去争取别人注意力，又能被别人看到的人。<br>比如说，开会的时候，大家会讨论，而忽略我；或者新的环境里，有很多小团体，他们也不会主动叫上我。<br>一开始是尴尬，时间长了次数多了，我开始感受到被忽略的愤怒，知道自己不应该这样，可是就是好生气，该怎么办呢？'}]
tensor([[  906, 31007,   326, 30962,  6631, 31007, 30994,  4865,    13, 54673,
         32282, 54535, 31937, 54546, 31737, 31123, 33600, 54657, 38216, 31123,
         49086, 31514, 54627, 54736, 31155, 33103, 36941, 33033, 31783, 54701,
         34629, 32282, 38038, 31123, 38617, 54732, 32282, 31857, 31635, 31155,
         31002,  1335, 30994, 37514, 31123, 45955, 31737, 31123, 31684, 54549,
         32654, 31123, 54617, 36481, 54546, 54659, 31767, 31888, 31747, 54662,
         31123, 33446, 54603, 34393, 31123, 31633, 34093, 32270, 55483, 54547,
         54546, 31155, 31002,  

In [9]:
tokenizer.decode([  906, 31007,   326, 30962,  6631, 31007, 30994])

'<|im_start|>'

In [10]:
tokenizer.decode([  906, 31007,   326, 30962,  6631, 31007, 30994,  4865,    13, 39701,
         31002, 31007,   326, 30962,   437, 31007, 30994,    13])

'<|im_start|>user\n你好<|im_end|>\n'

In [11]:
tokenizer.decode([  906, 31007,   326, 30962,  6631, 31007, 30994,  4865,    13, 39701,
         31002, 31007,   326, 30962,   437, 31007, 30994,    13, 31002, 31007,
           326, 30962,  6631, 31007, 30994,   530, 18971,    13])

'<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n'

In [12]:
tokenizer.decode([  906, 31007,   326, 30962,  6631, 31007, 30994,  4865,    13, 39701,
         31002, 31007,   326, 30962,   437, 31007, 30994,    13, 31002, 31007,
           326, 30962,  6631, 31007, 30994,   530, 18971,    13, 48214, 31123,
         33030, 34797, 42481, 31155, 42693, 33277, 31639, 40648, 55268, 55353,
         36295, 55398, 31514, 31002, 31007,  4865, 31007,  6144, 31007,   326,
         30962,   437, 31007, 30994,    13, 31002, 31007,   326, 30962,  6631,
         31007, 30994,  4865,    13, 42693, 34607, 55622, 31514, 31002, 31007,
           326, 30962,   437, 31007, 30994,    13, 31002, 31007,   326, 30962,
          6631, 31007, 30994,   530, 18971,    13, 54546, 32103, 34797, 42481,
         31123, 31628, 33287, 55353, 32184, 54542, 31692, 31934, 31155, 31002,
         31007,  4865, 31007,  6144, 31007,   326, 30962,   437, 31007, 30994,
            13, 31002, 31007,   326, 30962,  6631, 31007, 30994,  4865,    13,
         54673, 32282, 54535, 31937, 54546, 31737, 31123, 33600, 54657, 38216,
         31123, 49086, 31514, 54627, 54736, 31155, 33103, 36941, 33033, 31783,
         54701, 34629, 32282, 38038, 31123, 38617, 54732, 32282, 31857, 31635,
         31155, 31002,  1335, 30994, 37514, 31123, 45955, 31737, 31123, 31684,
         54549, 32654, 31123, 54617, 36481, 54546, 54659, 31767, 31888, 31747,
         54662, 31123, 33446, 54603, 34393, 31123, 31633, 34093, 32270, 55483,
         54547, 54546, 31155, 31002,  1335, 30994, 35872, 54532, 35556, 31123,
         31643, 42165, 36942, 33851, 31123, 51809, 33816, 54732, 36481, 54530,
         38216, 31123, 43292, 40919, 31676, 31123, 32435, 31632, 54591, 36443,
         31123, 49086, 55282, 31514, 31002, 31007,   326, 30962,   437, 31007,
         30994,    13, 31002, 31007,   326, 30962,  6631, 31007, 30994,   530,
         18971,    13])

'<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n您好，我是人工智能助手。请问有什么问题我可以帮您解答吗？<|user|><|im_end|>\n<|im_start|>user\n请问你是谁？<|im_end|>\n<|im_start|>assistant\n我是一个人工智能助手，可以回答您的问题和提供帮助。<|user|><|im_end|>\n<|im_start|>user\n当别人不注意我的时候，我就很愤怒，该怎么办？如题。我想做一个不用那么去争取别人注意力，又能被别人看到的人。<br>比如说，开会的时候，大家会讨论，而忽略我；或者新的环境里，有很多小团体，他们也不会主动叫上我。<br>一开始是尴尬，时间长了次数多了，我开始感受到被忽略的愤怒，知道自己不应该这样，可是就是好生气，该怎么办呢？<|im_end|>\n<|im_start|>assistant\n'

In [13]:
tokenizer.decode(31002)

'<'

```
def simple_chat(prompts):
    response, history = model.chat(
        tokenizer, 
        prompts, 
        history=[],
        do_sample=True, 
        temperature=0.3,
        top_p=1,
        repetition_penalty=1.0,)
    return response

demo = gr.Interface(fn=simple_chat, 
                    inputs=gr.Textbox(
                        label='''说点什么~
                        因为这是一个以心理疏导为导向的模型，问一些与心理疏导无关的内容效果可能会变差
                        目前web Demo只支持单论对话（即模型无法了解你之前发送的内容）, 正在改进中...
                        '''), 
                    outputs=gr.Text(
                        label="""回复~
                        目前参数temperature=0.8, repetition_penalty=1.2"""
                    ))

demo.launch(share=True, server_port=6006)
```

In [14]:
def simple_chat(prompts):
    response, history = model.chat(
        tokenizer, 
        prompts, 
        history=[],
        do_sample=True, 
        temperature=0.3,
        top_p=1,
        repetition_penalty=1.0,)
    return response

In [15]:
simple_chat('你好')

'很高兴认识你。请问有什么我可以帮助你的？'