From 7716866f6df8ba4dbe5e55ceb18ff24607fe5d13 Mon Sep 17 00:00:00 2001 From: lizexu <2694294196@qq.com> Date: Mon, 22 Sep 2025 20:33:37 +0800 Subject: [PATCH 01/10] support qwen3-embedding --- .../model_executor/layers/embeddings.py | 178 +++++++++++++++++- fastdeploy/model_executor/layers/lm_head.py | 9 + fastdeploy/model_executor/layers/utils.py | 19 ++ tests/pooling/test_embedding.py | 157 +++++++++------ 4 files changed, 300 insertions(+), 63 deletions(-) diff --git a/fastdeploy/model_executor/layers/embeddings.py b/fastdeploy/model_executor/layers/embeddings.py index 43fbd76a848..7b97c53e5b1 100644 --- a/fastdeploy/model_executor/layers/embeddings.py +++ b/fastdeploy/model_executor/layers/embeddings.py @@ -14,6 +14,7 @@ # limitations under the License. """ +from dataclasses import dataclass from typing import Dict import numpy as np @@ -22,9 +23,73 @@ from paddle.distributed import fleet from fastdeploy.config import FDConfig -from fastdeploy.model_executor.utils import set_weight_attrs +from fastdeploy.model_executor.utils import set_weight_attrs, slice_fn -from .utils import get_tensor +from .utils import ( + DEFAULT_VOCAB_PADDING_SIZE, + get_tensor, + pad_vocab_size, + vocab_range_from_global_vocab_size, +) + + +@dataclass +class VocabParallelEmbeddingShardIndices: + """Indices for a shard of a vocab parallel embedding.""" + + padded_org_vocab_start_index: int + padded_org_vocab_end_index: int + padded_added_vocab_start_index: int + padded_added_vocab_end_index: int + + org_vocab_start_index: int + org_vocab_end_index: int + added_vocab_start_index: int + added_vocab_end_index: int + + @property + def num_org_elements(self) -> int: + return self.org_vocab_end_index - self.org_vocab_start_index + + @property + def num_added_elements(self) -> int: + return self.added_vocab_end_index - self.added_vocab_start_index + + @property + def num_org_elements_padded(self) -> int: + return self.padded_org_vocab_end_index - self.padded_org_vocab_start_index + + @property + def num_added_elements_padded(self) -> int: + return self.padded_added_vocab_end_index - self.padded_added_vocab_start_index + + @property + def num_org_vocab_padding(self) -> int: + return self.num_org_elements_padded - self.num_org_elements + + @property + def num_added_vocab_padding(self) -> int: + return self.num_added_elements_padded - self.num_added_elements + + @property + def num_elements_padded(self) -> int: + return self.num_org_elements_padded + self.num_added_elements_padded + + def __post_init__(self): + # sanity checks + assert self.padded_org_vocab_start_index <= self.padded_org_vocab_end_index + assert self.padded_added_vocab_start_index <= self.padded_added_vocab_end_index + + assert self.org_vocab_start_index <= self.org_vocab_end_index + assert self.added_vocab_start_index <= self.added_vocab_end_index + + assert self.org_vocab_start_index <= self.padded_org_vocab_start_index + assert self.added_vocab_start_index <= self.padded_added_vocab_start_index + assert self.org_vocab_end_index <= self.padded_org_vocab_end_index + assert self.added_vocab_end_index <= self.padded_added_vocab_end_index + + assert self.num_org_elements <= self.num_org_elements_padded + assert self.num_added_elements <= self.num_added_elements_padded class VocabParallelEmbedding(nn.Layer): @@ -39,6 +104,7 @@ def __init__( embedding_dim: int = 768, params_dtype: str = "bfloat16", prefix="", + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, ) -> None: """ Initialize the VocabParallelEmbedding layer for the model. 
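
For orientation, the padded shard layout that `_get_indices` builds can be reproduced with a few lines of plain Python. The following is a standalone sketch, not part of the patch: `pad_vocab_size` mirrors the helper added to fastdeploy/model_executor/layers/utils.py, `shard_range` is a simplified stand-in for `vocab_range_from_global_vocab_size` (no offset or added-vocab handling), and the 151,669-token vocabulary is only an illustrative figure.

DEFAULT_VOCAB_PADDING_SIZE = 64


def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    # Round up to the next multiple of pad_to so the padded vocab divides evenly.
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to


def shard_range(padded_vocab_size: int, rank: int, world_size: int) -> tuple:
    # Equal split of the padded vocab across tensor-parallel ranks.
    per_rank = padded_vocab_size // world_size
    start = rank * per_rank
    return start, start + per_rank


org_vocab_size = 151_669                      # illustrative vocab, not a multiple of 64
padded = pad_vocab_size(org_vocab_size)       # -> 151_680
for rank in range(2):
    p_start, p_end = shard_range(padded, rank, world_size=2)
    # Clamp to the real vocab; the trailing padded rows hold no tokens.
    start, end = min(p_start, org_vocab_size), min(p_end, org_vocab_size)
    print(f"rank {rank}: padded [{p_start}, {p_end}) -> real [{start}, {end})")
# rank 0: padded [0, 75840) -> real [0, 75840)
# rank 1: padded [75840, 151680) -> real [75840, 151669)

The 11 trailing ids on rank 1 exist only because of padding; they never correspond to real tokens, which is why the weight loader below zero-fills those rows.
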
@@ -65,10 +131,32 @@ def __init__( self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings self.params_dtype: str = params_dtype + self.padding_size = padding_size + + self.org_vocab_size = num_embeddings + self.num_embeddings = num_embeddings + num_added_embeddings = num_embeddings - self.org_vocab_size + + self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, self.padding_size) + self.num_embeddings_padded = pad_vocab_size( + self.org_vocab_size_padded + num_added_embeddings, self.padding_size + ) + assert self.org_vocab_size_padded <= self.num_embeddings_padded + self.shard_indices = self._get_indices( + self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, + self.tensor_parallel_rank, + self.world_size, + ) + + if num_embeddings % self.world_size != 0: + self.num_embeddings_padded = pad_vocab_size(num_embeddings, self.padding_size) if not self.column_cut: self.embeddings = fleet.meta_parallel.VocabParallelEmbedding( - num_embeddings, + self.num_embeddings_padded, embedding_dim, mp_group=self.tp_group, weight_attr=paddle.ParamAttr( @@ -76,7 +164,7 @@ def __init__( ), ) if self.world_size > 1: - set_weight_attrs(self.embeddings.weight, {"output_dim": False}) + set_weight_attrs(self.embeddings.weight, {"output_dim": False, "weight_loader": self.weight_loader}) else: # column cut embedding self.embeddings = nn.Embedding( @@ -106,6 +194,88 @@ def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]): self.embeddings.weight.set_value(weight_tensor) + @classmethod + def _get_indices( + cls, + vocab_size_paded: int, + org_vocab_size_padded: int, + vocab_size: int, + org_vocab_size: int, + tp_rank: int, + tp_size: int, + ) -> VocabParallelEmbeddingShardIndices: + """Get start and end indices for vocab parallel embedding, following the + layout outlined in the class docstring, based on the given tp_rank and + tp_size.""" + + num_added_embeddings_padded = vocab_size_paded - org_vocab_size_padded + padded_org_vocab_start_index, padded_org_vocab_end_index = vocab_range_from_global_vocab_size( + org_vocab_size_padded, tp_rank, tp_size + ) + + padded_added_vocab_start_index, padded_added_vocab_end_index = vocab_range_from_global_vocab_size( + num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size + ) + # remove padding + org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size) + org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size) + added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size) + added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size) + return VocabParallelEmbeddingShardIndices( + padded_org_vocab_start_index, + padded_org_vocab_end_index, + padded_added_vocab_start_index, + padded_added_vocab_end_index, + org_vocab_start_index, + org_vocab_end_index, + added_vocab_start_index, + added_vocab_end_index, + ) + + def weight_loader(self, param, loaded_weight, shard_id=None): + output_dim = getattr(param, "output_dim", None) + packed_dim = getattr(param, "packed_dim", None) + + loaded_weight = get_tensor(loaded_weight) + if param.dtype != loaded_weight.dtype: + if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn: + loaded_weight = loaded_weight.cast(param.dtype) + else: + loaded_weight = loaded_weight.cast(param.dtype) + + if output_dim is None: + assert ( + param.shape == loaded_weight.shape + ), f"Shape 
mismatch: param {param.shape} vs loaded_weight {loaded_weight.shape}" + param.set_value(loaded_weight) + return + + start_idx = self.shard_indices.org_vocab_start_index + end_idx = self.shard_indices.org_vocab_end_index + shard_size = self.shard_indices.org_vocab_end_index - start_idx + + # If param packed on the same dim we are sharding on, then + # need to adjust offsets of loaded weight by pack_factor. + if packed_dim is not None and packed_dim == output_dim: + packed_factor = getattr(param, "packed_factor", getattr(param, "pack_factor", 1)) + assert loaded_weight.shape[output_dim] == (self.org_vocab_size // packed_factor) + start_idx = start_idx // packed_factor + shard_size = shard_size // packed_factor + else: + assert loaded_weight.shape[output_dim] == self.org_vocab_size, ( + f"Loaded weight dim {output_dim} size {loaded_weight.shape[output_dim]} " + f"!= org_vocab_size {self.org_vocab_size}" + ) + + shard_weight = slice_fn(loaded_weight, output_dim, start_idx, end_idx) + + if output_dim == 0: + param[: shard_weight.shape[0]].copy_(shard_weight, False) + param[shard_weight.shape[0] :].fill_(0) + else: + param[:, : shard_weight.shape[1]].copy_(shard_weight, False) + param[:, shard_weight.shape[1] :].fill_(0) + def forward(self, ids_remove_padding=None) -> paddle.Tensor: """ Defines the forward computation of the layer. diff --git a/fastdeploy/model_executor/layers/lm_head.py b/fastdeploy/model_executor/layers/lm_head.py index 57131b00a27..ff1bdaa9217 100644 --- a/fastdeploy/model_executor/layers/lm_head.py +++ b/fastdeploy/model_executor/layers/lm_head.py @@ -22,6 +22,10 @@ from paddle.distributed import fleet from fastdeploy.config import FDConfig +from fastdeploy.model_executor.layers.utils import ( + DEFAULT_VOCAB_PADDING_SIZE, + pad_vocab_size, +) from fastdeploy.model_executor.utils import ( default_weight_loader, set_weight_attrs, @@ -44,6 +48,7 @@ def __init__( prefix: str = "", with_bias: bool = False, dtype: str = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, ) -> None: """ Parallelized LMhead. @@ -68,6 +73,10 @@ def __init__( self.column_cut = True self.nranks = fd_config.parallel_config.tensor_parallel_size self.fd_config = fd_config + self.padding_size = padding_size + + if num_embeddings % self.nranks != 0: + num_embeddings = pad_vocab_size(num_embeddings, self.padding_size) ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear RowParallelLinear = fleet.meta_parallel.RowParallelLinear diff --git a/fastdeploy/model_executor/layers/utils.py b/fastdeploy/model_executor/layers/utils.py index 85de8ec4c14..27bc770e88d 100644 --- a/fastdeploy/model_executor/layers/utils.py +++ b/fastdeploy/model_executor/layers/utils.py @@ -45,6 +45,14 @@ c8_state_dict = paddle.load(cache_params, return_numpy=True) +DEFAULT_VOCAB_PADDING_SIZE = 64 + + +def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: + """Pad the vocab size to the given value.""" + return ((vocab_size + pad_to - 1) // pad_to) * pad_to + + def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Tensor, Tensor]: """ Only used in deep_gemm block wise quant weight. @@ -372,3 +380,14 @@ def create_empty_tensor(shape: Tuple[int, ...], dtype: Union[paddle.dtype, str]) paddle.Tensor: An empty tensor with the specified shape and data type. 
""" return paddle.empty(list(shape), dtype=dtype) + + +def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int, rank: int, offset: int = 0): + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f + offset, index_l + offset + + +def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int, offset: int = 0): + per_partition_vocab_size = divide(global_vocab_size, world_size) + return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, offset=offset) diff --git a/tests/pooling/test_embedding.py b/tests/pooling/test_embedding.py index d609726e235..a548494dcad 100644 --- a/tests/pooling/test_embedding.py +++ b/tests/pooling/test_embedding.py @@ -27,7 +27,9 @@ ModelConfig, ParallelConfig, ) +from fastdeploy.model_executor.models.adapters import as_embedding_model from fastdeploy.model_executor.models.model_base import ModelRegistry +from fastdeploy.scheduler import SchedulerConfig current_dir = os.path.dirname(os.path.abspath(__file__)) project_root = os.path.abspath(os.path.join(current_dir, "..")) @@ -36,58 +38,103 @@ from tests.model_loader.utils import get_torch_model_path +test_model_configs = { + "Qwen3-0.6B": { + "tensor_parallel_size": 2, + "max_model_len": 8192, + "baseline_suffix": "standard", + }, + "Qwen3-Embedding-0.6B": { + "tensor_parallel_size": 2, + "max_model_len": 8192, + "baseline_suffix": "embedding", + }, +} + class TestModelLoader: @pytest.fixture(scope="session", autouse=True) def setup_paddle(self): if not paddle.is_compiled_with_cuda(): - print("CUDA not available, using CPU") - paddle.set_device("cpu") - else: - print("Using CUDA device") - paddle.set_device("gpu") + raise AssertionError("CUDA not available") + paddle.set_device("gpu") yield - @pytest.fixture(scope="session") - def model_path(self): + @pytest.fixture(scope="session", params=list(test_model_configs.keys())) + def model_info(self, request): + model_name = request.param try: - torch_model_path = get_torch_model_path("Qwen3-0.6B") - if os.path.exists(torch_model_path): - return torch_model_path + torch_model_path = get_torch_model_path(model_name) + if not os.path.exists(torch_model_path): + raise AssertionError(f"Model path does not exist: {torch_model_path}") + return {"name": model_name, "path": torch_model_path, "config": test_model_configs[model_name]} except Exception as e: - print(f"Could not get torch model path: {e}") + raise AssertionError(f"Could not get torch model path for {model_name}: {e}") @pytest.fixture - def model_config(self, model_path): + def model_config(self, model_info): + if model_info is None: + raise AssertionError("model_info is None") + model_args = { - "model": model_path, + "model": model_info["path"], "dtype": "bfloat16", - "max_model_len": 8192, - "tensor_parallel_size": 1, + "max_model_len": model_info["config"]["max_model_len"], + "tensor_parallel_size": model_info["config"]["tensor_parallel_size"], "runner": "auto", "convert": "auto", } try: - return ModelConfig(model_args) + config = ModelConfig(model_args) + return config + except Exception as e: + raise AssertionError(f"Could not create ModelConfig: {e}") + + @pytest.fixture + def scheduler_config(self): + scheduler_args = { + "name": "local", + "max_num_seqs": 256, + "max_num_batched_tokens": 8192, + "splitwise_role": "mixed", + "max_size": -1, + "ttl": 900, + "max_model_len": 8192, + "enable_chunked_prefill": False, + "max_num_partial_prefills": 1, + "max_long_partial_prefills": 1, + 
"long_prefill_token_threshold": 0, + } + + try: + config = SchedulerConfig(scheduler_args) + return config except Exception as e: - print(f"Could not create ModelConfig: {e}") + raise AssertionError(f"Could not create SchedulerConfig: {e}") @pytest.fixture - def fd_config(self, model_config): + def fd_config(self, model_info, model_config, scheduler_config): + if model_config is None: + raise AssertionError("ModelConfig is None") + if scheduler_config is None: + raise AssertionError("SchedulerConfig is None") + try: + tensor_parallel_size = model_info["config"]["tensor_parallel_size"] + cache_args = { "block_size": 64, "gpu_memory_utilization": 0.9, "cache_dtype": "bfloat16", "model_cfg": model_config, - "tensor_parallel_size": 1, + "tensor_parallel_size": tensor_parallel_size, } cache_config = CacheConfig(cache_args) parallel_args = { - "tensor_parallel_size": 1, + "tensor_parallel_size": tensor_parallel_size, "data_parallel_size": 1, } parallel_config = ParallelConfig(parallel_args) @@ -95,88 +142,80 @@ def fd_config(self, model_config): load_args = {} load_config = LoadConfig(load_args) - graph_opt_args = { - "enable_cudagraph": False, - "cudagraph_capture_sizes": None, - } + graph_opt_args = {} graph_opt_config = GraphOptimizationConfig(graph_opt_args) - return FDConfig( + fd_config = FDConfig( model_config=model_config, cache_config=cache_config, parallel_config=parallel_config, + scheduler_config=scheduler_config, load_config=load_config, graph_opt_config=graph_opt_config, test_mode=True, ) + return fd_config + except Exception as e: - print(f"Could not create FDConfig: {e}") + raise AssertionError(f"Could not create FDConfig: {e}") @pytest.fixture - def model_json_config(self, model_path): - config_path = os.path.join(model_path, "config.json") - if os.path.exists(config_path): - with open(config_path, "r", encoding="utf-8") as f: - return json.load(f) - return None + def model_json_config(self, model_info): + if model_info is None: + raise AssertionError("model_info is None") - def test_embedding_with_none_convert_type(self, fd_config, model_json_config): - if model_json_config is None: - pytest.skip("Model config not available") + config_path = os.path.join(model_info["path"], "config.json") + if not os.path.exists(config_path): + raise AssertionError(f"Config file does not exist: {config_path}") - if fd_config is None: - pytest.skip("FDConfig not available") + with open(config_path, "r", encoding="utf-8") as f: + return json.load(f) - print("=" * 60) - print("Testing initialize_model with convert_type='none'") - print("=" * 60) + def test_embedding_with_none_convert_type(self, model_info, fd_config, model_json_config): + if any(x is None for x in [model_info, fd_config, model_json_config]): + raise AssertionError("Required configs not available") architectures = model_json_config.get("architectures", []) if not architectures: - pytest.skip("No architectures found in model config") + raise AssertionError("No architectures found in model config") fd_config.model_config.convert_type = "none" try: - model_cls = ModelRegistry.get_class(architectures) + model_cls = ModelRegistry.get_class(architectures[0]) if hasattr(model_cls, "__name__"): assert ( "ForEmbedding" not in model_cls.__name__ ), f"Standard model should not have 'ForEmbedding' in name, but got: {model_cls.__name__}" - print(f"Confirmed standard model type (no ForEmbedding): {model_cls.__name__}") standard_methods = set(dir(model_cls)) assert "_init_pooler" not in standard_methods, "Standard model should not have 
_init_pooler method" except Exception as e: - print(f"Error in none: {e}") + raise AssertionError(f"Error in none convert type test: {e}") - def test_embedding_with_embed_convert_type(self, fd_config, model_json_config): - if model_json_config is None: - pytest.skip("Model config not available") - - if fd_config is None: - pytest.skip("FDConfig not available") - - print("=" * 60) - print("Testing embedding with convert_type='embed'") - print("=" * 60) + def test_embedding_with_embed_convert_type(self, model_info, fd_config, model_json_config): + if any(x is None for x in [model_info, fd_config, model_json_config]): + raise AssertionError("Required configs not available") architectures = model_json_config.get("architectures", []) if not architectures: - pytest.skip("No architectures found in model config") + raise AssertionError("No architectures found in model config") fd_config.model_config.convert_type = "embed" try: - model_cls = ModelRegistry.get_class(architectures) + model_cls = ModelRegistry.get_class(architectures[0]) + model_cls = as_embedding_model(model_cls) + if hasattr(model_cls, "__name__"): - assert "ForEmbedding" in model_cls.__name__, "Embedding model should have 'ForEmbedding' in name" - print(f"Confirmed embedding model type: {model_cls.__name__}") + assert ( + "ForEmbedding" in model_cls.__name__ + ), f"Embedding model should have 'ForEmbedding' in name, but got: {model_cls.__name__}" embedding_methods = set(dir(model_cls)) assert "_init_pooler" in embedding_methods, "Embedding model should have _init_pooler method" except Exception as e: - print(f"Error in convert embed: {e}") + raise AssertionError(f"Error in embed convert type test: {e}") From 7e7867923c474dc4725dbe4cda292d34e5bd67fb Mon Sep 17 00:00:00 2001 From: lizexu <2694294196@qq.com> Date: Tue, 23 Sep 2025 11:35:33 +0800 Subject: [PATCH 02/10] fix ci bug --- tests/pooling/test_embedding.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/pooling/test_embedding.py b/tests/pooling/test_embedding.py index a548494dcad..23190fdb51f 100644 --- a/tests/pooling/test_embedding.py +++ b/tests/pooling/test_embedding.py @@ -19,6 +19,11 @@ import paddle import pytest +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.abspath(os.path.join(current_dir, "..")) +if project_root not in sys.path: + sys.path.insert(0, project_root) + from fastdeploy.config import ( CacheConfig, FDConfig, From deffa1832079d7bea5c53bf97e61c3b9a54384de Mon Sep 17 00:00:00 2001 From: lizexu <2694294196@qq.com> Date: Tue, 23 Sep 2025 11:46:16 +0800 Subject: [PATCH 03/10] fix --- tests/pooling/test_embedding.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/pooling/test_embedding.py b/tests/pooling/test_embedding.py index 23190fdb51f..a548494dcad 100644 --- a/tests/pooling/test_embedding.py +++ b/tests/pooling/test_embedding.py @@ -19,11 +19,6 @@ import paddle import pytest -current_dir = os.path.dirname(os.path.abspath(__file__)) -project_root = os.path.abspath(os.path.join(current_dir, "..")) -if project_root not in sys.path: - sys.path.insert(0, project_root) - from fastdeploy.config import ( CacheConfig, FDConfig, From 72fae93a37f95489b46585ede73d947fcb1918b1 Mon Sep 17 00:00:00 2001 From: lizexu <2694294196@qq.com> Date: Tue, 23 Sep 2025 11:50:56 +0800 Subject: [PATCH 04/10] fix ci bug --- fastdeploy/model_executor/models/qwen3.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py index 
47ed104babf..50fdb7222a3 100644 --- a/fastdeploy/model_executor/models/qwen3.py +++ b/fastdeploy/model_executor/models/qwen3.py @@ -303,6 +303,7 @@ def load_weights(self, weights_iterator) -> None: if model_param_name not in params_dict: continue param = params_dict[model_param_name] + print("params_dict", params_dict) weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) weight_loader(param, loaded_weight, shard_id) @@ -312,6 +313,7 @@ def load_weights(self, weights_iterator) -> None: if model_param_name not in params_dict: continue param = params_dict[model_param_name] + print("params_dict", params_dict) weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) weight_loader(param, loaded_weight) From 549c033034f4f0a03d9d7e52141035749543e2be Mon Sep 17 00:00:00 2001 From: lizexu <2694294196@qq.com> Date: Tue, 23 Sep 2025 11:51:44 +0800 Subject: [PATCH 05/10] fix ci bug --- fastdeploy/model_executor/layers/pool/__init__.py | 15 +++++++++++++++ fastdeploy/model_executor/models/qwen3.py | 1 - 2 files changed, 15 insertions(+), 1 deletion(-) create mode 100644 fastdeploy/model_executor/layers/pool/__init__.py diff --git a/fastdeploy/model_executor/layers/pool/__init__.py b/fastdeploy/model_executor/layers/pool/__init__.py new file mode 100644 index 00000000000..f4ede90624a --- /dev/null +++ b/fastdeploy/model_executor/layers/pool/__init__.py @@ -0,0 +1,15 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py index 50fdb7222a3..17c900e4b3e 100644 --- a/fastdeploy/model_executor/models/qwen3.py +++ b/fastdeploy/model_executor/models/qwen3.py @@ -303,7 +303,6 @@ def load_weights(self, weights_iterator) -> None: if model_param_name not in params_dict: continue param = params_dict[model_param_name] - print("params_dict", params_dict) weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) weight_loader(param, loaded_weight, shard_id) From 9a17cbdada514adf5a0064e196e845c1bd098126 Mon Sep 17 00:00:00 2001 From: lizexu <2694294196@qq.com> Date: Tue, 23 Sep 2025 11:53:04 +0800 Subject: [PATCH 06/10] fix --- fastdeploy/model_executor/models/qwen3.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py index 17c900e4b3e..47ed104babf 100644 --- a/fastdeploy/model_executor/models/qwen3.py +++ b/fastdeploy/model_executor/models/qwen3.py @@ -312,7 +312,6 @@ def load_weights(self, weights_iterator) -> None: if model_param_name not in params_dict: continue param = params_dict[model_param_name] - print("params_dict", params_dict) weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) weight_loader(param, loaded_weight) From 345cd7a2d4d41ae58e0df8beffc694ba2c56be27 Mon Sep 17 00:00:00 2001 From: lizexu <2694294196@qq.com> Date: Tue, 23 Sep 2025 17:26:49 +0800 Subject: [PATCH 07/10] fix qwen3-embedding --- .../model_executor/layers/embeddings.py | 27 ++++++++++++------- fastdeploy/model_executor/models/adapters.py | 3 ++- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/fastdeploy/model_executor/layers/embeddings.py b/fastdeploy/model_executor/layers/embeddings.py index 7b97c53e5b1..3a0c6fc40d3 100644 --- a/fastdeploy/model_executor/layers/embeddings.py +++ b/fastdeploy/model_executor/layers/embeddings.py @@ -238,10 +238,7 @@ def weight_loader(self, param, loaded_weight, shard_id=None): loaded_weight = get_tensor(loaded_weight) if param.dtype != loaded_weight.dtype: - if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn: - loaded_weight = loaded_weight.cast(param.dtype) - else: - loaded_weight = loaded_weight.cast(param.dtype) + loaded_weight = loaded_weight.cast(param.dtype) if output_dim is None: assert ( @@ -269,12 +266,24 @@ def weight_loader(self, param, loaded_weight, shard_id=None): shard_weight = slice_fn(loaded_weight, output_dim, start_idx, end_idx) - if output_dim == 0: - param[: shard_weight.shape[0]].copy_(shard_weight, False) - param[shard_weight.shape[0] :].fill_(0) + total_padded_size = self.shard_indices.num_elements_padded + actual_size = shard_weight.shape[output_dim] + padding_size = total_padded_size - actual_size + + if padding_size > 0: + padding_shape = list(shard_weight.shape) + padding_shape[output_dim] = padding_size + padding_tensor = paddle.zeros(padding_shape, dtype=shard_weight.dtype) + + final_weight = paddle.concat([shard_weight, padding_tensor], axis=output_dim) else: - param[:, : shard_weight.shape[1]].copy_(shard_weight, False) - param[:, shard_weight.shape[1] :].fill_(0) + final_weight = shard_weight + + assert ( + final_weight.shape == param.shape + ), f"Final weight shape {final_weight.shape} doesn't match param shape {param.shape}" + + param.copy_(final_weight, False) def forward(self, ids_remove_padding=None) -> paddle.Tensor: """ diff --git a/fastdeploy/model_executor/models/adapters.py 
b/fastdeploy/model_executor/models/adapters.py index d56c1dcb1f4..1f2590acdd3 100644 --- a/fastdeploy/model_executor/models/adapters.py +++ b/fastdeploy/model_executor/models/adapters.py @@ -22,7 +22,6 @@ from fastdeploy.config import ModelConfig from fastdeploy.model_executor.layers.activation import get_act_fn -from fastdeploy.model_executor.models.interfaces_base import is_pooling_model from fastdeploy.transformer_utils.config import get_hf_file_to_dict _T = TypeVar("_T", bound=type[nn.Layer]) @@ -191,6 +190,8 @@ def as_embedding_model(cls: _T) -> _T: please implement your own model if this is not the case. """ # Avoid modifying existing embedding models + from fastdeploy.model_executor.models.interfaces_base import is_pooling_model + if is_pooling_model(cls): return cls From 1fde23caaa65a2a62c3496b310bbe2c81c9e4e33 Mon Sep 17 00:00:00 2001 From: lizexu <2694294196@qq.com> Date: Tue, 23 Sep 2025 18:25:31 +0800 Subject: [PATCH 08/10] fix --- .../model_executor/layers/embeddings.py | 64 ++++++++++--------- 1 file changed, 33 insertions(+), 31 deletions(-) diff --git a/fastdeploy/model_executor/layers/embeddings.py b/fastdeploy/model_executor/layers/embeddings.py index 3a0c6fc40d3..4d67b660eb1 100644 --- a/fastdeploy/model_executor/layers/embeddings.py +++ b/fastdeploy/model_executor/layers/embeddings.py @@ -237,53 +237,55 @@ def weight_loader(self, param, loaded_weight, shard_id=None): packed_dim = getattr(param, "packed_dim", None) loaded_weight = get_tensor(loaded_weight) + + print("param", param) + if not hasattr(param, "_initialized") or not param._initialized: + try: + temp_data = paddle.zeros(param.shape, dtype=param.dtype) + param.copy_(temp_data, False) + except: + # 如果copy_失败,使用set_value + param.set_value(paddle.zeros(param.shape, dtype=param.dtype)) + if param.dtype != loaded_weight.dtype: loaded_weight = loaded_weight.cast(param.dtype) if output_dim is None: - assert ( - param.shape == loaded_weight.shape - ), f"Shape mismatch: param {param.shape} vs loaded_weight {loaded_weight.shape}" - param.set_value(loaded_weight) + assert param.shape == loaded_weight.shape + param.copy_(loaded_weight, False) return + # 获取分片索引 start_idx = self.shard_indices.org_vocab_start_index - end_idx = self.shard_indices.org_vocab_end_index shard_size = self.shard_indices.org_vocab_end_index - start_idx - # If param packed on the same dim we are sharding on, then - # need to adjust offsets of loaded weight by pack_factor. 
if packed_dim is not None and packed_dim == output_dim: packed_factor = getattr(param, "packed_factor", getattr(param, "pack_factor", 1)) assert loaded_weight.shape[output_dim] == (self.org_vocab_size // packed_factor) start_idx = start_idx // packed_factor shard_size = shard_size // packed_factor else: - assert loaded_weight.shape[output_dim] == self.org_vocab_size, ( - f"Loaded weight dim {output_dim} size {loaded_weight.shape[output_dim]} " - f"!= org_vocab_size {self.org_vocab_size}" - ) - - shard_weight = slice_fn(loaded_weight, output_dim, start_idx, end_idx) - - total_padded_size = self.shard_indices.num_elements_padded - actual_size = shard_weight.shape[output_dim] - padding_size = total_padded_size - actual_size - - if padding_size > 0: - padding_shape = list(shard_weight.shape) - padding_shape[output_dim] = padding_size - padding_tensor = paddle.zeros(padding_shape, dtype=shard_weight.dtype) - - final_weight = paddle.concat([shard_weight, padding_tensor], axis=output_dim) + assert loaded_weight.shape[output_dim] == self.org_vocab_size + + shard_weight = slice_fn(loaded_weight, output_dim, start_idx, start_idx + shard_size) + + # 参考vLLM的处理方式:直接对参数的前N个元素进行复制 + # 关键:确保不会访问超出边界的内存 + copy_size = min(shard_weight.shape[0], param.shape[0]) + + # 创建临时张量来存储完整的参数数据 + if output_dim == 0: + # 创建与param同样大小的新张量 + new_param_data = paddle.zeros_like(param) + # 将shard_weight复制到新张量的前面部分 + new_param_data[:copy_size].copy_(shard_weight[:copy_size], False) + # 整体替换参数 + param.copy_(new_param_data, False) else: - final_weight = shard_weight - - assert ( - final_weight.shape == param.shape - ), f"Final weight shape {final_weight.shape} doesn't match param shape {param.shape}" - - param.copy_(final_weight, False) + new_param_data = paddle.zeros_like(param) + copy_size = min(shard_weight.shape[1], param.shape[1]) + new_param_data[:, :copy_size].copy_(shard_weight[:, :copy_size], False) + param.copy_(new_param_data, False) def forward(self, ids_remove_padding=None) -> paddle.Tensor: """ From 03c37b0c322385f0e44861bc270e1089094473dc Mon Sep 17 00:00:00 2001 From: lizexu <2694294196@qq.com> Date: Tue, 23 Sep 2025 18:27:10 +0800 Subject: [PATCH 09/10] fix --- .../model_executor/layers/embeddings.py | 52 ++++++++----------- 1 file changed, 22 insertions(+), 30 deletions(-) diff --git a/fastdeploy/model_executor/layers/embeddings.py b/fastdeploy/model_executor/layers/embeddings.py index 4d67b660eb1..b35304f5afc 100644 --- a/fastdeploy/model_executor/layers/embeddings.py +++ b/fastdeploy/model_executor/layers/embeddings.py @@ -236,56 +236,48 @@ def weight_loader(self, param, loaded_weight, shard_id=None): output_dim = getattr(param, "output_dim", None) packed_dim = getattr(param, "packed_dim", None) - loaded_weight = get_tensor(loaded_weight) - - print("param", param) - if not hasattr(param, "_initialized") or not param._initialized: - try: - temp_data = paddle.zeros(param.shape, dtype=param.dtype) - param.copy_(temp_data, False) - except: - # 如果copy_失败,使用set_value - param.set_value(paddle.zeros(param.shape, dtype=param.dtype)) + if not param._is_initialized(): + param.initialize() + loaded_weight = get_tensor(loaded_weight) if param.dtype != loaded_weight.dtype: - loaded_weight = loaded_weight.cast(param.dtype) + if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn: + loaded_weight = loaded_weight.cast(param.dtype) + else: + loaded_weight = loaded_weight.cast(param.dtype) if output_dim is None: - assert param.shape == loaded_weight.shape + assert ( + param.shape == loaded_weight.shape + 
), f"Shape mismatch: param {param.shape} vs loaded_weight {loaded_weight.shape}" param.copy_(loaded_weight, False) return - # 获取分片索引 start_idx = self.shard_indices.org_vocab_start_index + end_idx = self.shard_indices.org_vocab_end_index shard_size = self.shard_indices.org_vocab_end_index - start_idx + # If param packed on the same dim we are sharding on, then + # need to adjust offsets of loaded weight by pack_factor. if packed_dim is not None and packed_dim == output_dim: packed_factor = getattr(param, "packed_factor", getattr(param, "pack_factor", 1)) assert loaded_weight.shape[output_dim] == (self.org_vocab_size // packed_factor) start_idx = start_idx // packed_factor shard_size = shard_size // packed_factor else: - assert loaded_weight.shape[output_dim] == self.org_vocab_size - - shard_weight = slice_fn(loaded_weight, output_dim, start_idx, start_idx + shard_size) + assert loaded_weight.shape[output_dim] == self.org_vocab_size, ( + f"Loaded weight dim {output_dim} size {loaded_weight.shape[output_dim]} " + f"!= org_vocab_size {self.org_vocab_size}" + ) - # 参考vLLM的处理方式:直接对参数的前N个元素进行复制 - # 关键:确保不会访问超出边界的内存 - copy_size = min(shard_weight.shape[0], param.shape[0]) + shard_weight = slice_fn(loaded_weight, output_dim, start_idx, end_idx) - # 创建临时张量来存储完整的参数数据 if output_dim == 0: - # 创建与param同样大小的新张量 - new_param_data = paddle.zeros_like(param) - # 将shard_weight复制到新张量的前面部分 - new_param_data[:copy_size].copy_(shard_weight[:copy_size], False) - # 整体替换参数 - param.copy_(new_param_data, False) + param[: shard_weight.shape[0]].copy_(shard_weight, False) + param[shard_weight.shape[0] :].fill_(0) else: - new_param_data = paddle.zeros_like(param) - copy_size = min(shard_weight.shape[1], param.shape[1]) - new_param_data[:, :copy_size].copy_(shard_weight[:, :copy_size], False) - param.copy_(new_param_data, False) + param[:, : shard_weight.shape[1]].copy_(shard_weight, False) + param[:, shard_weight.shape[1] :].fill_(0) def forward(self, ids_remove_padding=None) -> paddle.Tensor: """ From 660133ff123f83e19d822adf51eea7687520efb5 Mon Sep 17 00:00:00 2001 From: lizexu <2694294196@qq.com> Date: Tue, 23 Sep 2025 19:53:40 +0800 Subject: [PATCH 10/10] fix --- fastdeploy/model_executor/layers/embeddings.py | 4 +++- fastdeploy/worker/gpu_model_runner.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/fastdeploy/model_executor/layers/embeddings.py b/fastdeploy/model_executor/layers/embeddings.py index b35304f5afc..6df196f6547 100644 --- a/fastdeploy/model_executor/layers/embeddings.py +++ b/fastdeploy/model_executor/layers/embeddings.py @@ -164,7 +164,9 @@ def __init__( ), ) if self.world_size > 1: - set_weight_attrs(self.embeddings.weight, {"output_dim": False, "weight_loader": self.weight_loader}) + set_weight_attrs(self.embeddings.weight, {"output_dim": False}) + if num_embeddings % self.world_size != 0: + set_weight_attrs(self.embeddings.weight, {"weight_loader", self.weight_loader}) else: # column cut embedding self.embeddings = nn.Embedding( diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 3817b49326f..c776c03eadc 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1321,6 +1321,7 @@ def _dummy_run( logits = None if hasattr(self.model, "is_pooling_model") and self.model.is_pooling_model: + # TODO(lizexu123) The preheating the pooling function have not been implemented yet. pass else: # 4. Execute spec decode @@ -1632,9 +1633,9 @@ class at the server level, which is too granular for ModelRunner. 
             logits = None
 
         # 4. Compute logits, Sample
         if hasattr(self.model, "is_pooling_model") and self.model.is_pooling_model:
+            # TODO(lizexu123) Execution of the pooling function has not been implemented yet.
             pass
         else:
-            # 4. Execute spec decode
             logits = self.model.compute_logits(hidden_states)
         if not self.speculative_decoding:
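
The net effect of the final `weight_loader` (patch 09) on each rank can be summarized with a small numpy sketch: slice this rank's rows out of the full checkpoint matrix, copy them into the front of the padded parameter, and zero the tail. This is an illustration, not the shipped implementation; the shapes are toy values, while the real layer pads the vocabulary to a multiple of 64 and operates on paddle tensors.

import numpy as np

# Toy numbers: a 10-token "checkpoint" vocab, hidden size 4, 2 ranks,
# 6 rows allocated per rank after padding.
org_vocab_size, hidden = 10, 4
rank, world_size, rows_per_rank = 1, 2, 6

full_weight = np.arange(org_vocab_size * hidden, dtype=np.float32).reshape(org_vocab_size, hidden)

# Shard along dim 0 (output_dim == 0), clamped to the real vocab.
start = rank * rows_per_rank
end = min(start + rows_per_rank, org_vocab_size)
shard = full_weight[start:end]            # 4 real rows for rank 1

# Mirror `param[:n].copy_(shard, False); param[n:].fill_(0)` from the patch.
param = np.empty((rows_per_rank, hidden), dtype=np.float32)
param[: shard.shape[0]] = shard
param[shard.shape[0]:] = 0.0              # padded rows carry no tokens

print(param)

Zero-filling the tail is safe because input token ids always come from the unpadded vocabulary, so the padded rows are never looked up; it also keeps every rank's parameter at the same padded shape that the parallel embedding is constructed with.
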