In [None]:
from datasets import load_from_disk
from transformers import (
    AutoModel,
    BitsAndBytesConfig,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
    HfArgumentParser
)
import torch
from peft import TaskType, LoraConfig, get_peft_model
import os
import numpy as np
from dataclasses import field, dataclass
from typing import Union, List, Optional
from loguru import logger

# prepare arguments 

In [2]:
@dataclass
class FinetuneArguments():
    dataset_path: str = field(default='data/Med_datasets.jsonl')
    model_path: str = field(default='THUDM/chatglm2-6b')
    label_pad_token_id: int = field(default=-100)
    load_in_4bit: bool = field(default=True)
    bnb_4bit_quant_type: str = field(default='nf4')
    bnb_4bit_compute_dtype: str = field(default='float32')
    seed: int = field(default=42)
    resume_from_checkpoint: str = field(default=None)
    final_model_path: str = field(default='QLora_Adapter_THUDM_chatglm2-6b/finally_adapter')

In [3]:
@dataclass
class LoraArguments():
    target_modules: Union[List[int], str] = field(default='query_key_value')
    r: int = field(default=12)
    lora_alpha: int = field(default=12)
    lora_dropout: float = field(default=0.05)
    bias: str = field(default='none')
    inference_mode: bool = field(default=False)
    layers_to_transform: List[int] = field(default=None)
    layers_pattern: str = field(default=None)
    Adapter_name: str = field(default='LoraAdapter')

In [4]:
finetune_args = HfArgumentParser(FinetuneArguments)
finetune_args, = finetune_args.parse_json_file(json_file='args_file/finetune_args.json')
logger.info(f"current_finetune_args--->{finetune_args}")

