60 changes: 39 additions & 21 deletions fastdeploy/model_executor/layers/attention/attention.py
@@ -34,6 +34,7 @@
from safetensors import safe_open

from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import default_weight_loader


class Attention(nn.Layer):
@@ -77,6 +78,7 @@ def __init__(
ValueError: If the `v_head_dim` is less than 0.
"""
super().__init__()
self.fd_config = fd_config
self.num_heads: int = (
fd_config.model_config.num_attention_heads // fd_config.parallel_config.tensor_parallel_size
)
@@ -101,23 +103,21 @@
self.use_neox_rotary_style: bool = use_neox_rotary_style

if fd_config.quant_config and hasattr(fd_config.quant_config, "kv_cache_quant_type"):
self.kvcache_quant_method: QuantMethodBase = fd_config.quant_config.get_quant_method(self)
self.quant_method: QuantMethodBase = fd_config.quant_config.get_quant_method(self)
else:
self.kvcache_quant_method = None
self.quant_method = None

if self.kvcache_quant_method is None:
if self.quant_method is None:
logger.info(f"Attention is running in cache kv {self._dtype} mode")
else:
logger.info(
f"Attention is running in cache kv {self.kvcache_quant_method.cache_quant_config.quant_type} mode"
)
logger.info(f"Attention is running in cache kv {self.quant_method.cache_quant_config.quant_type} mode")
self.use_qk_norm = use_qk_norm
self.rms_norm_eps = rms_norm_eps
if self.use_qk_norm:
self.q_norm_key = f"{self.prefix}.q_norm"
self.k_norm_key = f"{self.prefix}.k_norm"
self.init_weight()

self.init_weight()
if (
fd_config.moba_attention_config is not None
and fd_config.moba_attention_config.moba_encoder_top_k_left is not None
@@ -161,32 +161,50 @@ def __init__(
)

def init_weight(self):
self.q_norm_weight = self.create_parameter(
shape=[self.qk_head_dim],
dtype="float32",
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
if self.quant_method is not None:
self.quant_method.create_weights(
self,
weight_loader=(
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
),
)

self.k_norm_weight = self.create_parameter(
shape=[self.qk_head_dim],
dtype="float32",
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
if self.use_qk_norm:
self.q_norm_weight = self.create_parameter(
shape=[self.qk_head_dim],
dtype="float32",
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)

self.k_norm_weight = self.create_parameter(
shape=[self.qk_head_dim],
dtype="float32",
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)

def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
"""
Attention only has quant-related scales, no other parameters.
"""
if self.kvcache_quant_method is not None:
self.kvcache_quant_method.create_weights(self, state_dict)
if self.quant_method is not None:
self.quant_method.process_loaded_weights(self, state_dict)
if self.use_qk_norm:
q_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.q_norm_key + ".weight")))
k_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.k_norm_key + ".weight")))
self.q_norm_weight.set_value(q_norm_weight_tensor.astype("float32"))
self.k_norm_weight.set_value(k_norm_weight_tensor.astype("float32"))

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
loaded_weight = get_tensor(loaded_weight).cast(paddle.get_default_dtype())
if self.quant_method.cache_quant_config.has_zero_point: # cache_int4_zp
loaded_weight = 1.0 / loaded_weight
else:
loaded_weight = self.quant_method.cache_quant_config.max_bound / loaded_weight

param.copy_(loaded_weight, False)

def forward(
self,
q: paddle.Tensor = None,
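
For orientation, a minimal sketch (not part of the diff) of the scale inversion the new Attention.weight_loader performs: the checkpoint stores an activation scale, and the parameter ends up holding either its reciprocal (zero-point / cache_int4_zp checkpoints) or max_bound divided by the scale. The tensor values and max_bound below are illustrative assumptions, not taken from a real checkpoint.

# Illustrative sketch only, mirroring the inversion in Attention.weight_loader above.
import paddle

max_bound = 127.0                              # e.g. cache_int8 bound (assumed)
loaded_scale = paddle.to_tensor([0.02, 0.04], dtype="float32")
has_zero_point = False                         # True for cache_int4_zp checkpoints

if has_zero_point:
    param_value = 1.0 / loaded_scale           # store the reciprocal scale
else:
    param_value = max_bound / loaded_scale     # fold max_bound into the stored scale

print(param_value)                             # value copied into cache_k_scale / cache_v_scale
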
103 changes: 88 additions & 15 deletions fastdeploy/model_executor/layers/quantization/kv_cache.py
@@ -21,8 +21,8 @@
from paddle import nn

from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import set_weight_attrs

from ..utils import create_and_set_parameter
from .quant_base import QuantConfigBase, QuantMethodBase


@@ -117,9 +117,8 @@ def load_zp(self, layer: nn.Layer, state_dict):
"""
cache_k_zeropoint = get_tensor(state_dict.pop(self.cache_k_zp_name)).cast(paddle.get_default_dtype())
cache_v_zeropoint = get_tensor(state_dict.pop(self.cache_v_zp_name)).cast(paddle.get_default_dtype())

create_and_set_parameter(layer, "cache_k_zp", cache_k_zeropoint)
create_and_set_parameter(layer, "cache_v_zp", cache_v_zeropoint)
layer.cache_k_zp.set_value(cache_k_zeropoint)
layer.cache_v_zp.set_value(cache_v_zeropoint)

def load_scale(self, layer: nn.Layer, state_dict):
"""
@@ -156,21 +155,15 @@ def load_scale(self, layer: nn.Layer, state_dict):
cache_k_out_scale = cache_k_scale_tensor / self.cache_quant_config.max_bound
cache_v_out_scale = cache_v_scale_tensor / self.cache_quant_config.max_bound

create_and_set_parameter(layer, "cache_k_scale", cache_k_scale)
create_and_set_parameter(layer, "cache_v_scale", cache_v_scale)
create_and_set_parameter(layer, "cache_k_out_scale", cache_k_out_scale)
create_and_set_parameter(layer, "cache_v_out_scale", cache_v_out_scale)
layer.cache_k_scale.set_value(cache_k_scale)
layer.cache_v_scale.set_value(cache_v_scale)
layer.cache_k_out_scale.set_value(cache_k_out_scale)
layer.cache_v_out_scale.set_value(cache_v_out_scale)

def create_weights(self, layer: nn.Layer, state_dict):
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
"""
create_weights
"""
self.prefix = layer.prefix
self.cache_k_scale_name = layer.prefix + ".cachek_matmul.activation_scale"
self.cache_v_scale_name = layer.prefix + ".cachev_matmul.activation_scale"
self.cache_k_zp_name = layer.prefix + ".cachek_matmul.activation_zero_point"
self.cache_v_zp_name = layer.prefix + ".cachev_matmul.activation_zero_point"

if self.cache_quant_config.quant_type == KvCacheQuantzationTypes.INT8:
layer.cache_quant_type_str = "cache_int8"
layer.quant_max_bound = 127.0
@@ -190,11 +183,91 @@
else:
raise NotImplementedError(f"{self.cache_quant_config.quant_type} is not implemented")

scale_shape = [layer.fd_config.model_config.num_key_value_heads]
if self.cache_quant_config.is_channel_wise:
scale_shape = [layer.fd_config.model_config.num_key_value_heads, layer.head_dim]

layer.cache_k_scale = layer.create_parameter(
shape=scale_shape,
dtype=paddle.get_default_dtype(),
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.cache_v_scale = layer.create_parameter(
shape=scale_shape,
dtype=paddle.get_default_dtype(),
default_initializer=paddle.nn.initializer.Constant(0),
)
set_weight_attrs(
layer.cache_k_scale,
{
**extra_weight_attrs,
},
)
set_weight_attrs(
layer.cache_v_scale,
{
**extra_weight_attrs,
},
)
layer.cache_k_out_scale = layer.create_parameter(
shape=scale_shape,
dtype=paddle.get_default_dtype(),
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.cache_v_out_scale = layer.create_parameter(
shape=scale_shape,
dtype=paddle.get_default_dtype(),
default_initializer=paddle.nn.initializer.Constant(0),
)

if self.cache_quant_config.has_zero_point:
layer.cache_k_zp = layer.create_parameter(
shape=scale_shape,
dtype=paddle.get_default_dtype(),
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.cache_v_zp = layer.create_parameter(
shape=scale_shape,
dtype=paddle.get_default_dtype(),
default_initializer=paddle.nn.initializer.Constant(0),
)
set_weight_attrs(
layer.cache_k_zp,
{
**extra_weight_attrs,
},
)
set_weight_attrs(
layer.cache_v_zp,
{
**extra_weight_attrs,
},
)

def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
use for loader v0
"""
self.prefix = layer.prefix
self.cache_k_scale_name = layer.prefix + ".cachek_matmul.activation_scale"
self.cache_v_scale_name = layer.prefix + ".cachev_matmul.activation_scale"
self.cache_k_zp_name = layer.prefix + ".cachek_matmul.activation_zero_point"
self.cache_v_zp_name = layer.prefix + ".cachev_matmul.activation_zero_point"

if "block_wise" not in layer.cache_quant_type_str:
self.load_scale(layer, state_dict)
if self.cache_quant_config.has_zero_point:
self.load_zp(layer, state_dict)

def process_weights_after_loading(self, layer: nn.Layer):
"""
use for loader v1
"""
if layer.cache_k_scale._is_initialized():
layer.cache_k_out_scale.set_value(1 / layer.cache_k_scale)
if layer.cache_v_scale._is_initialized():
layer.cache_v_out_scale.set_value(1 / layer.cache_v_scale)

def apply(self, layer):
"""
apply
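
A rough end-to-end sketch (not part of the change) of the loader-v1 path introduced above, under assumed shapes: create_weights allocates zero-initialized scale parameters, per key-value head or per channel when is_channel_wise is set; the weight loader fills cache_k_scale / cache_v_scale; and process_weights_after_loading derives the out scales as reciprocals. All numbers are placeholders.

# Sketch only; shapes and values are assumptions for illustration.
import paddle

num_key_value_heads, head_dim = 8, 128
is_channel_wise = False
scale_shape = [num_key_value_heads, head_dim] if is_channel_wise else [num_key_value_heads]

# create_weights(): zero-initialized placeholder parameters
cache_k_scale = paddle.zeros(scale_shape, dtype=paddle.get_default_dtype())
cache_k_out_scale = paddle.zeros_like(cache_k_scale)

# weight_loader(): folds the checkpoint scale as max_bound / scale (see attention.py above)
cache_k_scale = 127.0 / paddle.full(scale_shape, 0.05, dtype=paddle.get_default_dtype())

# process_weights_after_loading(): the out scale is the reciprocal of the loaded scale
cache_k_out_scale = 1.0 / cache_k_scale
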
9 changes: 8 additions & 1 deletion fastdeploy/model_executor/models/ernie4_5_moe.py
@@ -539,6 +539,10 @@ def load_weights(self, weights_iterator) -> None:
("qkv_proj", "v_proj", None, "v"),
("up_gate_proj", "gate_proj", None, "gate"),
("up_gate_proj", "up_proj", None, "up"),
("attn.cache_k_scale", "cachek_matmul.activation_scale", None, None),
("attn.cache_v_scale", "cachev_matmul.activation_scale", None, None),
("attn.cache_k_zp", "cachek_matmul.activation_zero_point", None, None),
("attn.cache_v_zp", "cachev_matmul.activation_zero_point", None, None),
]

expert_params_mapping = []
@@ -563,6 +567,7 @@ def load_weights(self, weights_iterator) -> None:
all_param_mapping = general_params_mapping + expert_params_mapping

params_dict = dict(self.named_parameters())

process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))

for loaded_weight_name, loaded_weight in weights_iterator:
@@ -591,7 +596,9 @@
else:
weight_loader(param, loaded_weight, shard_id)

model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
model_sublayer_name = re.sub(
r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$", "", model_param_name
)
process_weights_after_loading_fn(model_sublayer_name, param)

if self.tie_word_embeddings:
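
As a hypothetical illustration of the four new mapping tuples (not part of the diff): a checkpoint key containing the second string is renamed to use the first, so the KV-cache scales and zero points land on the Attention layer's parameters. The layer prefix below is an assumed example, not a real checkpoint key.

# Hypothetical example of how the new (param_name, weight_name, ...) entries rename keys.
checkpoint_key = "ernie.layers.0.self_attn.cachek_matmul.activation_scale"   # assumed key
param_key = checkpoint_key.replace("cachek_matmul.activation_scale", "attn.cache_k_scale")
print(param_key)   # ernie.layers.0.self_attn.attn.cache_k_scale
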
@@ -616,6 +616,10 @@ def load_weights(self, weights_iterator) -> None:
("resampler_model", "ernie.resampler_model", None, None),
("vision_model", "ernie.vision_model", None, None),
("gate_correction_bias", "moe_statics.e_score_correction_bias", None, None),
("attn.cache_k_scale", "cachek_matmul.activation_scale", None, None),
("attn.cache_v_scale", "cachev_matmul.activation_scale", None, None),
("attn.cache_k_zp", "cachek_matmul.activation_zero_point", None, None),
("attn.cache_v_zp", "cachev_matmul.activation_zero_point", None, None),
# for torch model
("resampler_model", "model.resampler_model", None, None),
("qkv_proj", "q_proj", None, "q"),
@@ -679,7 +683,9 @@ def load_weights(self, weights_iterator) -> None:
weight_loader(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
else:
weight_loader(param, loaded_weight, shard_id)
model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
model_sublayer_name = re.sub(
r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$", "", model_param_name
)
process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings:
# because we use lazy guard and is not initialized by default
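
For the widened regex used in both model files, a small sketch (not part of the change) of how the trailing parameter name is stripped so that cache scales resolve to their owning attention sublayer before process_weights_after_loading is called. The parameter names are assumed, illustrative values.

# Illustrative only: the extended suffix list now also strips cache_k_scale / cache_v_scale.
import re

pattern = r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$"
print(re.sub(pattern, "", "ernie.layers.0.self_attn.attn.cache_k_scale"))   # ernie.layers.0.self_attn.attn
print(re.sub(pattern, "", "ernie.layers.0.mlp.down_proj.weight"))           # ernie.layers.0.mlp.down_proj
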
4 changes: 4 additions & 0 deletions fastdeploy/worker/worker_process.py
@@ -709,6 +709,10 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:

if quantization_config is not None:
quant_config_name = quantization_config["quantization"]
# TODO(YuanRisheng) is_checkpoint_bf16 may need to be removed and replaced by is_quantized in future
if "kv_cache_quant_type" in quantization_config and load_config.load_choices == "default_v1":
quantization_config["is_checkpoint_bf16"] = True

elif args.quantization != "None":
quantization_config = {}
quant_config_name = args.quantization
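
A minimal sketch (not part of the diff) of the new config branch: when the checkpoint's quantization_config carries a kv_cache_quant_type and the v1 loader is selected, the checkpoint is flagged as bf16 so the KV-cache scale parameters are created and filled during weight loading. The dict contents below are assumptions for illustration.

# Sketch with an assumed quantization_config; mirrors the branch added above.
quantization_config = {"quantization": "mix_quant", "kv_cache_quant_type": "int8"}  # assumed
load_choices = "default_v1"

if "kv_cache_quant_type" in quantization_config and load_choices == "default_v1":
    # loader v1: treat checkpoint weights as bf16; KV-cache scales are handled at load time
    quantization_config["is_checkpoint_bf16"] = True
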
15 changes: 0 additions & 15 deletions tests/layers/test_quant_layer.py

This file was deleted.
