In [31]:
from datasets import load_from_disk
from transformers import (
    AutoModel,
    AutoTokenizer,
    TrainingArguments,
    DataCollatorForSeq2Seq,
    HfArgumentParser
)
from train_lora import FinetuneArguments, LoraArguments, LoraTrainer
from loguru import logger
from peft import LoraConfig, get_peft_model, TaskType
import numpy as np
import os
import torch

# prepare arguments 

In [4]:
# Read the fine-tuning options from JSON and unpack the single parsed
# FinetuneArguments instance straight into `finetune_args`.
(finetune_args,) = HfArgumentParser(FinetuneArguments).parse_json_file(
    json_file='args_file/finetune_args.json'
)
logger.info(f"current_finetune_args--->{finetune_args}")

[32m2023-07-17 06:26:58.626[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mcurrent_finetune_args--->FinetuneArguments(dataset_path='data/Med_datasets.jsonl', model_path='THUDM/chatglm2-6b', label_pad_token_id=-100, load_in_4bit=True, bnb_4bit_quant_type='nf4', bnb_4bit_compute_dtype='float32', seed=42, resume_from_checkpoint=None, final_model_path='Lora_Adapter_THUDM_chatglm2-6b/finally_adapter')[0m


In [5]:
# Read the Hugging Face TrainingArguments from JSON and unpack the single
# parsed instance straight into `train_args`.
(train_args,) = HfArgumentParser(TrainingArguments).parse_json_file(
    json_file='args_file/train_args.json'
)
logger.info(f"current_train_args--->{train_args}")

[32m2023-07-17 06:27:08.342[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mcurrent_train_args--->TrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=False,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
do_eval=True,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=100,
evaluation_strategy=IntervalStrategy.STEPS,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'fsdp_min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
greate

In [6]:
# Read the LoRA options from JSON and unpack the single parsed
# LoraArguments instance straight into `lora_args`.
(lora_args,) = HfArgumentParser(LoraArguments).parse_json_file(
    json_file='args_file/lora_args.json'
)
logger.info(f"current_lora_args--->{lora_args}")

[32m2023-07-17 06:27:20.675[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mcurrent_lora_args--->LoraArguments(target_modules=['query_key_value'], r=12, lora_alpha=24, lora_dropout=0.05, bias='none', inference_mode=False, layers_to_transform=None, layers_pattern=None, Adapter_name='LoraAdapter')[0m


# process datasets 

In [7]:
# Announce the source path, then read the dataset previously saved there.
# Later cells index the result by the 'train' and 'valid' splits.
logger.info(f"from {finetune_args.dataset_path} loading datasets and tokenize datasets")
datasets = load_from_disk(finetune_args.dataset_path)

[32m2023-07-17 06:28:03.820[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mfrom data/Med_datasets.jsonl loading datasets and tokenize datasets[0m


In [8]:
# Tokenizer for the same checkpoint the model is loaded from.
# trust_remote_code=True allows the tokenizer code shipped with the checkpoint to run.
tokenizer = AutoTokenizer.from_pretrained(
    finetune_args.model_path,
    trust_remote_code=True,
)

In [9]:
def tokenize_dataset(data, label_pad_token_id=finetune_args.label_pad_token_id):
    """Convert one ``context``/``target`` record into ``input_ids`` and ``labels``.

    Args:
        data: Mapping holding the prompt text under ``'context'`` and the
            answer text under ``'target'``.
        label_pad_token_id: Value written into ``labels`` at every prompt
            position. Defaults to ``finetune_args.label_pad_token_id``, read
            once when this function is defined.

    Returns:
        dict: ``'input_ids'`` is prompt ids + target ids + EOS id; ``'labels'``
        has the same length, with prompt positions replaced by
        ``label_pad_token_id`` and the target ids + EOS id kept as-is.
    """
    prompt_text, answer_text = data['context'], data['target']
    # Only the prompt is encoded with the tokenizer's special tokens.
    prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=True)
    answer_ids = tokenizer.encode(answer_text, add_special_tokens=False)

    # The answer is always terminated with EOS, in both inputs and labels.
    answer_with_eos = answer_ids + [tokenizer.eos_token_id]
    masked_prompt = [label_pad_token_id] * len(prompt_ids)

    return {
        'input_ids': prompt_ids + answer_with_eos,
        'labels': masked_prompt + answer_with_eos,
    }

In [10]:
# Tokenize the train split. Every original column is listed for removal, so the
# mapped dataset keeps just the fields returned by tokenize_dataset.
remove_columns = datasets['train'].column_names
datasets_train = datasets['train'].map(
    function=tokenize_dataset,
    remove_columns=remove_columns,
)

                                                                 

In [11]:
# Tokenize the validation split the same way as the train split: all original
# columns are listed for removal, leaving the fields returned by tokenize_dataset.
remove_columns = datasets['valid'].column_names
datasets_valid = datasets['valid'].map(
    function=tokenize_dataset,
    remove_columns=remove_columns,
)

                                                                

In [12]:
# Sanity check: log the first tokenized training example (input_ids and labels).
# NOTE(review): select(range(2)) takes the first two rows only to read row 0;
# this looks equivalent to datasets_train[0] whenever the split has >= 2 rows.
logger.info(f"datasets_train--->{datasets_train.select(range(2))[0]}")

[32m2023-07-17 06:29:08.142[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mdatasets_train--->{'input_ids': [64790, 64792, 30910, 39501, 31211, 33182, 32103, 32834, 31930, 32184, 31123, 55073, 31793, 43021, 32834, 34301, 31788, 31123, 37881, 36266, 54530, 33287, 31123, 38545, 32548, 54567, 33287, 43082, 31123, 54535, 55124, 55643, 55216, 55802, 54820, 31155, 13, 31639, 31211, 31623, 32016, 31699, 32044, 35531, 45331, 31755, 37079, 54706, 31123, 32082, 32066, 31751, 35531, 45331, 55251, 57252, 31123, 31755, 34529, 57194, 54557, 54679, 31123, 31843, 35531, 43114, 31123, 34802, 32066, 54538, 31017, 54831, 54900, 32066, 31751, 36082, 55021, 54629, 31123, 54695, 42693, 31831, 31642, 32222, 31514, 13, 33287, 30954, 30910, 47383, 35531, 45331, 55251, 57252, 31201, 57194, 54557, 54679, 31201, 35531, 43114, 41641, 31123, 31975, 31017, 54831, 54900, 32066, 31951, 31123, 51634, 54541, 35531, 44234, 31155, 2], 'labels': [-100, -100, -100, -100, -100, -100, -100, -100

# load model 

In [13]:
logger.info(f"from {finetune_args.model_path} loading model")
# Load the checkpoint with automatic device placement, then cast it to fp16.
# NOTE(review): the parsed finetune_args also carry load_in_4bit / bnb_4bit_*
# options, but none of them are passed here - confirm whether that is intended.
model = AutoModel.from_pretrained(
    finetune_args.model_path,
    device_map='auto',
    trust_remote_code=True,
).half()

[32m2023-07-17 06:29:49.552[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mfrom THUDM/chatglm2-6b loading model[0m
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
Loading checkpoint shards: 100%|██████████| 7/7 [01:12<00:00, 10.30s/it]


# prepare model for training

In [14]:
# Freeze every weight of the loaded base model.
for base_weight in model.parameters():
    base_weight.requires_grad = False

# For backward compatibility
model.enable_input_require_grads()

# Turn on gradient checkpointing. `use_cache=True` is incompatible with it,
# so the cache is switched off for training.
model.gradient_checkpointing_enable()
logger.info("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`")
model.config.use_cache = False

# Reminder for later: the cache should be re-enabled for inference.
logger.info("but during inference make sure to set it back to True")

[32m2023-07-17 06:32:13.183[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1m`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`[0m
[32m2023-07-17 06:32:13.185[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1mbut during inference make sure to set it back to True[0m


# LoRA config

In [19]:
# Build the PEFT LoRA configuration from the parsed LoraArguments fields.
Lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    # Which modules / layers receive an adapter.
    target_modules=lora_args.target_modules,
    layers_to_transform=lora_args.layers_to_transform,
    layers_pattern=lora_args.layers_pattern,
    # Adapter hyper-parameters.
    r=lora_args.r,
    lora_alpha=lora_args.lora_alpha,
    lora_dropout=lora_args.lora_dropout,
    bias=lora_args.bias,
    inference_mode=lora_args.inference_mode,
)
logger.info(f"Lora_config--->{Lora_config}")

[32m2023-07-17 06:33:20.546[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m12[0m - [1mLora_config--->LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, base_model_name_or_path=None, revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, r=12, target_modules=['query_key_value'], lora_alpha=24, lora_dropout=0.05, fan_in_fan_out=False, bias='none', modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None)[0m


In [20]:
# Wrap the frozen base model with LoRA adapters registered under the configured
# adapter name.
# NOTE(review): `model` is rebound to the wrapped model, so re-running this cell
# would pass the already-wrapped model back into get_peft_model.
model = get_peft_model(
    model,
    Lora_config,
    adapter_name=lora_args.Adapter_name,
)
logger.info(f"lora_model--->{model}")

[32m2023-07-17 06:33:40.857[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mlora_model--->PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): ChatGLMForConditionalGeneration(
      (transformer): ChatGLMModel(
        (embedding): Embedding(
          (word_embeddings): Embedding(65024, 4096)
        )
        (rotary_pos_emb): RotaryEmbedding()
        (encoder): GLMTransformer(
          (layers): ModuleList(
            (0-27): 28 x GLMBlock(
              (input_layernorm): RMSNorm()
              (self_attention): SelfAttention(
                (query_key_value): Linear(
                  in_features=4096, out_features=4608, bias=True
                  (lora_dropout): ModuleDict(
                    (LoraAdapter): Dropout(p=0.05, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (LoraAdapter): Linear(in_features=4096, out_features=12, bias=False)
                  )
                  (lora_B): Mo

In [21]:
# Log a marker, then have the PEFT-wrapped model print its trainable vs. total
# parameter counts (and the trainable percentage).
logger.info("print trainable parameters")
model.print_trainable_parameters()

[32m2023-07-17 06:34:08.527[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mprint trainable parameters[0m


trainable params: 2,924,544 || all params: 6,246,508,544 || trainable%: 0.04681885855753982


# train 

In [24]:
# compute model output perplexity
def compute_metrics(loss):
    """Convert evaluation loss values into a perplexity metric.

    Args:
        loss: Array-like of loss values exposing ``.mean()``. Assumed to be a
            natural-log cross-entropy (nats), which is what PyTorch's
            cross-entropy produces - TODO confirm against LoraTrainer.

    Returns:
        dict: ``{"perplexity": exp(mean(loss))}``.
    """
    loss_mean = loss.mean()
    # Perplexity of a natural-log loss is e**loss. np.exp2 (2**loss) would only
    # be correct for a loss measured in bits and under-reports perplexity here.
    perplexity = np.exp(loss_mean)
    return {"perplexity": perplexity}

In [25]:
# Batch collator: pads input_ids and labels per batch, rounding the padded
# length up to a multiple of 8.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    padding=True,
    pad_to_multiple_of=8,
    # Pad labels with the same ignore value tokenize_dataset writes over the
    # prompt positions, so both stay in sync if the configured value changes.
    label_pad_token_id=finetune_args.label_pad_token_id,
)

In [26]:
# Wire the LoRA-wrapped model, parsed training arguments, tokenized splits,
# collator and perplexity metric into the project's LoraTrainer.
# How compute_metrics is invoked is defined in train_lora (not shown here).
trainer = LoraTrainer(
    model=model,
    args=train_args,
    data_collator=data_collator,
    train_dataset=datasets_train,
    eval_dataset=datasets_valid,
    compute_metrics=compute_metrics,
)

You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set to `True` to avoid any unexpected behavior such as device placement mismatching.


In [28]:
# Optionally reload previously saved adapter weights before training.
# resume_from_checkpoint stays None when no checkpoint path is configured.
resume_from_checkpoint = finetune_args.resume_from_checkpoint
if resume_from_checkpoint is not None:
    if not os.path.exists(resume_from_checkpoint):
        # FileNotFoundError is still an Exception subclass, so anything that
        # caught the old bare Exception keeps working.
        raise FileNotFoundError(f'{resume_from_checkpoint} is not a correct path!')
    logger.info(f'Restarting from {resume_from_checkpoint}')
    # The adapter is loaded from the sub-folder named after the adapter.
    model.load_adapter(resume_from_checkpoint, lora_args.Adapter_name, subfolder=lora_args.Adapter_name)

In [29]:
# Log where training starts from; the value is None when no checkpoint path
# was configured (i.e. a fresh run).
logger.info(f"start training from {resume_from_checkpoint}")

[32m2023-07-17 06:36:01.005[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mstart training from None[0m


In [None]:
# Run training, handing the (possibly None) checkpoint path through to the trainer.
trainer.train(resume_from_checkpoint=resume_from_checkpoint)

In [None]:
# Announce the output directory before the final save.
logger.info(f"training finished, save model to {finetune_args.final_model_path}")

In [None]:
# Save the trainer's model into final_model_path.
# NOTE(review): the model is PEFT-wrapped, so presumably only the adapter
# weights/config are written rather than the full base model - confirm.
trainer.model.save_pretrained(finetune_args.final_model_path)

In [None]:
# Persist the trainer's arguments object next to the saved model as training_args.bin.
# Assumes final_model_path already exists (e.g. created by the preceding
# save_pretrained cell) - TODO confirm.
torch.save(trainer.args, os.path.join(finetune_args.final_model_path, "training_args.bin"))