<a href="https://colab.research.google.com/github/WSH032/ChatGLM-6B/blob/main/Colab_ChatGLM_ptuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title #（一）克隆ChatGLM-6B的库、安装依赖

!git clone https://github.com/THUDM/ChatGLM-6B/
%cd /content/ChatGLM-6B/
print(f"正在安装依赖，请耐心等待")
!pip install --upgrade -r requirements.txt  > /dev/null 2>&1
!pip install rouge_chinese nltk jieba datasets  > /dev/null 2>&1
print(f"依赖安装完成")

# (二) txt转换成训练集、验证集

In [None]:
#@title 通过对话生成txt文件
import gradio as gr
import os


# Colab form parameters for the dialogue-recording WebUI.
# The `#@markdown` / `#@param` annotations are rendered by Colab as form
# fields, so their text is part of the UI and is left untouched.

# Number of visible lines in the answer textbox (passed as `lines=` below).
#@markdown WebUI里显示的回答行数
max_talk_num = 20 #@param {type:"number"}
# Directory that receives ask.txt and ans.txt.
# NOTE(review): this path is under /content/drive, but no Google Drive mount
# step is visible in this section - confirm Drive is mounted before running,
# otherwise makedirs would just create a local directory with that name.
#@markdown txt文件输出路径
txt_output_dir = "/content/drive/MyDrive/ChatGLM/input/" #@param {type:"string"}
os.makedirs(txt_output_dir, exist_ok=True)    # create the directory up front so later file writes do not fail
# Module-level paths read by say2txt() and del_txt() at call time.
ask_txt_path = os.path.join(txt_output_dir,"ask.txt")
ans_txt_path = os.path.join(txt_output_dir,"ans.txt")

#生成对话txt文件
def say2txt(ask, ans):
  """Append one question and its answers to the ask/ans text files.

  Every non-blank line of ``ans`` is treated as one answer. The question is
  written once per answer, so line i of the ask file always pairs with line i
  of the ans file. The target paths are the module-level ``ask_txt_path`` and
  ``ans_txt_path``; missing files are created.

  Args:
    ask: Question text. Line breaks inside it are replaced by single spaces,
      because a multi-line question would occupy several lines per answer and
      break the line-by-line pairing.
    ans: Answer text, one answer per line. Whitespace-only lines are ignored.

  Returns:
    The status string shown in the WebUI output box.
  """
  def read_existing(txt_path):
    # Return the current file content, or "" when the file does not exist yet.
    try:
      with open(txt_path, 'r', encoding='utf-8') as f:
        return f.read()
    except FileNotFoundError:
      return ""

  def append_lines(txt_path, lines):
    # Rewrite the file as previous content + one new line per entry.
    existing = read_existing(txt_path)
    # Guard against a hand-edited file whose last line lacks a newline;
    # without this the first new record would be glued onto that line.
    if existing and not existing.endswith("\n"):
      existing += "\n"
    with open(txt_path, 'w', encoding='utf-8') as f:
      f.write(existing + "".join(line + "\n" for line in lines))

  # Collapse the question onto a single line (see Args above).
  ask_line = " ".join(part for part in (ask or "").splitlines() if part.strip())
  # Keep only answer lines that contain something other than whitespace.
  ans_lines = [line for line in (ans or "").splitlines() if line.strip() != ""]

  # Repeat the question so both files grow by exactly the same number of lines.
  append_lines(ask_txt_path, [ask_line] * len(ans_lines))
  append_lines(ans_txt_path, ans_lines)
  return "写入成功"

#清空ask和ans输入框
def clear_text():
  """Build the updates that blank both input boxes and report success.

  Returns:
    A 3-tuple of ``gr.update`` results, in the order question box,
    answer box, status box.
  """
  question_reset = gr.update(value='')
  answer_reset = gr.update(value='')
  status_message = gr.update(value='清空成功')
  return (question_reset, answer_reset, status_message)

#删除txt文件


def del_txt():
  """Delete the ask and ans text files, ignoring files that do not exist.

  Returns:
    The status string shown in the WebUI output box.
  """
  for txt_path in (ask_txt_path, ans_txt_path):
    try:
      os.remove(txt_path)
    except FileNotFoundError:
      # Nothing to delete for this file; carry on with the next one.
      pass
  return "删除成功"

