From a1a40aff1a13adb73c1a21a78595dd6a83369594 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 4 Nov 2025 12:18:09 +0000 Subject: [PATCH 1/3] support v1 loader in wint8 --- .../layers/backends/xpu/moe/fused_moe.py | 114 ++++++++++++++++-- .../backends/xpu/quantization/weight_only.py | 94 ++++++++++++--- 2 files changed, 178 insertions(+), 30 deletions(-) diff --git a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py index 20f47cf36f7..9d3bd038f70 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py +++ b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py @@ -30,6 +30,7 @@ xpu_moe_layer, ) from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs +from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs class XPUMoEMethod(MoEMethodBase): @@ -62,20 +63,14 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): """ create weight process. """ - if layer.fd_config.load_config.load_choices == "default_v1" and self.moe_quant_type in ["w16a16"]: - self.up_gate_proj_weight_shape = [ - layer.num_local_experts, - layer.moe_intermediate_size * 2, - layer.hidden_size, - ] + if layer.fd_config.load_config.load_choices == "default_v1" and self.moe_quant_type in ["w16a16", "weight_only_int8"]: + self.up_gate_proj_weight_shape = [layer.num_local_experts,layer.moe_intermediate_size * 2, layer.hidden_size,] self.down_proj_weight_shape = [layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size] - extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}} - layer.up_gate_proj_weight = layer.create_parameter( - shape=self.up_gate_proj_weight_shape, - dtype=layer.weight_dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ) + shape=self.up_gate_proj_weight_shape, + dtype=layer.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ) layer.down_proj_weight = layer.create_parameter( shape=self.down_proj_weight_shape, @@ -86,18 +81,22 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): set_weight_attrs( layer.up_gate_proj_weight, { + "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}, "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)), "weight_need_transpose": extra_weight_attrs.get("model_format") == "torch", + "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=False), + }, ) set_weight_attrs( layer.down_proj_weight, { + "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}, "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)), "weight_need_transpose": extra_weight_attrs.get("model_format") == "torch", + "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=True), }, ) - if layer.with_bias: layer.up_gate_proj_bias = layer.create_parameter( shape=[layer.num_experts, layer.moe_intermediate_size * 2], @@ -128,6 +127,17 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): "model_format": extra_weight_attrs.get("model_format", ""), }, ) + if self.moe_quant_type in ["weight_only_int8"]: + self.up_gate_proj_scale_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size * 2, + ] + self.down_proj_scale_shape = [ + layer.num_local_experts, + layer.hidden_size, + ] + + else: self.up_gate_proj_weight_shape = [ @@ -536,6 +546,84 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict): quanted_weight_scale = paddle.stack(weight_scale_list, axis=0) getattr(layer, scale_name).set_value(quanted_weight_scale) + def process_weights_after_loading(self, layer): + """ """ + if not self.quant_config.is_checkpoint_bf16: + return + weight_id_map = {"gate_up": 0, "down": 1} + if ( + hasattr(layer.up_gate_proj_weight, "tensor_track") + and layer.up_gate_proj_weight.tensor_track is not None + and layer.up_gate_proj_weight.tensor_track.is_fully_copied() + ): + weight_type = "gate_up" + else: + weight_type = "down" + + # 1.init shape and type + # weight + weight_name = self.added_weight_attrs[weight_id_map[weight_type]] + unquantized_weight_name = weight_name.replace("quant_weight", "weight") + if weight_type == "gate_up": + weight_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size * 2, + layer.hidden_size, + ] + else: + weight_shape = [ + layer.num_local_experts, + layer.hidden_size, + layer.moe_intermediate_size, + ] + weight_dtype = "int8" + # scale + scale_name = self.added_scale_attrs[weight_id_map[weight_type]] + scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape + scale_dtype = "float32" + + # 2.crate tmp tensor + + # weight = paddle.empty(weight_shape, dtype=weight_dtype) + # scale = paddle.empty(scale_shape, dtype=scale_dtype) + + # 3.quantize weight + weight_list = [] + weight_scale_list = [] + for expert_id in range(layer.num_local_experts): + quant_weight, scale = weight_quantize_xpu(getattr(layer, unquantized_weight_name)[expert_id].transpose([1, 0]), self.moe_quant_type,-1,-1) + weight_list.append(quant_weight.transpose([1, 0])) + weight_scale_list.append(scale) + quanted_weight = paddle.stack(weight_list, axis=0) + quanted_weight_scale = paddle.stack(weight_scale_list, axis=0) + + free_tensor(getattr(layer, unquantized_weight_name)) + + + # create weight + setattr( + layer, + weight_name, + layer.create_parameter( + shape=weight_shape, + dtype=weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + # create scale + setattr( + layer, + scale_name, + layer.create_parameter( + shape=scale_shape, + dtype=scale_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + + getattr(layer, weight_name).set_value(quanted_weight) + getattr(layer, scale_name).set_value(quanted_weight_scale) + class XPUW4A8MoEMethod(XPUMoEMethod): """ diff --git a/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py b/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py index fa95561d4d4..1212dd44af1 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py +++ b/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py @@ -22,7 +22,12 @@ WeightOnlyLinearMethod, ) from fastdeploy.model_executor.ops.xpu import weight_quantize_xpu - +from fastdeploy.model_executor.layers.linear import ( + MergedColumnParallelLinear, + MergedReplicatedLinear, + QKVParallelLinear, +) +from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod): """ @@ -41,22 +46,48 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs) -> None: Create weights for linear layer on XPU """ # The scale shape should be equal to the output dim of weight using Per-Channel Quantization. - weight_scale_shape = [layer.weight_shape[1]] - layer.weight_shape.reverse() - if self.quant_config.name() == "weight_only_int4": - layer.weight_shape[0] //= 2 - layer.weight_dtype = "int8" - layer.weight = layer.create_parameter( - shape=layer.weight_shape, - dtype=layer.weight_dtype, - is_bias=False, - default_initializer=paddle.nn.initializer.Constant(0), - ) - layer.weight_scale = layer.create_parameter( - shape=weight_scale_shape, - dtype="float32", - is_bias=False, - ) + if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1": + layer.weight = layer.create_parameter( + shape=layer.weight_shape, + dtype=layer.weight_dtype, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch" + quant_attrs = extra_weight_attrs + if ( + isinstance(layer, MergedColumnParallelLinear) + or isinstance(layer, QKVParallelLinear) + or isinstance(layer, MergedReplicatedLinear) + ): + quant_attrs = { + **extra_weight_attrs, + "tensor_track": TensorTracker( + shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim", True) + ), + } + set_weight_attrs( + layer.weight, + quant_attrs, + ) + else: + # The scale shape should be equal to the output dim of weight using Per-Channel Quantization. + weight_scale_shape = [layer.weight_shape[1]] + layer.weight_shape.reverse() + if self.quant_config.name() == "weight_only_int4": + layer.weight_shape[0] //= 2 + layer.weight_dtype = "int8" + layer.weight = layer.create_parameter( + shape=layer.weight_shape, + dtype=layer.weight_dtype, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + layer.weight_scale = layer.create_parameter( + shape=weight_scale_shape, + dtype="float32", + is_bias=False, + ) def process_loaded_weights(self, layer: nn.Layer, weight: paddle.Tensor) -> None: """ @@ -76,3 +107,32 @@ def process_loaded_weights(self, layer: nn.Layer, weight: paddle.Tensor) -> None weight_scale_tensor = paddle.concat(weight_scale_tensors, axis=0) layer.weight.set_value(paddle.transpose(quanted_weight_tensor, [1, 0])) layer.weight_scale.set_value(weight_scale_tensor) + + + def process_weights_after_loading(self, layer) -> None: + if not self.quant_config.is_checkpoint_bf16: + return + + quanted_weight_tensor, weight_scale_tensor = weight_quantize_xpu( + layer.weight, + self.quant_config.algo, + -1, + -1 + ) + + free_tensor(layer.weight) + + layer.weight = layer.create_parameter( + shape=quanted_weight_tensor.shape[::-1], + dtype="int8" , + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + layer.weight_scale = layer.create_parameter( + shape=weight_scale_tensor.shape, + dtype=weight_scale_tensor.dtype, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + layer.weight.set_value(paddle.transpose(quanted_weight_tensor, [1, 0])) + layer.weight_scale.copy_(weight_scale_tensor, False) \ No newline at end of file From ad5171f508857e7b48841571c5004c494cc2fc02 Mon Sep 17 00:00:00 2001 From: iosmers Date: Tue, 4 Nov 2025 12:23:07 +0000 Subject: [PATCH 2/3] code style --- .../layers/backends/xpu/moe/fused_moe.py | 35 ++++++++++++------- .../backends/xpu/quantization/weight_only.py | 23 +++++------- 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py index 9d3bd038f70..749e5ca7a02 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py +++ b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py @@ -29,8 +29,12 @@ weight_quantize_xpu, xpu_moe_layer, ) -from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs -from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs +from fastdeploy.model_executor.utils import ( + TensorTracker, + default_weight_loader, + free_tensor, + set_weight_attrs, +) class XPUMoEMethod(MoEMethodBase): @@ -63,14 +67,21 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): """ create weight process. """ - if layer.fd_config.load_config.load_choices == "default_v1" and self.moe_quant_type in ["w16a16", "weight_only_int8"]: - self.up_gate_proj_weight_shape = [layer.num_local_experts,layer.moe_intermediate_size * 2, layer.hidden_size,] + if layer.fd_config.load_config.load_choices == "default_v1" and self.moe_quant_type in [ + "w16a16", + "weight_only_int8", + ]: + self.up_gate_proj_weight_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size * 2, + layer.hidden_size, + ] self.down_proj_weight_shape = [layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size] layer.up_gate_proj_weight = layer.create_parameter( - shape=self.up_gate_proj_weight_shape, - dtype=layer.weight_dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ) + shape=self.up_gate_proj_weight_shape, + dtype=layer.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ) layer.down_proj_weight = layer.create_parameter( shape=self.down_proj_weight_shape, @@ -85,7 +96,6 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)), "weight_need_transpose": extra_weight_attrs.get("model_format") == "torch", "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=False), - }, ) set_weight_attrs( @@ -137,8 +147,6 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): layer.hidden_size, ] - - else: self.up_gate_proj_weight_shape = [ layer.num_local_experts, @@ -591,7 +599,9 @@ def process_weights_after_loading(self, layer): weight_list = [] weight_scale_list = [] for expert_id in range(layer.num_local_experts): - quant_weight, scale = weight_quantize_xpu(getattr(layer, unquantized_weight_name)[expert_id].transpose([1, 0]), self.moe_quant_type,-1,-1) + quant_weight, scale = weight_quantize_xpu( + getattr(layer, unquantized_weight_name)[expert_id].transpose([1, 0]), self.moe_quant_type, -1, -1 + ) weight_list.append(quant_weight.transpose([1, 0])) weight_scale_list.append(scale) quanted_weight = paddle.stack(weight_list, axis=0) @@ -599,7 +609,6 @@ def process_weights_after_loading(self, layer): free_tensor(getattr(layer, unquantized_weight_name)) - # create weight setattr( layer, diff --git a/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py b/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py index 1212dd44af1..ed07c55ada9 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py +++ b/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py @@ -17,18 +17,19 @@ import paddle from paddle import nn -from fastdeploy.model_executor.layers.quantization.weight_only import ( - WeightOnlyConfig, - WeightOnlyLinearMethod, -) -from fastdeploy.model_executor.ops.xpu import weight_quantize_xpu from fastdeploy.model_executor.layers.linear import ( MergedColumnParallelLinear, MergedReplicatedLinear, QKVParallelLinear, ) +from fastdeploy.model_executor.layers.quantization.weight_only import ( + WeightOnlyConfig, + WeightOnlyLinearMethod, +) +from fastdeploy.model_executor.ops.xpu import weight_quantize_xpu from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs + class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod): """ Weight only quantization method for linear layer on XPU @@ -108,23 +109,17 @@ def process_loaded_weights(self, layer: nn.Layer, weight: paddle.Tensor) -> None layer.weight.set_value(paddle.transpose(quanted_weight_tensor, [1, 0])) layer.weight_scale.set_value(weight_scale_tensor) - def process_weights_after_loading(self, layer) -> None: if not self.quant_config.is_checkpoint_bf16: return - quanted_weight_tensor, weight_scale_tensor = weight_quantize_xpu( - layer.weight, - self.quant_config.algo, - -1, - -1 - ) + quanted_weight_tensor, weight_scale_tensor = weight_quantize_xpu(layer.weight, self.quant_config.algo, -1, -1) free_tensor(layer.weight) layer.weight = layer.create_parameter( shape=quanted_weight_tensor.shape[::-1], - dtype="int8" , + dtype="int8", is_bias=False, default_initializer=paddle.nn.initializer.Constant(0), ) @@ -135,4 +130,4 @@ def process_weights_after_loading(self, layer) -> None: default_initializer=paddle.nn.initializer.Constant(0), ) layer.weight.set_value(paddle.transpose(quanted_weight_tensor, [1, 0])) - layer.weight_scale.copy_(weight_scale_tensor, False) \ No newline at end of file + layer.weight_scale.copy_(weight_scale_tensor, False) From eb46c3c6e13effa2fa1c153cca488d48cec88661 Mon Sep 17 00:00:00 2001 From: iosmers Date: Wed, 5 Nov 2025 02:53:14 +0000 Subject: [PATCH 3/3] update --- .../model_executor/layers/backends/xpu/moe/fused_moe.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py index 749e5ca7a02..bfd3ca0bc77 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py +++ b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py @@ -70,6 +70,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): if layer.fd_config.load_config.load_choices == "default_v1" and self.moe_quant_type in [ "w16a16", "weight_only_int8", + "weight_only_int4", ]: self.up_gate_proj_weight_shape = [ layer.num_local_experts, @@ -137,7 +138,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): "model_format": extra_weight_attrs.get("model_format", ""), }, ) - if self.moe_quant_type in ["weight_only_int8"]: + if self.moe_quant_type in ["weight_only_int8", "weight_only_int4"]: self.up_gate_proj_scale_shape = [ layer.num_local_experts, layer.moe_intermediate_size * 2, @@ -588,6 +589,8 @@ def process_weights_after_loading(self, layer): # scale scale_name = self.added_scale_attrs[weight_id_map[weight_type]] scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape + if self.moe_quant_type in ["weight_only_int4"]: + weight_shape[-1] //= 2 scale_dtype = "float32" # 2.crate tmp tensor