
Added QLoRA support for Decoder transformers with tune_strategy "Normal" #613

Merged
23 commits merged on Aug 13, 2023

Conversation

TensorBlast
Contributor

I have added arguments under model_args to enable QLoRA support. The new arguments are:

  • use_qlora: bool (default: False)
  • bits: int (default: 4, with choices 4 or 8)
  • quant_type: str (default: 'nf4', with an optional choice of 'fp4')
  • double_quant: bool (default: True)

If model_args.use_qlora is set to 1/True, it also sets model_args.use_lora to True so the entire pipeline works.
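For context, here is a minimal sketch of how these arguments could map onto a bitsandbytes quantization config; the helper name build_quant_config and the bfloat16 compute dtype are illustrative assumptions, not the exact LMFlow code.

import torch
from transformers import BitsAndBytesConfig

def build_quant_config(model_args):
    # Only build a quantization config when QLoRA is requested.
    if not model_args.use_qlora:
        return None
    return BitsAndBytesConfig(
        load_in_4bit=(model_args.bits == 4),
        load_in_8bit=(model_args.bits == 8),
        bnb_4bit_quant_type=model_args.quant_type,         # 'nf4' or 'fp4'
        bnb_4bit_use_double_quant=model_args.double_quant,
        bnb_4bit_compute_dtype=torch.bfloat16,             # assumption; depends on hardware
    )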

@TensorBlast (Contributor, Author) left a comment

I don't know why the error "Action failed with: The process '/usr/bin/git' failed with exit code 128" occurs. It seems to be a setting in the OptimalScale/LMFlow repository.

@TensorBlast closed this Aug 9, 2023
@shizhediao
Contributor

shizhediao commented Aug 9, 2023

Hi,
Thanks for your contribution!
Let us check the action failure.

@shizhediao
Contributor

Hi,
It is caused by the repository's access permissions. We can ignore this error.
I have reopened this PR so that we can review it.
Thanks!

@shizhediao reopened this Aug 9, 2023
Commit: "… and prepare_model_for_kbit_training if qlora is enabled"
@TensorBlast
Contributor Author

I have further added some code to improve performance (gradient checkpointing and peft's prepare_model_for_kbit_training), which needs to be included as well...
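A minimal sketch of those two additions, assuming peft >= 0.4; the helper below is illustrative rather than the exact LMFlow code.

from peft import prepare_model_for_kbit_training

def prepare_quantized_model(model, use_qlora: bool):
    if use_qlora:
        # Trade extra recomputation for lower activation memory.
        model.gradient_checkpointing_enable()
        # Prepare the k-bit (4/8-bit) quantized model for training, e.g. casting
        # norm layers to fp32 and enabling input gradients for the adapters.
        model = prepare_model_for_kbit_training(model)
    return model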

@TensorBlast
Contributor Author

How do I install LMFlow on Colab? I have tried the attached notebook but it gets stuck at 'Running setup.py develop for lmflow'.
LMFlow.ipynb.zip

@shizhediao
Contributor

That's strange. Normally it should work with the attached notebook. Which notebook are you working on? Could you share the link with me?

@yaoguany
Collaborator

yaoguany commented Aug 10, 2023

Bugs:

src/lmflow/args.py

[line 168](https://github.com/ankitpasi/LMFlow/blob/293db00da85cdb6f96f2970272469c763af25185/src/lmflow/args.py#L168)
Change it to

use_qlora: bool = field(
        default=False,
        metadata={
            "help": "Whether to use qlora."},
)

line 172
Change it to

bits: int = field(
        default=4,
        metadata={"help": "The number of bits for quantization.",
                  "choices": [4 ,8],},
)

line 177
Change it to

quant_type: str = field(
        default="nf4",
        metadata={"help": "The quantization type for quantization.",
                "choices": ["nf4", "fp4"],},
)
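For completeness, a self-contained sketch of how these suggested fields could sit together in a dataclass; the ModelArguments name and the double_quant field are assumptions for illustration.

from dataclasses import dataclass, field

@dataclass
class ModelArguments:
    use_qlora: bool = field(
        default=False,
        metadata={"help": "Whether to use qlora."},
    )
    bits: int = field(
        default=4,
        metadata={"help": "The number of bits for quantization.",
                  "choices": [4, 8]},
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "The quantization type for quantization.",
                  "choices": ["nf4", "fp4"]},
    )
    double_quant: bool = field(
        default=True,
        metadata={"help": "Whether to use double quantization."},
    )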

Errors

Maybe you need to find proper versions of transformers, deepspeed, peft, and bitsandbytes.

Traceback (most recent call last):
  File "/home/yaoguanyu/qlora/LMFlow/examples/finetune.py", line 61, in <module>
    main()
  File "/home/yaoguanyu/qlora/LMFlow/examples/finetune.py", line 54, in main
    model = AutoModel.get_model(model_args)
  File "/home/yaoguanyu/qlora/LMFlow/src/lmflow/models/auto_model.py", line 16, in get_model
    return HFDecoderModel(model_args, *args, **kwargs)
  File "/home/yaoguanyu/qlora/LMFlow/src/lmflow/models/hf_decoder_model.py", line 263, in __init__
    quant_config = BitsAndBytesConfig(
TypeError: __init__() got an unexpected keyword argument 'load_in_4bit'
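For reference, this TypeError typically appears when the installed transformers predates the 4-bit BitsAndBytesConfig API; a quick sanity check along these lines may help (the 4.30 floor is an assumption based on when 4-bit loading landed upstream).

import transformers
from packaging import version

# load_in_4bit was added to BitsAndBytesConfig around transformers 4.30 (assumption).
assert version.parse(transformers.__version__) >= version.parse("4.30.0"), (
    "BitsAndBytesConfig(load_in_4bit=...) needs a newer transformers release"
)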

@TensorBlast
Contributor Author

Here is the Colab link (available to everyone with the link) where I am trying to install the LMFlow dependencies and test. I have been trying on an A100 machine but burned 100 compute units just waiting for setup.py to finish. Now trying again on a V100...

https://colab.research.google.com/drive/1rcD2OnTGZ_dz8BLn49XiaKB9JEUY9Aoz?usp=sharing

@TensorBlast
Contributor Author

> (quoting @yaoguany's bug report above)

Hi,
I have updated the requirements to the recommended versions from the official qlora source (https://github.com/artidoro/qlora/blob/main/requirements.txt).

I don't have a GPU machine to try with, so I've been trying on Colab, but as you can see from my previous messages, I've been unable to install.

I would greatly appreciate your help in testing this.
Thank you

@shizhediao
Contributor

Hi,
Yes, @yaoguany will test it and get back to you in a day.
Thanks

@TensorBlast
Contributor Author

TensorBlast commented Aug 10, 2023 via email

@shizhediao
Contributor

shizhediao commented Aug 10, 2023

That's nice! Thank you so much for your contribution, which means a lot to us!

@yaoguany
Collaborator

yaoguany commented Aug 11, 2023

Errors

When using multi-GPU training it throws the error below; maybe you can search for a solution. Single-GPU training is OK.
Besides, we need to update transformers, deepspeed, peft, and bitsandbytes before merging this branch.

Traceback (most recent call last):
  File "/home/yaoguanyu/qlora/LMFlow/examples/finetune.py", line 61, in <module>
    main()
  File "/home/yaoguanyu/qlora/LMFlow/examples/finetune.py", line 57, in main
    tuned_model = finetuner.tune(model=model, dataset=dataset)
  File "/home/yaoguanyu/qlora/LMFlow/src/lmflow/pipeline/finetuner.py", line 298, in tune
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/home/yaoguanyu/anaconda3/envs/test/lib/python3.9/site-packages/transformers/trainer.py", line 1539, in train
    return inner_training_loop(
  File "/home/yaoguanyu/anaconda3/envs/test/lib/python3.9/site-packages/transformers/trainer.py", line 1809, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/yaoguanyu/anaconda3/envs/test/lib/python3.9/site-packages/transformers/trainer.py", line 2654, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/yaoguanyu/anaconda3/envs/test/lib/python3.9/site-packages/transformers/trainer.py", line 2679, in compute_loss
    outputs = model(**inputs)
  File "/home/yaoguanyu/anaconda3/envs/test/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yaoguanyu/anaconda3/envs/test/lib/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/yaoguanyu/anaconda3/envs/test/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 1769, in forward
    loss = self.module(*inputs, **kwargs)
  File "/home/yaoguanyu/anaconda3/envs/test/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yaoguanyu/anaconda3/envs/test/lib/python3.9/site-packages/peft/peft_model.py", line 922, in forward
    return self.base_model(
  File "/home/yaoguanyu/anaconda3/envs/test/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yaoguanyu/anaconda3/envs/test/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/yaoguanyu/anaconda3/envs/test/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1076, in forward
    transformer_outputs = self.transformer(
  File "/home/yaoguanyu/anaconda3/envs/test/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yaoguanyu/anaconda3/envs/test/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/yaoguanyu/anaconda3/envs/test/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 843, in forward
    inputs_embeds = self.wte(input_ids)
  File "/home/yaoguanyu/anaconda3/envs/test/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/yaoguanyu/anaconda3/envs/test/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/yaoguanyu/anaconda3/envs/test/lib/python3.9/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/home/yaoguanyu/anaconda3/envs/test/lib/python3.9/site-packages/torch/nn/functional.py", line 2210, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

@TensorBlast
Contributor Author

> (quoting @yaoguany's multi-GPU error report above)

Hi @yaoguany,

Thanks so much for the test. I suspect the earlier device_map={"": 0} was the issue. I have added code to use the LOCAL_RANK environment variable to set the device map, and otherwise leave it at "auto", which is better. This should fix the issue.
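A rough sketch of that device-map logic, assuming the LOCAL_RANK variable set by the distributed launcher; this is illustrative, not the exact LMFlow code.

import os

def resolve_device_map():
    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is not None:
        # Pin the whole quantized model to this rank's GPU so every tensor in a
        # given process stays on one device.
        return {"": int(local_rank)}
    # Otherwise let accelerate place the model automatically.
    return "auto"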

Could you please do a quick try on multi-GPU again with the updated code? Greatly appreciate your kind support!

Thanks

@yaoguany
Collaborator

The code runs well now, but we need to update the LMFlow requirements before merging this branch.
Thanks!

@TensorBlast
Contributor Author

TensorBlast commented Aug 11, 2023

> The code runs well now, but we need to update the LMFlow requirements before merging this branch. Thanks!

OK, sure. Is that something I need to work on? I have updated the LMFlow requirements.txt to the below:

numpy==1.24.2
datasets==2.10.1
peft>=0.4.0
torch>=2.0.0
wandb==0.14.0
deepspeed==0.8.3
trl
sentencepiece==0.1.99
transformers>=4.31.0
flask
flask_cors
icetk
cpm_kernels==1.0.11
evaluate==0.4.0
scikit-learn==1.2.2
lm-eval
dill<0.3.5
bitsandbytes>=0.40.0
pydantic<=1.10.9
gradio
accelerate>=0.21.0
einops>=0.6.1

@shizhediao
Contributor

Hi,
The QLoRA feature depends on the most up-to-date Hugging Face transformers and deepspeed. We are currently working on upgrading these two packages, which is expected to take 1 day.
After the upgrade, we will merge this PR soon.
Thanks for your contribution!

@Dominic789654
Contributor

Hi,

Thank you for your contribution.
I encountered a bug while training with your QLoRA code, specifically when using the --save_aggregated_lora=1 flag, which is intended for merging the trained LoRA with the base model. The error message indicates that merging LoRA with an int4 model isn't possible. Could you provide a script to facilitate merging the LoRA model with the base model?

@TensorBlast
Contributor Author

TensorBlast commented Aug 11, 2023 via email

Commit: "…rained using QLoRA. The script reloads the model in torch_dtype and then calls merge_and_unload() on the peft model generated from training"
@TensorBlast
Contributor Author

> (quoting @Dominic789654's comment above)

Hi,
I've updated the code to enable merging of the LoRA adapters with the base model; a rough sketch of the approach follows. Could you please help me out by testing it?
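A hedged sketch of that merging approach: reload the base model in a regular torch dtype, attach the trained adapter, and fold it into the base weights. The model name and paths are placeholders, not the actual LMFlow arguments.

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Reload the base model without 4-bit quantization (placeholder name).
base = AutoModelForCausalLM.from_pretrained(
    "base-model-name",
    torch_dtype=torch.float16,
)
# Attach the trained LoRA adapter (placeholder path) and merge it into the base weights.
peft_model = PeftModel.from_pretrained(base, "path/to/lora_adapter")
merged = peft_model.merge_and_unload()
merged.save_pretrained("path/to/merged_model")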

@TensorBlast
Contributor Author

TensorBlast commented Aug 11, 2023

Question: shouldn't the following code

def save(self, dir, save_full_model=False, *args, **kwargs):
        """
        Perform generation process of the model.
    
        Parameters
        ------------
        dir :
            The directory to save model and tokenizer
            
        save_full_model : Optional.
            Whether to save full model.
        
        kwargs : Optional.
            Keyword arguments.    
        
        Returns
        ------------
        outputs :
            The generated sequence output 
        """
        self.get_tokenizer().save_pretrained(dir)
        if save_full_model and self.model_args.use_lora:
            self.backend_model_full.save_pretrained(dir)
        else:
            self.get_backend_model().save_pretrained(dir)

Be this instead, with only the two save branches swapped?

        self.get_tokenizer().save_pretrained(dir)
        if save_full_model and self.model_args.use_lora:
            self.get_backend_model().save_pretrained(dir)
        else:
            self.backend_model_full.save_pretrained(dir)

because backend_model_full refers to the model without peft and backend_model refers to the PeftModel?

@TensorBlast
Contributor Author

Hi,
I have rented an ML server with 2 x A100 80GB GPUs to test my code. I am happy to say that I ironed out a few bugs related to bnb_4bit_compute_dtype and have also fixed the above comment about saving the aggregated LoRA by reloading the base model and merging the LoRA weights (disposing of the quantized model).

@TensorBlast
Contributor Author

> (quoting @Dominic789654's comment above)

This has been done, per my previous comment. The pipeline is now able to train using QLoRA and then merge the LoRA adapters with the base model (tested on a multi-GPU setup).

@shizhediao
Contributor

shizhediao commented Aug 12, 2023

Hi,
Thanks so much!
We have successfully upgraded the transformers and deepspeed versions as planned.
Will review and merge this PR ASAP (in a day).
Thank you again!

@hendrydong
Contributor

> (quoting @TensorBlast's question above about swapping the save() branches)

Hi, thanks for your interest. self.get_backend_model().save_pretrained(dir) is our default saving strategy: for a full model it saves the full model, and for a LoRA model it saves the LoRA adapters. However, sometimes when we merge LoRA models, we want to save the full model rather than the LoRA model, i.e., the save_full_model and self.model_args.use_lora case. Our code should also be compatible with full training without peft. That is our motivation.

@shizhediao
Contributor

QLoRA is an important feature for large language model training. We wanted to express our deepest gratitude for your outstanding contribution in implementing QLoRA in LMFlow.

Your contributions are always welcome and will undoubtedly help us shape a more successful future for LMFlow.

Thanks!

@shizhediao (Contributor) left a comment

LGTM. Merged.

@shizhediao merged commit df30f11 into OptimalScale:main Aug 13, 2023
1 check failed
@TensorBlast
Contributor Author

TensorBlast commented Aug 13, 2023

> (quoting @shizhediao's comment above)

Thank you! I want to express my gratitude for the opportunity to contribute and learn along the way!

I just had a follow-up question regarding the code below, and I am not sure where else to put it:

def save(self, dir, save_full_model=False, *args, **kwargs):
        """
        Perform generation process of the model.
    
        Parameters
        ------------
        dir :
            The directory to save model and tokenizer
            
        save_full_model : Optional.
            Whether to save full model.
        
        kwargs : Optional.
            Keyword arguments.    
        
        Returns
        ------------
        outputs :
            The generated sequence output 
        """
        self.get_tokenizer().save_pretrained(dir)
        if save_full_model and self.model_args.use_lora:
            self.backend_model_full.save_pretrained(dir)
        else:
            self.get_backend_model().save_pretrained(dir)

The code above saves backend_model_full if save_full_model=True, but in the code from hf_decoder_model.py (lines 279-312) below, self.backend_model_full actually doesn't include the LoRA adapters, correct? The model which includes the LoRA adapters and gets merged is self.backend_model (not self.backend_model_full):


    def merge_lora_weights(self):
        if self.model_args.use_lora and not self.model_args.use_qlora:
            self.get_backend_model().merge_and_unload()

and, from __init__:

                model = AutoModelForCausalLM.from_pretrained(
                    model_args.model_name_or_path,
                    from_tf=bool(".ckpt" in model_args.model_name_or_path),
                    config=config,
                    quantization_config=quant_config if model_args.use_qlora else None,
                    cache_dir=model_args.cache_dir,
                    revision=model_args.model_revision,
                    use_auth_token=True if model_args.use_auth_token else None,
                    torch_dtype=torch_dtype,
                    device_map=device_map,
                    trust_remote_code=model_args.trust_remote_code,
                )
                if model_args.use_qlora:
                    model.gradient_checkpointing_enable()
                    model = prepare_model_for_kbit_training(model)
            else:
                model = AutoModelForCausalLM.from_config(config)
                n_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
                logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
            self.backend_model_full = model
            if model_args.use_lora:
                if model_args.lora_target_modules:
                    lora_target_modules = model_args.lora_target_modules
                else:
                    lora_target_modules = None
                peft_config = LoraConfig(
                    task_type=TaskType.CAUSAL_LM,
                    inference_mode=False,
                    r=model_args.lora_r,
                    lora_alpha=model_args.lora_alpha,
                    lora_dropout=model_args.lora_dropout,
                    target_modules=lora_target_modules,
                )
                model = get_peft_model(model, peft_config)
                ...
                self.backend_model = model

Apologies for this, but it has been nagging at me, so I would appreciate it if these code snippets could be validated.
Thanks
Ankit

@hendrydong
Contributor

hendrydong commented Aug 13, 2023


Hi, thank you for your careful investigation.

In fact, backend_model_full refers to the base_model and backend_model refers to base_model + LoRA. When we use LoRA merging, the LoRA weights are merged into the base_model, so we save backend_model_full. But for backend_model, since its class is still a peft model, there would be some compatibility issues.

You can regard backend_model_full as backend_model.base_model, whose class is a plain model rather than a peft model. The parameters of backend_model_full and backend_model are shared, so that works.
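A small illustrative sketch of that relationship, using a generic base model and placeholder variable names (an assumption, not LMFlow's actual construction code).

from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")   # placeholder base model
backend_model_full = base                             # plain HF model handle
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=16, lora_dropout=0.05)
backend_model = get_peft_model(base, peft_config)     # peft-wrapped handle

# The wrapper and the plain handle point at the same underlying module, so
# merging LoRA weights through the wrapper also changes what
# backend_model_full.save_pretrained(...) would write out.
print(backend_model.base_model.model is backend_model_full)   # expected: True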
