diff --git a/configs/alpaca/alpaca_standford_llama-7b.py b/configs/alpaca/alpaca_standford_llama-7b.py
index fba1b0709..ea3e44263 100644
--- a/configs/alpaca/alpaca_standford_llama-7b.py
+++ b/configs/alpaca/alpaca_standford_llama-7b.py
@@ -9,11 +9,6 @@ from .._base_.schedules.guanaco import *  # noqa: F401,F403
 
 pretrained_model_name_or_path = '/nvme/share_data/llama-7b'
 
-model = dict(
-    type=SupervisedFinetune,
-    llm=dict(
-        type=AutoModelForCausalLM.from_pretrained,
-        pretrained_model_name_or_path=pretrained_model_name_or_path))
 
 tokenizer = dict(
     type=AutoTokenizer.from_pretrained,
@@ -21,4 +16,11 @@
     use_fast=False,
     padding_side='right')
 
+model = dict(
+    type=SupervisedFinetune,
+    llm=dict(
+        type=AutoModelForCausalLM.from_pretrained,
+        pretrained_model_name_or_path=pretrained_model_name_or_path),
+    tokenizer=tokenizer)
+
 train_dataloader['dataset']['tokenizer'] = tokenizer  # noqa: F405
diff --git a/configs/alpaca/alpaca_standford_llama-7b_deepspeed.py b/configs/alpaca/alpaca_standford_llama-7b_deepspeed.py
index 6aafcbf56..05bf72158 100644
--- a/configs/alpaca/alpaca_standford_llama-7b_deepspeed.py
+++ b/configs/alpaca/alpaca_standford_llama-7b_deepspeed.py
@@ -9,11 +9,6 @@ from .._base_.schedules.guanaco_deepspeed import *  # noqa: F401,F403
 
 pretrained_model_name_or_path = '/nvme/share_data/llama-7b'
 
-model = dict(
-    type=SupervisedFinetune,
-    llm=dict(
-        type=AutoModelForCausalLM.from_pretrained,
-        pretrained_model_name_or_path=pretrained_model_name_or_path))
 
 tokenizer = dict(
     type=AutoTokenizer.from_pretrained,
@@ -21,4 +16,11 @@
     use_fast=False,
     padding_side='right')
 
+model = dict(
+    type=SupervisedFinetune,
+    llm=dict(
+        type=AutoModelForCausalLM.from_pretrained,
+        pretrained_model_name_or_path=pretrained_model_name_or_path),
+    tokenizer=tokenizer)
+
 train_dataloader['collate_fn']['tokenizer'] = tokenizer  # noqa: F405
diff --git a/configs/alpaca/alpaca_standford_llama-7b_qlora.py b/configs/alpaca/alpaca_standford_llama-7b_qlora.py
index d312cfac8..02722f854 100644
--- a/configs/alpaca/alpaca_standford_llama-7b_qlora.py
+++ b/configs/alpaca/alpaca_standford_llama-7b_qlora.py
@@ -12,6 +12,13 @@ from .._base_.schedules.guanaco import *  # noqa: F401,F403
 
 pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b'
 
+
+tokenizer = dict(
+    type=AutoTokenizer.from_pretrained,
+    pretrained_model_name_or_path=pretrained_model_name_or_path,
+    use_fast=False,
+    padding_side='right')
+
 model = dict(
     type=SupervisedQloraFinetune,
     llm=dict(
@@ -31,12 +38,7 @@
         lora_alpha=16,
         lora_dropout=0.1,
         bias='none',
-        task_type='CAUSAL_LM'))
-
-tokenizer = dict(
-    type=AutoTokenizer.from_pretrained,
-    pretrained_model_name_or_path=pretrained_model_name_or_path,
-    use_fast=False,
-    padding_side='right')
+        task_type='CAUSAL_LM'),
+    tokenizer=tokenizer)
 
 train_dataloader['dataset']['tokenizer'] = tokenizer  # noqa: F405
