In [1]:
# List the checkpoint directories; CFG.output_dir in the next cell points at one of them.
!ls /root/autodl-tmp/checkpoints/

flagged			     glm3-3-dataset-Rank64_32000
glm3-3-dataset-Rank64	     glm3-3-dataset-Rank64_33000
glm3-3-dataset-Rank64_1000   glm3-3-dataset-Rank64_34000
glm3-3-dataset-Rank64_10000  glm3-3-dataset-Rank64_35000
glm3-3-dataset-Rank64_11000  glm3-3-dataset-Rank64_36000
glm3-3-dataset-Rank64_12000  glm3-3-dataset-Rank64_37000
glm3-3-dataset-Rank64_13000  glm3-3-dataset-Rank64_38000
glm3-3-dataset-Rank64_14000  glm3-3-dataset-Rank64_39000
glm3-3-dataset-Rank64_15000  glm3-3-dataset-Rank64_4000
glm3-3-dataset-Rank64_16000  glm3-3-dataset-Rank64_40000
glm3-3-dataset-Rank64_17000  glm3-3-dataset-Rank64_41000
glm3-3-dataset-Rank64_18000  glm3-3-dataset-Rank64_42000
glm3-3-dataset-Rank64_19000  glm3-3-dataset-Rank64_43000
glm3-3-dataset-Rank64_2000   glm3-3-dataset-Rank64_44000
glm3-3-dataset-Rank64_20000  glm3-3-dataset-Rank64_45000
glm3-3-dataset-Rank64_21000  glm3-3-dataset-Rank64_46000
glm3-3-dataset-Rank64_22000  glm3-3-dataset-Rank64_47000
glm3-3-dataset-Rank64_23000  glm3-3-dataset-Rank

In [2]:
class CFG:
    """Notebook-wide configuration.

    Holds the base ChatGLM3-6B weight path, the training-data JSONL path, and
    the LoRA checkpoint directory that the load cell merges into the base model.
    """

    # --- paths ---------------------------------------------------------------
    model_path = '/root/autodl-tmp/weights/chatglm3-6b'
    data_path = '/root/autodl-tmp/dataset/psychology-dataset/data/train.jsonl'
    # LoRA checkpoint to merge. Previously used alternatives:
    #   /root/autodl-tmp/checkpoints/glm3-single_query_turbo3
    #   /root/autodl-tmp/checkpoints/glm3-full_query_turbo3
    output_dir = '/root/autodl-tmp/checkpoints/glm3-3-dataset-Rank64_15000'

    # --- training hyper-parameters (the inference cells appear not to read these) ---
    num_train_epochs = 5
    batch_size = 8
    max_tokens = 192
    max_query = 64
    lr = 1e-5
    warm_up_steps = 200

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, get_linear_schedule_with_warmup
from torch.nn import CrossEntropyLoss

from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoModel
from peft import LoraConfig, get_peft_model, PeftModel, PeftConfig

import sys
import json
import pandas as pd
from tqdm import tqdm

In [4]:
# Make the project-local helper modules in /root/tuning_space/Components/
# (interact, model_tools, Static) importable from this notebook.
sys.path.append('/root/tuning_space/Components/')
import interact
import model_tools
from Static import prompt_dict, st, si

In [5]:
def merge_lora(base_model_path, lora_path):
    """Load a base model, apply a LoRA adapter, and merge it into the base weights.

    Args:
        base_model_path: Path of the base model, loaded with
            ``AutoModel.from_pretrained(..., trust_remote_code=True)``.
        lora_path: Directory of the PEFT/LoRA adapter passed to
            ``PeftModel.from_pretrained``.

    Returns:
        The merged model (LoRA wrappers removed by ``merge_and_unload``) in
        fp16 on the current CUDA device, with training mode switched off.

    Raises:
        AssertionError: if the first layer's ``query_key_value`` weight is
            unchanged after the merge, i.e. the adapter apparently had no effect.
    """
    # Cast to fp16 *before* moving to the GPU so the fp32 copy of the weights
    # never has to fit in device memory; the resulting model is the same.
    base_model = AutoModel.from_pretrained(base_model_path, trust_remote_code=True).half().cuda()

    # Keep a live reference to one weight plus a snapshot copy. The check
    # below relies on merge_and_unload updating this parameter in place.
    first_weight = base_model.transformer.encoder.layers[0].self_attention.query_key_value.weight
    first_weight_old = first_weight.clone()

    # Wrap the base model with the LoRA adapter, then fold the adapter into
    # the base weights and drop the LoRA modules.
    lora_model = PeftModel.from_pretrained(base_model, lora_path)
    lora_model = lora_model.merge_and_unload()
    lora_model.train(False)

    # Sanity check: a successful merge must change the base weights.
    assert not torch.allclose(first_weight_old, first_weight), 'Weight Should Change after Lora Merge'

    return lora_model

