diff --git a/docs/source/developer_guides/lora.md b/docs/source/developer_guides/lora.md index 88dd6547cc..9ba53839e0 100644 --- a/docs/source/developer_guides/lora.md +++ b/docs/source/developer_guides/lora.md @@ -69,6 +69,16 @@ from peft import LoraConfig config = LoraConfig(use_rslora=True, ...) ``` +### Weight-Decomposed Low-Rank Adaptation (DoRA) + +This technique decomposes the updates of the weights into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, especially at low ranks. Right now, DoRA only supports non-quantized linear layers. DoRA introduces a bigger overhead than pure LoRA, so it is recommended to merge weights for inference, see [`LoraModel.merge_and_unload`]. For more information on DoRA, see https://arxiv.org/abs/2402.09353. + +```py +from peft import LoraConfig + +config = LoraConfig(use_dora=True, ...) +``` + ### QLoRA-style training The default LoRA settings in PEFT add trainable weights to the query and value layers of each attention block. But [QLoRA](https://hf.co/papers/2305.14314), which adds trainable weights to all the linear layers of a transformer model, can provide performance equal to a fully finetuned model. To apply LoRA to all the linear layers, like in QLoRA, set `target_modules="all-linear"` (easier than specifying individual modules by name which can vary depending on the architecture). diff --git a/src/peft/tuners/lora/bnb.py b/src/peft/tuners/lora/bnb.py index 06a6f431cd..02a14ac169 100644 --- a/src/peft/tuners/lora/bnb.py +++ b/src/peft/tuners/lora/bnb.py @@ -38,13 +38,25 @@ def __init__( lora_dropout: float = 0.0, init_lora_weights: bool = True, use_rslora: bool = False, + use_dora: bool = False, **kwargs, ) -> None: super().__init__() LoraLayer.__init__(self, base_layer) + if use_dora: + raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + self._active_adapter = adapter_name - self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora) + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + use_dora=use_dora, + ) def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: """ @@ -216,13 +228,25 @@ def __init__( lora_dropout: float = 0.0, init_lora_weights: bool = True, use_rslora: bool = False, + use_dora: bool = False, **kwargs, ) -> None: super().__init__() LoraLayer.__init__(self, base_layer) + if use_dora: + raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + self._active_adapter = adapter_name - self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora) + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + use_dora=use_dora, + ) def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: """ diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index 0df8b5b623..695fed6f82 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -101,6 +101,13 @@ class LoraConfig(PeftConfig): The configuration of LoftQ. If this is not None, then LoftQ will be used to quantize the backbone weights and initialize Lora layers. 
Also pass `init_lora_weights='loftq'`. Note that you should not pass a quantized model in this case, as LoftQ will quantize the model itself. + use_dora (`bool`): + Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the weights + into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the magnitude is + handled by a separate learnable parameter. This can improve the performance of LoRA, especially at low + ranks. Right now, DoRA only supports non-quantized linear layers. DoRA introduces a bigger overhead than + pure LoRA, so it is recommended to merge weights for inference. For more information, see + https://arxiv.org/abs/2402.09353. """ r: int = field(default=8, metadata={"help": "Lora attention dimension"}) @@ -224,6 +231,19 @@ class LoraConfig(PeftConfig): ) }, ) + use_dora: bool = field( + default=False, + metadata={ + "help": ( + "Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the " + "weights into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the " + "magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, " + "especially at low ranks. Right now, DoRA only supports non-quantized linear layers. DoRA introduces " + "a bigger overhead than pure LoRA, so it is recommended to merge weights for inference. For more " + "information, see https://arxiv.org/abs/2402.09353." + ) + }, + ) def __post_init__(self): self.peft_type = PeftType.LORA @@ -238,6 +258,9 @@ def __post_init__(self): if isinstance(self.target_modules, str) and self.layers_pattern is not None: raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.") + if self.use_dora and (self.megatron_config or self.init_lora_weights == "loftq"): + raise ValueError("DoRA does not support megatron_core or LoftQ. 
Please set `use_dora=False`.") + # handle init_lora_weights and loftq_config if self.init_lora_weights == "loftq": import importlib diff --git a/src/peft/tuners/lora/gptq.py b/src/peft/tuners/lora/gptq.py index c7d7ceefc5..333dfa6feb 100644 --- a/src/peft/tuners/lora/gptq.py +++ b/src/peft/tuners/lora/gptq.py @@ -31,16 +31,28 @@ def __init__( lora_dropout: float = 0.0, init_lora_weights: bool = True, use_rslora: bool = False, + use_dora: bool = False, **kwargs, ): super().__init__() LoraLayer.__init__(self, base_layer) + if use_dora: + raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter # for backwards compatibility self.quant_linear_module = base_layer self._active_adapter = adapter_name - self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora) + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + use_dora=use_dora, + ) def forward(self, x: torch.Tensor): # note: logic differs from default Linear because merging is not supported diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 52049783c2..83617b69a2 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -22,6 +22,7 @@ from transformers.pytorch_utils import Conv1D from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge +from peft.utils.integrations import gather_params_ctx from peft.utils.other import transpose from .config import LoraConfig @@ -47,6 +48,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: # Mark the weight as unmerged self._disable_adapters = False self.merged_adapters = [] + self.use_dora: dict[str, bool] = {} + self.lora_magnitude_vector: Optional[torch.nn.ParameterDict] = None # for DoRA + self._caches: dict[str, Any] = {} self.kwargs = kwargs base_layer = self.get_base_layer() @@ -78,7 +82,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: self.in_features = in_features self.out_features = out_features - def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora): + def update_layer( + self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora: bool = False + ): # This code works for linear layers, override for other layer types if r <= 0: raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") @@ -114,6 +120,13 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig else: self.to(weight.device) break + + if use_dora: + self.dora_init(adapter_name) + self.use_dora[adapter_name] = True + else: + self.use_dora[adapter_name] = False + self.set_adapter(self.active_adapters) def reset_lora_parameters(self, adapter_name, init_lora_weights): @@ -156,6 +169,65 @@ def loftq_init(self, adapter_name): self.lora_embedding_B[adapter_name].weight.data = lora_B self.get_base_layer().weight.data = qweight + def _get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor: + # calculate L2 norm of weight matrix, column-wise + weight = weight + scaling * lora_weight + weight_norm = torch.linalg.norm(weight, dim=1) + return weight_norm + + def dora_init(self, adapter_name: str) -> None: + lora_A = self.lora_A[adapter_name] + lora_B = self.lora_B[adapter_name] + scaling = self.scaling[adapter_name] + 
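+        # gather_params_ctx (imported from peft.utils.integrations, added in this change) wraps
+        # deepspeed.zero.GatheredParameters when DeepSpeed ZeRO-3 is enabled and is a no-op otherwise.
+        # DoRA needs the full, unsharded base weight here to compute the column-wise norm that seeds
+        # the magnitude vector.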
with gather_params_ctx(self.get_base_layer()): + weight = self.get_base_layer().weight + lora_weight = lora_B.weight @ lora_A.weight + weight_norm = self._get_weight_norm(weight, lora_weight, scaling) + self.lora_magnitude_vector = nn.ParameterDict() + self.lora_magnitude_vector[adapter_name] = nn.Parameter(weight_norm, requires_grad=True) + # add lora_magnitude_vector to the list of learnable parameters + self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",) + + def _cache_store(self, key: str, value: Any) -> None: + self._caches[key] = value + + def _cache_pop(self, key: str) -> Any: + value = self._caches.pop(key) + return value + + def _apply_dora(self, x, lora_A, lora_B, scaling, active_adapter): + """ + For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer + output. + """ + lora_weight = lora_B.weight @ lora_A.weight + magnitude = self.lora_magnitude_vector[active_adapter] + weight = self.get_base_layer().weight + weight_norm = self._get_weight_norm(weight, lora_weight, scaling) + # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353) + # "[...] we suggest treating ||V +∆V ||_c in + # Eq. (5) as a constant, thereby detaching it from the gradient + # graph. This means that while ||V + ∆V ||_c dynamically + # reflects the updates of ∆V , it won’t receive any gradient + # during backpropagation" + weight_norm = weight_norm.detach() + mag_norm_scale = (magnitude / weight_norm).view(1, -1) + result_dora = (mag_norm_scale - 1) * ( + F.linear(x, transpose(weight, self.fan_in_fan_out)) + ) + mag_norm_scale * lora_B(lora_A(x)) * scaling + + # Note: Computation could potentially be accelerated by using the code below instead of calculating X@W again. + # This is only correct if dropout=0, otherwise results will differ: + # https://github.com/huggingface/peft/pull/1474#issuecomment-1964682771 + # bias = self.get_base_layer().bias + # if bias is not None: + # result = result - bias + # result = mag_norm_scale * result + mag_norm_scale * lora_B(lora_A(x)) * scaling + # if bias is not None: + # result = result + bias + + return result_dora + def set_scale(self, adapter, scale): if adapter not in self.scaling: # Ignore the case where the adapter is not in the layer @@ -206,6 +278,7 @@ def __init__( is_target_conv_1d_layer: bool = False, init_lora_weights: Union[bool, str] = True, use_rslora: bool = False, + use_dora: bool = False, **kwargs, ) -> None: super().__init__() @@ -213,7 +286,15 @@ def __init__( self.fan_in_fan_out = fan_in_fan_out self._active_adapter = adapter_name - self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora) + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + use_dora=use_dora, + ) self.is_target_conv_1d_layer = is_target_conv_1d_layer def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: @@ -241,7 +322,19 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N # Note that safe_merge will be slower than the normal merge # because of the copy operation. 
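+                    # For DoRA adapters, merging folds the magnitude vector into the weight:
+                    # W_merged = (m / ||W + scaling * B @ A||_c) * (W + scaling * B @ A),
+                    # where ||.||_c is the column-wise L2 norm. The pre-merge norm is cached below so
+                    # that unmerge() can divide it back out exactly.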
orig_weights = base_layer.weight.data.clone() - orig_weights += self.get_delta_weight(active_adapter) + delta_weight = self.get_delta_weight(active_adapter) + if not self.use_dora[active_adapter]: + orig_weights += delta_weight + else: + # handle dora + # since delta_weight already includes scaling, set it to 1 here + weight_norm = self._get_weight_norm(orig_weights, delta_weight, scaling=1).detach() + # We need to cache weight_norm because it has to be based on the original weights. We + # cannot calculate it on the fly based on the merged weights when unmerging because its a + # different value + self._cache_store(f"{active_adapter}-weight_norm", weight_norm) + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + orig_weights = dora_factor.view(-1, 1) * (orig_weights + delta_weight) if not torch.isfinite(orig_weights).all(): raise ValueError( @@ -250,7 +343,21 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N base_layer.weight.data = orig_weights else: - base_layer.weight.data += self.get_delta_weight(active_adapter) + delta_weight = self.get_delta_weight(active_adapter) + if not self.use_dora[active_adapter]: + base_layer.weight.data += delta_weight + else: + # handle dora + # since delta_weight already includes scaling, set it to 1 here + weight_norm = self._get_weight_norm(base_layer.weight, delta_weight, scaling=1).detach() + # We need to cache weight_norm because it has to be based on the original weights. We + # cannot calculate it on the fly based on the merged weights when unmerging because its a + # different value + self._cache_store(f"{active_adapter}-weight_norm", weight_norm) + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + new_weight = dora_factor.view(-1, 1) * (base_layer.weight.data + delta_weight) + base_layer.weight.data = new_weight + self.merged_adapters.append(active_adapter) def unmerge(self) -> None: @@ -263,7 +370,15 @@ def unmerge(self) -> None: while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() if active_adapter in self.lora_A.keys(): - self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) + weight = self.get_base_layer().weight + delta_weight = self.get_delta_weight(active_adapter) + if not self.use_dora[active_adapter]: + weight.data -= delta_weight + else: + weight_norm = self._cache_pop(f"{active_adapter}-weight_norm") + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + weight_orig = weight.data / dora_factor.view(-1, 1) - delta_weight + weight.data = weight_orig def get_delta_weight(self, adapter) -> torch.Tensor: """ @@ -317,7 +432,12 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: dropout = self.lora_dropout[active_adapter] scaling = self.scaling[active_adapter] x = x.to(lora_A.weight.dtype) - result += lora_B(lora_A(dropout(x))) * scaling + + if not self.use_dora[active_adapter]: + result = result + lora_B(lora_A(dropout(x))) * scaling + else: + x = dropout(x) + result = result + self._apply_dora(x, lora_A, lora_B, scaling, active_adapter) result = result.to(torch_result_dtype) return result @@ -338,15 +458,27 @@ def __init__( lora_dropout: float = 0.0, init_lora_weights: Union[bool, str] = True, use_rslora: bool = False, + use_dora: bool = False, **kwargs, ) -> None: super().__init__() LoraLayer.__init__(self, base_layer) + if use_dora: + raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + self._active_adapter = 
adapter_name - self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora) + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + use_dora=use_dora, + ) - def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora): + def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora): if r <= 0: raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") @@ -514,15 +646,27 @@ def __init__( lora_dropout: float = 0.0, init_lora_weights: Union[bool, str] = True, use_rslora: bool = False, + use_dora: bool = False, **kwargs, ) -> None: super().__init__() LoraLayer.__init__(self, base_layer) + if use_dora: + raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + self._active_adapter = adapter_name - self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora) + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + use_dora=use_dora, + ) - def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora): + def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora): if r <= 0: raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index e0c0ec4a2f..a6b7a29318 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -154,6 +154,7 @@ def _create_and_replace( "fan_in_fan_out": lora_config.fan_in_fan_out, "init_lora_weights": lora_config.init_lora_weights, "use_rslora": lora_config.use_rslora, + "use_dora": lora_config.use_dora, "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False), "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False), } @@ -171,10 +172,11 @@ def _create_and_replace( target.update_layer( adapter_name, r, - alpha, - lora_config.lora_dropout, - lora_config.init_lora_weights, - lora_config.use_rslora, + lora_alpha=alpha, + lora_dropout=lora_config.lora_dropout, + init_lora_weights=lora_config.init_lora_weights, + use_rslora=lora_config.use_rslora, + use_dora=lora_config.use_dora, ) else: new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs) diff --git a/src/peft/tuners/lora/tp_layer.py b/src/peft/tuners/lora/tp_layer.py index f3e8215234..062c1cc615 100644 --- a/src/peft/tuners/lora/tp_layer.py +++ b/src/peft/tuners/lora/tp_layer.py @@ -44,11 +44,15 @@ def __init__( fan_in_fan_out: bool = False, init_lora_weights: bool = True, use_rslora: bool = False, + use_dora: bool = False, **kwargs, ): super().__init__() LoraLayer.__init__(self, base_layer=base_layer) + if use_dora: + raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + self.backend = backend self.is_parallel_a = isinstance(base_layer, backend.RowParallelLinear) self.fan_in_fan_out = fan_in_fan_out @@ -68,13 +72,14 @@ def __init__( self.update_layer( adapter_name, r, - lora_alpha, - lora_dropout, - init_lora_weights, - use_rslora, - init_method, - input_is_parallel, - gather_output, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + 
use_dora=use_dora, + init_method=init_method, + input_is_parallel=input_is_parallel, + gather_output=gather_output, **parallel_linear_kwargs, ) @@ -97,6 +102,7 @@ def update_layer( lora_dropout, init_lora_weights, use_rslora, + use_dora=False, init_method=init.xavier_normal_, input_is_parallel=True, gather_output=False, diff --git a/src/peft/utils/integrations.py b/src/peft/utils/integrations.py new file mode 100644 index 0000000000..3b0e139483 --- /dev/null +++ b/src/peft/utils/integrations.py @@ -0,0 +1,39 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import contextmanager + +import packaging.version +import torch +import transformers + + +@contextmanager +def gather_params_ctx(module: torch.nn.Module, modifier_rank: int = 0): + """Call DeepSpeed GatheredParameters context manager if DeepSpeed is enabled, otherwise do nothing.""" + if packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.33.0"): + from transformers.integrations import is_deepspeed_zero3_enabled + else: + from transformers.deepspeed import is_deepspeed_zero3_enabled + + if not is_deepspeed_zero3_enabled(): + yield + return + + import deepspeed + + params_to_gather = module.parameters() + with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=modifier_rank): + yield + return diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index f53da815c1..6d5f9f74e5 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -55,6 +55,14 @@ "lora_dropout": 0.1, }, ), + ("Vanilla MLP 7 LoRA with DoRA", "MLP", LoraConfig, {"target_modules": ["lin0"], "use_dora": True}), + ("Vanilla MLP 8 LoRA with DoRA", "MLP", LoraConfig, {"target_modules": ["lin0", "lin1"], "use_dora": True}), + ( + "Vanilla MLP 9 LoRA with DoRA", + "MLP", + LoraConfig, + {"target_modules": "lin1", "use_dora": True, "lora_alpha": 32}, + ), ("Embedding + transformers Conv1D 1 LoRA", "EmbConv1D", LoraConfig, {"target_modules": ["conv1d"]}), ("Embedding + transformers Conv1D 2 LoRA", "EmbConv1D", LoraConfig, {"target_modules": ["emb"]}), ("Embedding + transformers Conv1D 3 LoRA", "EmbConv1D", LoraConfig, {"target_modules": ["emb", "conv1d"]}), @@ -558,7 +566,8 @@ def test_forward_output_finite(self, test_name, model_id, config_cls, config_kwa @parameterized.expand(TEST_CASES) def test_only_params_are_updated(self, test_name, model_id, config_cls, config_kwargs): - # An explicit test that when using LoRA on a custom model, only the LoRA parameters are updated during training + # An explicit test that when using an adapter on a custom model, only the adapter parameters are updated during + # training X = self.prepare_inputs_for_testing() model = self.transformers_class.from_pretrained(model_id).to(self.torch_device) config = config_cls( @@ -606,7 +615,8 @@ def test_parameters_after_loading_model(self, test_name, model_id, config_cls, c ) model = get_peft_model(model, config) model.train() - optimizer = 
torch.optim.SGD(model.parameters(), lr=0.5) + lr = 0.5 if not config_kwargs.get("use_dora") else 0.1 # otherwise we get nan + optimizer = torch.optim.SGD(model.parameters(), lr=lr) # train at least 3 steps for all parameters to be updated (probably this is required because of symmetry # breaking of some LoRA layers that are initialized with constants) @@ -706,6 +716,7 @@ def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, co optimizer.step() model.eval() + outputs_unmerged = model(**X) model.merge_adapter() outputs_after = model(**X) @@ -723,6 +734,9 @@ def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, co # check that there is a difference in results after training assert not torch.allclose(outputs_before, outputs_after, atol=atol, rtol=rtol) + # unmerged or merged should make no difference + assert torch.allclose(outputs_after, outputs_unmerged, atol=atol, rtol=rtol) + # check that disabling adapters gives the same results as before training assert torch.allclose(outputs_before, outputs_disabled, atol=atol, rtol=rtol) diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 5e05fa5f33..9e8ef21f3f 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - +import pytest import torch from scipy import stats from torch import nn @@ -22,7 +21,7 @@ from peft.utils import infer_device -class InitializationTest(unittest.TestCase): +class TestInitialization: """Test class to check the initialization of adapters.""" torch_device = infer_device() @@ -47,10 +46,16 @@ def __init__(self): self.conv2d = nn.Conv2d(100, 100, 3) def forward(self, x): - return self.linear(x) + x_int = (100 * x).int() + x_4d = x.flatten().reshape(1, 100, 10, 10) + return self.linear(x), self.embed(x_int), self.conv2d(x_4d) return MyModule().eval().to(self.torch_device) + @pytest.fixture + def data(self): + return torch.rand(10, 1000).to(self.torch_device) + def test_lora_linear_init_default(self): # default is True torch.manual_seed(0) @@ -315,3 +320,46 @@ def test_rslora_scaling_pattern(self): assert model.linear.scaling["default"] == expected_scaling["linear"] assert model.embed.scaling["default"] == expected_scaling["embed"] assert model.conv2d.scaling["default"] == expected_scaling["conv2d"] + + def test_use_dora_linear(self, data): + # check that dora is a no-op when initialized + torch.manual_seed(0) + model = self.get_model() + output_base, _, _ = model(data) + + # check scaling factor use_rslora=True + config = LoraConfig(target_modules=["linear"], use_dora=True) + model = get_peft_model(model, config) + + with model.disable_adapter(): + output_disabled, _, _ = model(data) + output_dora, _, _ = model(data) + + assert torch.allclose(output_base, output_disabled) + assert torch.allclose(output_base, output_dora) + + def test_use_dora_linear_init_false(self, data): + # with init_lora_weights=False, dora should not be a no-op + torch.manual_seed(0) + model = self.get_model() + output_base, _, _ = model(data) + + # check scaling factor use_rslora=True + config = LoraConfig(target_modules=["linear"], use_dora=True, init_lora_weights=False) + model = get_peft_model(model, config) + + with model.disable_adapter(): + output_disabled, _, _ = model(data) + output_dora, _, _ = model(data) + + assert torch.allclose(output_base, output_disabled) + assert not torch.allclose(output_base, output_dora) + 
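+        # Rationale for the two tests above: with the default init, lora_B is zero, so B @ A == 0 and
+        # the magnitude vector is initialized to ||W||_c; the DoRA scale m / ||W||_c is then exactly 1
+        # and the adapter is a no-op. With init_lora_weights=False the LoRA weights are non-zero, so
+        # the DoRA output must differ from the base output.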
+ def test_use_dora_with_loftq_raises(self): + with pytest.raises(ValueError, match="DoRA does not support megatron_core or LoftQ"): + LoraConfig(target_modules=["linear"], use_dora=True, init_lora_weights="loftq") + + def test_use_dora_with_megatron_core_raises(self): + megatron_config = {"does-not": "matter-here"} + with pytest.raises(ValueError, match="DoRA does not support megatron_core or LoftQ"): + LoraConfig(target_modules=["linear"], use_dora=True, megatron_config=megatron_config)
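
For context, a minimal sketch of the workflow this change enables, following the documentation's recommendation to merge DoRA weights for inference. The checkpoint name and `target_modules` below are placeholders and depend on the architecture being fine-tuned.

```py
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Placeholder checkpoint; DoRA currently supports non-quantized linear layers only.
base_model = AutoModelForCausalLM.from_pretrained("my-org/my-base-model")

config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],  # adjust to the model's linear layer names
    use_dora=True,
)
model = get_peft_model(base_model, config)

# ... train the adapter as usual ...

# DoRA adds overhead at inference time, so merge the adapter into the base weights before deploying.
merged_model = model.merge_and_unload()
```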