[32m2023-07-15 08:44:22.404[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mcurrent_finetune_args--->FinetuneArguments(dataset_path='data/Med_datasets.jsonl', model_path='THUDM/chatglm2-6b', label_pad_token_id=-100, load_in_4bit=True, bnb_4bit_quant_type='nf4', bnb_4bit_compute_dtype='float32', seed=42, resume_from_checkpoint=None, final_model_path='Lora_Adapter_THUDM_chatglm2-6b/finally_adapter')[0m


In [5]:
train_args = HfArgumentParser(TrainingArguments)
train_args, = train_args.parse_json_file(json_file='args_file/train_args.json')
logger.info(f"current_train_args--->{train_args}")

[32m2023-07-15 08:44:23.866[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mcurrent_train_args--->TrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=False,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
do_eval=True,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=100,
evaluation_strategy=IntervalStrategy.STEPS,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'fsdp_min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
greate

In [6]:
lora_args = HfArgumentParser(LoraArguments)
lora_args, = lora_args.parse_json_file(json_file='args_file/lora_args.json')
logger.info(f"current_lora_args--->{lora_args}")

[32m2023-07-15 08:44:26.086[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mcurrent_lora_args--->LoraArguments(target_modules=['query_key_value'], r=12, lora_alpha=24, lora_dropout=0.05, bias='none', inference_mode=False, layers_to_transform=None, layers_pattern=None, Adapter_name='LoraAdapter')[0m


# process dataset

In [7]:
logger.info(f"from {finetune_args.dataset_path} loading datasets and tokenize datasets")
datasets = load_from_disk(dataset_path=finetune_args.dataset_path)

[32m2023-07-15 08:44:28.355[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mfrom data/Med_datasets.jsonl loading datasets and tokenize datasets[0m


In [8]:
datasets

DatasetDict({
    train: Dataset({
        features: ['context', 'target'],
        num_rows: 6622
    })
    valid: Dataset({
        features: ['context', 'target'],
        num_rows: 1000
    })
})

In [9]:
tokenizer = AutoTokenizer.from_pretrained(finetune_args.model_path,
                                          trust_remote_code=True)

## tokenize data 

In [10]:
def tokenize_dataset(data, label_pad_token_id=finetune_args.label_pad_token_id):
    instruction = data['context']
    target = data['target']
    instruction_ids = tokenizer.encode(instruction, add_special_tokens=True)
    target_ids = tokenizer.encode(target, add_special_tokens=False)
    input_ids = instruction_ids + target_ids + [tokenizer.eos_token_id]
    labels = ([label_pad_token_id] * len(instruction_ids) + target_ids
              + [tokenizer.eos_token_id])
    return {'input_ids': input_ids, 'labels': labels}

In [11]:
remove_columns = datasets['train'].column_names
datasets_train = datasets['train'].map(tokenize_dataset,
                                       remove_columns=remove_columns)

Loading cached processed dataset at /mnt/workspace/chatglm2-6b-AdaLoRA/data/Med_datasets.jsonl/train/cache-b3767ca69873ef7a.arrow


In [12]:
datasets_train

Dataset({
    features: ['input_ids', 'labels'],
    num_rows: 6622
})

In [13]:
remove_columns = datasets['valid'].column_names
datasets_valid = datasets['valid'].map(tokenize_dataset,
                                       remove_columns=remove_columns)

Loading cached processed dataset at /mnt/workspace/chatglm2-6b-AdaLoRA/data/Med_datasets.jsonl/valid/cache-fe09c3d2a66d4e85.arrow


In [14]:
datasets_valid

Dataset({
    features: ['input_ids', 'labels'],
    num_rows: 1000
})

In [15]:
logger.info(f"datasets_train--->{datasets_train.select(range(2))[0]}")

[32m2023-07-15 08:44:36.905[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mdatasets_train--->{'input_ids': [64790, 64792, 30910, 39501, 31211, 33182, 32103, 32834, 31930, 32184, 31123, 55073, 31793, 43021, 32834, 34301, 31788, 31123, 37881, 36266, 54530, 33287, 31123, 38545, 32548, 54567, 33287, 43082, 31123, 54535, 55124, 55643, 55216, 55802, 54820, 31155, 13, 31639, 31211, 31623, 32016, 31699, 32044, 35531, 45331, 31755, 37079, 54706, 31123, 32082, 32066, 31751, 35531, 45331, 55251, 57252, 31123, 31755, 34529, 57194, 54557, 54679, 31123, 31843, 35531, 43114, 31123, 34802, 32066, 54538, 31017, 54831, 54900, 32066, 31751, 36082, 55021, 54629, 31123, 54695, 42693, 31831, 31642, 32222, 31514, 13, 33287, 30954, 30910, 47383, 35531, 45331, 55251, 57252, 31201, 57194, 54557, 54679, 31201, 35531, 43114, 41641, 31123, 31975, 31017, 54831, 54900, 32066, 31951, 31123, 51634, 54541, 35531, 44234, 31155, 2], 'labels': [-100, -100, -100, -100, -100, -100, -100, -100

# load model and quantize

In [16]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit=finetune_args.load_in_4bit,
    bnb_4bit_quant_type=finetune_args.bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=finetune_args.bnb_4bit_compute_dtype
)

In [17]:
logger.info(f"from {finetune_args.model_path} loading model")
model = AutoModel.from_pretrained(finetune_args.model_path,
                                  trust_remote_code=True,
                                  device_map='auto',
                                  quantization_config=quantization_config,
                                  )

[32m2023-07-15 08:44:40.197[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mfrom THUDM/chatglm2-6b loading model[0m
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
Loading checkpoint shards: 100%|██████████| 7/7 [00:15<00:00,  2.27s/it]


In [18]:
model

ChatGLMForConditionalGeneration(
  (transformer): ChatGLMModel(
    (embedding): Embedding(
      (word_embeddings): Embedding(65024, 4096)
    )
    (rotary_pos_emb): RotaryEmbedding()
    (encoder): GLMTransformer(
      (layers): ModuleList(
        (0-27): 28 x GLMBlock(
          (input_layernorm): RMSNorm()
          (self_attention): SelfAttention(
            (query_key_value): Linear4bit(in_features=4096, out_features=4608, bias=True)
            (core_attention): CoreAttention(
              (attention_dropout): Dropout(p=0.0, inplace=False)
            )
            (dense): Linear4bit(in_features=4096, out_features=4096, bias=False)
          )
          (post_attention_layernorm): RMSNorm()
          (mlp): MLP(
            (dense_h_to_4h): Linear4bit(in_features=4096, out_features=27392, bias=False)
            (dense_4h_to_h): Linear4bit(in_features=13696, out_features=4096, bias=False)
          )
        )
      )
      (final_layernorm): RMSNorm()
    )
    (output_la

#  prepare quantization model for train

In [19]:
logger.info("prepare model for training")

[32m2023-07-15 08:45:08.259[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mprepare model for training[0m


In [20]:
for name, param in model.named_parameters():
    # freeze model parameters
    param.requires_grad = False

# cast all non INT8 parameters to fp32
for param in model.parameters():
    if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
        param.data = param.data.to(torch.float32)

# For backward compatibility
model.enable_input_require_grads()
# Enable gradient checkpoint
model.gradient_checkpointing_enable()
logger.info("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`")
# `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`
model.config.use_cache = False
# but during inference make sure to set it back to True
logger.info("but during inference make sure to set it back to True")

[32m2023-07-15 08:45:09.779[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m14[0m - [1m`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`[0m
[32m2023-07-15 08:45:09.780[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m18[0m - [1mbut during inference make sure to set it back to True[0m


# Lora Config 

In [21]:
Lora_config = LoraConfig(
    r=lora_args.r,
    target_modules=lora_args.target_modules,
    lora_alpha=lora_args.lora_alpha,
    lora_dropout=lora_args.lora_dropout,
    bias=lora_args.bias,
    inference_mode=lora_args.inference_mode,
    task_type=TaskType.CAUSAL_LM,
    layers_to_transform=lora_args.layers_to_transform,
    layers_pattern=lora_args.layers_pattern
)
logger.info(f"Lora_config--->{Lora_config}")

[32m2023-07-15 08:45:12.082[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m12[0m - [1mLora_config--->LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, base_model_name_or_path=None, revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, r=12, target_modules=['query_key_value'], lora_alpha=24, lora_dropout=0.05, fan_in_fan_out=False, bias='none', modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None)[0m


In [22]:
model = get_peft_model(model, Lora_config, lora_args.Adapter_name)
logger.info(f"lora_model--->{model}")

[32m2023-07-15 08:45:17.748[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mlora_model--->PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): ChatGLMForConditionalGeneration(
      (transformer): ChatGLMModel(
        (embedding): Embedding(
          (word_embeddings): Embedding(65024, 4096)
        )
        (rotary_pos_emb): RotaryEmbedding()
        (encoder): GLMTransformer(
          (layers): ModuleList(
            (0-27): 28 x GLMBlock(
              (input_layernorm): RMSNorm()
              (self_attention): SelfAttention(
                (query_key_value): Linear4bit(
                  in_features=4096, out_features=4608, bias=True
                  (lora_dropout): ModuleDict(
                    (LoraAdapter): Dropout(p=0.05, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (LoraAdapter): Linear(in_features=4096, out_features=12, bias=False)
                  )
                  (lora_B)

In [23]:
logger.info("print trainable parameters")
model.print_trainable_parameters()

[32m2023-07-15 08:45:22.450[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mprint trainable parameters[0m


trainable params: 2,924,544 || all params: 3,391,236,096 || trainable%: 0.08623828943816479


# train 

In [24]:
# compute model output perplexity
def compute_metrics(loss):
    loss_mean = loss.mean()
    perplexity = np.exp2(loss_mean)
    return {"perplexity": perplexity}

In [27]:
class LoraTrainer(Trainer):

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        # only save Lora adapter
        if output_dir is None:
            output_dir = self.args.output_dir
        self.model.save_pretrained(output_dir)
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

In [28]:
data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        pad_to_multiple_of=8,
        padding=True,
)

In [29]:
trainer = LoraTrainer(
    model=model,
    args=train_args,
    train_dataset=datasets_train,
    eval_dataset=datasets_valid,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set to `True` to avoid any unexpected behavior such as device placement mismatching.
The model is loaded in 8-bit precision. To train this model you need to add additional modules inside the model such as adapters using `peft` library and freeze the model weights. Please check  the examples in https://github.com/huggingface/peft for more details.


In [30]:
resume_from_checkpoint = finetune_args.resume_from_checkpoint
if resume_from_checkpoint is not None:
    if os.path.exists(resume_from_checkpoint):
        logger.info(f'Restarting from {resume_from_checkpoint}')
        model.load_adapter(resume_from_checkpoint, lora_args.Adapter_name, subfolder=lora_args.Adapter_name)
    else:
        raise Exception(f'{resume_from_checkpoint} is not a correct path!')

In [31]:
logger.info(f"start training from {resume_from_checkpoint}")

[32m2023-07-15 08:45:43.776[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mstart training from None[0m


In [34]:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)

***** Running training *****
  Num examples = 6,622
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 2,484
  Number of trainable parameters = 2,924,544


Step,Training Loss,Validation Loss,Perplexity
100,2.4153,2.321457,4.998368
200,2.1056,2.183702,4.54318
300,2.2151,2.144331,4.420872
400,2.1475,2.121216,4.350604
500,2.1506,2.105941,4.304783
600,2.1684,2.093386,4.267486
700,2.0514,2.083433,4.238144
800,2.1558,2.074739,4.212681
900,2.0104,2.067188,4.190691
1000,2.0638,2.062333,4.17661


***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8
***** Running Evalua

TrainOutput(global_step=2484, training_loss=2.0947799682617188, metrics={'train_runtime': 12110.1361, 'train_samples_per_second': 1.64, 'train_steps_per_second': 0.205, 'total_flos': 5.549398492328755e+16, 'train_loss': 2.0947799682617188, 'epoch': 3.0})

In [35]:
logger.info(f"training finished, save model to {finetune_args.final_model_path}")

[32m2023-07-15 12:11:32.992[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mtraining finished, save model to Lora_Adapter_THUDM_chatglm2-6b/finally_adapter[0m


In [36]:
trainer.model.save_pretrained(finetune_args.final_model_path)

In [37]:
torch.save(trainer.args, os.path.join(finetune_args.final_model_path, "training_args.bin"))