From 36dbbe27d09af4a6ba8db8b27b8113d33b4476c2 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Tue, 9 Dec 2025 21:00:43 +0800 Subject: [PATCH 01/12] [Model] tp+ep support v1_loader --- .../layers/attention/attention.py | 5 ++ fastdeploy/model_executor/layers/linear.py | 52 +++++++++++++++++++ fastdeploy/model_executor/layers/lm_head.py | 3 ++ .../model_executor/layers/mtp_linear.py | 2 + .../model_executor/layers/normalization.py | 4 ++ 5 files changed, 66 insertions(+) diff --git a/fastdeploy/model_executor/layers/attention/attention.py b/fastdeploy/model_executor/layers/attention/attention.py index 79804aa2d5c..4cad10ec506 100644 --- a/fastdeploy/model_executor/layers/attention/attention.py +++ b/fastdeploy/model_executor/layers/attention/attention.py @@ -230,6 +230,11 @@ def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]): def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): loaded_weight = get_tensor(loaded_weight).cast(paddle.get_default_dtype()) + + if self.use_qk_norm and ("q_norm" in param.name or "k_norm" in param.name): + param.copy_(loaded_weight, False) + return + if self.quant_method.cache_quant_config.has_zero_point: # cache_int4_zp loaded_weight = 1.0 / loaded_weight else: diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 14d1e0dcc0c..4505f81c104 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -25,6 +25,7 @@ from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase from fastdeploy.model_executor.utils import ( default_weight_loader, + fd_cast, h2d_copy, process_weight_transpose, set_weight_attrs, @@ -901,6 +902,57 @@ def __init__( if self.tp_size > 1 and self.reduce_results: set_weight_attrs(self.bias, {"tp_row_bias": True}) + def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): + # In some senerio such as tsp, weight and bias of this layer will not be split in specific module. + # For example, weight and bias of this layer in shared_experts will not split, but will be split in o_proj. + # So, we add a white list to avoid split weight and bias in these layers. + layer_white_list = ["shared_experts"] + layer_in_white_list = any(key in self.prefix for key in layer_white_list) + + output_dim = getattr(param, "output_dim", None) + weight_need_transpose = getattr(param, "weight_need_transpose", False) + if weight_need_transpose: + loaded_weight = loaded_weight.transpose([1, 0]) + # Tensor parallelism splits the weight along the output_dim + if ( + output_dim is not None + and self.fd_config is not None + and self.fd_config.parallel_config.tensor_parallel_size > 1 + ): + dim = -1 if output_dim else 0 + if isinstance(loaded_weight, paddle.Tensor): + size = loaded_weight.shape[dim] + else: + size = loaded_weight.get_shape()[dim] + block_size = size // self.fd_config.parallel_config.tensor_parallel_size + shard_offset = self.fd_config.parallel_config.tensor_parallel_rank * block_size + shard_size = (self.fd_config.parallel_config.tensor_parallel_rank + 1) * block_size + + # when use_sequence_parallel_moe, we don't split. 
+ if layer_in_white_list: + pass + else: + loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size) + + tp_row_bias = getattr(param, "tp_row_bias", None) + if layer_in_white_list: + pass + else: + if tp_row_bias: + loaded_weight = loaded_weight / self.fd_config.parallel_config.tensor_parallel_size + + # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation + loaded_weight = fd_cast(loaded_weight, param) + + if param.shape != loaded_weight.shape: + # for e_score_correction_bias + loaded_weight = loaded_weight.reshape(param.shape) + assert param.shape == loaded_weight.shape, ( + f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" + ) + loaded_weight = get_tensor(loaded_weight) + param.copy_(loaded_weight, False) + def all2all_transpose(self, x: paddle.Tensor) -> paddle.Tensor: token_num = x.shape[0] token_num_pad = (token_num + self.tp_size - 1) // self.tp_size * self.tp_size diff --git a/fastdeploy/model_executor/layers/lm_head.py b/fastdeploy/model_executor/layers/lm_head.py index ff2797a0415..0bd0a965b71 100644 --- a/fastdeploy/model_executor/layers/lm_head.py +++ b/fastdeploy/model_executor/layers/lm_head.py @@ -102,6 +102,9 @@ def __init__( }, ) set_weight_attrs(self.linear.weight, {"output_dim": True}) + if with_bias: + set_weight_attrs(self.linear.bias, {"output_dim": True}) + else: self.linear = RowParallelLinear( embedding_dim, diff --git a/fastdeploy/model_executor/layers/mtp_linear.py b/fastdeploy/model_executor/layers/mtp_linear.py index b1699720bdd..23f5a56d482 100644 --- a/fastdeploy/model_executor/layers/mtp_linear.py +++ b/fastdeploy/model_executor/layers/mtp_linear.py @@ -86,6 +86,8 @@ def __init__( ) if self.tp_size > 1: set_weight_attrs(self.linear.weight, {"output_dim": True}) + set_weight_attrs(self.linear.bias, {"output_dim": True}) + else: self.linear = RowParallelLinear( embedding_dim, diff --git a/fastdeploy/model_executor/layers/normalization.py b/fastdeploy/model_executor/layers/normalization.py index ec1f0e65891..1e37d73bd09 100644 --- a/fastdeploy/model_executor/layers/normalization.py +++ b/fastdeploy/model_executor/layers/normalization.py @@ -130,6 +130,10 @@ def init_weight(self): dtype=self._norm_weight_dtype, ) + def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): + loaded_weight = get_tensor(loaded_weight).astype(self._norm_weight_dtype) + param.copy_(loaded_weight, False) + def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]): """ Load the checkpoint state dictionary into the layer. From 6f6256f975828b39830449567e65b39fa12ad26e Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Mon, 15 Dec 2025 12:24:29 +0800 Subject: [PATCH 02/12] fix --- fastdeploy/model_executor/layers/linear.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 4505f81c104..19439ad4435 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -903,10 +903,16 @@ def __init__( set_weight_attrs(self.bias, {"tp_row_bias": True}) def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): - # In some senerio such as tsp, weight and bias of this layer will not be split in specific module. + # In tp_size > 1 and ep_size > 1, weight and bias of this layer will not be split in specific module. 
# For example, weight and bias of this layer in shared_experts will not split, but will be split in o_proj. # So, we add a white list to avoid split weight and bias in these layers. - layer_white_list = ["shared_experts"] + if ( + self.fd_config.parallel_config.tensor_parallel_size > 1 + and self.fd_config.parallel_config.expert_parallel_size > 1 + ): + layer_white_list = ["shared_experts"] + else: + layer_white_list = [] layer_in_white_list = any(key in self.prefix for key in layer_white_list) output_dim = getattr(param, "output_dim", None) From aee0907457bca6627b6e52d4bb203ef031521c5a Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Mon, 15 Dec 2025 14:15:23 +0800 Subject: [PATCH 03/12] fix mtp_linear --- fastdeploy/model_executor/layers/mtp_linear.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fastdeploy/model_executor/layers/mtp_linear.py b/fastdeploy/model_executor/layers/mtp_linear.py index 23f5a56d482..e1f52d73899 100644 --- a/fastdeploy/model_executor/layers/mtp_linear.py +++ b/fastdeploy/model_executor/layers/mtp_linear.py @@ -86,7 +86,8 @@ def __init__( ) if self.tp_size > 1: set_weight_attrs(self.linear.weight, {"output_dim": True}) - set_weight_attrs(self.linear.bias, {"output_dim": True}) + if self.bias_key is not None: + set_weight_attrs(self.linear.bias, {"output_dim": True}) else: self.linear = RowParallelLinear( From e380ec48fcc1ef005f20aa164c7c4973a1224949 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Mon, 15 Dec 2025 14:25:22 +0800 Subject: [PATCH 04/12] fix mtp_linear --- fastdeploy/model_executor/layers/linear.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 19439ad4435..0dfb9c35f48 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -935,17 +935,12 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N shard_size = (self.fd_config.parallel_config.tensor_parallel_rank + 1) * block_size # when use_sequence_parallel_moe, we don't split. 
- if layer_in_white_list: - pass - else: + if not layer_in_white_list: loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size) tp_row_bias = getattr(param, "tp_row_bias", None) - if layer_in_white_list: - pass - else: - if tp_row_bias: - loaded_weight = loaded_weight / self.fd_config.parallel_config.tensor_parallel_size + if not layer_in_white_list and tp_row_bias: + loaded_weight = loaded_weight / self.fd_config.parallel_config.tensor_parallel_size # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation loaded_weight = fd_cast(loaded_weight, param) From f0ee4d0ac46a0b3718950a2eb263219735de691c Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Mon, 15 Dec 2025 17:28:37 +0800 Subject: [PATCH 05/12] fix --- .../layers/attention/attention.py | 4 +- .../model_executor/load_weight_utils.py | 55 ++++++++++++++----- 2 files changed, 43 insertions(+), 16 deletions(-) diff --git a/fastdeploy/model_executor/layers/attention/attention.py b/fastdeploy/model_executor/layers/attention/attention.py index 4cad10ec506..a5ac1876e34 100644 --- a/fastdeploy/model_executor/layers/attention/attention.py +++ b/fastdeploy/model_executor/layers/attention/attention.py @@ -229,12 +229,12 @@ def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]): self.sinks.set_value(sinks_tensor) def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): - loaded_weight = get_tensor(loaded_weight).cast(paddle.get_default_dtype()) - if self.use_qk_norm and ("q_norm" in param.name or "k_norm" in param.name): + loaded_weight = get_tensor(loaded_weight).astype("float32") param.copy_(loaded_weight, False) return + loaded_weight = get_tensor(loaded_weight).cast(paddle.get_default_dtype()) if self.quant_method.cache_quant_config.has_zero_point: # cache_int4_zp loaded_weight = 1.0 / loaded_weight else: diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py index a795a9e0304..09dbfbd0181 100644 --- a/fastdeploy/model_executor/load_weight_utils.py +++ b/fastdeploy/model_executor/load_weight_utils.py @@ -71,9 +71,9 @@ def load_weights_from_cache(model, weights_iterator): def get_weight_iterator(model_path: str): - _, files_list, use_safetensors = get_all_weights_file(model_path) + key_name_list, files_list, use_safetensors = get_all_weights_file(model_path) if use_safetensors: - weights_iterator = safetensors_weights_iterator(files_list) + weights_iterator = safetensors_weights_iterator(key_name_list, files_list) else: weights_iterator = pdparams_weight_iterator(files_list) return weights_iterator @@ -319,18 +319,35 @@ def get_expert_ranges(fd_config): return state_dict -def safetensors_weights_iterator(safe_tensor_list: list[str]): +class SafetensorFileCache: + def __init__(self): + self._files = {} + + def get(self, filename): + if filename not in self._files: + self._files[filename] = safe_open(filename, framework="paddle", device="cpu") + return self._files[filename] + + def close(self): + for f in self._files.values(): + f.__exit__(None, None, None) + self._files.clear() + + +def safetensors_weights_iterator(key_name_list: list[str], safe_tensor_list: list[str]): """ safetensors_weights_iterator """ - for st_file in tqdm( - safe_tensor_list, - desc="Loading safetensors checkpoint shards", + + safe_tensor_cache = SafetensorFileCache() + for i, key_name in tqdm( + enumerate(key_name_list), + total=len(key_name_list), + desc="Loading weights", ): - with 
safe_open(st_file, framework="paddle", device="cpu") as f: - for name in f.keys(): - param = f.get_tensor(name) - yield name, param + f = safe_tensor_cache.get(safe_tensor_list[i]) + param = f.get_tensor(key_name) + yield key_name, param def fast_weights_iterator(safe_tensor_list: list[str]): @@ -360,10 +377,18 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int): return state_dict +def natural_key(s: str): + import re + + return [int(text) if text.isdigit() else text for text in re.split(r"(\d+)", s)] + + def get_all_weights_file(model_path: str): """ get_all_safetensors """ + from collections import OrderedDict + model_path = Path(model_path) use_safetensors = True files_list = [str(file) for file in model_path.glob("*.pdparams") if file.name != "scheduler.pdparams"] @@ -373,17 +398,19 @@ def get_all_weights_file(model_path: str): else: safe_model_path = model_path / "model.safetensors" if safe_model_path.exists(): - files_list = [str(safe_model_path)] with safe_open(safe_model_path, framework="np", device="cpu") as f: key_name_list = f.keys() + + files_list = [str(safe_model_path)] * len(key_name_list) return key_name_list, files_list, use_safetensors else: index_file = model_path / "model.safetensors.index.json" with index_file.open("r") as f: weight_map = json.load(f)["weight_map"] - weight_files_in_index = {str(model_path / weight_map[name]) for name in weight_map} - key_name_list = list(weight_map.keys()) - files_list = sorted(weight_files_in_index) + sorted_weight_map = OrderedDict(sorted(weight_map.items(), key=lambda kv: natural_key(kv[0]))) + files_list = [str(model_path / file_name) for (_, file_name) in sorted_weight_map.items()] + key_name_list = list(sorted_weight_map.keys()) + return key_name_list, files_list, use_safetensors From 45b03519e2f57291c5a57c6abd56145d4aa262d9 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Mon, 15 Dec 2025 20:49:56 +0800 Subject: [PATCH 06/12] fix --- fastdeploy/model_executor/layers/embeddings.py | 6 ++++-- fastdeploy/model_executor/layers/lm_head.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/fastdeploy/model_executor/layers/embeddings.py b/fastdeploy/model_executor/layers/embeddings.py index 52d7dadeebc..5ae82efe4ca 100644 --- a/fastdeploy/model_executor/layers/embeddings.py +++ b/fastdeploy/model_executor/layers/embeddings.py @@ -283,10 +283,12 @@ def weight_loader(self, param, loaded_weight, shard_id=None): if output_dim == 0: h2d_copy(param[: shard_weight.shape[0]], shard_weight) if not current_platform.is_maca(): - param[shard_weight.shape[0] :].fill_(0) + if param.shape[0] != shard_weight.shape[0]: + param[shard_weight.shape[0] :].fill_(0) else: h2d_copy(param[:, : shard_weight.shape[1]], shard_weight) - param[:, shard_weight.shape[1] :].fill_(0) + if param.shape[1] != shard_weight.shape[1]: + param[:, shard_weight.shape[1] :].fill_(0) def forward(self, ids_remove_padding=None) -> paddle.Tensor: """ diff --git a/fastdeploy/model_executor/layers/lm_head.py b/fastdeploy/model_executor/layers/lm_head.py index 0bd0a965b71..a7bff3905b0 100644 --- a/fastdeploy/model_executor/layers/lm_head.py +++ b/fastdeploy/model_executor/layers/lm_head.py @@ -102,8 +102,9 @@ def __init__( }, ) set_weight_attrs(self.linear.weight, {"output_dim": True}) - if with_bias: - set_weight_attrs(self.linear.bias, {"output_dim": True}) + if self.tp_size > 1: + if with_bias: + set_weight_attrs(self.linear.bias, {"output_dim": True}) else: self.linear = RowParallelLinear( From 
3923b6d2e30ef44f89badd11ffa27382038110ae Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Tue, 16 Dec 2025 11:19:10 +0800 Subject: [PATCH 07/12] fix v0 loader --- fastdeploy/model_executor/load_weight_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py index 09dbfbd0181..a94f2786eb9 100644 --- a/fastdeploy/model_executor/load_weight_utils.py +++ b/fastdeploy/model_executor/load_weight_utils.py @@ -370,8 +370,8 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int): """ state_dict = {} - _, safetensor_files, _ = get_all_weights_file(os.path.join(model_path, f"rank{local_rank}")) - weights_iterator = safetensors_weights_iterator(safetensor_files) + key_name_list, safetensor_files, _ = get_all_weights_file(os.path.join(model_path, f"rank{local_rank}")) + weights_iterator = safetensors_weights_iterator(key_name_list, safetensor_files) for name, weight in weights_iterator: state_dict[name] = weight.clone() return state_dict From 2bb4bf60acdc349636f49d1ad6f53782fdbdec14 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Tue, 16 Dec 2025 18:16:43 +0800 Subject: [PATCH 08/12] fix --- .../model_executor/load_weight_utils.py | 59 +++++-------------- 1 file changed, 16 insertions(+), 43 deletions(-) diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py index a94f2786eb9..a795a9e0304 100644 --- a/fastdeploy/model_executor/load_weight_utils.py +++ b/fastdeploy/model_executor/load_weight_utils.py @@ -71,9 +71,9 @@ def load_weights_from_cache(model, weights_iterator): def get_weight_iterator(model_path: str): - key_name_list, files_list, use_safetensors = get_all_weights_file(model_path) + _, files_list, use_safetensors = get_all_weights_file(model_path) if use_safetensors: - weights_iterator = safetensors_weights_iterator(key_name_list, files_list) + weights_iterator = safetensors_weights_iterator(files_list) else: weights_iterator = pdparams_weight_iterator(files_list) return weights_iterator @@ -319,35 +319,18 @@ def get_expert_ranges(fd_config): return state_dict -class SafetensorFileCache: - def __init__(self): - self._files = {} - - def get(self, filename): - if filename not in self._files: - self._files[filename] = safe_open(filename, framework="paddle", device="cpu") - return self._files[filename] - - def close(self): - for f in self._files.values(): - f.__exit__(None, None, None) - self._files.clear() - - -def safetensors_weights_iterator(key_name_list: list[str], safe_tensor_list: list[str]): +def safetensors_weights_iterator(safe_tensor_list: list[str]): """ safetensors_weights_iterator """ - - safe_tensor_cache = SafetensorFileCache() - for i, key_name in tqdm( - enumerate(key_name_list), - total=len(key_name_list), - desc="Loading weights", + for st_file in tqdm( + safe_tensor_list, + desc="Loading safetensors checkpoint shards", ): - f = safe_tensor_cache.get(safe_tensor_list[i]) - param = f.get_tensor(key_name) - yield key_name, param + with safe_open(st_file, framework="paddle", device="cpu") as f: + for name in f.keys(): + param = f.get_tensor(name) + yield name, param def fast_weights_iterator(safe_tensor_list: list[str]): @@ -370,25 +353,17 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int): """ state_dict = {} - key_name_list, safetensor_files, _ = get_all_weights_file(os.path.join(model_path, f"rank{local_rank}")) - weights_iterator = 
safetensors_weights_iterator(key_name_list, safetensor_files) + _, safetensor_files, _ = get_all_weights_file(os.path.join(model_path, f"rank{local_rank}")) + weights_iterator = safetensors_weights_iterator(safetensor_files) for name, weight in weights_iterator: state_dict[name] = weight.clone() return state_dict -def natural_key(s: str): - import re - - return [int(text) if text.isdigit() else text for text in re.split(r"(\d+)", s)] - - def get_all_weights_file(model_path: str): """ get_all_safetensors """ - from collections import OrderedDict - model_path = Path(model_path) use_safetensors = True files_list = [str(file) for file in model_path.glob("*.pdparams") if file.name != "scheduler.pdparams"] @@ -398,19 +373,17 @@ def get_all_weights_file(model_path: str): else: safe_model_path = model_path / "model.safetensors" if safe_model_path.exists(): + files_list = [str(safe_model_path)] with safe_open(safe_model_path, framework="np", device="cpu") as f: key_name_list = f.keys() - - files_list = [str(safe_model_path)] * len(key_name_list) return key_name_list, files_list, use_safetensors else: index_file = model_path / "model.safetensors.index.json" with index_file.open("r") as f: weight_map = json.load(f)["weight_map"] - sorted_weight_map = OrderedDict(sorted(weight_map.items(), key=lambda kv: natural_key(kv[0]))) - files_list = [str(model_path / file_name) for (_, file_name) in sorted_weight_map.items()] - key_name_list = list(sorted_weight_map.keys()) - + weight_files_in_index = {str(model_path / weight_map[name]) for name in weight_map} + key_name_list = list(weight_map.keys()) + files_list = sorted(weight_files_in_index) return key_name_list, files_list, use_safetensors From adbb24b29e7d98dbccef4c6069cdc9c11a68f3df Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Wed, 17 Dec 2025 14:09:14 +0800 Subject: [PATCH 09/12] Add get_tensor for EP --- fastdeploy/model_executor/layers/moe/moe.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 5b1be52d183..11725729a9b 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -274,10 +274,13 @@ def weight_loader( if not param._is_initialized(): param.initialize() weight_need_transpose = getattr(param, "weight_need_transpose", False) + + if self.ep_size > 1 or weight_need_transpose: + loaded_weight = get_tensor(loaded_weight) + if shard_id is None: # 1.gate up fused in disk if weight_need_transpose: - loaded_weight = get_tensor(loaded_weight) loaded_weight = loaded_weight.transpose([1, 0]) output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]] shard_offsets = [ @@ -293,7 +296,6 @@ def weight_loader( self.weight_loader(param, loaded_weight_shard, expert_id, shard_id, "fused") else: if weight_need_transpose and source != "fused": - loaded_weight = get_tensor(loaded_weight) loaded_weight = loaded_weight.transpose([1, 0]) # 2.gate up splited in disk assert shard_id in ["gate", "down", "up"] From f9df70922c166ae0736a6962ea48980af0b7c64e Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Wed, 17 Dec 2025 20:12:22 +0800 Subject: [PATCH 10/12] fix linear weight_loader --- fastdeploy/model_executor/layers/linear.py | 139 +++++++++--------- .../layers/quantization/block_wise_fp8.py | 4 +- 2 files changed, 75 insertions(+), 68 deletions(-) diff --git a/fastdeploy/model_executor/layers/linear.py 
b/fastdeploy/model_executor/layers/linear.py index 0dfb9c35f48..fb24cdd9709 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -25,7 +25,6 @@ from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase from fastdeploy.model_executor.utils import ( default_weight_loader, - fd_cast, h2d_copy, process_weight_transpose, set_weight_attrs, @@ -357,25 +356,31 @@ def __init__( self.output_sizes = output_sizes def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): - assert loaded_shard_id in ["q_a", "kv_a"] if not param._is_initialized(): param.initialize() + if loaded_shard_id is None: + axis = -1 if (self.fd_config.model_config.model_format == "torch") ^ True else 0 + if hasattr(param, "tensor_track"): + param.tensor_track.mark(start=0, end=loaded_weight.shape[axis]) - if loaded_shard_id == "q_a": - param_shard_offset = 0 - param_shard_size = self.output_sizes[0] else: - # loaded_shard_id == "kv_a" - param_shard_offset = self.output_sizes[0] - param_shard_size = self.output_sizes[1] - if hasattr(param, "tensor_track"): - param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size) - param = slice_fn( - param, - (self.fd_config.model_config.model_format == "torch") ^ True, - start=param_shard_offset, - end=param_shard_offset + param_shard_size, - ) + assert loaded_shard_id in ["q_a", "kv_a", "gate", "up"] + + if loaded_shard_id == "q_a" or "gate": + param_shard_offset = 0 + param_shard_size = self.output_sizes[0] + elif loaded_shard_id == "kv_a" or "up": + param_shard_offset = self.output_sizes[0] + param_shard_size = self.output_sizes[1] + + if hasattr(param, "tensor_track"): + param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size) + param = slice_fn( + param, + (self.fd_config.model_config.model_format == "torch") ^ True, + start=param_shard_offset, + end=param_shard_offset + param_shard_size, + ) assert param.shape == loaded_weight.shape, ( f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" ) @@ -902,57 +907,57 @@ def __init__( if self.tp_size > 1 and self.reduce_results: set_weight_attrs(self.bias, {"tp_row_bias": True}) - def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): - # In tp_size > 1 and ep_size > 1, weight and bias of this layer will not be split in specific module. - # For example, weight and bias of this layer in shared_experts will not split, but will be split in o_proj. - # So, we add a white list to avoid split weight and bias in these layers. 
- if ( - self.fd_config.parallel_config.tensor_parallel_size > 1 - and self.fd_config.parallel_config.expert_parallel_size > 1 - ): - layer_white_list = ["shared_experts"] - else: - layer_white_list = [] - layer_in_white_list = any(key in self.prefix for key in layer_white_list) - - output_dim = getattr(param, "output_dim", None) - weight_need_transpose = getattr(param, "weight_need_transpose", False) - if weight_need_transpose: - loaded_weight = loaded_weight.transpose([1, 0]) - # Tensor parallelism splits the weight along the output_dim - if ( - output_dim is not None - and self.fd_config is not None - and self.fd_config.parallel_config.tensor_parallel_size > 1 - ): - dim = -1 if output_dim else 0 - if isinstance(loaded_weight, paddle.Tensor): - size = loaded_weight.shape[dim] - else: - size = loaded_weight.get_shape()[dim] - block_size = size // self.fd_config.parallel_config.tensor_parallel_size - shard_offset = self.fd_config.parallel_config.tensor_parallel_rank * block_size - shard_size = (self.fd_config.parallel_config.tensor_parallel_rank + 1) * block_size - - # when use_sequence_parallel_moe, we don't split. - if not layer_in_white_list: - loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size) - - tp_row_bias = getattr(param, "tp_row_bias", None) - if not layer_in_white_list and tp_row_bias: - loaded_weight = loaded_weight / self.fd_config.parallel_config.tensor_parallel_size - - # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation - loaded_weight = fd_cast(loaded_weight, param) - - if param.shape != loaded_weight.shape: - # for e_score_correction_bias - loaded_weight = loaded_weight.reshape(param.shape) - assert param.shape == loaded_weight.shape, ( - f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" - ) - loaded_weight = get_tensor(loaded_weight) - param.copy_(loaded_weight, False) + # def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): + # # In tp_size > 1 and ep_size > 1, weight and bias of this layer will not be split in specific module. + # # For example, weight and bias of this layer in shared_experts will not split, but will be split in o_proj. + # # So, we add a white list to avoid split weight and bias in these layers. + # if ( + # self.fd_config.parallel_config.tensor_parallel_size > 1 + # and self.fd_config.parallel_config.expert_parallel_size > 1 + # ): + # layer_white_list = ["shared_experts"] + # else: + # layer_white_list = [] + # layer_in_white_list = any(key in self.prefix for key in layer_white_list) + + # output_dim = getattr(param, "output_dim", None) + # weight_need_transpose = getattr(param, "weight_need_transpose", False) + # if weight_need_transpose: + # loaded_weight = loaded_weight.transpose([1, 0]) + # # Tensor parallelism splits the weight along the output_dim + # if ( + # output_dim is not None + # and self.fd_config is not None + # and self.fd_config.parallel_config.tensor_parallel_size > 1 + # ): + # dim = -1 if output_dim else 0 + # if isinstance(loaded_weight, paddle.Tensor): + # size = loaded_weight.shape[dim] + # else: + # size = loaded_weight.get_shape()[dim] + # block_size = size // self.fd_config.parallel_config.tensor_parallel_size + # shard_offset = self.fd_config.parallel_config.tensor_parallel_rank * block_size + # shard_size = (self.fd_config.parallel_config.tensor_parallel_rank + 1) * block_size + + # # when use_sequence_parallel_moe, we don't split. 
+ # if not layer_in_white_list: + # loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size) + + # tp_row_bias = getattr(param, "tp_row_bias", None) + # if not layer_in_white_list and tp_row_bias: + # loaded_weight = loaded_weight / self.fd_config.parallel_config.tensor_parallel_size + + # # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation + # loaded_weight = fd_cast(loaded_weight, param) + + # if param.shape != loaded_weight.shape: + # # for e_score_correction_bias + # loaded_weight = loaded_weight.reshape(param.shape) + # assert param.shape == loaded_weight.shape, ( + # f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" + # ) + # loaded_weight = get_tensor(loaded_weight) + # param.copy_(loaded_weight, False) def all2all_transpose(self, x: paddle.Tensor) -> paddle.Tensor: token_num = x.shape[0] diff --git a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py index a7b61fc0ef8..59daa238480 100644 --- a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py +++ b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py @@ -138,7 +138,9 @@ def create_weights(self, layer, **extra_weight_attrs): weight_shape = layer.weight_shape weight_scale_inv_shape = weight_scale_inv_shape extra_weight_attrs["output_dim"] = ( - not extra_weight_attrs["output_dim"] if extra_weight_attrs["output_dim"] is not None else None + not extra_weight_attrs["output_dim"] + if extra_weight_attrs.get("output_dim", None) is not None + else None ) layer.weight_dtype = "float8_e4m3fn" From 65efe2df73d3736cca4fae222d53388c7c4561f4 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Wed, 17 Dec 2025 20:31:06 +0800 Subject: [PATCH 11/12] fix typo --- fastdeploy/model_executor/layers/linear.py | 52 ---------------------- 1 file changed, 52 deletions(-) diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index fb24cdd9709..a3e6e7ed99b 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -907,58 +907,6 @@ def __init__( if self.tp_size > 1 and self.reduce_results: set_weight_attrs(self.bias, {"tp_row_bias": True}) - # def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): - # # In tp_size > 1 and ep_size > 1, weight and bias of this layer will not be split in specific module. - # # For example, weight and bias of this layer in shared_experts will not split, but will be split in o_proj. - # # So, we add a white list to avoid split weight and bias in these layers. 
- # if ( - # self.fd_config.parallel_config.tensor_parallel_size > 1 - # and self.fd_config.parallel_config.expert_parallel_size > 1 - # ): - # layer_white_list = ["shared_experts"] - # else: - # layer_white_list = [] - # layer_in_white_list = any(key in self.prefix for key in layer_white_list) - - # output_dim = getattr(param, "output_dim", None) - # weight_need_transpose = getattr(param, "weight_need_transpose", False) - # if weight_need_transpose: - # loaded_weight = loaded_weight.transpose([1, 0]) - # # Tensor parallelism splits the weight along the output_dim - # if ( - # output_dim is not None - # and self.fd_config is not None - # and self.fd_config.parallel_config.tensor_parallel_size > 1 - # ): - # dim = -1 if output_dim else 0 - # if isinstance(loaded_weight, paddle.Tensor): - # size = loaded_weight.shape[dim] - # else: - # size = loaded_weight.get_shape()[dim] - # block_size = size // self.fd_config.parallel_config.tensor_parallel_size - # shard_offset = self.fd_config.parallel_config.tensor_parallel_rank * block_size - # shard_size = (self.fd_config.parallel_config.tensor_parallel_rank + 1) * block_size - - # # when use_sequence_parallel_moe, we don't split. - # if not layer_in_white_list: - # loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size) - - # tp_row_bias = getattr(param, "tp_row_bias", None) - # if not layer_in_white_list and tp_row_bias: - # loaded_weight = loaded_weight / self.fd_config.parallel_config.tensor_parallel_size - - # # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation - # loaded_weight = fd_cast(loaded_weight, param) - - # if param.shape != loaded_weight.shape: - # # for e_score_correction_bias - # loaded_weight = loaded_weight.reshape(param.shape) - # assert param.shape == loaded_weight.shape, ( - # f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" - # ) - # loaded_weight = get_tensor(loaded_weight) - # param.copy_(loaded_weight, False) - def all2all_transpose(self, x: paddle.Tensor) -> paddle.Tensor: token_num = x.shape[0] token_num_pad = (token_num + self.tp_size - 1) // self.tp_size * self.tp_size From f52d7d8ea30563c5f1c11d6f4558bac3bf57a5a7 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Thu, 18 Dec 2025 11:30:04 +0800 Subject: [PATCH 12/12] fix --- fastdeploy/model_executor/layers/linear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index a3e6e7ed99b..49b25dc3d0c 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -366,10 +366,10 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N else: assert loaded_shard_id in ["q_a", "kv_a", "gate", "up"] - if loaded_shard_id == "q_a" or "gate": + if loaded_shard_id in ["q_a", "gate"]: param_shard_offset = 0 param_shard_size = self.output_sizes[0] - elif loaded_shard_id == "kv_a" or "up": + elif loaded_shard_id in ["kv_a", "up"]: param_shard_offset = self.output_sizes[0] param_shard_size = self.output_sizes[1]
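Note on the final hunk above: the original condition hit a standard Python precedence pitfall. Because == binds tighter than or, the check loaded_shard_id == "q_a" or "gate" parses as (loaded_shard_id == "q_a") or "gate", and the non-empty literal "gate" makes the branch truthy for every shard id, so the elif for "kv_a"/"up" could never run and those shards got the q_a/gate offsets. The membership test introduced in PATCH 12/12 restores the intended dispatch. A minimal standalone sketch of the difference (the input value "kv_a" is hypothetical, not taken from the loader):

    # Python parses `x == "q_a" or "gate"` as `(x == "q_a") or "gate"`.
    # A non-empty string literal is always truthy, so the buggy check
    # matches every shard id, including "kv_a" and "up".
    loaded_shard_id = "kv_a"

    buggy = loaded_shard_id == "q_a" or "gate"   # evaluates to "gate" -> truthy
    fixed = loaded_shard_id in ["q_a", "gate"]   # evaluates to False, as intended

    print(bool(buggy), fixed)                    # True False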