3 changes: 1 addition & 2 deletions fastdeploy/engine/args_utils.py
@@ -442,8 +442,7 @@ def __post_init__(self):
raise NotImplementedError("Only CUDA platform supports logprob.")
if self.speculative_config is not None and self.logprobs_mode.startswith("processed"):
raise NotImplementedError("processed_logprobs not support in speculative.")
if self.speculative_config is not None:
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0

if self.splitwise_role != "mixed" and self.cache_transfer_protocol != "rdma":
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
if not current_platform.is_cuda() and not current_platform.is_xpu():
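Note: after this change, speculative decoding no longer forces the V1 KV-cache scheduler off; only disaggregated deployments without RDMA do. A minimal sketch of the resulting condition (the helper name is ours, not from the PR):

    def v1_scheduler_enabled(splitwise_role: str, cache_transfer_protocol: str) -> bool:
        # Forced off only for non-"mixed" roles whose cache transfer is not RDMA.
        return not (splitwise_role != "mixed" and cache_transfer_protocol != "rdma")

    assert v1_scheduler_enabled("mixed", "ipc")        # mixed role: scheduler stays on
    assert not v1_scheduler_enabled("prefill", "ipc")  # disaggregated without RDMA: off
    assert v1_scheduler_enabled("decode", "rdma")      # disaggregated over RDMA: on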
@@ -92,6 +92,8 @@ def __init__(
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
fd_config.model_config, "use_3d_rope", False
)
if fd_config.speculative_config.model_type != "main":
self.rope_3d = False
self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.speculative_method: str = fd_config.speculative_config.method
self.use_speculate: bool = self.speculative_method is not None
@@ -364,7 +366,7 @@ def forward_mixed(
getattr(layer, "cache_v_zp", None),
layer.linear_shift,
layer.linear_smooth,
forward_meta.attn_mask_offsets,
None if self.use_speculate else forward_meta.attn_mask_offsets,
metadata.kv_signal_data_list[layer.layer_id],
getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None),
@@ -383,7 +385,7 @@ def forward_mixed(
metadata.max_partition_size,
metadata.encoder_max_partition_size,
self.speculate_max_draft_token_num + 1,
self.causal,
self.causal or self.use_speculate,
self.speculative_method is not None,
sliding_window,
)
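Note: both hunks gate on speculative decoding — the precomputed mask offsets are dropped and the causal flag is forced whenever drafting is active. A trivial sketch of the two expressions (the helper name is ours):

    def speculative_attn_args(use_speculate: bool, causal: bool, attn_mask_offsets):
        # Speculative decoding ignores precomputed mask offsets and always runs causal.
        mask_offsets = None if use_speculate else attn_mask_offsets
        return mask_offsets, causal or use_speculate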
101 changes: 56 additions & 45 deletions fastdeploy/model_executor/layers/mtp_linear.py
@@ -18,7 +18,7 @@
from paddle import nn
from paddle.distributed import fleet

from fastdeploy.model_executor.utils import set_weight_attrs
from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs

from .utils import get_tensor

@@ -53,44 +53,61 @@ def __init__(
self.bias_key = prefix + ".bias"
else:
self.bias_key = None
self.use_ep = fd_config.parallel_config.use_ep
self.fd_config = fd_config
self.tp_group = fd_config.parallel_config.tp_group
self.column_cut = True
self.nranks = fd_config.parallel_config.tensor_parallel_size

ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear

if self.use_ep:
self.weight = self.create_parameter(
shape=[embedding_dim, num_embeddings],
dtype=paddle.get_default_dtype(),
is_bias=False,
if self.column_cut:
need_gather = True
self.linear = ColumnParallelLinear(
embedding_dim,
num_embeddings,
mp_group=self.tp_group,
weight_attr=None,
has_bias=True if self.bias_key is not None else False,
gather_output=need_gather,
fuse_matmul_bias=False, # False yields a smaller numerical diff
)
else:
if self.column_cut:
need_gather = True
self.linear = ColumnParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
weight_attr=None,
has_bias=True if self.bias_key is not None else False,
gather_output=need_gather,
fuse_matmul_bias=False, # False yields a smaller numerical diff
set_weight_attrs(
self.linear.weight,
{
"weight_loader": default_weight_loader(self.fd_config),
"model_format": self.fd_config.model_config.model_format,
},
)
if self.bias_key is not None:
set_weight_attrs(
self.linear.bias,
{"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}},
)
if self.nranks > 1:
set_weight_attrs(self.linear.weight, {"output_dim": True})
if self.bias_key is not None:
set_weight_attrs(self.linear.bias, {"output_dim": True})
else:
self.linear = RowParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
weight_attr=None,
has_bias=True if self.bias_key is not None else False,
input_is_parallel=False,
fuse_matmul_bias=False, # False yields a smaller numerical diff
)
set_weight_attrs(self.linear.weight, {"output_dim": False})
else:
self.linear = RowParallelLinear(
embedding_dim,
num_embeddings,
mp_group=self.tp_group,
weight_attr=None,
has_bias=True if self.bias_key is not None else False,
input_is_parallel=False,
fuse_matmul_bias=False, # False yields a smaller numerical diff
)
set_weight_attrs(
self.linear.weight,
{
"weight_loader": default_weight_loader(self.fd_config),
"model_format": self.fd_config.model_config.model_format,
},
)
if self.nranks > 1:
set_weight_attrs(self.linear.weight, {"output_dim": True})
set_weight_attrs(
self.linear.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
)

def load_state_dict(self, state_dict):
"""
@@ -100,17 +117,14 @@ def load_state_dict(self, state_dict):
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""

if self.use_ep:
self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
else:
weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
if self.linear.weight.shape != weight_tensor.shape:
weight_tensor = weight_tensor.transpose([1, 0])
self.linear.weight.set_value(weight_tensor)
weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
if self.linear.weight.shape != weight_tensor.shape:
weight_tensor = weight_tensor.transpose([1, 0])
self.linear.weight.set_value(weight_tensor)

if self.bias_key is not None:
bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
self.linear.bias.set_value(bias)
if self.bias_key is not None:
bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
self.linear.bias.set_value(bias)

def forward(self, input):
"""
@@ -123,8 +137,5 @@ def forward(self, input):
Tensor: The output tensor after processing through the layer.
"""
logits = input
if self.use_ep:
logits = paddle.matmul(logits, self.weight)
else:
logits = self.linear(logits)
logits = self.linear(logits)
return logits
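Note: the EP-specific bare-matmul path is gone; both branches now build a fleet parallel linear over self.tp_group and attach weight_loader/output_dim attributes. A self-contained illustration of the column-parallel split this relies on (pure paddle, no fleet; all names here are ours):

    import paddle

    embedding_dim, vocab_size, nranks = 8, 16, 2
    x = paddle.randn([4, embedding_dim])
    w = paddle.randn([embedding_dim, vocab_size])

    shards = paddle.split(w, nranks, axis=-1)         # one [dim, vocab/nranks] shard per TP rank
    partials = [paddle.matmul(x, s) for s in shards]  # per-rank partial logits
    logits = paddle.concat(partials, axis=-1)         # what gather_output=True reassembles
    assert paddle.allclose(logits, paddle.matmul(x, w), atol=1e-5).item()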
5 changes: 5 additions & 0 deletions fastdeploy/model_executor/model_loader/default_loader.py
@@ -72,6 +72,11 @@ def load_model(self, fd_config: FDConfig) -> nn.Layer:
# register rl model
import fastdeploy.rl # noqa

if fd_config.speculative_config.model_type != "mtp":
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
else:
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")

architectures = architectures + "RL"
context = paddle.LazyGuard()
else:
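Note: the same rewrite appears in default_loader_v1.py below. A minimal sketch of the mapping (the helper name is ours):

    def rl_architecture(arch: str, speculative_model_type: str) -> str:
        # Non-MTP workers load the MoE variant; MTP workers load the MTP variant.
        if speculative_model_type != "mtp":
            arch = arch.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
        else:
            arch = arch.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
        return arch + "RL"

    assert rl_architecture("Ernie5ForCausalLM", "main") == "Ernie5MoeForCausalLMRL"
    assert rl_architecture("Ernie5ForCausalLM", "mtp") == "Ernie5MTPForCausalLMRL"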
5 changes: 5 additions & 0 deletions fastdeploy/model_executor/model_loader/default_loader_v1.py
@@ -65,6 +65,11 @@ def load_model(self, fd_config: FDConfig) -> nn.Layer:
# register rl model
import fastdeploy.rl # noqa

if fd_config.speculative_config.model_type != "mtp":
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
else:
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")

architectures = architectures + "RL"

enable_cache, _, weight_cache_context = is_weight_cache_enabled(fd_config)
5 changes: 4 additions & 1 deletion fastdeploy/output/token_processor.py
@@ -502,7 +502,7 @@ def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False

def _compute_speculative_status(self):
# TODO(liuzichang): Supplement more statistics
interval = 10
interval = 1
if self.speculative_stats_step % interval == 0:
accept_ratio = 1 - self.total_step * 1.0 / self.number_of_output_tokens
spec_logger.info(
@@ -593,6 +593,9 @@ def _process_batch_output(self):
+ accept_num[i]
].tolist()
if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
self.resource_manager.reschedule_preempt_task(task_id)
continue
else:
token_id = int(tokens[i, 0])
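Note: with interval = 1 the speculative stats are logged on every step. The accept ratio reads as the share of emitted tokens that did not need their own target-model step; a worked example with hypothetical counts:

    total_step, number_of_output_tokens = 600, 1000       # hypothetical counts
    accept_ratio = 1 - total_step / number_of_output_tokens
    # 0.4: 40% of output tokens came from accepted draft tokens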
22 changes: 13 additions & 9 deletions fastdeploy/rl/dynamic_weight_manager.py
@@ -17,11 +17,10 @@
import os
import time
from multiprocessing.shared_memory import SharedMemory
from typing import Any, Dict
from typing import Any, Dict, List

import numpy as np
import paddle
from paddle import nn
from paddleformers.utils.log import logger

from fastdeploy.config import FDConfig
@@ -31,7 +30,7 @@
class DynamicWeightManager:
"""Manages model weights loading, updating and shared state across processes."""

def __init__(self, fd_config: FDConfig, model: nn.Layer):
def __init__(self, fd_config: FDConfig, models):
"""Initialize with config and model instances."""
self.fd_config = fd_config
self.load_config = fd_config.load_config
@@ -42,7 +41,10 @@ def __init__(self, fd_config: FDConfig, model: nn.Layer):
self.meta_src_id = self._get_gpu_id()
self.first_load = True
self.ipc_path = f"/shared_ipc_meta/ipc_metas_{self.meta_src_id}"
self.model: nn.Layer = model
if not isinstance(models, List):
self.model_list = [models]
else:
self.model_list = models
self._capture_model_state()
self.update_parameters()
self.finalize_update()
@@ -55,9 +57,10 @@ def __init__(self, fd_config: FDConfig, model: nn.Layer):
@paddle.no_grad()
def _capture_model_state(self):
"""Capture and store initial model parameters state."""
for name, param in self.model.state_dict().items():
logger.debug(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
self.state_dict[name] = param
for model in self.model_list:
for name, param in model.state_dict().items():
logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
self.state_dict[name] = param

def update_parameters(self, pid: int = 0) -> None:
"""Core method to update model parameters based on strategy."""
@@ -137,8 +140,9 @@ def clear_parameters(self, pid: int = 0) -> None:

paddle.device.cuda.empty_cache()
# step2: release model weight
for param in self.model.state_dict().values():
param._clear_data()
for model in self.model_list:
for param in model.state_dict().values():
param._clear_data()

self._verify_parameters("clearance")

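Note: DynamicWeightManager now tracks a list of models (the main model plus, e.g., an MTP draft model) rather than a single nn.Layer. A sketch of the normalization idiom (the function name is ours; isinstance against typing.List behaves like isinstance against the built-in list):

    from typing import List

    def as_model_list(models) -> list:
        # Wrap a lone model so downstream code can always iterate.
        return [models] if not isinstance(models, List) else models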
7 changes: 7 additions & 0 deletions fastdeploy/spec_decode/base.py
@@ -38,13 +38,20 @@ def __init__(self, fd_config: FDConfig):
Init Speculative proposer
"""
fd_config.parallel_config.tp_group = None
fd_config.parallel_config.ep_group = None
self.fd_config = deepcopy(fd_config)
fd_config.parallel_config.tp_group = dist.get_group(
fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
)
fd_config.parallel_config.ep_group = dist.get_group(
fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
)
self.fd_config.parallel_config.tp_group = dist.get_group(
fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
)
self.fd_config.parallel_config.ep_group = dist.get_group(
fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
)
self.parallel_config = self.fd_config.parallel_config
self.model_config = self.fd_config.model_config
self.speculative_config = self.fd_config.speculative_config
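Note: tp_group and ep_group are nulled out before the deepcopy because communication-group handles cannot be deep-copied; live groups are then re-attached to both the original config and the copy. A sketch of the pattern (assuming non-copyable group handles; names are ours):

    from copy import deepcopy

    def copy_config_sharing_groups(cfg, tp_group, ep_group):
        cfg.parallel_config.tp_group = None   # drop handles deepcopy would choke on
        cfg.parallel_config.ep_group = None
        cfg_copy = deepcopy(cfg)
        for c in (cfg, cfg_copy):             # re-attach the live groups to both
            c.parallel_config.tp_group = tp_group
            c.parallel_config.ep_group = ep_group
        return cfg_copy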
37 changes: 29 additions & 8 deletions fastdeploy/spec_decode/mtp.py
@@ -96,7 +96,8 @@ def _update_mtp_config(self, main_model):
"""
Update config for MTP from global config
"""
self.model_config.architectures[0] = "Ernie4_5_MTPForCausalLM"
self.forward_meta: ForwardMeta = None
self.model_config.architectures[0] = self.model_config.architectures[0].replace("Moe", "MTP")
self.speculative_config.sharing_model = main_model
self.model_config.num_hidden_layers = 1
self.model_config.model = self.speculative_config.model
@@ -169,6 +170,9 @@ def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False):
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
)
if kv_cache_quant_type == "block_wise_fp8":
kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
if not profile and (
self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
):
@@ -178,8 +182,8 @@ def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False):
self.num_main_model_layers + self.model_config.num_hidden_layers,
):
key_cache = paddle.empty(shape=[], dtype=cache_type)
key_cache_name = f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}"
val_cache_name = f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}"
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
cache_kvs_list.append(key_cache)
value_cache = paddle.empty(shape=[], dtype=cache_type)
@@ -199,6 +203,17 @@ def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False):
fill_value=0,
dtype=cache_type,
)
if kv_cache_quant_type == "block_wise_fp8":
self.cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
shape=kv_cache_scale_shape,
fill_value=0,
dtype=paddle.get_default_dtype(),
)
self.cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
shape=kv_cache_scale_shape,
fill_value=0,
dtype=paddle.get_default_dtype(),
)
self.model_inputs["caches"] = list(self.cache_kvs.values())
for value in self.cache_kvs.values():
del value
@@ -430,11 +445,10 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
if "caches" not in self.model_inputs:
self.initialize_kv_cache()
req_len = len(req_dicts)
# has_prefill_task = False
# has_decode_task = False

for i in range(req_len):
request = req_dicts[i]
logger.info(f"{i}th request-{request.request_id}: {request}")
logger.debug(f"{i}th request-{request.request_id}: {request}")
idx = request.idx
if request.task_type.value == RequestType.PREFILL.value: # prefill task
prefill_start_index = request.prefill_start_index
@@ -688,7 +702,7 @@ def _post_process(self, sampled_token_ids):
self.max_model_len,
self.model_inputs["substep"],
)
if self.role == "prefill":
if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0:
mtp_save_first_token(
self.model_inputs["base_model_draft_tokens"],
self.model_inputs["not_need_stop"],
@@ -820,11 +834,18 @@ def _propose(self, step_use_cudagraph: bool = False):
)

if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.broadcast(sampled_token_ids, 0)
paddle.distributed.broadcast(
sampled_token_ids,
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
group=self.parallel_config.tp_group,
)

self._post_process(sampled_token_ids)
if substep != self.num_model_steps - 1:
self._get_self_hidden_states(hidden_states)
else:
if hasattr(self.model, "empty_input_forward"):
self.model.empty_input_forward()

def _get_self_hidden_states(self, hidden_states):
target_hidden_states = eagle_get_self_hidden_states(
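Note: the draft-token broadcast now names its source rank explicitly and stays inside tp_group. With ranks laid out DP-major, the first tensor-parallel rank of data-parallel group d has global rank d * tp_size. A sketch (the function name is ours):

    def broadcast_src(dp_rank: int, tp_size: int) -> int:
        # Global rank of TP rank 0 within this data-parallel group.
        return dp_rank * tp_size

    assert broadcast_src(0, 4) == 0   # DP group 0 broadcasts from global rank 0
    assert broadcast_src(2, 4) == 8   # DP group 2 broadcasts from global rank 8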