# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import List

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torchao.dtypes.nf4tensor import linear_nf4, to_nf4

from torchtune.modules.peft.peft_utils import AdapterModule
from torchtune.utils import _register_nf4_dispatch_ops  # noqa: F401


class LoRALinear(nn.Module, AdapterModule):
    """LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`_.

    LoRA perturbs a given layer via a low-rank approximation where only
    the rank decomposition matrices are trainable. In a linear layer, instead of
    :math:`x \\mapsto W_0x`, a LoRALinear layer is defined as
    :math:`x \\mapsto W_0x + (\\alpha / r)BAx`, where :math:`r` is the rank of
    the matrices :math:`A` and :math:`B` and :math:`\\alpha` is a scaling factor.
    As in the original implementation, we support dropout before multiplication
    by the low-rank matrices.

    Args:
        in_dim (int): input dimension
        out_dim (int): output dimension
        rank (int): rank of the low-rank approximation
        alpha (float): scaling factor for the low-rank approximation
        dropout (float): dropout probability. Default: 0.0
        use_dora (bool): whether to use DoRA (Weight-Decomposed Low-Rank Adaptation).
            Default: False
        use_bias (bool): whether to include bias in the original linear layer.
            Default: False
        quantize_base (bool): whether to quantize the base linear weight or not.
            Default: False
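
    Example:
        A minimal usage sketch; the dimensions below are illustrative
        assumptions, not values taken from this file::

            >>> import torch
            >>> lora_linear = LoRALinear(in_dim=512, out_dim=512, rank=8, alpha=16)
            >>> x = torch.randn(2, 16, 512)
            >>> lora_linear(x).shape
            torch.Size([2, 16, 512])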
"""

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        rank: int,
        alpha: float,
        dropout: float = 0.0,
        use_dora: bool = False,
        use_bias: bool = False,
        quantize_base: bool = False,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.rank = rank
        self.alpha = alpha
        self.out_dim = out_dim
        self.use_bias = use_bias
        self.use_dora = use_dora
        self._quantize_base = quantize_base
        weight, bias = self._create_weight_and_bias()
        # 'self.disabled' is a flag showing whether to turn off LoRA adapters;
        # this can be used in DPO for treating the LoRA adapters as the policy model
        # and disabling them to treat the base model as the reference model.
        self.disabled = False
        self.register_parameter("weight", nn.Parameter(weight))
        self.register_parameter(
            "bias", nn.Parameter(bias) if bias is not None else None
        )
        self.dropout = nn.Dropout(p=dropout)
        self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
        self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
        # Learnable DoRA magnitude vector, initialized to ones
        self.m = nn.Parameter(torch.ones(1, out_dim)) if self.use_dora else None
        self.merged = False
        # Note: FSDP's meta device initialization contract assumes that a module's
        # reset_parameters method only initializes its own parameters (i.e. no child
        # params are initialized, as is done in initialize_parameters below).
        # For that reason, we patch reset_parameters directly on the lora_a and lora_b
        # submodules when using meta device. This is done in
        # torchtune.utils.prepare_model_for_fsdp_with_meta_device.
        # See this issue for more details: https://github.com/pytorch/pytorch/issues/104187.
        # Without meta device, we only need the following:
        self.initialize_parameters()

    def initialize_parameters(self):
        # Initialize as in
        # https://github.com/microsoft/LoRA/blob/4c0333854cb905966f8cc4e9a74068c1e507c7b7/loralib/layers.py#L119
        _lora_a_init_params(self.lora_a)
        _lora_b_init_params(self.lora_b)

    def _create_weight_and_bias(self):
        """
        Creates a linear weight and bias tensor, using NF4 dtype if we're quantizing
        (indicated via quantize_base=True).
        """
        in_dim, out_dim, use_bias = self.in_dim, self.out_dim, self.use_bias
        linear = nn.Linear(in_features=in_dim, out_features=out_dim, bias=use_bias)
        weight = linear.weight if not self._quantize_base else to_nf4(linear.weight)
        bias = None
        if self.use_bias:
            if self._quantize_base:
                raise NotImplementedError(
                    "Quantized LoRALinear does not support bias at the moment."
                )
            bias = linear.bias
        return weight, bias

    def adapter_params(self) -> List[str]:
        """
        Return lora_a.weight and lora_b.weight as adapter params.
        """
        # NOTE: this function has to be updated if the names of "lora_a" and "lora_b"
        # in this module change.
        adapter_params = ["lora_a.weight", "lora_b.weight"]
        return adapter_params

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): input tensor with shape ``(..., in_dim)``

        Returns:
            Tensor: output tensor with shape ``(..., out_dim)``
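
        Example:
            An illustrative shape check; the sizes are assumptions made for
            this example, not values from the source::

                >>> layer = LoRALinear(in_dim=64, out_dim=128, rank=4, alpha=8)
                >>> layer(torch.randn(3, 64)).shape
                torch.Size([3, 128])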
"""
if self._quantize_base:
out = linear_nf4(input=x, weight=self.weight)
else:
out = F.linear(x, self.weight, self.bias)
if self.disabled:
return out
lora_out = self.lora_a(self.dropout(x))
lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
# Adding 1e-6 to avoid division by zero
if self.use_dora:
return out + self.m * lora_out / (
lora_out.norm(p=2, dim=-1, keepdim=True) + 1e-6
)
return out + lora_out


def _lora_a_init_params(x: nn.Linear) -> None:
    """
    Initialize LoRA A weight to Kaiming uniform.
    """
    nn.init.kaiming_uniform_(x.weight, a=math.sqrt(5))


def _lora_b_init_params(x: nn.Linear) -> None:
    """
    Initialize LoRA B weight to zeros.
    """
    nn.init.zeros_(x.weight)
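

if __name__ == "__main__":
    # A minimal smoke-test sketch, not part of the original module: it relies
    # only on the public surface defined above (LoRALinear, adapter_params,
    # and the `disabled` flag) and uses illustrative dimensions.
    layer = LoRALinear(in_dim=32, out_dim=32, rank=4, alpha=8, dropout=0.1)
    layer.eval()  # disable dropout for a deterministic comparison
    x = torch.randn(2, 32)

    # lora_b is zero-initialized, so at init the layer matches the frozen base.
    base_out = F.linear(x, layer.weight, layer.bias)
    assert torch.allclose(layer(x), base_out)

    # Disabling the adapter likewise reduces the layer to the base projection.
    layer.disabled = True
    assert torch.allclose(layer(x), base_out)

    # Only the low-rank matrices are reported as trainable adapter params.
    print(layer.adapter_params())  # ['lora_a.weight', 'lora_b.weight']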