
Error during training #368

@Brain2nd

Description

I made custom modifications to the following script:

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import alpaca_zh_map_fn, template_map_fn_factory
from xtuner.engine import DatasetInfoHook, EvaluateChatHook
from xtuner.model import SupervisedFinetune
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
pretrained_model_name_or_path = '/home/tangshi/TangShi/Pku政务大模型/Models/Qwen-1_8B-Chat'

# Data
alpaca_zh_path = '/home/tangshi/TangShi/Pku政务大模型/Trainer/Tools/LsDtata/Data/pku'
prompt_template = PROMPT_TEMPLATE.qwen_chat
max_length = 2048
pack_to_max_length = True

# Scheduler & Optimizer
batch_size = 1  # per_device
accumulative_counts = 16
dataloader_num_workers = 0
max_epochs = 3
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03

# Evaluate the generation performance during training
evaluation_freq = 500
SYSTEM = "你的任务是重庆市政务文书写作、政务问答 \n 你生成的问题必须包含:1、留言标题,2、留言摘要。你生成的答复内容部分必须有法律依据,且表明已审查,参照你固有的知识或者我给出的法律文献,在引用法律文件时使用《》包裹其名称。\n"
evaluation_inputs = [
    '留言标题:合川区一超市怀疑出售过期食品\n留言摘要:市民在合川区一超市发现了怀疑是过期的食品,担心会对消费者的健康造成威胁。', '留言标题:渝北区小区内共享单车难以管理\n留言摘要:渝北区某小区内共享单车聚集,严重影响居民出行,请求相关部门加强管理。'
]

#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right',
    eos_token='<|im_end|>')
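
# Each dict(type=..., ...) above is a lazy mmengine config: the runner
# builds the object at startup, roughly equivalent to calling
# AutoTokenizer.from_pretrained(pretrained_model_name_or_path,
#     trust_remote_code=True, padding_side='right', eos_token='<|im_end|>').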

# Original QLoRA variant, kept for reference (currently disabled):
# model = dict(
#     type=SupervisedFinetune,
#     llm=dict(
#         type=AutoModelForCausalLM.from_pretrained,
#         pretrained_model_name_or_path=pretrained_model_name_or_path,
#         trust_remote_code=True,
#         torch_dtype=torch.float16,
#         quantization_config=dict(
#             type=BitsAndBytesConfig,
#             load_in_4bit=True,
#             load_in_8bit=False,
#             llm_int8_threshold=6.0,
#             llm_int8_has_fp16_weight=False,
#             bnb_4bit_compute_dtype=torch.float16,
#             bnb_4bit_use_double_quant=True,
#             bnb_4bit_quant_type='nf4')),
#     lora=dict(
#         type=LoraConfig,
#         r=64,
#         lora_alpha=16,
#         lora_dropout=0.1,
#         bias='none',
#         task_type='CAUSAL_LM'))
model = dict(
    type=SupervisedFinetune,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
    ),
)
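# With quantization_config and lora left commented out, this variant
# fine-tunes all fp16 weights directly instead of training a 4-bit QLoRA
# adapter, so it needs considerably more GPU memory.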
#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
alpaca_zh = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path=alpaca_zh_path),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=alpaca_zh_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length)
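# pack_to_max_length=True concatenates the shuffled tokenized samples into
# fixed 2048-token sequences, so the number of iterations per epoch scales
# with the total token count rather than the raw sample count; a small
# dataset can therefore yield only a handful of steps per epoch.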

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=alpaca_zh,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')
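
# AmpOptimWrapper runs fp16 mixed-precision training with a dynamic loss
# scale and steps the optimizer once every accumulative_counts (16) batches.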

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        T_max=max_epochs,
        convert_to_iter_based=True)
]
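
# Note: with by_epoch=True and convert_to_iter_based=True, mmengine turns
# these epoch-based begin/end values into iterations by multiplying by the
# per-epoch iteration count and truncating to int; very short epochs can
# therefore collapse the 0.09-epoch warmup window (warmup_ratio * max_epochs)
# to zero iterations.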

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template)
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per epoch.
    checkpoint=dict(type=CheckpointHook, interval=1),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

The error is as follows:

Traceback (most recent call last):
  File "/home/tangshi/miniconda3/envs/xtuner/lib/python3.10/site-packages/xtuner/tools/train.py", line 299, in <module>
    main()
  File "/home/tangshi/miniconda3/envs/xtuner/lib/python3.10/site-packages/xtuner/tools/train.py", line 295, in main
    runner.train()
  File "/home/tangshi/miniconda3/envs/xtuner/lib/python3.10/site-packages/mmengine/runner/_flexible_runner.py", line 1182, in train
    self.strategy.prepare(
  File "/home/tangshi/miniconda3/envs/xtuner/lib/python3.10/site-packages/mmengine/_strategy/deepspeed.py", line 389, in prepare
    self.param_schedulers = self.build_param_scheduler(
  File "/home/tangshi/miniconda3/envs/xtuner/lib/python3.10/site-packages/mmengine/_strategy/base.py", line 658, in build_param_scheduler
    param_schedulers = self._build_param_scheduler(
  File "/home/tangshi/miniconda3/envs/xtuner/lib/python3.10/site-packages/mmengine/_strategy/base.py", line 563, in _build_param_scheduler
    PARAM_SCHEDULERS.build(
  File "/home/tangshi/miniconda3/envs/xtuner/lib/python3.10/site-packages/mmengine/registry/registry.py", line 570, in build
    return self.build_func(cfg, *args, **kwargs, registry=self)
  File "/home/tangshi/miniconda3/envs/xtuner/lib/python3.10/site-packages/mmengine/registry/build_functions.py", line 294, in build_scheduler_from_cfg
    return scheduler_cls.build_iter_from_epoch(  # type: ignore
  File "/home/tangshi/miniconda3/envs/xtuner/lib/python3.10/site-packages/mmengine/optim/scheduler/param_scheduler.py", line 787, in build_iter_from_epoch
    return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
  File "/home/tangshi/miniconda3/envs/xtuner/lib/python3.10/site-packages/mmengine/optim/scheduler/lr_scheduler.py", line 20, in __init__
    super().__init__(optimizer, 'lr', *args, **kwargs)
  File "/home/tangshi/miniconda3/envs/xtuner/lib/python3.10/site-packages/mmengine/optim/scheduler/param_scheduler.py", line 759, in __init__
    super().__init__(
  File "/home/tangshi/miniconda3/envs/xtuner/lib/python3.10/site-packages/mmengine/optim/scheduler/param_scheduler.py", line 68, in __init__
    raise ValueError('end should be larger than begin, but got'
ValueError: end should be larger than begin, but got begin=0, end=0
[2024-01-29 17:32:32,013] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 47059) of binary: /home/tangshi/miniconda3/envs/xtuner/bin/python
Traceback (most recent call last):
  File "/home/tangshi/miniconda3/envs/xtuner/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/home/tangshi/miniconda3/envs/xtuner/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/home/tangshi/miniconda3/envs/xtuner/lib/python3.10/site-packages/torch/distributed/run.py", line 806, in main
    run(args)
  File "/home/tangshi/miniconda3/envs/xtuner/lib/python3.10/site-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/home/tangshi/miniconda3/envs/xtuner/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/tangshi/miniconda3/envs/xtuner/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
/home/tangshi/miniconda3/envs/xtuner/lib/python3.10/site-packages/xtuner/tools/train.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2024-01-29_17:32:32
  host      : tangshi
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 47060)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[2]:
  time      : 2024-01-29_17:32:32
  host      : tangshi
  rank      : 2 (local_rank: 2)
  exitcode  : 1 (pid: 47061)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[3]:
  time      : 2024-01-29_17:32:32
  host      : tangshi
  rank      : 3 (local_rank: 3)
  exitcode  : 1 (pid: 47062)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-01-29_17:32:32
  host      : tangshi
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 47059)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
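
For context, the ValueError above is raised while building the warmup LinearLR schedule. With by_epoch=True and convert_to_iter_based=True, mmengine's build_iter_from_epoch (visible in the traceback) converts the epoch-based begin/end to iterations and truncates to int, so end = warmup_ratio * max_epochs = 0.09 epochs becomes 0 whenever an epoch has fewer than roughly 12 iterations. Since pack_to_max_length=True packs the whole dataset into 2048-token blocks, a small dataset can easily produce that few steps per epoch. A minimal sketch of the failing conversion (iters_per_epoch here is a hypothetical illustration, not a value from this run):

warmup_ratio = 0.03
max_epochs = 3

# Hypothetical: a small packed dataset yields only 10 steps per epoch.
iters_per_epoch = 10

# mmengine's build_iter_from_epoch truncates epoch-based boundaries to int.
begin = int(0 * iters_per_epoch)                        # -> 0
end = int(warmup_ratio * max_epochs * iters_per_epoch)  # int(0.9) -> 0

if not end > begin:
    raise ValueError('end should be larger than begin, but got '
                     f'begin={begin}, end={end}')

Under this reading, the usual remedies are a larger dataset, setting pack_to_max_length = False, or enlarging the warmup window so that it converts to at least one iteration.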
