FEAT Implement DoRA (huggingface#1474)
Add DoRA (Weight-Decomposed Low-Rank Adaptation).

https://arxiv.org/abs/2402.09353

To use this with LoRA, add use_dora=True to the LoraConfig.

Currently, DoRA only supports nn.Linear layers, not other layer types or
quantized linear layers such as bnb (bitsandbytes).
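
For example, a minimal sketch of enabling it on a small causal language model; the model name, rank, and target_modules below are illustrative assumptions, not part of this change:

```py
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

config = LoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj"],  # must resolve to nn.Linear modules for DoRA
    use_dora=True,  # adds the learnable magnitude on top of the LoRA direction update
)
model = get_peft_model(base_model, config)
model.print_trainable_parameters()
```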
BenjaminBossan committed Mar 14, 2024
1 parent e137d30 commit 6f5b9d4
Showing 10 changed files with 352 additions and 30 deletions.
10 changes: 10 additions & 0 deletions docs/source/developer_guides/lora.md
@@ -69,6 +69,16 @@ from peft import LoraConfig
config = LoraConfig(use_rslora=True, ...)
```

### Weight-Decomposed Low-Rank Adaptation (DoRA)

This technique decomposes the updates of the weights into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, especially at low ranks. Right now, DoRA only supports non-quantized linear layers. DoRA introduces a bigger overhead than pure LoRA, so it is recommended to merge weights for inference; see [`LoraModel.merge_and_unload`]. For more information on DoRA, see https://arxiv.org/abs/2402.09353.

```py
from peft import LoraConfig

config = LoraConfig(use_dora=True, ...)
```
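
Because DoRA introduces a bigger overhead than pure LoRA, merging the weights for inference is recommended; a minimal sketch, assuming `model` is a `PeftModel` created from such a config with `get_peft_model`:

```py
# Sketch: fold the DoRA-adapted LoRA weights back into the base model for inference.
merged_model = model.merge_and_unload()
merged_model.eval()
```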

### QLoRA-style training

The default LoRA settings in PEFT add trainable weights to the query and value layers of each attention block. But [QLoRA](https://hf.co/papers/2305.14314), which adds trainable weights to all the linear layers of a transformer model, can provide performance equal to a fully finetuned model. To apply LoRA to all the linear layers, like in QLoRA, set `target_modules="all-linear"` (easier than specifying individual modules by name which can vary depending on the architecture).
28 changes: 26 additions & 2 deletions src/peft/tuners/lora/bnb.py
@@ -38,13 +38,25 @@ def __init__(
        lora_dropout: float = 0.0,
        init_lora_weights: bool = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        LoraLayer.__init__(self, base_layer)

        if use_dora:
            raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            use_dora=use_dora,
        )

    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
        """
@@ -216,13 +228,25 @@ def __init__(
        lora_dropout: float = 0.0,
        init_lora_weights: bool = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        LoraLayer.__init__(self, base_layer)

        if use_dora:
            raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            use_dora=use_dora,
        )

    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
        """
23 changes: 23 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -101,6 +101,13 @@ class LoraConfig(PeftConfig):
            The configuration of LoftQ. If this is not None, then LoftQ will be used to quantize the backbone weights
            and initialize Lora layers. Also pass `init_lora_weights='loftq'`. Note that you should not pass a
            quantized model in this case, as LoftQ will quantize the model itself.
        use_dora (`bool`):
            Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the weights
            into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the magnitude is
            handled by a separate learnable parameter. This can improve the performance of LoRA, especially at low
            ranks. Right now, DoRA only supports non-quantized linear layers. DoRA introduces a bigger overhead than
            pure LoRA, so it is recommended to merge weights for inference. For more information, see
            https://arxiv.org/abs/2402.09353.
    """

    r: int = field(default=8, metadata={"help": "Lora attention dimension"})
@@ -224,6 +231,19 @@ class LoraConfig(PeftConfig):
            )
        },
    )
    use_dora: bool = field(
        default=False,
        metadata={
            "help": (
                "Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the "
                "weights into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the "
                "magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, "
                "especially at low ranks. Right now, DoRA only supports non-quantized linear layers. DoRA introduces "
                "a bigger overhead than pure LoRA, so it is recommended to merge weights for inference. For more "
                "information, see https://arxiv.org/abs/2402.09353."
            )
        },
    )

    def __post_init__(self):
        self.peft_type = PeftType.LORA
@@ -238,6 +258,9 @@ def __post_init__(self):
        if isinstance(self.target_modules, str) and self.layers_pattern is not None:
            raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.")

        if self.use_dora and (self.megatron_config or self.init_lora_weights == "loftq"):
            raise ValueError("DoRA does not support megatron_core or LoftQ. Please set `use_dora=False`.")

        # handle init_lora_weights and loftq_config
        if self.init_lora_weights == "loftq":
            import importlib
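The new check in `__post_init__` above can be exercised directly; a quick sketch, assuming a PEFT build that includes this change, showing that combining `use_dora=True` with LoftQ initialization is rejected at configuration time:

```py
from peft import LoraConfig

try:
    LoraConfig(use_dora=True, init_lora_weights="loftq")
except ValueError as err:
    # DoRA does not support megatron_core or LoftQ. Please set `use_dora=False`.
    print(err)
```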
14 changes: 13 additions & 1 deletion src/peft/tuners/lora/gptq.py
@@ -31,16 +31,28 @@ def __init__(
        lora_dropout: float = 0.0,
        init_lora_weights: bool = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        **kwargs,
    ):
        super().__init__()
        LoraLayer.__init__(self, base_layer)

        if use_dora:
            raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

        # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
        # for backwards compatibility
        self.quant_linear_module = base_layer
        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            use_dora=use_dora,
        )

    def forward(self, x: torch.Tensor):
        # note: logic differs from default Linear because merging is not supported
