Merged
28 changes: 28 additions & 0 deletions fastdeploy/config.py
@@ -1163,6 +1163,31 @@ def print(self):
logger.info("=============================================================")


class RoutingReplayConfig:
"""Configuration for Routing Replay used in RL training"""

def __init__(self, args) -> None:
self.enable_routing_replay: bool = False
self.routing_store_type: str = "local"

# Local routing store
self.local_store_dir: str = "./routing_replay_output"

# RDMA routing store
pass
Copilot AI Dec 4, 2025

The `pass` statement here serves no purpose and should be removed. If this is a placeholder for future RDMA configuration attributes, consider adding a comment instead:

# RDMA routing store
# TODO: Add RDMA-specific configuration parameters

Suggested change:
- pass
+ # TODO: Add RDMA-specific configuration parameters here

if args is not None:
for key, value in args.items():
if hasattr(self, key) and value != "None":
setattr(self, key, value)

def to_json_string(self):
"""
Convert routing replay config to json string.
"""
return json.dumps({key: value for key, value in self.__dict__.items()})
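
For reference, a minimal sketch of how the new config behaves, assuming it is importable from `fastdeploy.config` as this diff suggests; keys that match declared attributes override the defaults, everything else is ignored:

```python
from fastdeploy.config import RoutingReplayConfig

# Keys matching declared attributes override the defaults; unknown keys
# are silently dropped by the hasattr() check above.
cfg = RoutingReplayConfig({"enable_routing_replay": True, "local_store_dir": "/tmp/r3"})
print(cfg.to_json_string())
# {"enable_routing_replay": true, "routing_store_type": "local", "local_store_dir": "/tmp/r3"}
```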


class FDConfig:
"""
The configuration class which contains all fastdeploy-related configuration. This
@@ -1206,6 +1231,7 @@ def __init__(
test_mode=False,
enable_attention_dp_balance: bool = False,
attention_dp_time_out_iters: int = 0,
routing_replay_config: Optional[RoutingReplayConfig] = None,
):
self.model_config: ModelConfig = model_config # type: ignore
self.cache_config: CacheConfig = cache_config # type: ignore
@@ -1221,8 +1247,10 @@
self.cache_config: CacheConfig = cache_config # type: ignore
self.eplb_config: Optional[EPLBConfig] = eplb_config
self.moba_attention_config: Optional[MobaAttentionConfig] = moba_attention_config
self.routing_replay_config = routing_replay_config
self.enable_attention_dp_balance = enable_attention_dp_balance
self.attention_dp_time_out_iters = attention_dp_time_out_iters

# Initialize cuda graph capture list
max_capture_shape = self.parallel_config.max_num_seqs
if self.speculative_config is not None and self.speculative_config.method == "mtp":
22 changes: 22 additions & 0 deletions fastdeploy/engine/args_utils.py
@@ -33,6 +33,7 @@
MobaAttentionConfig,
ModelConfig,
ParallelConfig,
RoutingReplayConfig,
SpeculativeConfig,
TaskOption,
)
@@ -421,6 +422,11 @@ class EngineArgs:
Configuration for eplb.
"""

routing_replay_config: Optional[Dict[str, Any]] = None
"""
Configuration for rollout routing replay (R3).
"""

def __post_init__(self):
"""
Post-initialization processing to set default tokenizer if not provided.
@@ -733,6 +739,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=EngineArgs.eplb_config,
help="Config of eplb.",
)
parallel_group.add_argument(
"--routing-replay-config",
type=json.loads,
default=EngineArgs.routing_replay_config,
help="Config of rollout routing replay (R3).",
)
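
The flag accepts a JSON object on the command line; a hedged sketch of what `type=json.loads` does here, using plain `argparse` as a stand-in for `FlexibleArgumentParser`:

```python
import argparse
import json

# Stand-in parser; only the json.loads conversion is being illustrated.
parser = argparse.ArgumentParser()
parser.add_argument("--routing-replay-config", type=json.loads, default=None)

ns = parser.parse_args(["--routing-replay-config", '{"enable_routing_replay": true}'])
print(ns.routing_replay_config)  # {'enable_routing_replay': True}
```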

# Load group
load_group = parser.add_argument_group("Load Configuration")
@@ -1078,6 +1090,14 @@ def create_eplb_config(self) -> EPLBConfig:
eplb_args[k] = v
return EPLBConfig(eplb_args)

def create_routing_repaly_config(self) -> RoutingReplayConfig:
Copilot AI Dec 4, 2025

Typo in method name: "repaly" should be "replay". This should be `create_routing_replay_config` to match the naming convention used elsewhere (e.g., the `routing_replay_config` attribute).

Suggested change:
- def create_routing_repaly_config(self) -> RoutingReplayConfig:
+ def create_routing_replay_config(self) -> RoutingReplayConfig:
""" """
routing_replay_args = asdict(self)
if self.routing_replay_config is not None:
for k, v in self.routing_replay_config.items():
routing_replay_args[k] = v
return RoutingReplayConfig(routing_replay_args)

def create_engine_config(self, port_availability_check: bool = True) -> FDConfig:
"""
Create and return a Config object based on the current settings.
@@ -1120,6 +1140,7 @@ def create_engine_config(self, port_availability_check: bool = True) -> FDConfig:
graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
moba_attention_config = self.create_moba_attention_config()
eplb_cfg = self.create_eplb_config()
routing_replay_config = self.create_routing_repaly_config()

early_stop_cfg = self.create_early_stop_config()
early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
@@ -1165,4 +1186,5 @@ def create_engine_config(self, port_availability_check: bool = True) -> FDConfig:
early_stop_config=early_stop_cfg,
enable_attention_dp_balance=self.enable_attention_dp_balance,
attention_dp_time_out_iters=self.attention_dp_time_out_iters,
routing_replay_config=routing_replay_config,
)
1 change: 1 addition & 0 deletions fastdeploy/engine/engine.py
@@ -462,6 +462,7 @@ def _start_worker_service(self):
f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'"
f" --attention_dp_time_out_iters {self.cfg.attention_dp_time_out_iters}"
f" --eplb_config '{self.cfg.eplb_config.to_json_string()}'"
f" --routing_replay_config '{self.cfg.routing_replay_config.to_json_string()}'"
f" --ips {ips}"
)

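The engine forwards the config to the worker process as a JSON string. The worker-side handling is not part of this diff; a sketch under the assumption that the worker parses the flag back into a `RoutingReplayConfig`:

```python
import json

from fastdeploy.config import RoutingReplayConfig

# Hypothetical worker-side round trip; the actual worker argument parsing
# is not shown in this diff.
def parse_routing_replay_flag(raw: str) -> RoutingReplayConfig:
    return RoutingReplayConfig(json.loads(raw))

cfg = parse_routing_replay_flag('{"enable_routing_replay": true, "routing_store_type": "local"}')
assert cfg.enable_routing_replay is True
```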
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/forward_meta.py
@@ -110,6 +110,8 @@ class ForwardMeta:
block_tables: Optional[paddle.Tensor] = None
# KV caches
caches: Optional[list[paddle.Tensor]] = None
# Routing Replay table buffer
routing_replay_table: Optional[paddle.Tensor] = None

def clear_caches(self):
"""Safely clean up the caches"""
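The diff does not show where this buffer is allocated or how it is shaped; a hedged sketch, assuming one slot per token, per MoE layer, per top-k expert:

```python
import paddle

# Hypothetical layout: [max_num_tokens, num_moe_layers, top_k].
# The real allocation site, shape, and dtype are not part of this diff.
max_num_tokens, num_moe_layers, top_k = 8192, 28, 8
routing_replay_table = paddle.full(
    [max_num_tokens, num_moe_layers, top_k], fill_value=-1, dtype="int32"
)
# ForwardMeta would then carry it alongside the other per-step buffers:
# forward_meta.routing_replay_table = routing_replay_table
```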
@@ -14,6 +14,8 @@
# limitations under the License.
"""

from typing import Callable

import paddle
from paddle import nn

@@ -102,6 +104,7 @@ def apply(
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Triton compute Fused MoE.
@@ -119,6 +122,9 @@
topk_weights, topk_ids = paddle.topk(scores, k=top_k, axis=-1, sorted=False)
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)

if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids)

intermediate_cache1 = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
@@ -16,6 +16,7 @@

import multiprocessing
import os
from typing import Callable

import numpy as np
import paddle
@@ -189,6 +190,7 @@ def apply(
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle gcu compute Fused MoE.
@@ -201,6 +203,7 @@ def apply_ep_prefill(
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
@@ -212,6 +215,7 @@ def apply_ep_decode(
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
@@ -223,6 +227,7 @@ def apply_tp(
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
@@ -388,6 +393,7 @@ def apply(
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle gcu compute Fused MoE.
@@ -14,6 +14,8 @@
# limitations under the License.
"""

from typing import Callable

import paddle
from paddle import nn

@@ -132,6 +134,7 @@ def apply(
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Triton compute Fused MoE.
@@ -151,6 +154,10 @@
True, # apply_norm_weight,
False,
)

if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids)

up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
10 changes: 7 additions & 3 deletions fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
@@ -15,6 +15,7 @@
"""

from abc import abstractmethod
from typing import Callable

import paddle
from paddle import nn
@@ -120,6 +121,7 @@ def apply_ep_prefill(
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
@@ -144,6 +146,7 @@ def apply_tp(
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
@@ -155,6 +158,7 @@ def apply(
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
@@ -163,13 +167,13 @@
if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
if layer.fd_config.parallel_config.splitwise_role == "mixed" and layer.layer_idx == 0:
self.ep_prefill_runner.clean_low_latency_buffer()
return self.apply_ep_prefill(layer, x, gate)
return self.apply_ep_prefill(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
else:
if layer.fd_config.parallel_config.splitwise_role == "mixed" and layer.layer_idx == 0:
self.ep_decoder_runner.clean_low_latency_buffer()
return self.apply_ep_decode(layer, x, gate)
return self.apply_ep_decode(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
else:
return self.apply_tp(layer, x, gate)
return self.apply_tp(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)


class UnquantizedFusedMoEMethod(MoEMethodBase):
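Across all backends the contract is the same: the hook is invoked once per MoE layer with the selected expert ids passed as a `topk_ids` keyword argument. A sketch of a hook that records those routing decisions into a replay buffer; the buffer layout is an assumption carried over from the sketch above, not part of this diff:

```python
import paddle

def make_topk_ids_hook(routing_replay_table: paddle.Tensor, layer_idx: int):
    """Build a hook recording one MoE layer's routing choices.

    Hypothetical sketch: assumes routing_replay_table is laid out as
    [num_tokens, num_moe_layers, top_k] and topk_ids as [num_tokens, top_k].
    """
    def hook(topk_ids: paddle.Tensor):
        num_tokens = topk_ids.shape[0]
        routing_replay_table[:num_tokens, layer_idx, :] = topk_ids.astype("int32")
    return hook

# Matching the call sites above, the backend invokes hook(topk_ids=...)
# once per layer, e.g.:
# method.apply(layer, x, gate, topk_ids_hookfunc=make_topk_ids_hook(table, layer.layer_idx))
```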
16 changes: 16 additions & 0 deletions fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
@@ -14,6 +14,8 @@
# limitations under the License.
"""

from typing import Callable

import paddle
from paddle import nn
from paddle.nn.quant import weight_quantize
@@ -105,6 +107,7 @@ def apply_ep_prefill(
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
@@ -121,6 +124,10 @@
handle,
_,
) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights)

if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)

token_all_num = sum(recv_num_tokens_per_expert_list)

# 3. Compute ffn
@@ -178,6 +185,7 @@ def apply_ep_decode(
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
@@ -186,6 +194,10 @@
estimate_total_token_nums = gate_out.shape[0] * layer.top_k
# 1. Select topk experts and weights
topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)

if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)

expertwise_scale = None
if hasattr(layer, "up_gate_proj_in_scale_all_experts"): # only use in w4a8
expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None)
@@ -220,6 +232,7 @@ def apply_tp(
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
@@ -277,6 +290,9 @@ def apply_tp(
topk_only_mode=False,
)

if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)

if self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
# only w4a8 need expert_idx_per_token
# Other need not this tensor, so we make it None.