110 changes: 105 additions & 5 deletions fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py
@@ -29,7 +29,12 @@
weight_quantize_xpu,
xpu_moe_layer,
)
from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs
from fastdeploy.model_executor.utils import (
TensorTracker,
default_weight_loader,
free_tensor,
set_weight_attrs,
)


class XPUMoEMethod(MoEMethodBase):
@@ -62,15 +67,17 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
"""
create weight process.
"""
if layer.fd_config.load_config.load_choices == "default_v1" and self.moe_quant_type in ["w16a16"]:
if layer.fd_config.load_config.load_choices == "default_v1" and self.moe_quant_type in [
"w16a16",
"weight_only_int8",
"weight_only_int4",
]:
self.up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
layer.hidden_size,
]
self.down_proj_weight_shape = [layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size]
extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}

layer.up_gate_proj_weight = layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
dtype=layer.weight_dtype,
@@ -86,18 +93,21 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
set_weight_attrs(
layer.up_gate_proj_weight,
{
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
"weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=False),
},
)
set_weight_attrs(
layer.down_proj_weight,
{
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
"weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=True),
},
)

if layer.with_bias:
layer.up_gate_proj_bias = layer.create_parameter(
shape=[layer.num_experts, layer.moe_intermediate_size * 2],
@@ -128,6 +138,15 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
"model_format": extra_weight_attrs.get("model_format", ""),
},
)
if self.moe_quant_type in ["weight_only_int8", "weight_only_int4"]:
self.up_gate_proj_scale_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
]
self.down_proj_scale_shape = [
layer.num_local_experts,
layer.hidden_size,
]

else:
self.up_gate_proj_weight_shape = [
@@ -536,6 +555,87 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
getattr(layer, scale_name).set_value(quanted_weight_scale)

def process_weights_after_loading(self, layer):
""" """
if not self.quant_config.is_checkpoint_bf16:
return
weight_id_map = {"gate_up": 0, "down": 1}
if (
hasattr(layer.up_gate_proj_weight, "tensor_track")
and layer.up_gate_proj_weight.tensor_track is not None
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
):
weight_type = "gate_up"
else:
weight_type = "down"

        # 1. init shape and dtype
# weight
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
unquantized_weight_name = weight_name.replace("quant_weight", "weight")
if weight_type == "gate_up":
weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
layer.hidden_size,
]
else:
weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size,
]
weight_dtype = "int8"
# scale
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
if self.moe_quant_type in ["weight_only_int4"]:
weight_shape[-1] //= 2
scale_dtype = "float32"

        # 2. quantize weight
weight_list = []
weight_scale_list = []
for expert_id in range(layer.num_local_experts):
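            # weight_quantize_xpu consumes a [k, n] slice, so transpose each
            # expert's [n, k] weight in and transpose the int8 result back out.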
quant_weight, scale = weight_quantize_xpu(
getattr(layer, unquantized_weight_name)[expert_id].transpose([1, 0]), self.moe_quant_type, -1, -1
)
weight_list.append(quant_weight.transpose([1, 0]))
weight_scale_list.append(scale)
quanted_weight = paddle.stack(weight_list, axis=0)
quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)

free_tensor(getattr(layer, unquantized_weight_name))

# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight_shape,
dtype=weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=scale_shape,
dtype=scale_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)

getattr(layer, weight_name).set_value(quanted_weight)
getattr(layer, scale_name).set_value(quanted_weight_scale)
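
For reference, a minimal sketch of the per-expert quantize/transpose/stack pattern used above (an editor's illustration, not part of this PR; it assumes, as process_loaded_weights does, that weight_quantize_xpu(w, algo, -1, -1) returns an (int8 weight, per-channel scale) pair for a [k, n] input):

import paddle

from fastdeploy.model_executor.ops.xpu import weight_quantize_xpu


def quantize_experts(weights_3d, algo):
    """Quantize a [num_experts, n, k] bf16 tensor expert by expert (sketch)."""
    quant_list, scale_list = [], []
    for w in weights_3d:
        q, s = weight_quantize_xpu(w.transpose([1, 0]), algo, -1, -1)
        quant_list.append(q.transpose([1, 0]))  # restore the [n, k] layout
        scale_list.append(s)                    # one scale per output channel
    return paddle.stack(quant_list, axis=0), paddle.stack(scale_list, axis=0)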


class XPUW4A8MoEMethod(XPUMoEMethod):
"""
@@ -17,11 +17,17 @@
import paddle
from paddle import nn

from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear,
MergedReplicatedLinear,
QKVParallelLinear,
)
from fastdeploy.model_executor.layers.quantization.weight_only import (
WeightOnlyConfig,
WeightOnlyLinearMethod,
)
from fastdeploy.model_executor.ops.xpu import weight_quantize_xpu
from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs


class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
@@ -41,22 +47,48 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs) -> None:
Create weights for linear layer on XPU
"""
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
weight_scale_shape = [layer.weight_shape[1]]
layer.weight_shape.reverse()
if self.quant_config.name() == "weight_only_int4":
layer.weight_shape[0] //= 2
layer.weight_dtype = "int8"
layer.weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight_scale = layer.create_parameter(
shape=weight_scale_shape,
dtype="float32",
is_bias=False,
)
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
layer.weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
quant_attrs = extra_weight_attrs
if (
isinstance(layer, MergedColumnParallelLinear)
or isinstance(layer, QKVParallelLinear)
or isinstance(layer, MergedReplicatedLinear)
):
quant_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(
shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim", True)
),
}
set_weight_attrs(
layer.weight,
quant_attrs,
)
else:
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
weight_scale_shape = [layer.weight_shape[1]]
layer.weight_shape.reverse()
if self.quant_config.name() == "weight_only_int4":
layer.weight_shape[0] //= 2
layer.weight_dtype = "int8"
layer.weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight_scale = layer.create_parameter(
shape=weight_scale_shape,
dtype="float32",
is_bias=False,
)
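
A note on the tensor_track branch above: the tracker is attached only to the fused layer types (MergedColumnParallelLinear, QKVParallelLinear, MergedReplicatedLinear) because several checkpoint shards are copied into one parameter, and quantization has to wait until every shard has landed. A loader could gate the finalization roughly like this (a sketch using only the TensorTracker calls visible in this diff):

def finalize_if_ready(layer, method):
    # Quantize only once every shard has been copied into the fused weight.
    tracker = getattr(layer.weight, "tensor_track", None)
    if tracker is None or tracker.is_fully_copied():
        method.process_weights_after_loading(layer)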

def process_loaded_weights(self, layer: nn.Layer, weight: paddle.Tensor) -> None:
"""
@@ -76,3 +108,26 @@ def process_loaded_weights(self, layer: nn.Layer, weight: paddle.Tensor) -> None:
weight_scale_tensor = paddle.concat(weight_scale_tensors, axis=0)
layer.weight.set_value(paddle.transpose(quanted_weight_tensor, [1, 0]))
layer.weight_scale.set_value(weight_scale_tensor)

def process_weights_after_loading(self, layer) -> None:
if not self.quant_config.is_checkpoint_bf16:
return

quanted_weight_tensor, weight_scale_tensor = weight_quantize_xpu(layer.weight, self.quant_config.algo, -1, -1)

free_tensor(layer.weight)

layer.weight = layer.create_parameter(
shape=quanted_weight_tensor.shape[::-1],
dtype="int8",
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight_scale = layer.create_parameter(
shape=weight_scale_tensor.shape,
dtype=weight_scale_tensor.dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight.set_value(paddle.transpose(quanted_weight_tensor, [1, 0]))
layer.weight_scale.copy_(weight_scale_tensor, False)
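
Putting the pieces together, the default_v1 load path for a bf16 checkpoint now runs roughly as follows (a reconstruction from the code above; the driver loop and its variable names are placeholders, not FastDeploy API):

# Hypothetical driver; only the create_weights / weight_loader /
# process_weights_after_loading calls come from this PR.
method.create_weights(layer, model_format=model_format)  # bf16 placeholder + attrs
for shard in checkpoint_shards:
    layer.weight.weight_loader(layer.weight, shard)      # fill placeholder, update tensor_track
method.process_weights_after_loading(layer)              # quantize to int8, free bf16, attach scale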