diff --git a/xtuner/tools/chat.py b/xtuner/tools/chat.py index 002dc62ca..10b68faac 100644 --- a/xtuner/tools/chat.py +++ b/xtuner/tools/chat.py @@ -49,6 +49,11 @@ def parse_args(): '--no-streamer', action='store_true', help='Whether to with streamer') parser.add_argument('--command-stop-word', default=None, help='Stop key') parser.add_argument('--answer-stop-word', default=None, help='Stop key') + parser.add_argument( + '--offload-folder', + default=None, + help='The folder in which to offload the model weights (or where the ' + 'model weights are already offloaded).') parser.add_argument( '--max-new-tokens', type=int, @@ -138,11 +143,13 @@ def main(): quantization_config=quantization_config, load_in_8bit=load_in_8bit, device_map='auto', + offload_folder=args.offload_folder, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained( args.model_name_or_path, trust_remote_code=True) if args.adapter is not None: - model = PeftModel.from_pretrained(model, args.adapter) + model = PeftModel.from_pretrained( + model, args.adapter, offload_folder=args.offload_folder) print(f'Load adapter from {args.adapter}') model.eval() diff --git a/xtuner/tools/model_converters/merge.py b/xtuner/tools/model_converters/merge.py index 06fdf5908..cd0a58970 100644 --- a/xtuner/tools/model_converters/merge.py +++ b/xtuner/tools/model_converters/merge.py @@ -19,6 +19,11 @@ def parse_args(): default='2GB', help='Only applicable for LLM. The maximum size for ' 'each sharded checkpoint.') + parser.add_argument( + '--offload-folder', + default=None, + help='The folder in which to offload the model weights (or where ' + 'the model weights are already offloaded).') args = parser.parse_args() return args @@ -30,6 +35,7 @@ def main(): torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map='auto', + offload_folder=args.offload_folder, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained( args.model_name_or_path, trust_remote_code=True) @@ -38,6 +44,7 @@ def main(): args.adapter_name_or_path, device_map='auto', torch_dtype=torch.float16, + offload_folder=args.offload_folder, is_trainable=False) model_merged = model_unmerged.merge_and_unload() print(f'Saving to {args.save_dir}...')