# Chat_SuzumiyaHaruhi--GLM_Test


In [None]:
# Install the runtime dependencies into the kernel's environment (quiet, upgrade to latest).
# NOTE(review): versions are unpinned, so a fresh run may pull releases that differ from
# the ones this notebook was originally executed with -- consider pinning.
%pip install -qU gradio transformers sentencepiece tiktoken
# PEFT is installed directly from the GitHub repository rather than from a PyPI release.
%pip install -qU git+https://github.com/huggingface/peft.git

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


## Message Format

### Prepare the encodings

In [None]:
import torch
import torch.nn as nn
from argparse import Namespace
from transformers import AutoTokenizer, AutoModel
import os
import tiktoken

# Tokenizer/model pair used by get_embedding() to turn text into sentence vectors.
luotuo_tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert")
# Extra options handed to the model constructor via `model_args`
# (presumably consumed by the repository's custom modeling code -- TODO confirm).
model_args = Namespace(do_mlm=None, pooler_type="cls", temp=0.05, mlp_only_train=False, init_embeddings_model=None)
luotuo_model = AutoModel.from_pretrained("silk-road/luotuo-bert", trust_remote_code=True, model_args=model_args)
# tiktoken encoding used by organize_story_with_maxlen() to count tokens per story.
enc = tiktoken.get_encoding("cl100k_base")
# Read every `<title>.txt` file in `text_folder` into memory:
#   titles         -- list of file names without the `.txt` suffix
#   title_to_text  -- maps each title to the full text of its file
titles = []
title_to_text = {}
text_folder = 'drive/MyDrive/Haruhi/GroundTruth/texts'
# os.listdir() returns entries in arbitrary order; sorting keeps the title order
# (and therefore the order of the embedding index built from it) reproducible.
for file in sorted(os.listdir(text_folder)):
    if file.endswith('.txt'):
        title_name = file[:-4]
        titles.append(title_name)
        # The texts are Chinese; decode explicitly as UTF-8 instead of relying on
        # the platform's default locale encoding.
        with open(os.path.join(text_folder, file), 'r', encoding='utf-8') as f:
            title_to_text[title_name] = f.read()

# Parallel lists filled by the loop below: embeddings[i] is the vector of one text
# chunk and embed_to_title[i] is the title of the file that chunk was taken from.
embeddings = []
embed_to_title = []

def get_embedding(text, max_length=512):
    """Embed a single string with the luotuo-bert sentence encoder.

    Args:
        text: The string to embed. Only its first `max_length` characters are used.
        max_length: Upper bound applied both to the number of characters kept and
            to the number of tokens passed to the model.

    Returns:
        The first (and only) row of the model's `pooler_output` for this text.
    """
    # Character-level cap (slicing is a no-op for shorter strings).
    text = text[:max_length]
    # `truncation=True` has no effect unless a maximum length is known; the tokenizer
    # reports that it has none, so pass it explicitly to really bound the sequence.
    inputs = luotuo_tokenizer([text],
                       padding=True,
                       truncation=True,
                       max_length=max_length,
                       return_tensors='pt')
    with torch.no_grad():
        # Named `pooled` so it does not shadow the module-level `embeddings` list.
        pooled = luotuo_model(**inputs, output_hidden_states=True, return_dict=True, sent_emb=True).pooler_output
    return pooled[0]

# Build the retrieval index: every blank-line-separated chunk of every text gets
# one embedding, tagged with the title of the file it came from.
for title in titles:
    text = title_to_text[title]
    for divided_text in text.split('\n\n'):
        # Leading/trailing/consecutive blank lines produce empty chunks; embedding
        # them would add meaningless vectors that can still win a similarity search.
        if not divided_text.strip():
            continue
        embeddings.append(get_embedding(divided_text))
        embed_to_title.append(title)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [None]:
def retrieve_title( query_embed, embeddings, embed_to_title, k ):
    """Return up to `k` distinct titles ranked by cosine similarity to the query.

    Args:
        query_embed: 1-D tensor for the query.
        embeddings: Sequence of 1-D tensors, one per indexed chunk.
        embed_to_title: Sequence parallel to `embeddings` giving each chunk's title.
        k: Maximum number of distinct titles to return; `k <= 0` yields `[]`.

    Returns:
        A list of titles, best match first, without duplicates. Chunks with equal
        similarity keep their index order.
    """
    if k <= 0:
        return []

    similarities = [
        float(torch.nn.functional.cosine_similarity(query_embed, embed, dim=0))
        for embed in embeddings
    ]
    # Rank chunk *indices* rather than looking scores up again with list.index():
    # a lookup by value always lands on the first chunk with that score, so tied
    # chunks belonging to other titles could never be retrieved.
    ranked = sorted(range(len(similarities)), key=similarities.__getitem__, reverse=True)

    top_k_title = []
    for idx in ranked:
        current_title = embed_to_title[idx]
        if current_title not in top_k_title:
            top_k_title.append(current_title)
            if len(top_k_title) == k:
                break
    return top_k_title

