# ChatGLM3-6B-LoRA-16bit

将模型以半精度（fp16）加载有两种方法：在 from_pretrained 时指定 torch_dtype=torch.half（推荐）；或在加载后调用 model = model.half()。

显存占用约 15.95 GB（batch size 为 1 时）。

## Step1 导入相关包

In [None]:
import os
# Expose GPUs 0-3 to this process. This only takes effect if it runs before CUDA is
# initialised (i.e. before torch first touches a GPU), hence it sits above the imports.
# NOTE(review): with several visible GPUs in a single process, Trainer presumably uses
# DataParallel, so the effective batch per optimizer step scales with the GPU count.
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2,3"

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer

## Step2 加载数据集

In [None]:
# Load only the first 10,000 rows of the cleaned Chinese Alpaca training split.
ds = load_dataset("lifefabric/alpaca_data_cleaned.zh", split="train[:10000]")
ds

In [None]:
# Peek at the first three raw records (fields: instruction / input / output).
ds[:3]

## Step3 数据集预处理

In [None]:
# Tokenizer of the local ChatGLM3-6B base checkpoint. trust_remote_code=True lets
# transformers run the tokenizer code shipped with the checkpoint (presumably required
# because ChatGLM3 uses a custom tokenizer class).
tokenizer = AutoTokenizer.from_pretrained("/node6_1/tanshuai/ZhipuAI/chatglm3-6b-base", trust_remote_code=True)
tokenizer

In [None]:
# Compare how the EOS token string tokenizes with eos_token_id; process_func appends
# eos_token_id by hand because the tokenized response does not contain it.
tokenizer(tokenizer.eos_token, add_special_tokens=False), tokenizer.eos_token_id

In [None]:
def process_func(example, max_length=256, tok=None):
    """Turn one Alpaca-style record into ChatGLM3 causal-LM training features.

    The prompt is ``instruction`` and ``input`` joined by a newline, rendered through
    ``tok.build_chat_input`` (``[gMASK]sop<|user|> \\n query<|assistant|>``). It is
    followed by ``"\\n" + output`` and an explicitly appended EOS token. Prompt
    positions get label -100 so the loss covers only the response and the EOS.

    Args:
        example: mapping with string fields "instruction", "input" and "output".
        max_length: keep at most this many tokens (right truncation, which can drop
            the trailing EOS on long samples). ``None`` disables truncation.
            Defaults to 256.
        tok: tokenizer to use; defaults to the module-level ``tokenizer``.

    Returns:
        dict with equal-length lists "input_ids", "attention_mask" and "labels".
    """
    if tok is None:
        tok = tokenizer  # module-level tokenizer loaded in the preprocessing step

    query = "\n".join([example["instruction"], example["input"]]).strip()
    prompt = tok.build_chat_input(query, history=[], role="user")
    # build_chat_input returns batched tensors; take the single row as plain lists.
    prompt_ids = prompt["input_ids"][0].tolist()
    prompt_mask = prompt["attention_mask"][0].tolist()

    # "\n" + response, encoded without special tokens: it carries no EOS, so add it below.
    response = tok("\n" + example["output"], add_special_tokens=False)
    eos_id = tok.eos_token_id

    input_ids = prompt_ids + response["input_ids"] + [eos_id]
    attention_mask = prompt_mask + response["attention_mask"] + [1]
    labels = [-100] * len(prompt_ids) + response["input_ids"] + [eos_id]

    # Slicing is a no-op for shorter sequences (and for max_length=None).
    return {
        "input_ids": input_ids[:max_length],
        "attention_mask": attention_mask[:max_length],
        "labels": labels[:max_length],
    }

In [None]:
# Tokenize every record; remove_columns drops the raw text columns so only the
# input_ids / attention_mask / labels produced by process_func remain.
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
tokenized_ds

In [None]:
# Sanity check: decode a full training sequence to inspect the prompt template.
print(tokenizer.decode(tokenized_ds[1]["input_ids"]))

In [None]:
# Decode only the supervised positions (label != -100); expected to be "\n" + response + EOS.
tokenizer.decode([token_id for token_id in tokenized_ds[1]["labels"] if token_id != -100])

## Step4 创建模型

In [None]:
# import torch
# model = AutoModelForCausalLM.from_pretrained("/node6_1/tanshuai/ZhipuAI/chatglm3-6b-base", trust_remote_code=True, torch_dtype=torch.bfloat16)