In [6]:
%%time
tokenizer=AutoTokenizer.from_pretrained(CFG.model_path, trust_remote_code=True)
model=merge_lora(CFG.model_path, CFG.output_dir)
#model = AutoModel.from_pretrained(CFG.model_path, trust_remote_code=True).cuda().half()#float()
model = model.eval()

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

CPU times: user 7.11 s, sys: 17.4 s, total: 24.5 s
Wall time: 12.2 s


In [7]:
#model_tools.chat(model, tokenizer)

In [8]:
import os
import platform
from transformers import AutoTokenizer, AutoModel

os_name = platform.system()

# Shell command that main() runs when the user types "clear".
if os_name == 'Windows':
    clear_command = 'cls'
else:
    clear_command = 'clear'

# When True, main() abandons the response currently being streamed and resets
# the flag to False.
stop_stream = False

# Banner: "Welcome to the TanTaili fine-tuned model (based on ChatGLM3); type to
# chat, `clear` clears the history, `stop` exits."
welcome_prompt = "欢迎使用TanTaili微调模型（基于ChatGLM3），输入内容即可进行对话，clear 清空对话历史，stop 终止程序"


def build_prompt(history):
    """Render the welcome banner followed by every (query, response) turn.

    Args:
        history: Iterable of ``(query, response)`` pairs.

    Returns:
        ``welcome_prompt`` with each turn appended as a blank-line-separated
        user line and a "TanTaili-6B" reply line.
    """
    rendered_turns = "".join(
        f"\n\n用户：{user_text}" f"\n\nTanTaili-6B：{bot_text}"
        for user_text, bot_text in history
    )
    return welcome_prompt + rendered_turns


def main():
    """Run an interactive, streaming command-line chat with the merged model.

    Reads queries with ``input()`` until the user types ``stop``, or until
    input is closed or interrupted. ``clear`` resets the conversation history
    and the cached ``past_key_values``, clears the terminal and re-prints the
    banner. Relies on the module-level ``model``, ``tokenizer``,
    ``welcome_prompt``, ``clear_command`` and ``stop_stream``.
    """
    global stop_stream
    past_key_values, history = None, []
    print(welcome_prompt)
    while True:
        try:
            query = input("\n用户：")
        except (EOFError, KeyboardInterrupt):
            # A closed stdin or an interrupted prompt ends the session like "stop"
            # instead of crashing the loop.
            break
        command = query.strip()
        if command == "stop":
            break
        if command == "clear":
            past_key_values, history = None, []
            os.system(clear_command)
            print(welcome_prompt)
            continue
        # Speaker label matches the welcome banner and build_prompt ("TanTaili-6B").
        print("\nTanTaili-6B：", end="")
        current_length = 0
        # NOTE(review): with do_sample=False decoding is presumably greedy, so
        # temperature/top_p likely have no effect; kept as in the original call.
        for response, history, past_key_values in model.stream_chat(
            tokenizer,
            query,
            history=history,
            do_sample=False,
            temperature=0.1,
            top_p=1,
            past_key_values=past_key_values,
            return_past_key_values=True,
            repetition_penalty=1.0,
        ):
            if stop_stream:
                stop_stream = False
                break
            try:
                # Print only the newly generated suffix of the cumulative response.
                print(response[current_length:], end="", flush=True)
            except Exception:
                # NOTE(review): presumably guards against responses that cannot be
                # sliced (e.g. a non-str result); stop streaming this answer but let
                # KeyboardInterrupt/SystemExit propagate instead of swallowing them.
                break
            current_length = len(response)
        print("")


#if __name__ == "__main__":
    #main()
    
    
def inference(text):
    """Answer a single query with no prior context; used as the Gradio callback.

    Every call passes an empty history, so turns are independent and the
    updated history returned by ``model.chat`` is discarded.
    """
    reply, _unused_history = model.chat(tokenizer, text, history=[])
    return reply

In [9]:
import gradio as gr

# Minimal web UI: one text input routed through `inference` (empty history
# per call) to one text output.
demo = gr.Interface(inference, ["text"], ["text"])

# share=True requests a public share link in addition to the local URL; the
# recorded output shows it failing when Gradio's frpc helper binary is missing.
demo.launch(share=True)



Running on local URL:  http://127.0.0.1:7860

Could not create share link. Missing file: /root/miniconda3/lib/python3.8/site-packages/gradio/frpc_linux_amd64_v0.2. 

Please check your internet connection. This can happen if your antivirus software blocks the download of this file. You can install manually by following these steps: 

1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.2/frpc_linux_amd64
2. Rename the downloaded file to: frpc_linux_amd64_v0.2
3. Move the file to this location: /root/miniconda3/lib/python3.8/site-packages/gradio