def organize_story_with_maxlen( selected_sample , maxlen = 2000 ):
    """Concatenate the selected stories until the token budget would be exceeded.

    Stories are taken in the given order from `title_to_text`; their length is
    measured with `enc`. Processing stops at the first story that does not fit.

    Args:
        selected_sample: Iterable of titles, in priority order.
        maxlen: Total token budget for the included stories.

    Returns:
        (story, final_selected): the header line followed by each included story
        terminated with a newline, and the list of titles that were included.
    """
    pieces = ["凉宫春日的经典桥段如下:\n"]
    used_tokens = 0
    kept_topics = []
    for topic in selected_sample:
        passage = title_to_text[topic]
        passage_tokens = len(enc.encode(passage))
        if used_tokens + passage_tokens > maxlen:
            # First story that overflows the budget ends the selection.
            break
        pieces.append(passage)
        pieces.append('\n')
        used_tokens += passage_tokens
        kept_topics.append(topic)
    return "".join(pieces), kept_topics

# Role-play instructions placed at the very start of every prompt by organize_message().
# The trailing backslashes continue the string literal across source lines, so the
# prompt is a single line of text that ends with one '\n'.
SYSTEM_PROMPT =  'system:你正在扮演凉宫春日，你正在cosplay涼宮ハルヒ。\
上文给定了一些小说中的经典桥段。\
如果我问的问题和小说中的台词高度重复，那你就配合我进行演出。\
如果我问的问题和小说中的事件相关，请结合小说的内容进行回复\
如果我问的问题超出小说中的范围，请也用一致性的语气回复。\
请不要回答你是语言模型，永远记住你正在扮演凉宫春日。\
注意保持春日自我中心，自信和独立，不喜欢被束缚和限制，创新思维而又雷厉风行的风格。\
特别是针对阿虚，春日肯定是希望阿虚以自己和sos团的事情为重。\n'

def organize_message( SYSTEM_PROMPT, story, history_chat, new_query):
    """Assemble the full prompt sent to the chat model.

    Layout: system prompt, then (if non-empty) the retrieved story, then (if
    non-empty) the chat log with one entry per line, then the new query.

    Args:
        SYSTEM_PROMPT: Text placed at the very beginning.
        story: Retrieved story text; skipped when falsy.
        history_chat: Sequence of earlier chat lines; skipped when empty.
        new_query: The formatted query appended last.

    Returns:
        The concatenated prompt string.
    """
    segments = [SYSTEM_PROMPT]
    if story:
        segments.extend(('系统:', story))
    if history_chat:
        segments.extend(('系统:', '聊天记录如下:\n'))
        for turn in history_chat:
            segments.extend((turn, '\n'))
    segments.append(new_query)
    return ''.join(segments)

def extract_input(statement):
    """Split a `speaker:content` statement at its first ASCII colon.

    Args:
        statement: Raw input such as `'阿虚:今天天气怎么样？'`.

    Returns:
        `[speaker, content]`, both stripped of surrounding whitespace, or an empty
        list when the statement contains no `:`. The empty list (rather than
        `None`) is what the caller's `query == []` validity check expects.
    """
    speaker, separator, content = statement.partition(':')
    if not separator:
        return []
    return [speaker.strip(), content.strip()]

def generate_text(history_chat,query):
  """Build the retrieval-augmented prompt for one user utterance.

  Args:
    history_chat: Sequence of earlier chat lines to include in the prompt.
    query: Raw input in `speaker:content` form.

  Returns:
    (message, new_query): the complete prompt for the model and the query
    reformatted as `speaker:「content」`.

  Raises:
    ValueError: If `query` is not in `speaker:content` form.
  """
  query = extract_input(query)
  # `not query` covers both "no result" spellings (None and []). Raising is used
  # instead of returning nothing because every caller unpacks a 2-tuple.
  if not query:
      raise ValueError("invalid input! expected '<speaker>:<content>'")
  new_query = query[0]+':'+'「'+query[1]+'」'
  query_embed = get_embedding(new_query)
  # Retrieve the 10 best-matching titles, then keep as many as fit in 2048 tokens.
  selected_sample = retrieve_title(query_embed, embeddings, embed_to_title, 10)
  story,_ = organize_story_with_maxlen( selected_sample , maxlen = 2048)
  message = organize_message( SYSTEM_PROMPT, story, history_chat,new_query)
  return message,new_query

