fastdeploy/engine/pooling_params.py (0 additions, 2 deletions)

@@ -37,8 +37,6 @@ class PoolingParams(
         normalize: Whether to normalize the embeddings outputs.
         dimensions: Reduce the dimensions of embeddings
             if model support matryoshka representation.
-        activation: Whether to apply activation function to
-            the classification outputs.
         softmax: Whether to apply softmax to the reward outputs.
         step_tag_id: Step tag ID for process reward models to identify
             specific steps in multi-step reasoning tasks.
fastdeploy/model_executor/layers/embeddings.py (4 additions, 6 deletions)

@@ -163,10 +163,8 @@ def __init__(
                     initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
                 ),
             )
-            if self.world_size > 1:
-                set_weight_attrs(self.embeddings.weight, {"output_dim": False})
-                if num_embeddings % self.world_size != 0:
-                    set_weight_attrs(self.embeddings.weight, {"weight_loader", self.weight_loader})
+            set_weight_attrs(self.embeddings.weight, {"output_dim": False})
+            set_weight_attrs(self.embeddings.weight, {"weight_loader": self.weight_loader})
         else:
             # column cut embedding
             self.embeddings = nn.Embedding(
@@ -176,8 +174,8 @@ def __init__(

             self.embeddings.weight.is_distributed = True
             self.embeddings.weight.split_axis = 1
-            if self.world_size > 1:
-                set_weight_attrs(self.embeddings.weight, {"output_dim": True})
+
+            set_weight_attrs(self.embeddings.weight, {"output_dim": True})

         self.prefix = prefix
         self.dropout = nn.Dropout(self.hidden_dropout_prob)
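The bug fixed in the first hunk is easy to miss: {"weight_loader", self.weight_loader} is a Python set literal, not a dict, so no weight_loader attribute was ever attached. A minimal sketch of the pitfall, assuming set_weight_attrs simply copies dict entries onto the parameter (the real helper lives in fastdeploy.model_executor.utils and may do more):

# Minimal sketch, not FastDeploy code: assumes set_weight_attrs copies
# key/value pairs onto the parameter object.
def set_weight_attrs(weight, attrs):
    for key, value in attrs.items():
        setattr(weight, key, value)

class FakeWeight:  # hypothetical stand-in for a paddle parameter
    pass

w = FakeWeight()
loader = lambda param, tensor: None  # hypothetical loader callable

set_weight_attrs(w, {"weight_loader": loader})  # dict literal: works
assert w.weight_loader is loader

try:
    set_weight_attrs(w, {"weight_loader", loader})  # set literal: the old bug
except AttributeError as e:
    print(e)  # 'set' object has no attribute 'items'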
fastdeploy/model_executor/layers/pool/metadata.py (3 additions, 2 deletions)

@@ -69,8 +69,9 @@ def build_pooling_cursor(num_scheduled_tokens: list[int], prompt_lens: paddle.Tensor,

     n_seq = len(num_scheduled_tokens)
     index = list(range(n_seq))
-    num_scheduled_tokens = paddle.to_tensor(num_scheduled_tokens, device="cpu")
-    cumsum = paddle.zeros([n_seq + 1], dtype="int64", place=paddle.CPUPlace())
+    num_scheduled_tokens = paddle.to_tensor(num_scheduled_tokens)
+    cumsum = paddle.zeros([n_seq + 1], dtype="int64")
+
     paddle.cumsum(num_scheduled_tokens, axis=0, out=cumsum[1:])
     if device == "gpu":
         cumsum_device = cumsum.cuda()
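The deleted lines passed torch-style device="cpu" and place= keyword arguments, which paddle.to_tensor and paddle.zeros do not accept in that form; the replacement lets placement default. A standalone sketch of the cursor arithmetic, with assumed toy inputs:

# Standalone sketch of the prefix-sum cursor built above (toy inputs assumed).
import paddle

num_scheduled_tokens = [3, 5, 2]  # tokens scheduled for each of 3 sequences
n_seq = len(num_scheduled_tokens)

tokens = paddle.to_tensor(num_scheduled_tokens)    # no device= kwarg
cumsum = paddle.zeros([n_seq + 1], dtype="int64")  # no place= kwarg
cumsum[1:] = paddle.cumsum(tokens, axis=0)

# Sequence i owns flattened token rows cumsum[i]:cumsum[i+1].
print(cumsum.numpy())  # [ 0  3  8 10]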
fastdeploy/model_executor/layers/pooler.py (23 additions, 0 deletions)

@@ -332,6 +332,29 @@ def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
             return MeanPool()
         raise NotImplementedError(f"Unsupported method: {pooling_type}")

+    @abstractmethod
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        raise NotImplementedError
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
+        return PoolingParamsUpdate()
+
+    @abstractmethod
+    def forward_all(
+        self,
+        hidden_states: paddle.Tensor,
+        pooling_cursor: PoolingCursor,
+    ) -> Union[list[paddle.Tensor], paddle.Tensor]:
+        raise NotImplementedError
+
+    def forward(
+        self,
+        hidden_states: paddle.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Union[list[paddle.Tensor], paddle.Tensor]:
+        pooling_cursor = pooling_metadata.pooling_cursor
+        return self.forward_all(hidden_states, pooling_cursor)
+

 class LastPool(PoolingMethod):
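For orientation, a self-contained sketch of the contract these additions define: forward_all consumes a cursor describing where each sequence's tokens live in the flattened batch, and a concrete method such as LastPool selects one row per sequence. ToyCursor and ToyLastPool below are illustrative stand-ins, not the real PoolingCursor or LastPool:

# Illustrative stand-ins only; the real PoolingCursor and LastPool live in
# fastdeploy.model_executor.layers and may differ in detail.
from abc import ABC, abstractmethod
from dataclasses import dataclass

import paddle

@dataclass
class ToyCursor:
    last_token_indices: paddle.Tensor  # inclusive end offset of each sequence

class ToyPoolingMethod(ABC):
    @abstractmethod
    def forward_all(self, hidden_states, cursor):
        raise NotImplementedError

class ToyLastPool(ToyPoolingMethod):
    """Pool each sequence by selecting its final token's hidden state."""
    def forward_all(self, hidden_states, cursor):
        return hidden_states[cursor.last_token_indices]

# Three sequences of lengths 3, 5, 2 flattened into 10 token states of width 8:
hs = paddle.randn([10, 8])
cursor = ToyCursor(last_token_indices=paddle.to_tensor([2, 7, 9]))
print(ToyLastPool().forward_all(hs, cursor).shape)  # [3, 8]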
fastdeploy/model_executor/models/adapters.py (1 addition, 1 deletion)

@@ -180,7 +180,7 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:

 def as_embedding_model(cls: _T) -> _T:
     """
-    Subclass an existing vLLM model to support embeddings.
+    Subclass an existing FastDeploy model to support embeddings.

     By default, the embeddings of the whole prompt are extracted from the
     normalized hidden state corresponding to the last token.
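A hedged usage sketch of the adapter; the base class name and its import path are assumptions for illustration, not confirmed by this diff:

# Usage sketch under assumptions: Qwen3ForCausalLM and its module path are
# illustrative placeholders.
from fastdeploy.model_executor.models.adapters import as_embedding_model
from fastdeploy.model_executor.models.qwen3 import Qwen3ForCausalLM  # assumed

Qwen3EmbeddingModel = as_embedding_model(Qwen3ForCausalLM)
# Per the docstring, the subclass pools the normalized hidden state of the
# last prompt token by default.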
fastdeploy/model_executor/models/interfaces_base.py (56 additions, 8 deletions)

@@ -12,9 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Type
+from typing import ClassVar, Literal, Protocol, Type

+import paddle
 from paddle import nn
+from typing_extensions import TypeVar, runtime_checkable
+
+from fastdeploy.config import FDConfig
+from fastdeploy.model_executor.forward_meta import ForwardMeta
+from fastdeploy.model_executor.layers.pooler import Pooler
+
+T = TypeVar("T", default=paddle.Tensor)
+T_co = TypeVar("T_co", default=paddle.Tensor, covariant=True)


 def is_text_generation_model(model_cls: Type[nn.Layer]) -> bool:
@@ -24,13 +33,7 @@ def is_text_generation_model(model_cls: Type[nn.Layer]) -> bool:


 def is_pooling_model(model_cls: Type[nn.Layer]) -> bool:
-    class_name = model_cls.__name__
-    pooling_indicators = ["Embedding", "ForSequenceClassification"]
-    return (
-        any(indicator in class_name for indicator in pooling_indicators)
-        or hasattr(model_cls, "is_embedding_model")
-        and model_cls.is_embedding_model
-    )
+    return getattr(model_cls, "is_pooling_model", False)


 def is_multimodal_model(class_name: str) -> bool:
@@ -52,3 +55,48 @@ def get_default_pooling_type(model_cls: Type[nn.Layer] = None) -> str:
     if model_cls is not None:
         return getattr(model_cls, "default_pooling_type", "LAST")
     return "LAST"
+
+
+@runtime_checkable
+class FdModel(Protocol[T_co]):
+    """The interface required for all models in FastDeploy."""
[Review comment on lines +61 to +62, Collaborator]
Which classes will inherit FdModel, and what is its relationship to ModelForCasualLM?

[Reply, lizexu123 (Collaborator, Author), Oct 11, 2025]
Only FdModelForPooling inherits it; it is unrelated to ModelForCasualLM. ModelForCasualLM has compute_logits, which pooling models do not compute.

+    def __init__(
+        self,
+        fd_config: FDConfig,
+        prefix: str = "",
+    ) -> None:
+        pass
+
+    def forward(
+        self,
+        ids_remove_padding: paddle.Tensor,
+        forward_metadata: ForwardMeta,
+    ) -> T_co:
+        pass
+
+
+class FdModelForPooling(FdModel[T_co], Protocol[T_co]):
+    """The interface required for all pooling models in FastDeploy."""
+
+    is_pooling_model: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model supports pooling.
+
+    Note:
+        There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+    """
+
+    default_pooling_type: ClassVar[str] = "LAST"
+    """
+    Indicates the
+    [fastdeploy.config.PoolerConfig.pooling_type][]
+    to use by default.
+
+    You can use the
+    [fastdeploy.model_executor.models.interfaces_base.default_pooling_type][]
+    decorator to conveniently set this field.
+    """
+
+    pooler: Pooler
+    """The pooler is only called on TP rank 0."""
fastdeploy/model_executor/models/qwen3.py (2 additions, 0 deletions)

@@ -303,7 +303,9 @@ def load_weights(self, weights_iterator) -> None:
             if model_param_name not in params_dict:
                 continue
             param = params_dict[model_param_name]
+
             weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
+
             weight_loader(param, loaded_weight, shard_id)

             break
fastdeploy/model_executor/utils.py (2 additions, 2 deletions)

@@ -157,7 +157,7 @@ def free_tensor(tensor):
     del tensor


-def default_weight_loader(fd_config: FDConfig) -> None:
+def default_weight_loader(fd_config: FDConfig = None) -> None:
    """Default weight loader"""

    def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None):
@@ -169,7 +169,7 @@ def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None):
         loaded_weight = get_tensor(loaded_weight)
         loaded_weight = loaded_weight.transpose([1, 0])
         # Tensor parallelism splits the weight along the output_dim
-        if output_dim is not None and fd_config.parallel_config.tensor_parallel_size > 1:
+        if output_dim is not None and fd_config is not None and fd_config.parallel_config.tensor_parallel_size > 1:
             dim = -1 if output_dim else 0
             if isinstance(loaded_weight, paddle.Tensor):
                 size = loaded_weight.shape[dim]
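With fd_config defaulting to None, the loader can now be built without a config, and the added fd_config is not None guard keeps the tensor-parallel branch from dereferencing it. A condensed, self-contained rendition of that guard (SimpleNamespace stands in for FDConfig; the real fn also transposes weights and handles shard ids):

# Condensed sketch; only the None-guard logic is taken from the diff above.
from types import SimpleNamespace

def default_weight_loader(fd_config=None):
    def fn(param, loaded_weight, shard_id=None):
        output_dim = param.get("output_dim")
        if output_dim is not None and fd_config is not None and fd_config.parallel_config.tensor_parallel_size > 1:
            return "would shard loaded_weight along output_dim for this TP rank"
        return "full weight copied; TP split skipped"
    return fn

cfg = SimpleNamespace(parallel_config=SimpleNamespace(tensor_parallel_size=2))
print(default_weight_loader(cfg)({"output_dim": True}, loaded_weight=None))
# would shard loaded_weight along output_dim for this TP rank
print(default_weight_loader()({"output_dim": True}, loaded_weight=None))
# full weight copied; TP split skipped (fd_config=None no longer crashes)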