diff --git a/configs/guanaco/gunaco_llama-7b_qlora.py b/configs/guanaco/gunaco_llama-7b_qlora.py
index 62525b49b..1ccda85a8 100644
--- a/configs/guanaco/gunaco_llama-7b_qlora.py
+++ b/configs/guanaco/gunaco_llama-7b_qlora.py
@@ -5,6 +5,7 @@
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
                           BitsAndBytesConfig)
 
+from mmchat.engine import LogSampleHook
 from mmchat.models import SupervisedQloraFinetune
 
 with read_base():
@@ -14,6 +15,13 @@ from .._base_.schedules.guanaco import *  # noqa: F401,F403
 
 pretrained_model_name_or_path = '/nvme/share_data/llama-7b'
 
+
+tokenizer = dict(
+    type=AutoTokenizer.from_pretrained,
+    pretrained_model_name_or_path=pretrained_model_name_or_path,
+    use_fast=False,
+    padding_side='right')
+
 model = dict(
     type=SupervisedQloraFinetune,
     data_preprocessor=dict(type=BaseDataPreprocessor),
@@ -36,13 +44,8 @@
         lora_alpha=16,
         lora_dropout=0.1,
         bias='none',
-        task_type='CAUSAL_LM'))
-
-tokenizer = dict(
-    type=AutoTokenizer.from_pretrained,
-    pretrained_model_name_or_path=pretrained_model_name_or_path,
-    use_fast=False,
-    padding_side='right')
+        task_type='CAUSAL_LM'),
+    tokenizer=tokenizer)
 
 train_dataloader['dataset']['tokenizer'] = tokenizer  # noqa: F405
 val_dataloader['dataset']['tokenizer'] = tokenizer  # noqa: F405
@@ -50,3 +53,8 @@
 
 val_evaluator['tokenizer'] = tokenizer  # noqa: F405
 test_evaluator['tokenizer'] = tokenizer  # noqa: F405
+
+custom_hooks = [dict(
+    type=LogSampleHook,
+    tokenizer=tokenizer,
+)]
diff --git a/configs/guanaco/gunaco_llama-7b_qlora_deepspeed.py b/configs/guanaco/gunaco_llama-7b_qlora_deepspeed.py
index e1985ff22..1370e54f2 100644
--- a/configs/guanaco/gunaco_llama-7b_qlora_deepspeed.py
+++ b/configs/guanaco/gunaco_llama-7b_qlora_deepspeed.py
@@ -14,6 +14,13 @@ from .._base_.schedules.guanaco_deepspeed import *  # noqa: F401,F403
 
 pretrained_model_name_or_path = '/nvme/share_data/llama-7b'
 
+
+tokenizer = dict(
+    type=AutoTokenizer.from_pretrained,
+    pretrained_model_name_or_path=pretrained_model_name_or_path,
+    use_fast=False,
+    padding_side='right')
+
 model = dict(
     type=SupervisedQloraFinetune,
     data_preprocessor=dict(type=BaseDataPreprocessor),
@@ -36,17 +43,12 @@
         lora_alpha=16,
         lora_dropout=0.1,
         bias='none',
-        task_type='CAUSAL_LM'))
-
-tokenizer = dict(
-    type=AutoTokenizer.from_pretrained,
-    pretrained_model_name_or_path=pretrained_model_name_or_path,
-    use_fast=False,
-    padding_side='right')
+        task_type='CAUSAL_LM'),
+    tokenizer=tokenizer)
 
 train_dataloader['dataset']['tokenizer'] = tokenizer  # noqa: F405
 val_dataloader['dataset']['tokenizer'] = tokenizer  # noqa: F405
 test_dataloader['dataset']['tokenizer'] = tokenizer  # noqa: F405
 
 val_evaluator['tokenizer'] = tokenizer  # noqa: F405