# Gradio layout for recording question/answer pairs. The nesting and order of
# the `with` contexts determine where each component is placed.
with gr.Blocks() as demo:
  # Row 1: three equally weighted columns, one action button in each.
  with gr.Row():
    with gr.Column(scale=1):
      button_creat = gr.Button("生成txt")
    with gr.Column(scale=1):
      button_clear = gr.Button("清空问题与回答")
    with gr.Column(scale=1):
      button_del = gr.Button("删除txt文件")
  # Row 2: question input (scale 3) next to the status/output box (scale 1).
  with gr.Row():
    with gr.Column(scale=3):
      text_ask = gr.Textbox(lines=1, show_label=False, placeholder="问题")
    with gr.Column(scale=1):
      output = gr.Textbox(lines=1, show_label=False, placeholder="输出")
  # Row 3: answer input, max_talk_num lines tall.
  with gr.Row():
    with gr.Column(scale=5):
      text_ans = gr.Textbox(lines=max_talk_num, show_label=False, placeholder="回答")
  # Wire the buttons: handler, input components, output component(s).
  button_creat.click(say2txt, [text_ask, text_ans], output)
  # clear_text returns three updates, one per listed output component.
  button_clear.click(clear_text, [], [text_ask, text_ans, output])
  button_del.click(del_txt, [], output)
# Enable the request queue on the app, then launch it.
demo.queue().launch()

In [None]:
#@title txt转换成json

def txt2json(ask_txt_path, ans_txt_path, json_dir, sample_ratio, random_sample=True, shuffle_txt=True):
    """Convert line-aligned ask/ans text files into train/validation JSON-lines files.

    Line i of the ask file and line i of the ans file form one sample. Surplus
    trailing lines of the longer file are dropped. The samples are optionally
    shuffled, split into a validation part and a training part, and written as
    one JSON object per line (keys ``"ask"`` and ``"ans"``) to
    ``<json_dir>/train.json`` and ``<json_dir>/validation.json``.

    Args:
        ask_txt_path: Path of the UTF-8 text file with one question per line.
        ans_txt_path: Path of the UTF-8 text file with one answer per line.
        json_dir: Output directory; created if it does not exist.
        sample_ratio: Fraction (0..1) of samples used for validation; the
            count is ``floor(number_of_samples * sample_ratio)``.
        random_sample: If True, validation samples are picked at random;
            otherwise the last samples (after optional shuffling) are used.
        shuffle_txt: If True, shuffle the samples before splitting.

    Raises:
        ValueError: If ``sample_ratio`` is outside the range 0..1.
    """
    import os
    import json
    import random
    import math

    def read(file_path):
        # Read all lines and drop only the line terminator. The characters are
        # given as a set to rstrip, so nothing else at the line end is removed.
        with open(file_path, 'r', encoding='utf-8') as f:
            return [line.rstrip("\r\n") for line in f]

    def write_json(json_path, pairs):
        # One JSON object per line; ensure_ascii=False keeps non-ASCII text readable.
        with open(json_path, 'w', encoding='utf-8') as f:
            for ask_row, ans_row in pairs:
                f.write(json.dumps({"ask": ask_row, "ans": ans_row}, ensure_ascii=False) + "\n")

    if not 0 <= sample_ratio <= 1:
        raise ValueError(f"sample_ratio must be between 0 and 1, got {sample_ratio}")

    ask_content = read(ask_txt_path)
    ans_content = read(ans_txt_path)

    # Report a length mismatch; the surplus lines are dropped by zip() below.
    l_ask = len(ask_content)
    l_ans = len(ans_content)
    l_diff = l_ask - l_ans
    pair_count = min(l_ask, l_ans)
    if l_diff > 0:
        print(f"ask数量多于ans，ask末尾的{ abs(l_diff) } 行将被删除")
    elif l_diff < 0:
        print(f"ans数量多于ask，ans末尾的{ abs(l_diff) } 行将被删除")
    print(f"一共{pair_count}行数据")

    # Keep each question attached to its answer from here on. Shuffling the
    # two lists separately would pair every question with an unrelated answer.
    pairs = list(zip(ask_content, ans_content))
    if shuffle_txt:
        random.shuffle(pairs)

    # Choose which sample positions go to the validation set.
    sample_num = math.floor(pair_count * sample_ratio)
    if random_sample:
        validation_index = set(random.sample(range(pair_count), sample_num))
    else:
        validation_index = set(range(pair_count - sample_num, pair_count))

    validation_pairs = [pair for i, pair in enumerate(pairs) if i in validation_index]
    train_pairs = [pair for i, pair in enumerate(pairs) if i not in validation_index]
    print(f"采用其中{len(validation_index)}行做为验证集")

    os.makedirs(json_dir, exist_ok=True)
    write_json(os.path.join(json_dir, "train.json"), train_pairs)
    write_json(os.path.join(json_dir, "validation.json"), validation_pairs)


