In [1]:
import time
import mindspore as ms
import numpy as np
import argparse
from mindformers.models.glm import GLMConfig, GLMChatModel, GLMChatModelWithLora
from mindformers.models.glm.chatglm_6b_tokenizer import ChatGLMTokenizer
from mindformers.models.glm.glm_processor import process_response
from mindformers.pet.pet_config import LoraConfig

In [2]:
def parse_arguments(argv):
    """Parse the notebook's run-time options from an explicit token list.

    Args:
        argv: List of command-line style tokens, e.g. ``['--device_id', '0']``.

    Returns:
        argparse.Namespace with the attributes ``seq_length``, ``device_id``,
        ``checkpoint_path``, ``vocab_path`` and ``is_lora``.
    """
    parser = argparse.ArgumentParser()
    # This help text used to be a verbatim copy of the --device_id help.
    parser.add_argument('--seq_length', default=1024, type=int, help='Sequence length of the model.')
    parser.add_argument('--device_id', default=0, type=int, help='Which device to run service.')
    parser.add_argument('--checkpoint_path', type=str, default='/home/ma-user/work/mindglm/mindformers/output/checkpoint/rank_0/glm-6b-lora_rank_0_1-network.ckpt', help='Checkpoint file to load on.')
    parser.add_argument('--vocab_path', type=str, default='/home/ma-user/work/mindglm/checkpoint_download/glm/ice_text.model', help='Vocab file to load on.')
    # Kept as a string option (not a bool flag): the notebook compares
    # args.is_lora.lower() against "true" after parsing.
    parser.add_argument('--is_lora', type=str, default='true', help='Whether is lora model.')
    return parser.parse_args(argv)

In [3]:
# Pass an explicit token list so argparse parses exactly these options
# instead of reading sys.argv.
args = parse_arguments(["--device_id", "0"])

In [4]:
# --is_lora is parsed as a string; only "true" (case-insensitive) enables LoRA.
is_lora = args.is_lora.lower() == "true"

# GLM model configuration used to build the network.
config = GLMConfig(position_encoding_2d=True, use_past=True, is_sample_acceleration=True)

# LoRA adapter hyper-parameters (rank, alpha, dropout).
pet_config = LoraConfig(lora_rank=8, lora_alpha=32, lora_dropout=0.1)

In [5]:
# Run in graph mode on the Ascend device selected by --device_id.
ms.set_context(
    mode=ms.GRAPH_MODE,
    device_target="Ascend",
    device_id=args.device_id,
)

In [6]:
# Build the network. The non-LoRA branch used to be commented out, which left
# `model` undefined (NameError in every later cell) whenever --is_lora was not
# "true".
if is_lora:
    # Attach the LoRA settings before constructing the LoRA-wrapped model.
    config.pet_config = pet_config
    model = GLMChatModelWithLora(config)
else:
    model = GLMChatModel(config)



2023-09-09 16:09:51,614 - mindformers - INFO - model built, but weights is unloaded, since the config has no checkpoint_name_or_path attribute or checkpoint_name_or_path is None.
2023-09-09 16:09:57,432 - mindformers - INFO - model built, but weights is unloaded, since the config has no checkpoint_name_or_path attribute or checkpoint_name_or_path is None.


In [7]:
# Bare expression: the notebook displays the network's repr so the layer
# structure can be inspected.
model

GLMChatModelWithLora<
  (transformer): GLMModel<
    (embedding_dropout): Dropout<keep_prob=1.0>
    (word_embeddings): VocabEmbedding<>
    (layers): CellList<
      (0): DeepNormWithGLULayer<
        (input_layernorm): LayerNorm<>
        (attention): RotaryEmbeddingFP32SoftmaxSelfAttention<
          (rotary_emb): RotaryEmbedding<>
          (query_key_value): LoRADense<
            input_channels=4096, output_channels=12288, has_bias=True
            (lora_dropout): Dropout<keep_prob=0.9>
            >
          (attention_dropout): Dropout<keep_prob=1.0>
          (dense): Linear<>
          (output_dropout): Dropout<keep_prob=1.0>
          (softmax): Softmax<>
          >
        (post_attention_layernorm): LayerNorm<>
        (mlp): MLPWithGEGLU<
          (activation_func): GELU<>
          (dense_h_to_4h): Linear<>
          (dense_4h_to_h): Linear<>
          (dropout): Dropout<keep_prob=1.0>
          >
        >
      (1): DeepNormWithGLULayer<
        (input_layernorm): L

In [8]:
# Load the checkpoint file's parameters into the already-built network.
ms.load_checkpoint(ckpt_file_name=args.checkpoint_path, net=model)
# Build the tokenizer from the vocab file given by --vocab_path.
tokenizer = ChatGLMTokenizer(args.vocab_path)



此处加载的是新增加的 LoRA 权重，因此出现如下警告属于正常现象：

[WARNING] ME(838466:281473066355264,MainProcess):2023-09-09-16:12:11.693.595 [mindspore/train/serialization.py:716] transformer.layers.27.value_past is not loaded.

（注：`value_past` 看起来是 `use_past=True` 时创建的增量推理缓存参数，而非检查点中保存的权重——待确认。）

In [9]:
# Example prompts, kept for reference (not used below):
# prompts = ["你好", "请介绍一下华为", "用Python写一个快排"]
# Conversation history as a list of (query, response) pairs; starts empty.
# chat() binds this list as the default value of its `history` parameter.
history = []

In [10]:
# chat() passes this value to model.generate() as max_length.
config.max_decode_length = 30

In [11]:
def chat(query, history=history):
    """Send one user query to the model and print the reply.

    Args:
        query: The user's message for this turn.
        history: List of ``(query, response)`` tuples from earlier turns.
            Defaults to the notebook-level ``history`` list. The list is
            extended in place with this turn, so consecutive calls form a
            multi-turn conversation.

    Returns:
        None. The generation speed and the processed response are printed.
    """
    if not history:
        # First turn: the raw query is the whole prompt.
        prompt = query
    else:
        # Later turns: replay every earlier round, then open a new round
        # that ends with an empty answer slot for the model to fill.
        rounds = [
            "[Round {}]\n问：{}\n答：{}\n".format(i, old_query, response)
            for i, (old_query, response) in enumerate(history)
        ]
        rounds.append("[Round {}]\n问：{}\n答：".format(len(history), query))
        prompt = "".join(rounds)
    inputs = tokenizer(prompt)
    # Add a leading batch axis of size 1 and cast the token ids to int32.
    input_ids = np.expand_dims(np.array(inputs['input_ids']).astype(np.int32), 0)

    start_time = time.time()
    outputs = model.generate(input_ids, max_length=config.max_decode_length,
                             do_sample=False, top_p=0.7, top_k=1)
    end_time = time.time()
    # NOTE(review): outputs[0] presumably contains the prompt tokens as well as
    # the generated ones, so this figure may overstate decode speed - confirm.
    print(f'generate speed: {outputs[0].shape[0]/(end_time-start_time):.2f} tokens/s')
    response = tokenizer.decode(outputs)
    response = process_response(response[0])
    # The original `history = history + [...]` only rebound the local name, so
    # the turn was discarded on return and the multi-turn branch above was
    # never reached. Extending in place keeps the conversation.
    history.append((query, response))
    print(response)

In [None]:
# Ask a single question using the default conversation history.
chat(query="您上海迪士尼乐园的会员卡号是?")