diff --git a/fastdeploy/eplb/async_expert_loader.py b/fastdeploy/eplb/async_expert_loader.py
index cfffde97cb9..7ad3d5dab21 100644
--- a/fastdeploy/eplb/async_expert_loader.py
+++ b/fastdeploy/eplb/async_expert_loader.py
@@ -152,6 +152,9 @@ def load_tensor_from_shm_mem(tensor_infos, shm_ptr, logger=None):
         # NumPy does not support bfloat16, so read the raw data as uint16 first, then cast it to bfloat16 with Paddle
         tmp = np_array.view(np.uint16)
         tensor = paddle.Tensor(tmp, dtype=paddle.bfloat16, place=paddle.CPUPlace(), zero_copy=True)
+    elif dtype == paddle.float8_e4m3fn:
+        tmp = np_array.view(np.uint8)
+        tensor = paddle.Tensor(tmp, dtype=paddle.float8_e4m3fn, place=paddle.CPUPlace(), zero_copy=True)
     else:
         raise TypeError(f"Unsupported dtype: {dtype}")
@@ -294,8 +297,6 @@ def load_safetensor_fp8_from_disk(self, need_to_reload: List[Tuple[int, int]]):
         """
         up_gate_down = ["up_gate_proj", "down_proj"]
         quant_weight_scale = ["quant_weight", "weight_scale"]
-        if self.moe_quant_type == "w4a8":
-            quant_weight_scale = ["quant_weight"]
         ckpt_name = [
             (f"ernie.layers.{layer_id}.mlp.experts.{expert_id}.{proj_name}.{quant_name}")
             for layer_id, expert_id in need_to_reload
@@ -312,7 +313,7 @@ def load_safetensor_fp8_from_disk(self, need_to_reload: List[Tuple[int, int]]):
         from safetensors import safe_open

         for st_file in hf_weights_files:
-            with safe_open(st_file, framework="np", device="cpu") as f:
+            with safe_open(st_file, framework="paddle", device="cpu") as f:
                 for name in f.keys():
                     if name in ckpt_name:
                         weight = f.get_tensor(name)
diff --git a/fastdeploy/eplb/experts_manager.py b/fastdeploy/eplb/experts_manager.py
index cad2e92be74..ccfc5445c36 100644
--- a/fastdeploy/eplb/experts_manager.py
+++ b/fastdeploy/eplb/experts_manager.py
@@ -21,7 +21,7 @@ class RedundantExpertManager:
     RedundantExpertManger
     """

-    def __init__(self, rank=0, ep_size=32, fd_config=None):
+    def __init__(self, rank=0, ep_size=64, fd_config=None):
         self.logger = get_logger("eplb_expert_manager", "eplb_{0}.log".format(rank))

         self.rank = rank
@@ -101,7 +101,7 @@ def __init__(self, rank=0, ep_size=32, fd_config=None):
         self.http_timeout = 1
         # Reset the rearrange state: 'done' -> 'free'
         self.rearrange_end_ts = 0
-        self.rearrange_reset_interval = 300
+        self.rearrange_reset_interval = 30

         self.tensor_infos = None
@@ -250,8 +250,8 @@ def caculate_expert_rank_table(self, is_init=False):
         eplb_strategy = self.eplb_config.redundant_expert_eplb_strategy
         if is_init:
             num_groups = 1
-            num_nodes = 2
-            num_gpus = 2 * 8
+            num_nodes = 8
+            num_gpus = 8 * 8
             eplb_strategy = ""
         # eplb
         rank_expert_list, logical_to_physical_map, expert_count = rebalance_experts(
@@ -420,7 +420,9 @@ def allreduce_load_weight_result(self):
         if not exist_fail and all_success:
             # prefill needs to wait for the scheduler cordon
             if (
-                self.fd_config.splitwise_role == "decode"
+                self.fd_config.scheduler_config.splitwise_role == "mixed"
+                or self.fd_config.scheduler_config.splitwise_role == "decode"
+                or self.fd_config.scheduler_config.splitwise_role == "prefill"
                 or not self.eplb_config.redundant_expert_enable_schedule_cordon
             ):
                 self.logger.info("redundant_expert: allreduce_load_weight_result success, notify infer.py")
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
index 5736f6fda00..d00dd45a0b1 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
@@ -1153,6 +1153,10 @@ def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange:
         # self.check(layer, up_gate_proj_weights, down_proj_weights)
         up_gate_proj_weight_scale = []
         down_proj_weight_scale = []
+
+        if isinstance(state_dict, list):
+            state_dict = dict(state_dict)
+
         for expert_idx in logical_expert_ids:
             up_gate_proj_weight_scale.append(
                 get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
index 9910ac54c92..ba4fdb7ccd3 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
@@ -86,6 +86,10 @@ def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange:
         # self.check(layer, up_gate_proj_weights, down_proj_weights)
         up_gate_proj_weight_scale = []
         down_proj_weight_scale = []
+
+        if isinstance(state_dict, list):
+            state_dict = dict(state_dict)
+
         for expert_idx in logical_expert_ids:
             up_gate_proj_expert_weight_scale_key_name = up_gate_proj_expert_weight_scale_key.format(expert_idx)
             down_proj_expert_weight_scale_key_name = down_proj_expert_weight_scale_key.format(expert_idx)
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index b30db63a7ec..68a0d47cad3 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -260,6 +260,13 @@ def update_weights_from_tensor(self, mmap_infos):
         """
         update_weights_from_tensor
         """
+        import time
+
+        while True:
+            if self.experts_manager.tensor_infos is None:
+                time.sleep(0.1)
+            else:
+                break
         state_dicts = load_tensor_from_shm_mem(self.experts_manager.tensor_infos, mmap_infos[MODEL_MAIN_NAME], logger)
         rank_expert_list, logical_to_physical_map, expert_count = self.experts_manager.get_ep_rank_to_expert_id_list()
         self.worker.get_model().redundant_table_manger.update_expert_rank_table(
@@ -267,6 +274,7 @@ def update_weights_from_tensor(self, mmap_infos):
         )
         # TO BE FIXED
         self.worker.get_model().update_state_dict(state_dicts)
+        self.experts_manager.tensor_infos = None

     def _broadcast_model_weights_signal(self, src: int, group) -> int:
         model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32")
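
The new `float8_e4m3fn` branch in `load_tensor_from_shm_mem` follows the same byte-reinterpretation pattern as the existing bfloat16/uint16 path: NumPy has no fp8 dtype, so the shared-memory bytes are viewed as `uint8` and only relabeled as fp8 on the Paddle side. A minimal sketch of the pattern, using a throwaway NumPy buffer in place of the real shared-memory region:

```python
# Minimal sketch of the uint8-view trick from load_tensor_from_shm_mem.
# NumPy cannot represent fp8, so the bytes are viewed as uint8 and only
# relabeled as float8_e4m3fn by Paddle; zero_copy=True makes the tensor
# alias the original buffer instead of copying it.
import numpy as np
import paddle

np_array = np.arange(16, dtype=np.uint8)  # stand-in for bytes read out of shm
tmp = np_array.view(np.uint8)             # reinterpret in place, no copy
tensor = paddle.Tensor(tmp, dtype=paddle.float8_e4m3fn, place=paddle.CPUPlace(), zero_copy=True)
print(tensor.shape)  # [16] -- one fp8 element per byte
```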
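The `framework="np"` to `framework="paddle"` switch in `load_safetensor_fp8_from_disk` matters for the same reason: the NumPy backend cannot materialize fp8 tensors, while the Paddle backend returns `paddle.Tensor` objects directly from `get_tensor`. A self-contained sketch of that read path; the tensor key mimics the `ckpt_name` pattern built above, and the shard is a temporary demo file rather than a real ERNIE checkpoint:

```python
# Sketch of reading a tensor back via safetensors' Paddle backend, as the
# diff now does. The key follows the ckpt_name pattern from the diff; the
# demo.safetensors shard is created on the spot for illustration.
import paddle
from safetensors import safe_open
from safetensors.paddle import save_file

key = "ernie.layers.0.mlp.experts.0.up_gate_proj.quant_weight"
save_file({key: paddle.zeros([4, 4], dtype="float32")}, "demo.safetensors")

with safe_open("demo.safetensors", framework="paddle", device="cpu") as f:
    for name in f.keys():
        if name == key:
            weight = f.get_tensor(name)  # already a paddle.Tensor, no NumPy round-trip
            print(type(weight), weight.shape)
```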
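Finally, the `worker_process.py` changes turn `experts_manager.tensor_infos` into a simple produce/consume flag: the worker polls until the experts manager publishes the tensor metadata, loads the weights, and resets the field to `None` so a stale payload is never applied twice. A stripped-down illustration of that handshake; `ExpertsManagerStub` and the payload are stand-ins invented for the example:

```python
# Stripped-down version of the tensor_infos handshake added in
# update_weights_from_tensor: poll until the producer publishes, consume,
# then reset to None so the next rearrange round starts clean.
import threading
import time

class ExpertsManagerStub:
    def __init__(self):
        self.tensor_infos = None

mgr = ExpertsManagerStub()

def producer():
    time.sleep(0.3)  # pretend the weights were staged into shared memory
    mgr.tensor_infos = {"layer_0.expert_7": ("shape", "dtype", "offset")}

threading.Thread(target=producer, daemon=True).start()

while mgr.tensor_infos is None:  # same 100 ms busy-wait as the diff
    time.sleep(0.1)

payload = mgr.tensor_infos       # load_tensor_from_shm_mem would run here
mgr.tensor_infos = None          # mark consumed, as the diff does after update_state_dict
print("consumed:", payload)
```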