
PEFT initialization fix #361

Merged 4 commits into AutoGPTQ:main on Oct 27, 2023

Conversation

@alex4321 (Contributor) commented on Sep 29, 2023

Problem: with the latest peft and auto_gptq versions, get_gptq_peft_model does not replace GPTQ-quantized linear layers with the correct LoRA wrappers (GPTQLoraLinear / GPTQSVDLinear instead of peft's built-in QuantLinear). This leads to data-type issues in the forward pass and to bad LoRA initialization values that cause NaN losses.

In detail:

In this library, we have the following function:

def get_gptq_peft_model(
    model: BaseGPTQForCausalLM,
    peft_config: PeftConfig = None,
    model_id: str = None,
    adapter_name: str = "default",
    auto_find_all_linears: bool = True,
    train_mode: bool = False
):

which does some hijacking on top of peft's get_peft_model(model.model, peft_config).

However, inside that code we need to replace auto_gptq's quantized linear layers with custom LoRA layers such as GPTQLoraLinear / GPTQSVDLinear.
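
For context, a typical call to this helper looks roughly like the sketch below. This is illustrative only: the model id is a placeholder, and exact import paths and config fields may vary between auto_gptq and peft versions.

from peft import TaskType
from auto_gptq import AutoGPTQForCausalLM
from auto_gptq.utils.peft_utils import GPTQLoraConfig, get_gptq_peft_model

# Load a GPTQ-quantized checkpoint in trainable mode (placeholder model id).
model = AutoGPTQForCausalLM.from_quantized(
    "some-org/some-gptq-model",
    use_triton=True,
    trainable=True,
)

peft_config = GPTQLoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
)

# Wraps peft's get_peft_model() and is expected to swap the quantized linear
# layers for the GPTQ-aware LoRA wrappers discussed below.
peft_model = get_gptq_peft_model(model, peft_config=peft_config, train_mode=True)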

I see the corresponding code in GPTQLoraModel::_find_and_replace:

...
                        new_module = GPTQLoraLinear(adapter_name, target, **kwargs)

                    self._replace_module(parent, target_name, new_module, target)
...

and the same method for GPTQAdaLoraModel.

However, the latest versions of peft use a different system for initializing LoRA modules (peft's LoraModel::_create_and_replace):

        else:
            new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
            self._replace_module(parent, target_name, new_module, target)

where _create_new_module contains code like the following:

        gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
        AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)
...
        elif AutoGPTQQuantLinear is not None and isinstance(target, AutoGPTQQuantLinear):
            new_module = QuantLinear(adapter_name, target, **kwargs)
            target.weight = target.qweight

The problems are:

  • there are issues with resolving the correct AutoGPTQQuantLinear class (so, for instance, the loaded model contains GeneralQuantLinear modules while AutoGPTQQuantLinear resolves to QuantLinearCuda)
  • using QuantLinear instead of GPTQSVDLinear / GPTQLoraLinear leads to dtype issues during the forward pass, at least for fp16 computation (see the sketch after this list)
  • even after solving that, QuantLinear.reset_lora_parameters uses LoraLayer.reset_lora_parameters, which in the case of GPTQ Llama models produces initialization values that cause NaN losses:
    • LoraLayer's version:
      def reset_lora_parameters(self, adapter_name):
          if adapter_name in self.lora_A.keys():
              # initialize A the same way as the default for nn.Linear and B to zero
              nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5))
              nn.init.zeros_(self.lora_B[adapter_name].weight)
          if adapter_name in self.lora_embedding_A.keys():
              # initialize a the same way as the default for nn.linear and b to zero
              nn.init.zeros_(self.lora_embedding_A[adapter_name])
              nn.init.normal_(self.lora_embedding_B[adapter_name])
    Instead of GPTQLoraLinear's:
      def reset_lora_parameters(self, adapter_name):
          if adapter_name in self.lora_A.keys():
              torch.nn.init.xavier_uniform_(self.lora_A[adapter_name].weight)
              torch.nn.init.zeros_(self.lora_B[adapter_name].weight)
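
To illustrate the fp16 issue above, here is a minimal sketch (not the actual GPTQLoraLinear code) of why the LoRA branch has to run in the adapter's own dtype and its output has to be cast back before being added to the quantized layer's output:

def lora_forward(quant_linear, lora_A, lora_B, scaling, x):
    # Base output from the GPTQ kernel, typically fp16.
    result = quant_linear(x)
    # nn.Linear raises a dtype-mismatch error when the input dtype differs from
    # its weight dtype, so cast the input to the adapter's dtype first ...
    lora_x = x.to(lora_A.weight.dtype) if x.dtype != lora_A.weight.dtype else x
    lora_out = lora_B(lora_A(lora_x)) * scaling
    # ... and cast the adapter output back before adding it to the base result.
    return result + lora_out.to(result.dtype)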

What this patch does:

  • creates an appropriate _create_new_module method inside GPTQLoraModel and GPTQAdaLoraModel (a simplified sketch follows this list)
  • adds autotests for both of these PEFT types, ensuring that:
    • the correct LoRA wrappers are used
    • the resulting model is trainable (loss decreases, no NaNs, and so on)
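
For reference, a simplified sketch of such an override follows. It is illustrative only, not the exact code merged in this PR; it assumes GPTQLoraLinear is importable from auto_gptq.utils.peft_utils and keys the dispatch on the presence of a qweight attribute.

from peft.tuners.lora import LoraModel
from auto_gptq.utils.peft_utils import GPTQLoraLinear


class GPTQLoraModel(LoraModel):
    def _create_new_module(self, lora_config, adapter_name, target, **kwargs):
        # If the target is an auto_gptq quantized linear layer (it carries the
        # packed qweight buffer), wrap it with the GPTQ-aware LoRA layer instead
        # of letting peft fall back to its generic QuantLinear wrapper.
        if hasattr(target, "qweight"):
            return GPTQLoraLinear(adapter_name, target, **kwargs)
        # Otherwise defer to peft's default dispatch (nn.Linear and friends).
        return super()._create_new_module(lora_config, adapter_name, target, **kwargs)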

@fxmarty (Collaborator) left a comment

Thank you @alex4321 for submitting the PR, and thank you especially for implementing tests that are passing.

LGTM!

@fxmarty fxmarty merged commit a7d61ca into AutoGPTQ:main Oct 27, 2023
@vivekkhandelwal1 (Contributor) commented:

@alex4321, these changes result in an error for the following code:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

checkpoint = "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
quantization_config = GPTQConfig(bits=4, disable_exllama=True)

model = AutoModelForCausalLM.from_pretrained(checkpoint, low_cpu_mem_usage=True, device_map="cpu", quantization_config=quantization_config, torch_dtype=torch.float32)

inputs = tokenizer.encode("Hello how are you?", return_tensors="pt").to(model.device)
outputs = model.generate(inputs, max_new_tokens=4, do_sample=False)
print(tokenizer.decode(outputs[0]))

@fxmarty (Collaborator) commented on Nov 1, 2023

Thank you @vivekkhandelwal1, will have a look