-test_evaluator['tokenizer'] = tokenizer  # noqa: F405
\ No newline at end of file
+test_evaluator['tokenizer'] = tokenizer  # noqa: F405
diff --git a/mmchat/datasets/huggingface.py b/mmchat/datasets/huggingface.py
index 2991f9bf0..5487aae68 100644
--- a/mmchat/datasets/huggingface.py
+++ b/mmchat/datasets/huggingface.py
@@ -1,5 +1,7 @@
 from functools import partial
 
+from mmengine.config.lazy import LazyObject
+
 from mmchat.registry import DATASETS, TOKENIZER
 
 from .utils import Concatenator, encode_fn
@@ -17,7 +19,16 @@ def process_hf_dataset(dataset,
     dataset = DATASETS.build(dataset)
     if isinstance(map_fn, str):
         map_fn = eval(map_fn)
-    dataset = dataset.map(map_fn, remove_columns=remove_columns)
+    if isinstance(map_fn, list):
+        assert all(
+            [callable(fn) and isinstance(fn, LazyObject) for fn in map_fn])
+        for fn in map_fn[:-1]:
+            fn = fn.build()
+            dataset = dataset.map(fn)
+        dataset = dataset.map(
+            map_fn[-1].build(), remove_columns=remove_columns)
+    elif map_fn is not None:
+        dataset = dataset.map(map_fn, remove_columns=remove_columns)
     for old, new in rename_maps:
         dataset = dataset.rename_column(old, new)
     tokenizer = TOKENIZER.build(tokenizer)
diff --git a/mmchat/datasets/map_fns/__init__.py b/mmchat/datasets/map_fns/__init__.py
index f04a57ae3..eb0b63491 100644
--- a/mmchat/datasets/map_fns/__init__.py
+++ b/mmchat/datasets/map_fns/__init__.py
@@ -1,5 +1,7 @@
-from .alpaca_map_fn import alpaca_map_fn
-from .alpaca_zh_map_fn import alpaca_zh_map_fn
-from .oasst1_map_fn import oasst1_map_fn
+from .dataset_map_fn import alpaca_map_fn, alpaca_zh_map_fn, oasst1_map_fn
+from .model_map_fn import internlm_map_fn, llama2_map_fn
 
-__all__ = ['alpaca_map_fn', 'alpaca_zh_map_fn', 'oasst1_map_fn']
+__all__ = [
+    'alpaca_map_fn', 'alpaca_zh_map_fn', 'oasst1_map_fn', 'internlm_map_fn',
+    'llama2_map_fn'
+]
diff --git a/mmchat/datasets/map_fns/dataset_map_fn/__init__.py b/mmchat/datasets/map_fns/dataset_map_fn/__init__.py
new file mode 100644
index 000000000..f04a57ae3
--- /dev/null
+++ b/mmchat/datasets/map_fns/dataset_map_fn/__init__.py
@@ -0,0 +1,5 @@
+from .alpaca_map_fn import alpaca_map_fn
+from .alpaca_zh_map_fn import alpaca_zh_map_fn
+from .oasst1_map_fn import oasst1_map_fn
+
+__all__ = ['alpaca_map_fn', 'alpaca_zh_map_fn', 'oasst1_map_fn']
diff --git a/mmchat/datasets/map_fns/alpaca_map_fn.py b/mmchat/datasets/map_fns/dataset_map_fn/alpaca_map_fn.py
similarity index 100%
rename from mmchat/datasets/map_fns/alpaca_map_fn.py
rename to mmchat/datasets/map_fns/dataset_map_fn/alpaca_map_fn.py
diff --git a/mmchat/datasets/map_fns/alpaca_zh_map_fn.py b/mmchat/datasets/map_fns/dataset_map_fn/alpaca_zh_map_fn.py
similarity index 100%
rename from mmchat/datasets/map_fns/alpaca_zh_map_fn.py
rename to mmchat/datasets/map_fns/dataset_map_fn/alpaca_zh_map_fn.py
diff --git a/mmchat/datasets/map_fns/oasst1_map_fn.py b/mmchat/datasets/map_fns/dataset_map_fn/oasst1_map_fn.py
similarity index 100%
rename from mmchat/datasets/map_fns/oasst1_map_fn.py
rename to mmchat/datasets/map_fns/dataset_map_fn/oasst1_map_fn.py
diff --git a/mmchat/datasets/map_fns/model_map_fn/__init__.py b/mmchat/datasets/map_fns/model_map_fn/__init__.py
new file mode 100644
index 000000000..6cbbfd9ec
--- /dev/null
+++ b/mmchat/datasets/map_fns/model_map_fn/__init__.py
@@ -0,0 +1,4 @@
+from .internlm_map_fn import internlm_map_fn
+from .llama2_map_fn import llama2_map_fn
+
+__all__ = ['internlm_map_fn', 'llama2_map_fn']
diff --git a/mmchat/datasets/map_fns/model_map_fn/internlm_map_fn.py b/mmchat/datasets/map_fns/model_map_fn/internlm_map_fn.py
new file mode 100644
index 000000000..618489f59
--- /dev/null
+++ b/mmchat/datasets/map_fns/model_map_fn/internlm_map_fn.py
@@ -0,0 +1,8 @@
+def internlm_map_fn(example):
+    user = '<|User|>'
+    eoh = '<eoh>'
+    eoa = '<eoa>'  # noqa: F841
+    assistant = '<|Bot|>'
+    instruction = example.get('input', '')
+    prompt = f'{user}:{instruction}{eoh}\n{assistant}:'
+    return {'input': prompt}
diff --git a/mmchat/datasets/map_fns/model_map_fn/llama2_map_fn.py b/mmchat/datasets/map_fns/model_map_fn/llama2_map_fn.py
new file mode 100644
index 000000000..654f5d563
--- /dev/null
+++ b/mmchat/datasets/map_fns/model_map_fn/llama2_map_fn.py
@@ -0,0 +1,15 @@
+def llama2_map_fn(example):
+    B_INST, E_INST = '[INST]', '[/INST]'
+    B_SYS, E_SYS = '<<SYS>>\n', '\n<</SYS>>\n\n'
+
+    DEFAULT_SYSTEM_PROMPT = \
+        'You are a helpful, respectful and honest assistant. Always answer ' \
+        'as helpfully as possible, while being safe. Your answers should not' \
+        ' include any harmful, unethical, racist, sexist, toxic, dangerous, ' \
+        'or illegal content. Please ensure that your responses are socially ' \
+        'unbiased and positive in nature.'
+
+    instruction = example.get('input', '')
+    prompt = f'{B_INST} {B_SYS} {DEFAULT_SYSTEM_PROMPT} {E_SYS}' \
+        f'{instruction} {E_INST}'
+    return {'input': prompt}
diff --git a/mmchat/engine/__init__.py b/mmchat/engine/__init__.py
index bac868b1b..5a13bb6c0 100644
--- a/mmchat/engine/__init__.py
+++ b/mmchat/engine/__init__.py
@@ -1,3 +1,3 @@
-from .hooks import SampleGenerateHook
+from .hooks import LogSampleHook, SampleGenerateHook
 
-__all__ = ['SampleGenerateHook']
+__all__ = ['SampleGenerateHook', 'LogSampleHook']
diff --git a/mmchat/engine/hooks/__init__.py b/mmchat/engine/hooks/__init__.py
index 6ab19717b..498921eb7 100644
--- a/mmchat/engine/hooks/__init__.py
+++ b/mmchat/engine/hooks/__init__.py
@@ -1,3 +1,4 @@
+from .log_data_sample import LogSampleHook
 from .sample_generate_hook import SampleGenerateHook
 
-__all__ = ['SampleGenerateHook']
+__all__ = ['SampleGenerateHook', 'LogSampleHook']
diff --git a/mmchat/engine/hooks/log_data_sample.py b/mmchat/engine/hooks/log_data_sample.py
new file mode 100644
index 000000000..9276897c3
--- /dev/null
+++ b/mmchat/engine/hooks/log_data_sample.py
@@ -0,0 +1,29 @@
+from mmengine.hooks import Hook
+
+from mmchat.registry import HOOKS, TOKENIZER
+
+
+@HOOKS.register_module()
+class LogSampleHook(Hook):
+
+    def __init__(self, tokenizer):
+        self.tokenizer = TOKENIZER.build(tokenizer)
+
+    def log(self, runner, dataset, mode='train'):
+        runner.logger.info(f'Num {mode} samples {len(dataset)}')
+        runner.logger.info(f'{mode} example:')
+        runner.logger.info(self.tokenizer.decode(dataset[0]['input_ids']))
+
+    def before_run(self, runner) -> None:
+        do_train = runner.train_loop is not None
+        do_eval = runner.val_loop is not None
+        do_test = runner.test_loop is not None
+        if do_train:
+            train_dataset = runner.train_dataloader.dataset
+            self.log(runner, train_dataset, mode='train')
+        if do_eval:
+            eval_dataset = runner.val_dataloader.dataset
+            self.log(runner, eval_dataset, mode='eval')
+        if do_test:
+            test_dataset = runner.test_dataloader.dataset
+            self.log(runner, test_dataset, mode='test')
diff --git a/mmchat/models/algorithms/sft.py b/mmchat/models/algorithms/sft.py
index 267915617..5ea7d4141 100644
--- a/mmchat/models/algorithms/sft.py
+++ b/mmchat/models/algorithms/sft.py
@@ -1,5 +1,4 @@
 import dataclasses
-from typing import Dict
 
 import torch
 import transformers
@@ -7,7 +6,7 @@
 from mmengine.model import BaseModel
 from torch import nn
 
-from mmchat.registry import LLM
+from mmchat.registry import LLM, TOKENIZER
 
 
 def traverse_dict(d):
@@ -28,7 +27,6 @@ def traverse_dict(d):
 
 
 def smart_tokenizer_and_embedding_resize(
-    special_tokens_dict: Dict,
     tokenizer: transformers.PreTrainedTokenizer,
     model: transformers.PreTrainedModel,
 ):
@@ -37,8 +35,9 @@
 
     Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
     """
-    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+    model_vocab_size = model.get_output_embeddings().weight.size(0)
     model.resize_token_embeddings(len(tokenizer))
+    num_new_tokens = len(tokenizer) - model_vocab_size
 
     if num_new_tokens > 0:
         input_embeddings = model.get_input_embeddings().weight.data
@@ -51,16 +50,20 @@
 
         input_embeddings[-num_new_tokens:] = input_embeddings_avg
         output_embeddings[-num_new_tokens:] = output_embeddings_avg
+    elif num_new_tokens < 0:
+        raise RuntimeError
 
 
 class SupervisedFinetune(BaseModel):
 
-    def __init__(self, llm, data_preprocessor=None):
+    def __init__(self, llm, data_preprocessor=None, tokenizer=None):
         super().__init__(data_preprocessor)
         self.llm = self._build_from_cfg_or_module(llm, LLM)
         self.llm.config.use_cache = False
         self.llm.config.torch_dtype = torch.float32
-
+        tokenizer = TOKENIZER.build(tokenizer)
+        smart_tokenizer_and_embedding_resize(tokenizer, self.llm)
+
     def _build_from_cfg_or_module(self, cfg_or_mod, registry):
         if isinstance(cfg_or_mod, nn.Module):
             return cfg_or_mod
@@ -97,7 +100,7 @@ def compute_loss(self, data, data_samples=None):
         # import pdb;pdb.set_trace()
         loss_dict = {'loss': outputs.loss}
         return loss_dict
-
+
     def __getattr__(self, name: str):
         try:
             return super().__getattr__(name)
diff --git a/mmchat/models/algorithms/sft_lora.py b/mmchat/models/algorithms/sft_lora.py
index 39f40ee16..070932158 100644
--- a/mmchat/models/algorithms/sft_lora.py
+++ b/mmchat/models/algorithms/sft_lora.py
@@ -1,7 +1,142 @@
+from collections import OrderedDict
+
+import torch
+from mmengine.runner import load_checkpoint
+from peft import (PeftType, PromptLearningConfig, get_peft_model,
+                  prepare_model_for_kbit_training)
+
+from mmchat.registry import MODELS
 from .sft import SupervisedFinetune
 
 
+def find_all_linear_names(model):
+    lora_module_names = set()
+    for name, module in model.named_modules():
+        if isinstance(module, torch.nn.Linear):
+            names = name.split('.')
+            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+
+    if 'lm_head' in lora_module_names:  # needed for 16-bit
+        lora_module_names.remove('lm_head')
+    return list(lora_module_names)
+
+
 class SupervisedLoraFinetune(SupervisedFinetune):
 
-    def __init__(self, llm, tokenizer, lora):
-        super().__init__(llm, tokenizer)
+    def __init__(self,
+                 llm,
+                 lora,
+                 data_preprocessor=None,
+                 tokenizer=None,
+                 peft_model=None):
+        super().__init__(llm, data_preprocessor, tokenizer)
+
+        self.llm = prepare_model_for_kbit_training(self.llm)
+
+        lora = MODELS.build(lora)
+        if lora.target_modules is None:
+            modules = find_all_linear_names(self.llm)
+            lora.target_modules = modules
+
+        self.llm = get_peft_model(self.llm, lora)
+        if peft_model is not None:
+            _ = load_checkpoint(self, peft_model)
+
+        for name, module in self.llm.named_modules():
+            if 'norm' in name:
+                module = module.to(torch.float32)
+        self._is_init = True
+
+    def init_weights(self):
+        pass
+
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+
+        def get_peft_model_state_dict(model,
+                                      state_dict=None,
+                                      adapter_name='default'):
+            # Modified from `https://github.com/huggingface/peft/blob/main/src
+            # /peft/utils/save_and_load.py`
+
+            config = model.peft_config[adapter_name]
+            if state_dict is None:
+                state_dict = model.state_dict()
+            if config.peft_type in (PeftType.LORA, PeftType.ADALORA):
+                # to_return = lora_state_dict(model,
+                #                             bias=model.peft_config.bias)
+                # adapted from `https://github.com/microsoft/LoRA/blob/main/
+                # loralib/utils.py`
+                # to be used directly with the state dict which is necessary
+                # when using DeepSpeed or FSDP
+                bias = config.bias
+                if bias == 'none':
+                    to_return = {
+                        k: state_dict[k]
+                        for k in state_dict if 'lora_' in k
+                    }
+                elif bias == 'all':
+                    to_return = {
+                        k: state_dict[k]
+                        for k in state_dict if 'lora_' in k or 'bias' in k
+                    }
+                elif bias == 'lora_only':
+                    to_return = {}
+                    for k in state_dict:
+                        if 'lora_' in k:
+                            to_return[k] = state_dict[k]
+                            bias_name = k.split('lora_')[0] + 'bias'
+                            if bias_name in state_dict:
+                                to_return[bias_name] = state_dict[bias_name]
+                else:
+                    raise NotImplementedError
+                to_return = {
+                    k: v
+                    for k, v in to_return.items()
+                    if (('lora_' in k and adapter_name in k) or ('bias' in k))
+                }
+                if config.peft_type == PeftType.ADALORA:
+                    # todo
+                    raise NotImplementedError
+                    # rank_pattern = config.rank_pattern
+                    # if rank_pattern is not None:
+                    #     rank_pattern = {
+                    #         k.replace(f'.{adapter_name}', ''): v
+                    #         for k, v in rank_pattern.items()
+                    #     }
+                    #     config.rank_pattern = rank_pattern
+                    #     to_return = model.resize_state_dict_by_rank_pattern(
+                    #         rank_pattern, to_return, adapter_name)
+
+            elif config.peft_type == PeftType.ADAPTION_PROMPT:
+                to_return = {
+                    k: state_dict[k]
+                    for k in state_dict
+                    if k.split('.')[-1].startswith('adaption_')
+                }
+            elif isinstance(config, PromptLearningConfig):
+                to_return = {}
+                if config.inference_mode:
+                    prompt_embeddings = model.prompt_encoder[
+                        adapter_name].embedding.weight
+                else:
+                    prompt_embeddings = model.get_prompt_embedding_to_save(
+                        adapter_name)
+                to_return['prompt_embeddings'] = prompt_embeddings
+            elif config.peft_type == PeftType.IA3:
+                to_return = {
+                    k: state_dict[k]
+                    for k in state_dict if 'ia3_' in k
+                }
+            else:
+                raise NotImplementedError
+            if model.modules_to_save is not None:
+                for key, value in state_dict.items():
+                    if any(f'{module_name}.modules_to_save.{adapter_name}' in
+                           key for module_name in model.modules_to_save):
+                        to_return[key] = value
+
+            return to_return
+
+        state_dict = super().state_dict()
+        to_return = get_peft_model_state_dict(self.llm, state_dict=state_dict)
+        return OrderedDict(to_return)
diff --git a/mmchat/models/algorithms/sft_qlora.py b/mmchat/models/algorithms/sft_qlora.py
index e577feb46..fa503a82c 100644
--- a/mmchat/models/algorithms/sft_qlora.py
+++ b/mmchat/models/algorithms/sft_qlora.py
@@ -1,148 +1,13 @@
-from collections import OrderedDict
+from .sft_lora import SupervisedLoraFinetune
 
-import bitsandbytes as bnb
-import torch
-import torch.nn as nn
-from peft import (PeftType, PromptLearningConfig, PeftConfig, get_peft_model,
-                  prepare_model_for_kbit_training)
 
-from mmchat.registry import MODELS
-from .sft import SupervisedFinetune
+# todo: check if class `SupervisedQloraFinetune` is necessary
+class SupervisedQloraFinetune(SupervisedLoraFinetune):
 
-
-def find_all_linear_names(model):
-    cls = bnb.nn.Linear4bit
-    # cls = nn.Linear
-    lora_module_names = set()
-    for name, module in model.named_modules():
-        if isinstance(module, cls):
-            names = name.split('.')
-            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
-
-    if 'lm_head' in lora_module_names:  # needed for 16-bit
-        lora_module_names.remove('lm_head')
-    return list(lora_module_names)
-
-
-class SupervisedQloraFinetune(SupervisedFinetune):
-
-    def __init__(self, llm, lora, data_preprocessor=None):
-        super().__init__(llm, data_preprocessor)
-
-        self.llm = prepare_model_for_kbit_training(self.llm)
-
-        modules = find_all_linear_names(self.llm)
-
-        if isinstance(lora, PeftConfig):
-            lora = lora
-        elif isinstance(lora, dict):
-            lora = MODELS.build(lora)
-        else:
-            raise NotImplementedError
-
-        lora.target_modules = modules
-
-        self.llm = get_peft_model(self.llm, lora)
-
-        for name, module in self.llm.named_modules():
-            # todo
-            # if isinstance(module, LoraLayer):
-            #     module = module.to(torch.bfloat16)
-            if 'norm' in name:
-                module = module.to(torch.float32)
-            # if 'lm_head' in name or 'embed_tokens' in name:
-            #     if hasattr(module, 'weight'):
-            #         if module.weight.dtype == torch.float32:
-            #             module = module.to(torch.float16)
-        self._is_init = True
-
-    def init_weights(self):
-        pass
-
-    def __getattr__(self, name: str):
-        try:
-            return super().__getattr__(name)
-        except AttributeError:
-            return getattr(self.llm, name)
-
-    def state_dict(self, destination=None, prefix='', keep_vars=False):
-
-        def get_peft_model_state_dict(model,
-                                      state_dict=None,
-                                      adapter_name='default'):
-            # Modified from `https://github.com/huggingface/peft/blob/main/src
-            # /peft/utils/save_and_load.py`
-
-            config = model.peft_config[adapter_name]
-            if state_dict is None:
-                state_dict = model.state_dict()
-            if config.peft_type in (PeftType.LORA, PeftType.ADALORA):
-                # to_return = lora_state_dict(model,
-                #                             bias=model.peft_config.bias)
-                # adapted from `https://github.com/microsoft/LoRA/blob/main/
-                # loralib/utils.py`
-                # to be used directly with the state dict which is necessary
-                # when using DeepSpeed or FSDP
-                bias = config.bias
-                if bias == 'none':
-                    to_return = {
-                        k: state_dict[k]
-                        for k in state_dict if 'lora_' in k
-                    }
-                elif bias == 'all':
-                    to_return = {
-                        k: state_dict[k]
-                        for k in state_dict if 'lora_' in k or 'bias' in k
-                    }
-                elif bias == 'lora_only':
-                    to_return = {}
-                    for k in state_dict:
-                        if 'lora_' in k:
-                            to_return[k] = state_dict[k]
-                            bias_name = k.split('lora_')[0] + 'bias'
-                            if bias_name in state_dict:
-                                to_return[bias_name] = state_dict[bias_name]
-                else:
-                    raise NotImplementedError
-                to_return = {
-                    k: v
-                    for k, v in to_return.items()
-                    if (('lora_' in k and adapter_name in k) or ('bias' in k))
-                }
-                if config.peft_type == PeftType.ADALORA:
-                    rank_pattern = config.rank_pattern
-                    if rank_pattern is not None:
-                        rank_pattern = {
-                            k.replace(f'.{adapter_name}', ''): v
-                            for k, v in rank_pattern.items()
-                        }
-                        config.rank_pattern = rank_pattern
-
-            elif config.peft_type == PeftType.ADAPTION_PROMPT:
-                to_return = {
-                    k: state_dict[k]
-                    for k in state_dict
-                    if k.split('.')[-1].startswith('adaption_')
-                }
-            elif isinstance(config, PromptLearningConfig):
-                to_return = {}
-                if config.inference_mode:
-                    prompt_embeddings = model.prompt_encoder[
-                        adapter_name].embedding.weight
-                else:
-                    prompt_embeddings = model.get_prompt_embedding_to_save(
-                        adapter_name)
-                to_return['prompt_embeddings'] = prompt_embeddings
-            else:
-                raise NotImplementedError
-            if model.modules_to_save is not None:
-                for key, value in state_dict.items():
-                    if any(f'{module_name}.modules_to_save.{adapter_name}' in
-                           key for module_name in model.modules_to_save):
-                        to_return[key.replace('modules_to_save.', '')] = value
-
-            return to_return
-
-        state_dict = super().state_dict()
-        to_return = get_peft_model_state_dict(self.llm, state_dict=state_dict)
-        return OrderedDict(to_return)
+    def __init__(self,
+                 llm,
+                 lora,
+                 data_preprocessor=None,
+                 tokenizer=None,
+                 peft_path=None):
+        super().__init__(llm, lora, data_preprocessor, tokenizer, peft_path)
diff --git a/tools/test.py b/tools/test.py
index 9b783dd43..5d2fc3f53 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -13,7 +13,7 @@ def parse_args():
     parser = argparse.ArgumentParser(description='MMChat test a model')
     parser.add_argument('config', help='test config file path')
-    parser.add_argument('checkpoint', help='checkpoint file')
+    parser.add_argument('--checkpoint', help='checkpoint file', default=None)
     parser.add_argument(
         '--work-dir',
         help='the directory to save the file containing evaluation metrics')
@@ -60,7 +60,8 @@ def main():
         cfg.work_dir = osp.join('./work_dirs',
                                 osp.splitext(osp.basename(args.config))[0])
 
-    cfg.load_from = args.checkpoint
+    if args.checkpoint is not None:
+        cfg.load_from = args.checkpoint
 
     # build the runner from config
     if 'runner_type' not in cfg: