# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer

from xtuner.dataset import ConcatDataset, process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import (alpaca_map_fn, alpaca_zh_map_fn,
                                    template_map_fn_factory)
from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
                                 ThroughputHook,
                                 VarlenAttnArgsToMessageHubHook)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.parallel.sequence import SequenceParallelSampler
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
pretrained_model_name_or_path = '/mnt/petrelfs/share_data/caoweihan/Yi-34B-200K'
use_varlen_attn = False
sequence_parallel_size = 2

# Data
alpaca_zh_path = 'silk-road/alpaca-data-gpt4-chinese'
alpaca_en_path = 'tatsu-lab/alpaca'
prompt_template = PROMPT_TEMPLATE.default
max_length = 8192
pack_to_max_length = True

# Scheduler & Optimizer
batch_size = 1  # per_device
accumulative_counts = 2
dataloader_num_workers = 4
max_epochs = 3
optim_type = AdamW
lr = 2e-5
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.05

# Save
save_steps = 500
save_total_limit = 1  # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 10
SYSTEM = SYSTEM_TEMPLATE.alpaca
evaluation_inputs = [
    '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
]

#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    use_varlen_attn=use_varlen_attn,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation='flash_attention_2'))

#######################################################################
#                     PART 3  Dataset & Dataloader                    #
#######################################################################
alpaca_en = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path=alpaca_en_path),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=alpaca_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=False,
    pack_to_max_length=pack_to_max_length,
    use_varlen_attn=use_varlen_attn)

alpaca_zh = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path=alpaca_zh_path),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=alpaca_zh_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=False,
    pack_to_max_length=pack_to_max_length,
    use_varlen_attn=use_varlen_attn)

train_dataset = dict(type=ConcatDataset, datasets=[alpaca_en, alpaca_zh])
train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(type=SequenceParallelSampler, seed=1024),
    collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    # dict(
    #     type=LinearLR,
    #     start_factor=1 / 40,
    #     by_epoch=True,
    #     begin=0,
    #     end=warmup_ratio * max_epochs,
    #     convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=lr * 0.15,
        by_epoch=True,
        begin=0,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
# Note: training stops after `max_iters` iterations; `max_epochs` above is
# only consumed by the LR scheduler.
train_cfg = dict(type=TrainLoop, max_iters=32)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    # dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(type=ThroughputHook),
    # dict(
    #     type=EvaluateChatHook,
    #     tokenizer=tokenizer,
    #     every_n_iters=evaluation_freq,
    #     evaluation_inputs=evaluation_inputs,
    #     system=SYSTEM,
    #     max_new_tokens=100,
    #     prompt_template=prompt_template)
]

if use_varlen_attn:
    custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every iteration (interval=1).
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=1),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # checkpoint saving by interval is disabled here (interval=-1,
    # save_last=False); set interval=save_steps to save every `save_steps`
    # iterations.
    checkpoint=dict(
        type=CheckpointHook,
        save_optimizer=False,
        by_epoch=False,
        interval=-1,
        save_last=False,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False, window_size=1)
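
# ---------------------------------------------------------------------
# Usage sketch (comments only, not part of the config): with a standard
# xtuner installation, a quick way to check that the file parses is to
# load it with MMEngine before launching, e.g.
#
#   from mmengine.config import Config
#   cfg = Config.fromfile('path/to/this_config.py')
#   print(cfg.train_dataloader['batch_size'], cfg.sequence_parallel_size)
#
# Training itself is typically launched through the xtuner CLI, e.g.
# `NPROC_PER_NODE=8 xtuner train <path/to/this_config>.py --deepspeed deepspeed_zero3`;
# the exact flags and DeepSpeed preset depend on your xtuner version and
# cluster setup. Since `sequence_parallel_size = 2`, the total GPU count
# must be divisible by 2.
# ---------------------------------------------------------------------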