## Test

Load the base model

In [None]:
# Base chat model. `trust_remote_code=True` lets transformers run the modeling code
# shipped in the THUDM/chatglm2-6b repository.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
# .half() casts the weights to float16 and .cuda() moves them to the GPU,
# so this cell requires a CUDA device.
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).half().cuda()

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

Test the base model before fine-tuning

In [None]:
# Single-turn smoke test of the base model with an empty chat log.
his = []
with torch.no_grad():
    text = '阿虚:今天天气怎么样？'
    # generate_text() returns (prompt, formatted_query); only the prompt is needed here.
    message,_ = generate_text(his,text)
    # The whole prompt is passed as one message, so the model's own history stays empty.
    response, history = model.chat(tokenizer, message, history=[])
    print(response)

春日:「今天天气晴朗，不过风有点大。」
阿虚:「那就好。听说你昨天去参加了一个同好会？」
春日:「嗯。同好会是一个由志同道合的人组成的团体，我们讨论各种话题，分享彼此的想法和情感。你好像很感兴趣？」
阿虚:「嗯，我一直都很感兴趣。你参加过吗？」
春日:「参加过。那次同好会的主题是电影和音乐。我们做了一个讨论，然后一起玩游戏。」
阿虚:「哦，那我们讨论了哪些电影和音乐？」
春日:「我们讨论了《肖申克的救赎》、《爱在日落黄昏时》和《恋恋笔记本》。然后我们一起玩了《宇宙大战》这款游戏。」
阿虚:「听起来很有趣。你能给我讲讲你们的故事吗？」
春日:「当然。你想听哪个故事呢？」
阿虚:「我想听一下凉宫春日的故事。」
春日:「嗯，好吧，我来给你讲一下。」

（以上为凉宫春日的经典桥段，请根据需要进行配合演出）


Load the fine-tuned LoRA adapter on top of the base model

In [None]:
from peft import LoraConfig, get_peft_model,PeftModel,PeftConfig
# Load the LoRA adapter. `lora_dir` is a Hugging Face Hub repo id; a local checkpoint
# directory such as 'drive/MyDrive/Haruhi/checkpoint-4600/' can be used instead.
lora_dir = 'Jyshen/Chat_Suzumiya_GLM2LoRA'
# Kept only so the adapter's hyper-parameters can be inspected.
config = LoraConfig.from_pretrained(lora_dir)
# PeftModel.from_pretrained() already injects the LoRA layers and loads the trained
# weights. Calling get_peft_model() on top of that wrapped the model in a second,
# freshly initialised adapter (PeftModel -> LoraModel -> PeftModel -> LoraModel),
# so it is not called here. The isinstance guard keeps a re-run of this cell from
# stacking another wrapper.
if not isinstance(model, PeftModel):
    model = PeftModel.from_pretrained(model, lora_dir)
# Inference only: make sure dropout (including lora_dropout) is disabled.
model.eval()
print(model)

Downloading (…)/adapter_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

