partial implementation of lqlora #8324

Open
wants to merge 2 commits into base: develop
4 changes: 4 additions & 0 deletions llm/finetune_generation.py
@@ -56,6 +56,7 @@
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig, llmmetaclass
from paddlenlp.utils.log import logger
from paddlenlp.peft.lora.lqlora_utils import transform_lora_layers

# Fine-tune Environment Variables to support sharding stage1 overlap optimization.
os.environ["USE_CASUAL_MASK"] = "False"
@@ -465,6 +466,9 @@ def neft_post_hook(module, input, output):
    else:
        model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path)

    if model_args.lqlora:
        transform_lora_layers(model)
Contributor: Control this through an lqlora option passed in via lora_config (a sketch follows this diff).


    model.print_trainable_parameters()

    def compute_metrics_do_generation(eval_preds):
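A minimal sketch of the gating the reviewer suggests, assuming a hypothetical lqlora field on the LoRA config (LoRAConfig does not define this field in the PR as written):

    # Hypothetical: read the flag from the model's LoRA config instead of model_args.
    if getattr(model.lora_config, "lqlora", False):
        transform_lora_layers(model)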
57 changes: 57 additions & 0 deletions paddlenlp/peft/lora/lqlora_utils.py
@@ -0,0 +1,57 @@
import paddle

from paddlenlp.quantization.qlora import qlora_weight_quantize_dequantize

from .lora_model import LoRAModel
from .lora_layers import LoRALinear
Contributor: Suggest writing the LQ-LoRA initialization as an lqlora_init function, passing whether to use lqlora through lora_config, and applying this lqlora_init to the lora_module before line 621 (a sketch follows): https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/peft/lora/lora_model.py#L621

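A possible shape for that refactor, sketched with hypothetical names (lqlora_init and the lqlora config field are not in this PR):

    def lqlora_init(module: LoRALinear, num_iterations: int = 100) -> None:
        # Per-layer version of the loop below: decompose module.weight into a
        # quantized Q plus rank-r factors and write them back into the module.
        ...

    # Inside LoRAModel, where each target layer is wrapped (around lora_model.py#L621):
    # if self.lora_config.lqlora:
    #     lqlora_init(lora_module)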

def transform_lora_layers(
    model: LoRAModel,
    num_iterations: int = 100,
) -> None:
    if not isinstance(model, LoRAModel):
        raise NotImplementedError(f"Unknown model type: {type(model)}")
    for name, submodule in model.named_sublayers():
        if isinstance(submodule, LoRALinear):
            num_ranks = submodule.r
            W = submodule.weight
            # The SVD below runs in float32; the original dtype is restored after the loop.
            if W.dtype in [paddle.float16]:
                old_dtype = W.dtype
                W = paddle.cast(W, dtype=paddle.float32)
Contributor: What is the reason for casting to fp32?

            else:
                old_dtype = None
            # Alternate between a rank-r SVD fit of W - Q and re-quantizing the
            # residual, so that Q + lora_A @ lora_B approximates W.
            Q = paddle.zeros_like(W)
            last_error = paddle.to_tensor(float("inf"), dtype=W.dtype)
            for i in range(num_iterations):
                A = W - Q
                if A.ndim != 2:
                    raise ValueError(f"Expected a 2D matrix, but got ndim={A.ndim}.")

                # Best rank-r approximation of the residual via truncated SVD.
                U, S, Vh = paddle.linalg.svd(A, full_matrices=False)
                Ur = U[:, :num_ranks]
                Sr = S[:num_ranks]
                Vhr = Vh[:num_ranks]

                # Split the singular values evenly between the two factors.
                lora_A = Ur @ paddle.diag(paddle.sqrt(Sr))
Contributor: The configuration needs to take the LoRA scaling into account; as written, it looks like the scaling can only be forced to 1 (see the note below).

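Context for that remark: LoRALinear adds scaling * (x @ lora_A @ lora_B) to the base output, with scaling = lora_alpha / r, so the decomposition here only reconstructs W when scaling == 1 (i.e. lora_alpha == r). A minimal sketch of one way to support other scalings, assuming the layer exposes a scaling attribute (not part of this PR):

                # Hypothetical: fold the inverse scaling into lora_B so that
                # weight + scaling * lora_A @ lora_B still equals Q + (W - Q).
                lora_B = (paddle.diag(paddle.sqrt(Sr)) @ Vhr) / submodule.scaling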
                lora_B = paddle.diag(paddle.sqrt(Sr)) @ Vhr

                # Quantize-dequantize the residual; this becomes the new target Q.
                Q = qlora_weight_quantize_dequantize(W - lora_A @ lora_B, double_quant=True)
Contributor: double_quant=True should be an adjustable parameter, and the same goes for the other arguments of qlora_weight_quantize_dequantize (a sketch follows).

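A possible shape for that, surfacing the quantization knobs on the transform (parameter names beyond double_quant are illustrative):

    def transform_lora_layers(
        model: LoRAModel,
        num_iterations: int = 100,
        double_quant: bool = True,
        **quantize_kwargs,
    ) -> None:
        ...
        # Forward every quantization option to the quantize-dequantize call.
        Q = qlora_weight_quantize_dequantize(
            W - lora_A @ lora_B, double_quant=double_quant, **quantize_kwargs
        )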

                # Frobenius-norm error of the current reconstruction.
                W_ = Q + lora_A @ lora_B
                error = paddle.norm(W - W_, p="fro")

                # Stop early once the error no longer decreases.
                if error > last_error:
                    print("break.")
                    break
                last_error = error

            # Cast back to the original parameter dtype before writing back.
            if old_dtype is not None:
                lora_A = paddle.cast(lora_A, dtype=old_dtype)
                lora_B = paddle.cast(lora_B, dtype=old_dtype)
                Q = paddle.cast(Q, dtype=old_dtype)

            # Install the initialization: the quantized Q becomes the frozen base
            # weight, and the SVD factors become the trainable LoRA parameters.
            submodule.lora_A.set_value(lora_A)
            submodule.lora_B.set_value(lora_B)
            submodule.weight.set_value(Q)
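Taken together, transform_lora_layers implements the LQ-LoRA initialization: each LoRALinear weight W is split into a quantized component Q plus low-rank factors lora_A @ lora_B by alternating minimization, so training starts from Q + lora_A @ lora_B ≈ W rather than from W + 0. A minimal end-to-end sketch of using the transform as the PR stands (model name and target modules are illustrative):

    from paddlenlp.peft import LoRAConfig, LoRAModel
    from paddlenlp.peft.lora.lqlora_utils import transform_lora_layers
    from paddlenlp.transformers import AutoModelForCausalLM

    base = AutoModelForCausalLM.from_pretrained("facebook/llama-7b", dtype="float16")
    # lora_alpha == r keeps the LoRA scaling at 1, which the decomposition assumes.
    config = LoRAConfig(target_modules=[".*q_proj.*", ".*v_proj.*"], r=8, lora_alpha=8)
    model = LoRAModel(base, config)

    transform_lora_layers(model, num_iterations=100)
    model.print_trainable_parameters()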