Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions fastdeploy/model_executor/layers/attention/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,11 @@ def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
self.sinks.set_value(sinks_tensor)

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
if self.use_qk_norm and ("q_norm" in param.name or "k_norm" in param.name):
loaded_weight = get_tensor(loaded_weight).astype("float32")
param.copy_(loaded_weight, False)
return

loaded_weight = get_tensor(loaded_weight).cast(paddle.get_default_dtype())
if self.quant_method.cache_quant_config.has_zero_point: # cache_int4_zp
loaded_weight = 1.0 / loaded_weight
Expand Down
6 changes: 4 additions & 2 deletions fastdeploy/model_executor/layers/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,12 @@ def weight_loader(self, param, loaded_weight, shard_id=None):
if output_dim == 0:
h2d_copy(param[: shard_weight.shape[0]], shard_weight)
if not current_platform.is_maca():
param[shard_weight.shape[0] :].fill_(0)
if param.shape[0] != shard_weight.shape[0]:
param[shard_weight.shape[0] :].fill_(0)
else:
h2d_copy(param[:, : shard_weight.shape[1]], shard_weight)
param[:, shard_weight.shape[1] :].fill_(0)
if param.shape[1] != shard_weight.shape[1]:
param[:, shard_weight.shape[1] :].fill_(0)

def forward(self, ids_remove_padding=None) -> paddle.Tensor:
"""
Expand Down
36 changes: 21 additions & 15 deletions fastdeploy/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,25 +356,31 @@ def __init__(
self.output_sizes = output_sizes

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
assert loaded_shard_id in ["q_a", "kv_a"]
if not param._is_initialized():
param.initialize()
if loaded_shard_id is None:
axis = -1 if (self.fd_config.model_config.model_format == "torch") ^ True else 0
if hasattr(param, "tensor_track"):
param.tensor_track.mark(start=0, end=loaded_weight.shape[axis])

if loaded_shard_id == "q_a":
param_shard_offset = 0
param_shard_size = self.output_sizes[0]
else:
# loaded_shard_id == "kv_a"
param_shard_offset = self.output_sizes[0]
param_shard_size = self.output_sizes[1]
if hasattr(param, "tensor_track"):
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
param = slice_fn(
param,
(self.fd_config.model_config.model_format == "torch") ^ True,
start=param_shard_offset,
end=param_shard_offset + param_shard_size,
)
assert loaded_shard_id in ["q_a", "kv_a", "gate", "up"]

if loaded_shard_id in ["q_a", "gate"]:
param_shard_offset = 0
param_shard_size = self.output_sizes[0]
elif loaded_shard_id in ["kv_a", "up"]:
param_shard_offset = self.output_sizes[0]
param_shard_size = self.output_sizes[1]

if hasattr(param, "tensor_track"):
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
param = slice_fn(
param,
(self.fd_config.model_config.model_format == "torch") ^ True,
start=param_shard_offset,
end=param_shard_offset + param_shard_size,
)
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
Expand Down
4 changes: 4 additions & 0 deletions fastdeploy/model_executor/layers/lm_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ def __init__(
},
)
set_weight_attrs(self.linear.weight, {"output_dim": True})
if self.tp_size > 1:
if with_bias:
set_weight_attrs(self.linear.bias, {"output_dim": True})

else:
self.linear = RowParallelLinear(
embedding_dim,
Expand Down
6 changes: 4 additions & 2 deletions fastdeploy/model_executor/layers/moe/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,13 @@ def weight_loader(
if not param._is_initialized():
param.initialize()
weight_need_transpose = getattr(param, "weight_need_transpose", False)

if self.ep_size > 1 or weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)

if shard_id is None:
# 1.gate up fused in disk
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
shard_offsets = [
Expand All @@ -293,7 +296,6 @@ def weight_loader(
self.weight_loader(param, loaded_weight_shard, expert_id, shard_id, "fused")
else:
if weight_need_transpose and source != "fused":
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# 2.gate up splited in disk
assert shard_id in ["gate", "down", "up"]
Expand Down
3 changes: 3 additions & 0 deletions fastdeploy/model_executor/layers/mtp_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ def __init__(
)
if self.tp_size > 1:
set_weight_attrs(self.linear.weight, {"output_dim": True})
if self.bias_key is not None:
set_weight_attrs(self.linear.bias, {"output_dim": True})

else:
self.linear = RowParallelLinear(
embedding_dim,
Expand Down
4 changes: 4 additions & 0 deletions fastdeploy/model_executor/layers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ def init_weight(self):
dtype=self._norm_weight_dtype,
)

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
loaded_weight = get_tensor(loaded_weight).astype(self._norm_weight_dtype)
param.copy_(loaded_weight, False)

def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
"""
Load the checkpoint state dictionary into the layer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ def create_weights(self, layer, **extra_weight_attrs):
weight_shape = layer.weight_shape
weight_scale_inv_shape = weight_scale_inv_shape
extra_weight_attrs["output_dim"] = (
not extra_weight_attrs["output_dim"] if extra_weight_attrs["output_dim"] is not None else None
not extra_weight_attrs["output_dim"]
if extra_weight_attrs.get("output_dim", None) is not None
else None
)

layer.weight_dtype = "float8_e4m3fn"
Expand Down
Loading