
Conversation

@lugimzzz
Contributor

PR types

New features

PR changes

APIs

Description

loralinear

@paddle-bot

paddle-bot bot commented Apr 10, 2025

Thanks for your contribution!

@codecov

codecov bot commented Apr 11, 2025

Codecov Report

Attention: Patch coverage is 19.03766% with 387 lines in your changes missing coverage. Please review.

Project coverage is 48.95%. Comparing base (f232d82) to head (2d50107).
Report is 205 commits behind head on develop.

Files with missing lines                          Patch %   Lines
paddlenlp/quantization/quantization_linear.py      13.20%   92 Missing ⚠️
paddlenlp/peft/lora/lora_quantization_layers.py    16.66%   85 Missing ⚠️
paddlenlp/utils/optimizer.py                        8.23%   78 Missing ⚠️
paddlenlp/quantization/qat_utils.py                17.72%   65 Missing ⚠️
paddlenlp/quantization/hadamard_utils.py           10.00%   27 Missing ⚠️
paddlenlp/quantization/quantization_utils.py       28.57%   25 Missing ⚠️
paddlenlp/trainer/trainer.py                       28.57%    5 Missing ⚠️
paddlenlp/peft/lora/lora_model.py                  77.77%    4 Missing ⚠️
paddlenlp/peft/lora/lora_layers.py                 25.00%    3 Missing ⚠️
paddlenlp/transformers/model_utils.py              66.66%    2 Missing ⚠️
... and 1 more
Additional details and impacted files
@@             Coverage Diff             @@
##           develop   #10385      +/-   ##
===========================================
- Coverage    49.08%   48.95%   -0.13%     
===========================================
  Files          763      767       +4     
  Lines       125673   126153     +480     
===========================================
+ Hits         61689    61764      +75     
- Misses       63984    64389     +405     


self.disable_lora = False
if mp_moe or is_distributed:
    for p in self.parameters():
        p.is_distributed = is_distributed
Contributor Author

Used for expert parallelism (EP): is_distributed marks parameters that should not be synchronized when training starts, and mp_moe is used by unified checkpoint (UC).
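
For context, a minimal sketch of what this flag-setting amounts to (the wrapper function and its name are illustrative; only the attribute names come from the PR):

import paddle.nn as nn

def flag_ep_parameters(layer: nn.Layer, is_distributed: bool, mp_moe: bool) -> None:
    # Expert-parallel parameters are intentionally different across ranks,
    # so is_distributed tells the trainer not to broadcast them when
    # training starts; mp_moe is read later by unified checkpoint
    # (via getattr(param, "mp_moe", False)) to treat them as MoE shards.
    if mp_moe or is_distributed:
        for p in layer.parameters():
            p.is_distributed = is_distributed
            if mp_moe:
                p.mp_moe = True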

  level=self.args.fp16_opt_level,
  dtype=self.amp_dtype,
- excluded_layers=[QuantizationLinear] + self._decorate_exclude_layers(model),
+ excluded_layers=[QuantizationLinear, ColumnParallelQuantizationLinear, RowParallelQuantizationLinear]
Contributor Author

This prevents the fp32 quantization scales from being cast to bf16.
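
A sketch of the intended effect (assuming `paddle.amp.decorate`'s `excluded_layers` argument, which leaves the listed layer types undecorated; the import path follows the files touched in this PR):

import paddle
from paddlenlp.quantization.quantization_linear import (
    ColumnParallelQuantizationLinear,
    QuantizationLinear,
    RowParallelQuantizationLinear,
)

def decorate_keep_scales_fp32(model, level="O2", dtype="bfloat16"):
    # Layers listed in excluded_layers are skipped during decoration,
    # so their fp32 parameters (here: quantization scales) are not cast down.
    return paddle.amp.decorate(
        models=model,
        level=level,
        dtype=dtype,
        excluded_layers=[
            QuantizationLinear,
            ColumnParallelQuantizationLinear,
            RowParallelQuantizationLinear,
        ],
    )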


  # Optimize to skip unused shard files for super large models
- if sharded_metadata is not None and quantization_linear_list is None:
+ if sharded_metadata is not None:
Contributor Author

Skip the parameter shards that don't need to be read, to speed up loading.
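
The idea in miniature (an illustrative helper, not the PR's code): the sharded index maps each parameter name to its shard file, so any file containing only keys the model never asks for can be skipped without being opened.

def select_needed_shard_files(weight_map: dict, needed_keys: set) -> set:
    # weight_map is the index mapping, e.g.
    # {"llama.layers.0.mlp.weight": "model-00001-of-00005.safetensors", ...}
    return {fname for name, fname in weight_map.items() if name in needed_keys}

Loading then iterates only over the returned file set instead of every shard listed in the index.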

@@ -1,4 +1,4 @@
- # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+ # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
Contributor

This probably doesn't need to be changed.

Contributor Author

ok

# Merge the scaled low-rank LoRA update into the restored weight,
# then quantize the merged weight back in place of the original.
new_weight += self.lora_A @ self.lora_B * self.scaling
self.quantize_weight(new_weight)
self.merged = True
# Carry over the MoE flag from the quantized weight (set for unified checkpoint).
mp_moe = getattr(self.quant_weight, "mp_moe", False)
Contributor

I don't quite understand why the MoE parameters must be flagged here.

Contributor Author

This is required by unified checkpoint.

@@ -0,0 +1,154 @@
+ # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+ #
Contributor

Where does hadamard_utils.py come from?

Contributor Author

It was provided by a colleague on the Slim team; I've now removed the parts that aren't needed.
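
For readers new to the trick: multiplying activations by an orthogonal Hadamard matrix before quantization spreads outlier values across channels, which tightens the quantization scale. A minimal sketch using the classic Sylvester construction for power-of-two sizes (the actual random_hadamard_matrix in hadamard_utils.py may be built differently):

import paddle

def sylvester_hadamard(n: int) -> paddle.Tensor:
    # Builds an orthonormal Hadamard matrix H with H @ H.T == I,
    # doubling the size each step: H_{2k} = [[H_k, H_k], [H_k, -H_k]].
    assert n > 0 and n & (n - 1) == 0, "n must be a power of two"
    h = paddle.to_tensor([[1.0]])
    while h.shape[0] < n:
        h = paddle.concat(
            [paddle.concat([h, h], axis=1), paddle.concat([h, -h], axis=1)],
            axis=0,
        )
    return h / (n ** 0.5)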

from .hadamard_utils import random_hadamard_matrix


def quantize_tensorwise(x, quantization_config=None, bit_length=8, state=0, training=False, act_scale=None):
Contributor

Do the quantization methods used by QAT really need their own file? Wouldn't it be better to keep all the quantization methods together?

Contributor Author

The QAT methods are fairly complex and quite a lot more will be added later, so they get their own qat_utils.
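
A sketch of what symmetric tensor-wise quantization boils down to (the real quantize_tensorwise also threads through quantization_config, the Hadamard transform, act_scale tracking, and the training state, per the signature above):

import paddle

def quantize_tensorwise_sketch(x: paddle.Tensor, bit_length: int = 8):
    qmax = 2 ** (bit_length - 1) - 1           # 127 for int8
    scale = paddle.max(paddle.abs(x)) / qmax   # one fp32 scale per tensor
    q = paddle.clip(paddle.round(x / scale), -qmax - 1, qmax)
    return q.astype("int8"), scale             # dequantize: q * scale ≈ x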

if quantization_config.apply_hadamard:
    target_x = x @ infohub.hadamard[x.shape[-1]][0]
else:
    target_x = x.clone()
Contributor

What's the reason for the clone here?

Contributor Author

Removed it.

input_grad = None

if not quant_weight.stop_gradient:
    weight_grad = paddle.einsum("bsh,bsd->hd", x, grad_output)
Contributor

Paddle's einsum has pitfalls in some scenarios; please check whether einsum is really suitable here.

Contributor Author

There was indeed a problem! einsum is much slower than matmul, so I switched to matmul.
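
For reference, the einsum here is a single contraction over the batch and sequence dims, so it can be written as one matmul on flattened tensors; a sketch of the equivalence (shapes inferred from the einsum string):

import paddle

def weight_grad_matmul(x, grad_output):
    # x: [b, s, h], grad_output: [b, s, d]  ->  weight_grad: [h, d]
    h, d = x.shape[-1], grad_output.shape[-1]
    return paddle.matmul(
        x.reshape([-1, h]), grad_output.reshape([-1, d]), transpose_x=True
    )  # same result as paddle.einsum("bsh,bsd->hd", x, grad_output)

The matmul form dispatches directly to a GEMM kernel, which is presumably why it measured so much faster here.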

wawltor previously approved these changes Apr 21, 2025

@wawltor left a comment (Contributor)

LGTM

@wawltor left a comment (Contributor)

LGTM

@wawltor merged commit 72e1994 into PaddlePaddle:develop Apr 23, 2025
9 of 15 checks passed
@lugimzzz deleted the ssp branch April 25, 2025 08:37