7 changes: 4 additions & 3 deletions fastdeploy/eplb/async_expert_loader.py

@@ -152,6 +152,9 @@ def load_tensor_from_shm_mem(tensor_infos, shm_ptr, logger=None):
         # NumPy does not support bfloat16, so read the raw data as uint16 first, then cast to bfloat16 with Paddle
         tmp = np_array.view(np.uint16)
         tensor = paddle.Tensor(tmp, dtype=paddle.bfloat16, place=paddle.CPUPlace(), zero_copy=True)
+    elif dtype == paddle.float8_e4m3fn:
+        tmp = np_array.view(np.uint8)
+        tensor = paddle.Tensor(tmp, dtype=paddle.float8_e4m3fn, place=paddle.CPUPlace(), zero_copy=True)
     else:
         raise TypeError(f"Unsupported dtype: {dtype}")
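The new float8 branch mirrors the bfloat16 workaround above: NumPy has no float8_e4m3fn dtype, so the shared-memory bytes stay as uint8 and Paddle reinterprets them in place. A minimal standalone sketch of the same trick, with a hypothetical buffer standing in for the shared-memory region and assuming a Paddle build that exposes paddle.float8_e4m3fn:

import numpy as np
import paddle

# Hypothetical raw buffer standing in for the shared-memory region.
np_array = np.zeros((2, 4), dtype=np.uint8)

# NumPy cannot represent float8, so view the bytes as uint8 and let
# Paddle reinterpret them; zero_copy=True shares the buffer rather
# than copying it.
tmp = np_array.view(np.uint8)
tensor = paddle.Tensor(tmp, dtype=paddle.float8_e4m3fn, place=paddle.CPUPlace(), zero_copy=True)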
@@ -294,8 +297,6 @@ def load_safetensor_fp8_from_disk(self, need_to_reload: List[Tuple[int, int]]):
         """
         up_gate_down = ["up_gate_proj", "down_proj"]
         quant_weight_scale = ["quant_weight", "weight_scale"]
-        if self.moe_quant_type == "w4a8":
-            quant_weight_scale = ["quant_weight"]
         ckpt_name = [
             (f"ernie.layers.{layer_id}.mlp.experts.{expert_id}.{proj_name}.{quant_name}")
             for layer_id, expert_id in need_to_reload

@@ -312,7 +313,7 @@ def load_safetensor_fp8_from_disk(self, need_to_reload: List[Tuple[int, int]]):
         from safetensors import safe_open

         for st_file in hf_weights_files:
-            with safe_open(st_file, framework="np", device="cpu") as f:
+            with safe_open(st_file, framework="paddle", device="cpu") as f:
                 for name in f.keys():
                     if name in ckpt_name:
                         weight = f.get_tensor(name)
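With framework="paddle", safe_open hands back paddle.Tensor objects from get_tensor directly, which sidesteps the NumPy round trip (and NumPy's missing fp8 dtype) for the weights loaded here. A short usage sketch with a hypothetical checkpoint path:

from safetensors import safe_open

# Hypothetical file name; framework="paddle" makes get_tensor return
# paddle.Tensor objects instead of NumPy arrays.
with safe_open("model-00001-of-00004.safetensors", framework="paddle", device="cpu") as f:
    for name in f.keys():
        weight = f.get_tensor(name)  # paddle.Tensor on CPU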
12 changes: 7 additions & 5 deletions fastdeploy/eplb/experts_manager.py

@@ -21,7 +21,7 @@ class RedundantExpertManager:
     RedundantExpertManger
     """

-    def __init__(self, rank=0, ep_size=32, fd_config=None):
+    def __init__(self, rank=0, ep_size=64, fd_config=None):
         self.logger = get_logger("eplb_expert_manager", "eplb_{0}.log".format(rank))

         self.rank = rank

@@ -101,7 +101,7 @@ def __init__(self, rank=0, ep_size=32, fd_config=None):
         self.http_timeout = 1
         # Reset the rearrange state: 'done' -> 'free'
         self.rearrange_end_ts = 0
-        self.rearrange_reset_interval = 300
+        self.rearrange_reset_interval = 30

         self.tensor_infos = None

@@ -250,8 +250,8 @@ def caculate_expert_rank_table(self, is_init=False):
         eplb_strategy = self.eplb_config.redundant_expert_eplb_strategy
         if is_init:
             num_groups = 1
-            num_nodes = 2
-            num_gpus = 2 * 8
+            num_nodes = 8
+            num_gpus = 8 * 8
             eplb_strategy = ""
         # eplb
         rank_expert_list, logical_to_physical_map, expert_count = rebalance_experts(

@@ -420,7 +420,9 @@ def allreduce_load_weight_result(self):
         if not exist_fail and all_success:
             # prefill needs to wait for the scheduling cordon
             if (
-                self.fd_config.splitwise_role == "decode"
+                self.fd_config.scheduler_config.splitwise_role == "mixed"
+                or self.fd_config.scheduler_config.splitwise_role == "decode"
+                or self.fd_config.scheduler_config.splitwise_role == "prefill"
                 or not self.eplb_config.redundant_expert_enable_schedule_cordon
             ):
                 self.logger.info("redundant_expert: allreduce_load_weight_result success, notify infer.py")
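The widened condition now passes for all three splitwise roles; for readability the three equality checks could be collapsed into one membership test. A sketch of the equivalent condition (same semantics, not part of the PR):

role = self.fd_config.scheduler_config.splitwise_role
if role in ("mixed", "decode", "prefill") or not self.eplb_config.redundant_expert_enable_schedule_cordon:
    self.logger.info("redundant_expert: allreduce_load_weight_result success, notify infer.py")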
@@ -1153,6 +1153,10 @@ def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange:
         # self.check(layer, up_gate_proj_weights, down_proj_weights)
         up_gate_proj_weight_scale = []
         down_proj_weight_scale = []
+
+        if isinstance(state_dict, list):
+            state_dict = dict(state_dict)
+
         for expert_idx in logical_expert_ids:
             up_gate_proj_weight_scale.append(
                 get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
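The new guard (here and in the identical hunk below) accepts a state_dict delivered as a list of (name, tensor) pairs: dict() consumes an iterable of key/value pairs, after which state_dict.pop(...) works as before. A toy illustration with hypothetical keys and scalar stand-ins for tensors:

# state_dict may arrive as [(name, tensor), ...] instead of a mapping.
t0, t1 = 0.5, 0.25  # stand-ins for weight-scale tensors
state_dict = [("experts.0.up_gate_proj.weight_scale", t0),
              ("experts.0.down_proj.weight_scale", t1)]
if isinstance(state_dict, list):
    state_dict = dict(state_dict)  # dict() accepts key/value pairs
scale = state_dict.pop("experts.0.up_gate_proj.weight_scale")  # 0.5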
@@ -86,6 +86,10 @@ def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange:
         # self.check(layer, up_gate_proj_weights, down_proj_weights)
         up_gate_proj_weight_scale = []
         down_proj_weight_scale = []
+
+        if isinstance(state_dict, list):
+            state_dict = dict(state_dict)
+
         for expert_idx in logical_expert_ids:
             up_gate_proj_expert_weight_scale_key_name = up_gate_proj_expert_weight_scale_key.format(expert_idx)
             down_proj_expert_weight_scale_key_name = down_proj_expert_weight_scale_key.format(expert_idx)
8 changes: 8 additions & 0 deletions fastdeploy/worker/worker_process.py

@@ -260,13 +260,21 @@ def update_weights_from_tensor(self, mmap_infos):
         """
         update_weights_from_tensor
         """
+        import time
+
+        while True:
+            if self.experts_manager.tensor_infos is None:
+                time.sleep(0.1)
+            else:
+                break
         state_dicts = load_tensor_from_shm_mem(self.experts_manager.tensor_infos, mmap_infos[MODEL_MAIN_NAME], logger)
         rank_expert_list, logical_to_physical_map, expert_count = self.experts_manager.get_ep_rank_to_expert_id_list()
         self.worker.get_model().redundant_table_manger.update_expert_rank_table(
             rank_expert_list, logical_to_physical_map, expert_count
         )
+        # TO BE FIXED
         self.worker.get_model().update_state_dict(state_dicts)
         self.experts_manager.tensor_infos = None

     def _broadcast_model_weights_signal(self, src: int, group) -> int:
         model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32")
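The added loop busy-polls experts_manager.tensor_infos every 100 ms until the EPLB thread publishes it. If the producer side can be touched, a threading.Event expresses the same handshake without polling; a sketch under that assumption (the Event and helper are hypothetical, not part of the PR):

import threading

tensor_infos_ready = threading.Event()  # hypothetical; set() by the producer

def wait_for_tensor_infos(experts_manager, timeout=None):
    # Block until the producer publishes tensor_infos and calls
    # tensor_infos_ready.set(), instead of sleeping in 0.1 s slices.
    if not tensor_infos_ready.wait(timeout):
        raise TimeoutError("tensor_infos was not published in time")
    return experts_manager.tensor_infos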