4 changes: 2 additions & 2 deletions fastdeploy/engine/expert_service.py
@@ -103,7 +103,7 @@ def start(

llm_logger.info(f"start expert service {local_data_parallel_id}")

if self.cfg.splitwise_role != "mixed" or self.cfg.cache_config.enable_prefix_caching:
if self.cfg.scheduler_config.splitwise_role != "mixed" or self.cfg.cache_config.enable_prefix_caching:
if self.do_profile:
get_profile_block_num = np.zeros([1], dtype=np.int32)
while True:
@@ -123,7 +123,7 @@ def start(
self.cache_manager_processes = self.engine.start_cache_service(
self.cfg.local_device_ids, ipc_signal_suffix_cache
)
if self.cfg.splitwise_role != "mixed":
if self.cfg.scheduler_config.splitwise_role != "mixed":
self.engine.split_mode_get_tasks()

if self.cfg.scheduler_config.name == "splitwise":
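
The hunks above only change where splitwise_role is read from: it now lives on cfg.scheduler_config instead of the top-level config. A minimal, runnable sketch of that access pattern follows; the dataclasses are stand-ins for FastDeploy's real config objects, and any role value other than "mixed" is an assumption.

from dataclasses import dataclass, field


@dataclass
class SchedulerConfig:
    splitwise_role: str = "mixed"  # assumed values: "mixed", "prefill", "decode"


@dataclass
class CacheConfig:
    enable_prefix_caching: bool = False


@dataclass
class Cfg:
    scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig)
    cache_config: CacheConfig = field(default_factory=CacheConfig)


def needs_cache_service(cfg: Cfg) -> bool:
    # Mirrors the guard in the start() method shown above, after this change.
    return (
        cfg.scheduler_config.splitwise_role != "mixed"
        or cfg.cache_config.enable_prefix_caching
    )


print(needs_cache_service(Cfg()))  # False: mixed role and no prefix caching
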
86 changes: 47 additions & 39 deletions fastdeploy/model_executor/layers/backends/xpu/moe/ep.py
@@ -23,20 +23,19 @@

import fastdeploy
from fastdeploy.config import MoEPhase
from fastdeploy.model_executor.layers.moe.ep import DeepEPEngineBase, EPRunner
Collaborator (author) comment on the removed import: the shared code in layers.moe.ep changes too frequently. Code reuse is deliberately set aside for now, so this XPU backend does not depend on layers.moe.ep at all; this will be revisited once CI is in place.

from fastdeploy.utils import singleton


@singleton
class DeepEPEngine(DeepEPEngineBase):
class DeepEPEngine:
"""
A wrapper class for DeepEP engine.
"""

def __init__(
self,
num_max_dispatch_tokens_per_rank: int,
hidden: int,
hidden_size: int,
num_experts: int,
ep_size: int,
ep_rank: int,
@@ -52,20 +51,24 @@ def __init__(
ep_size: The number of ranks.
rank_id: The rank id.
num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch.
hidden: The hidden dimension of the model.
hidden_size: The hidden_size dimension of the model.
num_experts: The number of experts.
"""
super().__init__(
num_max_dispatch_tokens_per_rank,
hidden,
num_experts,
ep_size,
ep_rank,
splitwise_role,
moe_phase,
async_finish,
group,
)
self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
self.hidden_size = hidden_size
self.num_experts = num_experts
self.ep_size = ep_size
self.rank_id = ep_rank
self.splitwise_role = splitwise_role
self.moe_phase = moe_phase
self.async_finish = async_finish
# TODO(@wufeisheng): Support configurable EP size
if group is None:
group = paddle.distributed.new_group(range(ep_size))
self.group = group
self.num_local_experts = num_experts // ep_size
self.deepep_engine = None
self.init_deepep_engine()

def init_deepep_engine(self):
if self.splitwise_role == "mixed" or self.moe_phase.phase == "prefill":
@@ -89,14 +92,14 @@ def get_low_latency_buffer(self):
Args:
group: The MPI group object.
num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch.
hidden: The hidden dimension of the model.
hidden_size: The hidden_size dimension of the model.
"""
# NOTES: the low-latency mode will consume much more space than the normal mode
# So we recommend that `num_max_dispatch_tokens_per_rank`
# (the actual batch size in the decoding engine) should be less than 256
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
self.num_max_dispatch_tokens_per_rank,
self.hidden,
self.hidden_size,
self.ep_size,
self.num_experts,
)
@@ -127,12 +130,12 @@ def low_latency_dispatch(
):
"""
Args:
hidden_states: [token_num, hidden] 'bfloat16/int8'
hidden_states: [token_num, hidden_size] 'bfloat16/int8'
topk_idx: [token_num, num_topk] 'int64'

Returns:
recv_hidden_states: [num_local_experts,
num_max_dispatch_tokens_per_rank * ep_size, hidden]
num_max_dispatch_tokens_per_rank * ep_size, hidden_size]
ep_size * num_local_experts = num_experts
recv_count: [num_local_experts]
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
@@ -171,7 +174,7 @@ def low_latency_combine(
"""

Return:
combined_hidden_states: [num_tokens, hidden]
combined_hidden_states: [num_tokens, hidden_size]
"""
combined_hidden_states, combine_hook = self.deepep_engine.low_latency_combine(
hidden_states,
@@ -196,15 +199,15 @@ def barrier_all(self):
self.deepep_engine.barrier_all()


class XPUEPRunner(EPRunner):
class XPUEPRunner:
"""
EPRunnerBase
"""

def __init__(
self,
top_k: int,
hidden: int,
hidden_size: int,
num_experts: int,
splitwise_role: str,
moe_phase: MoEPhase,
@@ -214,23 +217,22 @@ def __init__(
redundant_experts_num: int = 0,
ep_group=None,
):
super().__init__(
top_k,
hidden,
num_experts,
splitwise_role,
moe_phase,
num_max_dispatch_tokens_per_rank,
ep_size,
ep_rank,
redundant_experts_num,
ep_group,
)
self.top_k = top_k
self.hidden_size = hidden_size
self.num_experts = num_experts
self.splitwise_role = splitwise_role
self.moe_phase = moe_phase
self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
self.ep_size = ep_size
self.ep_rank = ep_rank
self.redundant_experts_num = redundant_experts_num
self.ep_group = ep_group
self.init_ep_engine()

def init_ep_engine(self):
self.ep_engine = DeepEPEngine(
num_max_dispatch_tokens_per_rank=self.num_max_dispatch_tokens_per_rank,
hidden=self.hidden,
hidden_size=self.hidden_size,
num_experts=self.num_experts + self.redundant_experts_num,
ep_size=self.ep_size,
ep_rank=self.ep_rank,
@@ -285,6 +287,12 @@ def combine(self, *args, **kwargs):
"""
raise NotImplementedError

def clean_low_latency_buffer(self):
self.ep_engine.clean_low_latency_buffer()

def barrier_all(self):
self.ep_engine.barrier_all()


class XPUEPPrefillRunner(XPUEPRunner):
"""
@@ -294,7 +302,7 @@ class XPUEPPrefillRunner(XPUEPRunner):
def __init__(
self,
top_k: int,
hidden: int,
hidden_size: int,
num_experts: int,
splitwise_role: str,
num_max_dispatch_tokens_per_rank: int,
@@ -306,7 +314,7 @@ def __init__(
):
super().__init__(
top_k,
hidden,
hidden_size,
num_experts,
splitwise_role,
moe_phase,
@@ -358,7 +366,7 @@ class XPUEPDecoderRunner(XPUEPRunner):
def __init__(
self,
top_k: int,
hidden: int,
hidden_size: int,
num_experts: int,
splitwise_role: str,
num_max_dispatch_tokens_per_rank: int,
@@ -370,7 +378,7 @@ def __init__(
):
super().__init__(
top_k,
hidden,
hidden_size,
num_experts,
splitwise_role,
moe_phase,
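
For reference, a small standalone sketch (not part of the diff) that reproduces the shape bookkeeping described in the docstrings above, using the renamed hidden_size parameter: num_local_experts = num_experts // ep_size, and recv_hidden_states has shape [num_local_experts, num_max_dispatch_tokens_per_rank * ep_size, hidden_size]. The concrete numbers are illustrative assumptions, not values from this PR.

def low_latency_recv_shape(
    num_experts: int,
    ep_size: int,
    num_max_dispatch_tokens_per_rank: int,
    hidden_size: int,
) -> tuple:
    # Same expert split that DeepEPEngine.__init__ computes above.
    num_local_experts = num_experts // ep_size
    return (
        num_local_experts,
        num_max_dispatch_tokens_per_rank * ep_size,
        hidden_size,
    )


print(low_latency_recv_shape(num_experts=64, ep_size=8,
                             num_max_dispatch_tokens_per_rank=128, hidden_size=4096))
# -> (8, 1024, 4096)
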
@@ -366,7 +366,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
self.down_proj_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate // 2,
layer.moe_intermediate_size // 2,
]
else:
raise ValueError(f"Unsupported moe quant type: {self.moe_quant_type}")
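
A quick sanity check of the corrected attribute name: the down_proj weight shape set above is [num_local_experts, hidden_size, moe_intermediate_size // 2]. The halving of the intermediate dimension presumably comes from packed quantized storage in this quant path (an assumption), and the numbers below are illustrative only.

def down_proj_weight_shape(num_local_experts: int, hidden_size: int,
                           moe_intermediate_size: int) -> list:
    # Reproduces the shape expression from create_weights() above.
    return [num_local_experts, hidden_size, moe_intermediate_size // 2]


print(down_proj_weight_shape(8, 4096, 14336))  # -> [8, 4096, 7168]
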
19 changes: 14 additions & 5 deletions fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
@@ -43,6 +43,15 @@ def __init__(self, quant_config):
self.ep_prefill_runner = None
self.ep_decoder_runner = None

def import_backend_ep_runner(self) -> None:
"""
Different platform has different ep runner. Override this method to import the corresponding EP runner.
"""
from .ep import EPDecoderRunner, EPPrefillRunner

self.EPPrefillRunner = EPPrefillRunner
self.EPDecoderRunner = EPDecoderRunner

def init_ep(self, layer: nn.Layer) -> None:
"""
Initialize EP (Expert Parallel) related modules.
@@ -51,7 +60,7 @@ def init_ep(self, layer: nn.Layer) -> None:
return

# Lazy import to avoid circular dependency or unnecessary loading
from .ep import EPDecoderRunner, EPPrefillRunner
self.import_backend_ep_runner()

# Common arguments for both runners
common_args = {
@@ -76,16 +85,16 @@ def init_ep(self, layer: nn.Layer) -> None:
# for RL init model without deepep buff
return
else:
self.ep_prefill_runner = EPPrefillRunner(**common_args)
self.ep_decoder_runner = EPDecoderRunner(**common_args)
self.ep_prefill_runner = self.EPPrefillRunner(**common_args)
self.ep_decoder_runner = self.EPDecoderRunner(**common_args)
return

# For non-mixed ep
phase = config.model_config.moe_phase.phase
if phase == "prefill":
self.ep_prefill_runner = EPPrefillRunner(**common_args)
self.ep_prefill_runner = self.EPPrefillRunner(**common_args)
else:
self.ep_decoder_runner = EPDecoderRunner(**common_args)
self.ep_decoder_runner = self.EPDecoderRunner(**common_args)

def process_loaded_weights(self, layer, weights) -> None:
"""
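
The new import_backend_ep_runner hook lets a platform backend swap in its own prefill/decoder runners without touching init_ep. Below is a runnable toy sketch of that pattern; every class name in it is a hypothetical stand-in. The real XPU backend would instead bind XPUEPPrefillRunner / XPUEPDecoderRunner from fastdeploy.model_executor.layers.backends.xpu.moe.ep, but that subclass is not shown in this diff.

class _CUDAPrefillRunner:
    kind = "cuda-prefill"

    def __init__(self, **common_args):
        self.common_args = common_args


class _CUDADecoderRunner(_CUDAPrefillRunner):
    kind = "cuda-decoder"


class _XPUPrefillRunner(_CUDAPrefillRunner):
    kind = "xpu-prefill"


class _XPUDecoderRunner(_CUDAPrefillRunner):
    kind = "xpu-decoder"


class _MoEMethodBase:
    def import_backend_ep_runner(self) -> None:
        # Default binding; the real base class lazily imports the runners from .ep here.
        self.EPPrefillRunner = _CUDAPrefillRunner
        self.EPDecoderRunner = _CUDADecoderRunner

    def init_ep(self, **common_args) -> None:
        self.import_backend_ep_runner()
        self.ep_prefill_runner = self.EPPrefillRunner(**common_args)
        self.ep_decoder_runner = self.EPDecoderRunner(**common_args)


class _XPUMoEMethod(_MoEMethodBase):
    def import_backend_ep_runner(self) -> None:
        # Platform override: bind the XPU runner classes instead of the defaults.
        self.EPPrefillRunner = _XPUPrefillRunner
        self.EPDecoderRunner = _XPUDecoderRunner


method = _XPUMoEMethod()
method.init_ep(top_k=8, hidden_size=4096)
print(method.ep_prefill_runner.kind, method.ep_decoder_runner.kind)
# -> xpu-prefill xpu-decoder
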