<a href="https://colab.research.google.com/github/Erickrus/llm/blob/main/chatglm3_lora_finetune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Efficient LoRA Fine-Tuning of ChatGLM3-6B on a Single GPU
This cookbook walks developers through LoRA fine-tuning of ChatGLM3-6B on the `AdvertiseGen` dataset so that the model can generate professional advertising copy.

## Hardware Requirements
+ GPU memory: 24GB
+ GPU architecture: Ampere (recommended)
+ RAM: 16GB


Tested on Colab with a V100 GPU (16 GB, Volta architecture). Resource usage shown by Colab:

+ RAM: 2.9 / 16 GB
+ GPU RAM: 15.5 / 16.0 GB
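The V100 used here has less memory than the suggested 24 GB and predates Ampere. A minimal sketch for checking your own GPU against those two numbers, assuming PyTorch is installed; `assess_gpu` is a hypothetical helper, and "Ampere or newer" is approximated as CUDA compute capability 8.0 or higher:

```python
def assess_gpu(total_memory_bytes, major, min_mem_gb=24.0):
    """Compare one GPU with the suggested requirements (hypothetical helper).

    total_memory_bytes: device memory as reported by PyTorch.
    major: CUDA compute capability major version (Ampere cards report 8).
    """
    mem_gib = total_memory_bytes / 1024**3
    return {
        "memory_gib": round(mem_gib, 1),
        "enough_memory": mem_gib >= min_mem_gb,
        "ampere_or_newer": major >= 8,
    }


try:
    import torch
except ImportError:  # PyTorch missing: nothing to probe
    torch = None

if torch is not None and torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(props.name, assess_gpu(props.total_memory, props.major))
else:
    print("No CUDA device visible to PyTorch")
```

On a 16 GB Volta card both flags come out False, which is consistent with the batch-size reduction made before training in this notebook.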

In [None]:
#@title clone repository
#@markdown
!git clone https://github.com/THUDM/ChatGLM3/

Cloning into 'ChatGLM3'...
remote: Enumerating objects: 1333, done.
remote: Counting objects: 100% (33/33), done.
remote: Compressing objects: 100% (25/25), done.
remote: Total 1333 (delta 9), reused 22 (delta 8), pack-reused 1300
Receiving objects: 100% (1333/1333), 17.37 MiB | 18.64 MiB/s, done.
Resolving deltas: 100% (749/749), done.


## 1. Prepare the Dataset
We use the AdvertiseGen dataset for fine-tuning. Download the preprocessed AdvertiseGen dataset from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) and place the extracted `AdvertiseGen` directory under `data/` in this directory (`finetune_demo`), for example:
> /media/zr/Data/Code/ChatGLM3/finetune_demo/data/AdvertiseGen

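Next, the cells below download and extract the dataset, fetch the model weights, and convert the dataset into the `conversations` format expected by the fine-tuning script.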

In [None]:
#@title download dataset
#@markdown
!pip3 install -U gdown
!gdown 13_vf0xRTQsyneRKdD1bZIr93vBGOczrk
!tar -xzvf AdvertiseGen.tar.gz
!mkdir -p /content/ChatGLM3/finetune_demo/data
!mv AdvertiseGen /content/ChatGLM3/finetune_demo/data/AdvertiseGen

Collecting gdown
  Downloading gdown-5.1.0-py3-none-any.whl (17 kB)
Installing collected packages: gdown
  Attempting uninstall: gdown
    Found existing installation: gdown 4.7.3
    Uninstalling gdown-4.7.3:
      Successfully uninstalled gdown-4.7.3
Successfully installed gdown-5.1.0
Downloading...
From: https://drive.google.com/uc?id=13_vf0xRTQsyneRKdD1bZIr93vBGOczrk
To: /content/AdvertiseGen.tar.gz
100% 17.1M/17.1M [00:00<00:00, 136MB/s]
AdvertiseGen/
AdvertiseGen/train.json
AdvertiseGen/dev.json


In [None]:
#@title download weights
#@markdown

!cd /content && git lfs clone https://huggingface.co/THUDM/chatglm3-6b

          with new flags from 'git clone'

'git clone' has been updated in upstream Git to have comparable
speeds to 'git lfs clone'.
Cloning into 'chatglm3-6b'...
remote: Enumerating objects: 128, done.
remote: Counting objects: 100% (125/125), done.
remote: Compressing objects: 100% (124/124), done.
remote: Total 128 (delta 62), reused 0 (delta 0), pack-reused 3
Receiving objects: 100% (128/128), 51.37 KiB | 25.68 MiB/s, done.
Resolving deltas: 100% (62/62), done.


In [None]:
%cd /content/ChatGLM3/finetune_demo

/content/ChatGLM3/finetune_demo


In [None]:
#@title prepare dataset
#@markdown

import json
from typing import Union
from pathlib import Path


def _resolve_path(path: Union[str, Path]) -> Path:
    return Path(path).expanduser().resolve()


def _mkdir(dir_name: Union[str, Path]):
    dir_name = _resolve_path(dir_name)
    if not dir_name.is_dir():
        dir_name.mkdir(parents=True, exist_ok=False)


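def convert_adgen(data_dir: Union[str, Path], save_dir: Union[str, Path]):
    """Convert AdvertiseGen JSONL files into the chat format used for fine-tuning.

    Each input line is a JSON object with 'content' (the attribute string) and
    'summary' (the reference ad copy); each output line wraps them as one user
    turn followed by one assistant turn.
    """
    def _convert(in_file: Path, out_file: Path):
        _mkdir(out_file.parent)
        with open(in_file, encoding='utf-8') as fin:
            with open(out_file, 'wt', encoding='utf-8') as fout:
                for line in fin:
                    dct = json.loads(line)
                    sample = {'conversations': [{'role': 'user', 'content': dct['content']},
                                                {'role': 'assistant', 'content': dct['summary']}]}
                    # ensure_ascii=False keeps the Chinese text readable in the output file
                    fout.write(json.dumps(sample, ensure_ascii=False) + '\n')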

    data_dir = _resolve_path(data_dir)
    save_dir = _resolve_path(save_dir)

    train_file = data_dir / 'train.json'
    if train_file.is_file():
        out_file = save_dir / train_file.relative_to(data_dir)
        _convert(train_file, out_file)

    dev_file = data_dir / 'dev.json'
    if dev_file.is_file():
        out_file = save_dir / dev_file.relative_to(data_dir)
        _convert(dev_file, out_file)


convert_adgen('data/AdvertiseGen', 'data/AdvertiseGen_fix')
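Before training, it can be worth sanity-checking the converted files. A minimal sketch (the `check_conversation_file` helper is hypothetical, not part of ChatGLM3) that re-reads each output line and confirms it holds exactly one user turn followed by one assistant turn, as `convert_adgen` writes them:

```python
import json
from pathlib import Path


def check_conversation_file(path):
    """Validate a converted JSONL file and return the number of records.

    Each non-blank line must be {"conversations": [user turn, assistant turn]}
    with non-empty string contents.
    """
    count = 0
    with open(path, encoding="utf-8") as fin:
        for lineno, line in enumerate(fin, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            turns = json.loads(line)["conversations"]
            roles = [t["role"] for t in turns]
            if roles != ["user", "assistant"]:
                raise ValueError(f"{path}:{lineno}: unexpected roles {roles}")
            if not all(isinstance(t["content"], str) and t["content"] for t in turns):
                raise ValueError(f"{path}:{lineno}: empty content")
            count += 1
    return count


for name in ("train.json", "dev.json"):
    p = Path("data/AdvertiseGen_fix") / name
    if p.is_file():
        print(name, check_conversation_file(p))
```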

## 2. Start Fine-Tuning from the Command Line with LoRA
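Next, we only need to pass the configured parameters to the training script on the command line to run efficient fine-tuning. If you call the script through an explicit interpreter path such as `/media/zr/Data/Code/ChatGLM3/venv/bin/python3`, replace it with the absolute path of your own python3 so that it runs correctly; the cells in this notebook simply call `python3`.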
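If you launch the script from Python rather than from a `!` shell line, one way to avoid hard-coding an interpreter path is `sys.executable`, the absolute path of the running interpreter. A minimal sketch, assuming the positional argument order used in this notebook's training cell (data directory, model directory, config file); `build_finetune_cmd` is a hypothetical helper:

```python
import shlex
import sys


def build_finetune_cmd(data_dir, model_dir, config_file, python=sys.executable):
    """Build the argv list for finetune_hf.py (hypothetical helper).

    Positional order mirrors the notebook's training cell:
    data directory, base model directory, YAML config.
    """
    return [python, "finetune_hf.py", str(data_dir), str(model_dir), str(config_file)]


cmd = build_finetune_cmd("data/AdvertiseGen_fix", "/content/chatglm3-6b", "configs/lora.yaml")
print(shlex.join(cmd))  # shell-quoted form, handy for logging

# To actually start training (from inside finetune_demo/), one option is:
# import subprocess; subprocess.run(cmd, check=True)
```

Passing a list rather than a single string lets `subprocess` skip the shell, so paths containing spaces need no extra quoting.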

In [None]:
!pip3 install -r /content/ChatGLM3/requirements.txt

Collecting protobuf>=4.25.3 (from -r /content/ChatGLM3/requirements.txt (line 3))
  Downloading protobuf-5.26.1-cp37-abi3-manylinux2014_x86_64.whl (302 kB)
Collecting cpm_kernels>=1.0.11 (from -r /content/ChatGLM3/requirements.txt (line 6))
  Downloading cpm_kernels-1.0.11-py3-none-any.whl (416 kB)
Collecting gradio>=4.19.2 (from -r /content/ChatGLM3/requirements.txt (line 8))
  Downloading gradio-4.26.0-py3-none-any.whl (17.1 MB)

In [None]:
!pip3 install -q ruamel_yaml datasets peft
!pip3 install -r /content/ChatGLM3/finetune_demo/requirements.txt



Before training, edit line 10 of `/content/ChatGLM3/finetune_demo/configs/lora.yaml` so that training fits into the V100's 16 GB of GPU memory:

from:
  `per_device_train_batch_size: 4`

to:
  `per_device_train_batch_size: 2`
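The same one-line change can be made from a notebook cell instead of an editor. A minimal sketch using a plain-text regex substitution, so comments and layout in the YAML file are left alone; `set_train_batch_size` is a hypothetical helper, and it assumes the key is written as `per_device_train_batch_size: <integer>` on its own line:

```python
import re
from pathlib import Path

# Matches e.g. "  per_device_train_batch_size: 4" and captures everything before the number.
_BATCH_RE = re.compile(r"^([ \t]*per_device_train_batch_size:[ \t]*)\d+", re.MULTILINE)


def set_train_batch_size(config_path, new_size):
    """Rewrite per_device_train_batch_size in a YAML file; return how many lines matched."""
    path = Path(config_path)
    text = path.read_text(encoding="utf-8")
    new_text, count = _BATCH_RE.subn(lambda m: f"{m.group(1)}{new_size}", text)
    if count:
        path.write_text(new_text, encoding="utf-8")
    return count


CONFIG = Path("/content/ChatGLM3/finetune_demo/configs/lora.yaml")
if CONFIG.is_file():
    print("lines changed:", set_train_batch_size(CONFIG, 2))
```

A smaller per-device batch reduces activation memory per step; the model weights still occupy the same space.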

In [None]:
%cd /content/ChatGLM3/finetune_demo
%env CUDA_VISIBLE_DEVICES=0
!python3 finetune_hf.py  data/AdvertiseGen_fix  /content/chatglm3-6b  configs/lora.yaml

/content/ChatGLM3/finetune_demo
env: CUDA_VISIBLE_DEVICES=0
2024-04-11 01:44:57.640294: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-11 01:44:57.640350: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-11 01:44:57.641860: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Setting eos_token is not supported, use the default one.
Setting pad_token is not supported, use the default one.
Setting unk_token is not supported, use the default one.
Loading checkpoint shards: 100% 7/7 [00:04<00:00,  1.51it/s]
trainable params: 1,949,696 || all params: 6,245,533,696 || trainable%: 0
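The log reports 1,949,696 trainable parameters out of about 6.2 billion. A back-of-the-envelope check reproduces that number, assuming (this is not shown in the notebook) that `configs/lora.yaml` applies rank-8 LoRA only to each layer's fused `query_key_value` projection, and that ChatGLM3-6B has 28 layers, hidden size 4096 and 2 key/value groups with head dimension 128, as its `config.json` lists. A LoRA adapter on an `in × out` linear layer adds matrices `A` (r × in) and `B` (out × r), i.e. r·(in + out) parameters:

```python
def lora_params(in_features, out_features, rank):
    """Parameters added by one LoRA adapter: A is (rank x in), B is (out x rank)."""
    return rank * (in_features + out_features)


# Assumed ChatGLM3-6B dimensions and LoRA settings (see lead-in).
HIDDEN = 4096        # hidden_size
NUM_LAYERS = 28      # num_layers
KV_GROUPS = 2        # multi_query_group_num
HEAD_DIM = 128       # kv_channels
RANK = 8             # assumed LoRA r

# Fused QKV output: full-width queries plus grouped keys and values.
qkv_out = HIDDEN + 2 * KV_GROUPS * HEAD_DIM          # 4608
trainable = NUM_LAYERS * lora_params(HIDDEN, qkv_out, RANK)
all_params = 6_245_533_696                           # "all params" from the log

print(f"{trainable:,}")                              # 1,949,696
print(f"{100 * trainable / all_params:.4f}%")        # 0.0312%
```

Only about 0.03% of the weights are trained, which is why the optimizer state stays small enough for a single GPU.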

## 3. Run Inference with the Fine-Tuned Model
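After fine-tuning finishes, the `output` folder contains one or more `checkpoint-*` folders; the number in each name is the training step at which that checkpoint was saved.
We select the checkpoint from the final step and load it with `inference_hf.py`.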
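The `checkpoint-*` folders should be compared by their numeric suffix, not alphabetically (as text, `checkpoint-500` sorts after `checkpoint-2000`). A minimal sketch of picking the newest one; `latest_checkpoint` is a hypothetical helper, and `transformers.trainer_utils.get_last_checkpoint` in the transformers library serves a similar purpose:

```python
import re
from pathlib import Path

_CKPT_RE = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(output_dir="output"):
    """Return the checkpoint-<step> directory with the largest step, or None."""
    root = Path(output_dir)
    if not root.is_dir():
        return None
    best_step, best_path = -1, None
    for child in root.iterdir():
        match = _CKPT_RE.match(child.name)
        if match and child.is_dir():
            step = int(match.group(1))
            if step > best_step:
                best_step, best_path = step, child
    return best_path


print(latest_checkpoint("output"))
```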

In [None]:
!ls output/

checkpoint-2000  runs


In [None]:
%cd /content/ChatGLM3/finetune_demo
%env CUDA_VISIBLE_DEVICES=0
!python3 inference_hf.py output/checkpoint-2000/ --prompt "类型#裙*版型#显瘦*材质#网纱*风格#性感*裙型#百褶*裙下摆#压褶*裙长#连衣裙*裙衣门襟#拉链*裙衣门襟#套头*裙款式#拼接*裙款式#拉链*裙款式#木耳边*裙款式#抽褶*裙款式#不规则"

/content/ChatGLM3/finetune_demo
env: CUDA_VISIBLE_DEVICES=0
Loading checkpoint shards: 100% 7/7 [00:05<00:00,  1.35it/s]
Setting eos_token is not supported, use the default one.
Setting pad_token is not supported, use the default one.
Setting unk_token is not supported, use the default one.
2024-04-11 02:22:33.739465: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-11 02:22:33.739525: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-11 02:22:33.741092: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
这款连衣裙采用了显瘦的套头设计，不规则的衣摆，搭配上木耳边的设计，更显少女的可爱。而百褶的设计，更显优雅。再加上网纱的拼接，更是性感，整体修身版型
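The `--prompt` string above follows the AdvertiseGen input format: attribute pairs written as `key#value` and joined with `*`, where a key may repeat (`裙款式` appears five times). A minimal sketch for composing and splitting such prompts; `build_prompt` and `parse_prompt` are hypothetical helpers, not part of the repository:

```python
def build_prompt(attributes):
    """Join (key, value) pairs as 'key#value*key#value'.

    A list of pairs (not a dict) is used because keys may repeat.
    """
    parts = []
    for key, value in attributes:
        if "#" in key + value or "*" in key + value:
            raise ValueError(f"separator character in {key!r}: {value!r}")
        parts.append(f"{key}#{value}")
    return "*".join(parts)


def parse_prompt(prompt):
    """Split a 'key#value*key#value' prompt back into (key, value) pairs."""
    pairs = []
    for item in prompt.split("*"):
        key, sep, value = item.partition("#")
        if not sep:
            raise ValueError(f"missing '#' in {item!r}")
        pairs.append((key, value))
    return pairs


prompt = build_prompt([("类型", "裙"), ("版型", "显瘦"), ("裙款式", "拼接"), ("裙款式", "木耳边")])
print(prompt)  # 类型#裙*版型#显瘦*裙款式#拼接*裙款式#木耳边
```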

## 4. Summary
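At this point, we have fine-tuned ChatGLM3-6B with LoRA on a single GPU so that it produces better advertising copy.
In this chapter, you learned:
+ How to fine-tune the model with LoRA
+ How to prepare the fine-tuning dataset and convert it to the required format
+ How to run inference with the fine-tuned model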