# Colab form parameters for the txt -> json conversion, followed by the call.
# The `#@markdown` / `#@param` annotations are rendered by Colab as form
# fields, so their text is part of the UI and is left untouched.
#@markdown ask文件、ans文件路径，输出文件夹路径，验证集率，是否开启随机采样做验证集，是否打乱顺序
# NOTE: these two assignments rebind the same module-level names that
# say2txt() and del_txt() read at call time.
ask_txt_path = "/content/drive/MyDrive/ChatGLM/input/ask.txt" #@param {type:"string"}
ans_txt_path = "/content/drive/MyDrive/ChatGLM/input/ans.txt" #@param {type:"string"}
# Fraction of samples used for the validation set.
sample_ratio = 0.101 #@param {type:"slider", min:0.001, max:0.999, step:0.001}
# Pick validation samples at random (True) or take the last ones (False).
random_sample = True #@param {type:"boolean"}
# Shuffle the data before splitting.
shuffle_txt = True #@param {type:"boolean"}
# Directory that receives train.json and validation.json.
json_dir = "/content/ChatGLM-6B/train" #@param {type:"string"}

txt2json(ask_txt_path, ans_txt_path, json_dir, sample_ratio, random_sample=random_sample, shuffle_txt=shuffle_txt)

# (三)设置参数，开始训练

In [None]:
#@title 训练
#@markdown  训练集文件、验证集文件地址, 
train_file = "/content/ChatGLM-6B/train/train.json" #@param {type:"string"}
validation_file = "/content/ChatGLM-6B/train/validation.json" #@param {type:"string"}


#@markdown  soft prompt长度, 训练的学习率, batch大小`colab普通用户最多24`, 梯度累计算次数`一般使其与batch乘积为一常数`, 量化等级`不填则采用fp16`
PRE_SEQ_LEN = 8 #@param {type:"number"}
LR = 1e-2 #@param {type:"number"}
batch_size = 24 #@param {type:"number"}
gradient_accumulation_steps = 1 #@param {type:"number"}
quantization_bit = 4 #@param ["", 4, 8]

#@markdown 训练底模`Colab免费用户只能使用int4和qe模型`，或者填入自定义模型路径`将会覆盖预设模型选择`
model_path = "THUDM/chatglm-6b-int4-qe" #@param ["THUDM/chatglm-6b", "THUDM/chatglm-6b-int4", "THUDM/chatglm-6b-int4-qe"]
your_model_path = "" #@param {type:"string"}
#用自定义路径覆盖预设
if your_model_path:
  model_path = your_model_path

#@markdown 最大训练步数，日志输出步数，每n步保存一次checkpoints
max_steps = 1500 #@param {type:"number"}
logging_steps = 10 #@param {type:"number"}
save_steps = 500 #@param {type:"number"} 

#@markdown 模型输出路径
output_dir = "/content/drive/MyDrive/ChatGLM/output" #@param {type:"string"}
import os
output_dir = os.path.join(output_dir, f"adgen-chatglm-6b-pt-{PRE_SEQ_LEN}-")

!python /content/ChatGLM-6B/ptuning/main.py \
    --do_train \
    --train_file {train_file} \
    --validation_file {validation_file} \
    --prompt_column ask \
    --response_column ans \
    --overwrite_cache \
    --model_name_or_path {model_path} \
    --output_dir {output_dir} \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_train_batch_size {batch_size} \
    --per_device_eval_batch_size {batch_size} \
    --gradient_accumulation_steps {gradient_accumulation_steps} \
    --predict_with_generate \
    --max_steps {max_steps} \
    --logging_steps {logging_steps} \
    --save_steps {save_steps} \
    --learning_rate {LR} \
    --pre_seq_len {PRE_SEQ_LEN} \
    --quantization_bit {quantization_bit}