In [None]:
import torch

# Load the base model directly in fp16 (the recommended way, rather than calling
# .half() after loading). low_cpu_mem_usage avoids materialising a second full copy
# of the weights while loading. device_map is intentionally not passed, so device
# placement is left to Trainer.
model = AutoModelForCausalLM.from_pretrained(
    "/node6_1/tanshuai/ZhipuAI/chatglm3-6b-base",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.half,
)
print(model)

In [None]:
# List every parameter with its dtype to confirm torch_dtype=torch.half took effect.
for name, parameter in model.named_parameters():
    print(name, parameter.dtype)

## LoRA

### PEFT 1 配置文件

In [None]:
from peft import LoraConfig, TaskType, get_peft_model, PeftModel

# LoRA adapters on the "query_key_value" modules (presumably ChatGLM3's fused QKV
# projection). modules_to_save additionally trains and saves full copies of every
# "post_attention_layernorm" module. r / lora_alpha / dropout use LoraConfig defaults.
# NOTE(review): task_type is not set here (TaskType is imported but unused) — confirm
# that the generic PEFT wrapper is the intended one.
config = LoraConfig(target_modules=["query_key_value"], modules_to_save=["post_attention_layernorm"])
config

### PEFT 2 创建模型

In [None]:
# Wrap the base model with the LoRA adapters / trainable module copies described by config.
model = get_peft_model(model, config)

In [None]:
# Display the wrapped model structure (adapter and modules_to_save wrappers visible).
model

In [None]:
# Cast every floating-point parameter, including the newly added adapter / saved-module
# weights, to fp16 so the whole model is half precision. The next cell prints the
# dtypes to verify this.
model.half()

In [None]:
# Re-check dtypes after model.half(): every parameter is expected to be torch.float16.
for name, parameter in model.named_parameters():
    print(name, parameter.dtype)

## Step5 配置训练参数

In [None]:
args = TrainingArguments(
    output_dir="./chatbot",
    # Effective batch per optimizer step: 1 per device x 16 accumulation steps
    # (x the number of GPUs Trainer uses).
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    num_train_epochs=1,
    learning_rate=1e-4,
    # The whole model is fp16, so adam_epsilon has to be raised: the default 1e-8 is
    # below the smallest value fp16 can represent.
    adam_epsilon=1e-4,
    logging_steps=1,
    save_strategy="epoch",
    # Keep all dataset columns instead of filtering them by the model's forward signature.
    remove_unused_columns=False,
)

## Step6 创建训练器

In [None]:
# Train on the first 5000 of the 10000 tokenized examples.
train_subset = tokenized_ds.select(range(5000))
# Pads input_ids / attention_mask / labels to the longest sequence in each batch.
seq2seq_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)

trainer = Trainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    train_dataset=train_subset,
    data_collator=seq2seq_collator,
)

## Step7 模型训练

In [None]:
# Run training; with save_strategy="epoch" a checkpoint is written under ./chatbot after each epoch.
trainer.train()

In [None]:
from safetensors import safe_open

# Trainer names checkpoint folders "checkpoint-<global_step>". With save_strategy="epoch"
# and num_train_epochs=1 the only checkpoint is at the final step, so derive the folder
# from the trainer state. The step count is not hard-coded (e.g. 78) because it depends
# on the GPU count, the per-device batch, the accumulation steps and the dataset size.
adapter_file = os.path.join(
    args.output_dir,
    f"checkpoint-{trainer.state.global_step}",
    "adapter_model.safetensors",
)

with safe_open(adapter_file, framework="pt") as f:
    # Print the saved weights of layer 0's post_attention_layernorm (a modules_to_save
    # entry) to confirm it was trained and written into the adapter checkpoint.
    for key in f.keys():
        if ".0.post_attention_layernorm" in key:
            print(key)
            print(f.get_tensor(key))

## Step8 模型推理

In [None]:
# Switch to inference mode (disables dropout) and ask a single-turn question.
model.eval()
# NOTE(review): chat() is not a PEFT method; presumably the wrapper forwards it to the
# ChatGLM3 model's own chat() from its remote code, returning (response, history) — confirm.
print(model.chat(tokenizer, "数学考试怎么考高分？", history=[]))

In [None]:
# history=[] again, so this question is asked without the previous turn as context.
model.chat(tokenizer, "有什么考试技巧？", history=[])