Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/runtime.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
bitsandbytes>=0.40.0
datasets
einops
mmengine>=0.8.4
mmengine>=0.9.0
modelscope
peft>=0.4.0
scipy
Expand Down
17 changes: 12 additions & 5 deletions xtuner/model/utils.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
import dataclasses

import torch
from mmengine import print_log
from mmengine.utils.misc import get_object_from_string
from torch import nn


def set_obj_dtype(d):
    """Replace torch-dtype strings in ``d`` with real ``torch.dtype`` objects.

    Mutates ``d`` in place: any value equal to ``'torch.float16'``,
    ``'torch.float32'`` or ``'torch.bfloat16'`` is swapped for the
    corresponding attribute of :mod:`torch`. All other entries are untouched.
    """
    dtype_strings = ('torch.float16', 'torch.float32', 'torch.bfloat16')
    for name, val in d.items():
        if val in dtype_strings:
            # 'torch.float16' -> torch.float16
            d[name] = getattr(torch, val.rsplit('.', 1)[-1])


def traverse_dict(d):
    """Recursively materialize a nested config structure in place.

    Walks dicts and lists depth-first. At every dict level it first converts
    torch-dtype strings via ``set_obj_dtype``, then, for each nested dict that
    carries a ``'type'`` key, builds the object by calling the builder with the
    remaining keys as kwargs and replaces the dict with the built object.

    The recursion happens *before* the build so that arbitrarily deep entries
    (e.g. a ``'torch.float16'`` string inside a ``quantization_config`` dict)
    are converted/built before the enclosing object is constructed.

    NOTE(review): reconstructed from a garbled diff view of this function —
    confirm against the merged xtuner source.

    Args:
        d: a dict or list (possibly nested) parsed from a config; mutated
            in place. Other types are ignored.
    """
    if isinstance(d, dict):
        # Convert dtype strings at this level first.
        set_obj_dtype(d)
        for key, value in d.items():
            if isinstance(value, dict):
                # Depth-first: fully process the nested dict before building.
                traverse_dict(value)
                if 'type' in value:
                    builder = value.pop('type')
                    # 'type' may be a dotted-path string; resolve it to the
                    # actual class/function before calling.
                    if isinstance(builder, str):
                        builder = get_object_from_string(builder)
                    new_value = builder(**value)
                    d[key] = new_value
                    print_log(f'{key} convert to {builder}')
    elif isinstance(d, list):
        for element in d:
            traverse_dict(element)
Expand Down
5 changes: 1 addition & 4 deletions xtuner/tools/check_custom_dataset.py
Original file line number Diff line number Diff line change
def parse_args():
    """Parse command-line arguments for the custom-dataset checker.

    Returns:
        argparse.Namespace: parsed arguments with a single positional
        ``config`` entry (config file name or path).
    """
    parser = argparse.ArgumentParser(
        description='Verify the correctness of the config file for the '
        'custom dataset.')
    parser.add_argument('config', help='config file name or path.')
    args = parser.parse_args()
    return args

Expand Down
5 changes: 1 addition & 4 deletions xtuner/tools/log_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@

def parse_args():
    """Parse command-line arguments for dataset logging.

    Returns:
        argparse.Namespace: parsed arguments with a single positional
        ``config`` entry (config file name or path).
    """
    parser = argparse.ArgumentParser(description='Log processed dataset.')
    parser.add_argument('config', help='config file name or path.')
    args = parser.parse_args()
    return args

Expand Down
5 changes: 1 addition & 4 deletions xtuner/tools/model_converters/pth_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@
def parse_args():
parser = argparse.ArgumentParser(
description='Convert the pth model to HuggingFace model')
parser.add_argument(
'config',
help='config file name or path. Note: Please use the original '
'configs, instead of the automatically saved log configs.')
parser.add_argument('config', help='config file name or path.')
parser.add_argument('pth_model', help='pth model file')
parser.add_argument(
'save_dir', help='the directory to save HuggingFace model')
Expand Down
5 changes: 1 addition & 4 deletions xtuner/tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@

def parse_args():
parser = argparse.ArgumentParser(description='Test model')
parser.add_argument(
'config',
help='config file name or path. Note: Please use the original '
'configs, instead of the automatically saved log configs.')
parser.add_argument('config', help='config file name or path.')
parser.add_argument('--checkpoint', default=None, help='checkpoint file')
parser.add_argument(
'--work-dir',
Expand Down
5 changes: 1 addition & 4 deletions xtuner/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@

def parse_args():
parser = argparse.ArgumentParser(description='Train LLM')
parser.add_argument(
'config',
help='config file name or path. Note: Please use the original '
'configs, instead of the automatically saved log configs.')
parser.add_argument('config', help='config file name or path.')
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument(
'--deepspeed',
Expand Down