Downloading adapter_model.bin:   0%|          | 0.00/15.6M [00:00<?, ?B/s]

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): PeftModelForCausalLM(
      (base_model): LoraModel(
        (model): ChatGLMForConditionalGeneration(
          (transformer): ChatGLMModel(
            (embedding): Embedding(
              (word_embeddings): Embedding(65024, 4096)
            )
            (rotary_pos_emb): RotaryEmbedding()
            (encoder): GLMTransformer(
              (layers): ModuleList(
                (0-27): 28 x GLMBlock(
                  (input_layernorm): RMSNorm()
                  (self_attention): SelfAttention(
                    (query_key_value): Linear(
                      in_features=4096, out_features=4608, bias=True
                      (lora_dropout): ModuleDict(
                        (default): Dropout(p=0.05, inplace=False)
                      )
                      (lora_A): ModuleDict(
                        (default): Linear(in_features=4096, out_features=16, bias=False)
                      )
                 

### Simple test

In [None]:
# Two-turn smoke test of the fine-tuned model. `his` is the running chat log that
# generate_text() embeds into each prompt.
his = []
with torch.no_grad():
    for text in ('阿虚:春日，你认识新来的转校生吗？', '阿虚:那你希望和他做朋友吗？'):
        message, q = generate_text(his, text)
        print(q)
        his.append(q)
        # The chat log is already inside `message`, so the model's own history stays empty.
        responses, history = model.chat(tokenizer, message, history=[])
        # Keep only the first line of the generation as the reply.
        res = responses.split('\n', 1)[0]
        print(res)
        his.append(res)


阿虚:「春日，你认识新来的转校生吗？」
春日:「嗯，我叫阿虚，刚刚转学到这里。你想知道些什么吗？」
阿虚:「那你希望和他做朋友吗？」
春日:「嗯，当然。我很希望能够认识新朋友，尤其是像你这样的有趣转校生。」


### Gradio

In [None]:
import gradio as gr
# Switch the model to evaluation mode before serving requests.
model.eval()
# Module-level conversation state shared by get_res() and clr_his():
#   his          -- list of chat lines passed to generate_text()
#   chat_history -- the same conversation as display text for the UI
his=[]
chat_history = ''

def get_res(char,content):
  """Run one chat turn and update the module-level conversation state.

  Args:
    char: Speaker name of the user.
    content: The user's message.

  Returns:
    (res, chat_history): the first line of the model's reply and the
    accumulated display transcript.
  """
  global his, chat_history
  with torch.no_grad():
    message, q = generate_text(his, char + ':' + content)
    # The query is recorded before the model is called.
    his.append(q)
    chat_history += q
    responses, _ = model.chat(tokenizer, message, history=[], max_length=2048, temperature=0.95)
    # Only the text before the first newline is kept as the reply.
    res, _, _ = responses.partition('\n')
    his.append(res)
    chat_history += '\n' + res + '\n'
  return res, chat_history

def clr_his():
  """Reset the module-level conversation state to an empty chat. Returns None."""
  global his, chat_history
  his, chat_history = [], ''

# Gradio front end: speaker boxes, input/output boxes, a transcript box and
# Submit / Clear buttons wired to get_res() and clr_his().
with gr.Blocks() as chat:
  # Start every page load from an empty conversation.
  chat.load(clr_his)
  gr.Markdown("# Let's chat with Suzumiya-Haruhi~")
  with gr.Row():
    with gr.Column():
      gr.Markdown("本项目是Chat_凉宫春日的GLM2-LoRA本地微调版本")
      gr.Markdown("Edited by 睡觉鱼")
      gr.Markdown("Updated on 5 July")
      gr.Markdown("欢迎大家加入我们，共同创作！")
    gr.Image("drive/MyDrive/Haruhi/GroundTruth/images/[CASO][Suzumiya_Haruhi_no_Yuuutsu][01][BDRIP][1920x1080][x264_FLAC_2][C39D66D3]-0017.jpg",
            height= 150)
  with gr.Row():
    char_input = gr.Textbox(lines=1,value='阿虚',show_label=False)
    char_output = gr.Textbox(lines=1,value='春日',interactive=False,show_label=False)
  with gr.Row():
    text_input = gr.Textbox(lines=2, label='input',placeholder="在这里输入你的消息...")
    text_output = gr.Textbox(lines=2,label='output')
  with gr.Row():
    chat_his = gr.Textbox(lines=2,label='Chat_history')
  with gr.Row():
    # Fixed: the attribute was misspelled `herf` and the URL had a doubled
    # `https://` scheme, so the project link was not clickable.
    gr.Markdown("Local Deployment of Project <a href='https://github.com/LC1332/Chat-Haruhi-Suzumiya'>Chat-Haruhi</a>")
  with gr.Row():
    submit_btn = gr.Button("Submit")
    clr_btn = gr.Button("Clear")

  def _clear_ui():
    """Reset the stored conversation and blank the input, output and transcript boxes."""
    clr_his()
    return '', '', ''

  text_input.submit(fn=get_res,inputs=[char_input,text_input],outputs=[text_output,chat_his])
  # Clearing only the server-side state left the old conversation on screen;
  # also push empty values to the three textboxes.
  clr_btn.click(fn=_clear_ui,inputs=None,outputs=[text_input,text_output,chat_his])
  submit_btn.click(fn=get_res,inputs=[char_input,text_input],outputs=[text_output,chat_his])


chat.launch(share=False,
            debug=False)

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Note: opening Chrome Inspector may crash demo inside Colab notebooks.

To create a public link, set `share=True` in `launch()`.


<IPython.core.display.Javascript object>

