12 changes: 6 additions & 6 deletions docs/features/plas_attention.md
@@ -196,7 +196,7 @@ We selected a subset (longbook_sum_eng) from InfiniteBench as the performance ev
## Usage

```
export FD_ATTENTION_BACKEND="PLAS_ATTN"
export FD_ATTENTION_BACKEND="MOBA_ATTN"

python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
@@ -207,13 +207,13 @@ python -m fastdeploy.entrypoints.openai.api_server
--max-num-batched-tokens 8192 \
--max-model-len 131072 \
--max-num-seqs 32 \
--plas-attention-config '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}'
--moba-attention-config '{"moba_encoder_top_k_left": 50, "moba_encoder_top_k_right": 60, "moba_decoder_top_k_left": 100, "moba_decoder_top_k_right": 120}'
```
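Once the server is running, requests go through the standard OpenAI-compatible API. A minimal client sketch follows; the port 8000, the dummy API key, and the input file are assumptions for illustration, not values taken from this PR:

```python
# Minimal sketch of a long-context request against the server launched above.
# Assumptions: the api_server listens on http://localhost:8000/v1 and accepts any API key.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("long_document.txt") as f:  # hypothetical long input within --max-model-len
    document = f.read()

response = client.chat.completions.create(
    model="baidu/ERNIE-4.5-300B-A47B-Paddle",
    messages=[
        {"role": "user", "content": f"Summarize the following book:\n\n{document}"},
    ],
    max_tokens=512,
)
print(response.choices[0].message.content)
```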

**Note**: If sparse attention is enabled, the system will automatically load the MLP weights from `plas_attention_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling will be applied to the key representations.
**Note**: If sparse attention is enabled, the system will automatically load the MLP weights from `moba_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling will be applied to the key representations.

**Parameter Description:**

* Setting `FD_ATTENTION_BACKEND="PLAS_ATTN"` enables PLAS sparse attention.
* `plas_encoder_top_k_left=50, plas_encoder_top_k_right=60` means the encoder top-k is chosen from the range [50, 60] when encoder attention is sparse.
* `plas_decoder_top_k_left=100, plas_decoder_top_k_right=120` means the decoder top-k is chosen from the range [100, 120] when decoder attention is sparse.
* Setting `FD_ATTENTION_BACKEND="MOBA_ATTN"` enables MOBA sparse attention.
* `moba_encoder_top_k_left=50, moba_encoder_top_k_right=60` means the encoder top-k is chosen from the range [50, 60] when encoder attention is sparse.
* `moba_decoder_top_k_left=100, moba_decoder_top_k_right=120` means the decoder top-k is chosen from the range [100, 120] when decoder attention is sparse. The sketch just below shows how these lower bounds set the sequence length at which sparsity activates.
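These bounds also decide when sparsity kicks in at all: as the attention-config hunk later in this diff shows, the `*_use_encoder_seq_limit` / `*_use_decoder_seq_limit` fields default to `top_k_left * block_size` with a block size of 128 when not set explicitly. A small sketch of that arithmetic with the values from the command above:

```python
# Sketch: how the default "no sparsity below this length" thresholds are derived
# (mirrors the defaulting logic in the attention config class shown further down).
block_size = 128                    # *_block_size default in the config class

encoder_top_k_left = 50             # from the --*-attention-config JSON above
decoder_top_k_left = 100

use_encoder_seq_limit = encoder_top_k_left * block_size   # 6400 tokens
use_decoder_seq_limit = decoder_top_k_left * block_size   # 12800 tokens

# Prefill shorter than 6400 tokens and decode contexts shorter than 12800 tokens
# therefore fall back to dense attention with these settings.
print(use_encoder_seq_limit, use_decoder_seq_limit)
```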
16 changes: 8 additions & 8 deletions docs/zh/features/plas_attention.md
@@ -18,7 +18,7 @@
<img src="images/plas_training_distill.png" alt="Attention Gate Module" width="60%">
</div>

* **Attention Gate Module**: As shown in the figure above, to estimate the importance of each block at low computational cost, we design a lightweight attention gate module. The module first compresses each key block through an MLP layer into a representative low-dimensional form: $K_c^T=W_{kp}K^T$, where $W_{kp}$ denotes the MLP layer weights. Compared with plain mean pooling, a learnable MLP captures the semantic relationships and importance distribution across tokens more effectively, giving a finer-grained representation of each block. Given the compressed representation $K_c$, the importance of each query token with respect to each block is estimated as $Softmax(Q\cdot K_c^T)$. To strengthen the discriminative ability of the MLP layer, we use the full-attention result after 1D max pooling, $1DMaxPooling(Softmax(Q \cdot K^T))$, as the ground truth; minimizing the distributional difference between the two guides the MLP layer toward feature representations that better match the true attention distribution. An illustrative sketch of this gate follows below.

* **Training Data**: Thanks to the efficiency of the model architecture and training paradigm, our method reaches near-lossless accuracy with only 1B training tokens. The training data comes from an internally built corpus mixing long and short texts, which improves the module's adaptability to different sequence lengths.

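A minimal NumPy sketch of the gate described above and its distillation target, assuming a per-block flattening projection as the MLP compression and including the mean-pooling fallback mentioned in the usage notes; tensor names and shapes are illustrative, not taken from the FastDeploy implementation:

```python
# Illustrative sketch of the attention gate -- not the FastDeploy kernel.
# Assumed shapes: Q is [num_q, d]; K is [num_blocks, block_size, d];
# W_kp is [block_size * d, d] (the MLP treated as one per-block projection).
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def block_importance(Q, K, W_kp=None):
    """Estimate how important each key block is for each query token."""
    num_blocks, block_size, d = K.shape
    if W_kp is not None:
        # Learnable compression K_c^T = W_kp K^T: one representative vector per block.
        K_c = K.reshape(num_blocks, block_size * d) @ W_kp   # [num_blocks, d]
    else:
        # Fallback used when no MLP weight file is found: mean-pool each key block.
        K_c = K.mean(axis=1)                                  # [num_blocks, d]
    return softmax(Q @ K_c.T)                                 # [num_q, num_blocks]

def gate_training_target(Q, K):
    """Distillation target: full attention followed by 1D max pooling over each block."""
    num_blocks, block_size, d = K.shape
    full = softmax(Q @ K.reshape(num_blocks * block_size, d).T)
    return full.reshape(Q.shape[0], num_blocks, block_size).max(axis=-1)
```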
@@ -36,7 +36,7 @@

* **Prefill Token Union**: We observe that adjacent query tokens tend to select similar key blocks. Exploiting this locality, we take the union of the key blocks selected by 128 consecutive query tokens and compute sparse attention for these tokens jointly.

* **Decode Head Union**: Given the wide adoption of GQA in modern models, we find that different query heads within the same group often select overlapping key blocks. We therefore merge the key blocks selected by all query heads in a group into one unified set and compute sparse attention jointly. This also reduces memory-access overhead and further improves decoding efficiency.

* **Top-K Selection**: Conventional top-k algorithms based on sorting or direct calls to the Cub library introduce significant runtime overhead. To mitigate this, we implement an approximate top-k selection algorithm based on binary search, which markedly reduces latency while preserving accuracy and yields a clear end-to-end performance gain (see the sketch below).

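The approximate selection above can be illustrated by binary-searching a score threshold instead of sorting. The following NumPy sketch only shows the thresholding idea, not the GPU kernel; the iteration count and sizes are arbitrary assumptions:

```python
# Sketch: pick roughly the top-k block scores by binary-searching a score threshold
# instead of fully sorting. Illustrative only; the production kernel runs on GPU.
import numpy as np

def approx_top_k_mask(scores, k, iters=20):
    lo, hi = scores.min(), scores.max()
    for _ in range(iters):
        mid = (lo + hi) / 2
        count = int((scores >= mid).sum())
        if count > k:        # threshold too low -> too many blocks kept
            lo = mid
        else:                # threshold high enough -> tighten it
            hi = mid
    return scores >= hi      # boolean mask of the (approximately) top-k blocks

scores = np.random.rand(512)          # importance of 512 key blocks for one query group
mask = approx_top_k_mask(scores, k=100)
print(mask.sum())                     # close to 100, without a full sort
```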
@@ -200,7 +200,7 @@
## Usage

```
export FD_ATTENTION_BACKEND="PLAS_ATTN"
export FD_ATTENTION_BACKEND="MOBA_ATTN"

python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
@@ -211,13 +211,13 @@ python -m fastdeploy.entrypoints.openai.api_server
--max-num-batched-tokens 8192 \
--max-model-len 131072 \
--max-num-seqs 32 \
--plas-attention-config '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}'
--moba-attention-config '{"moba_encoder_top_k_left": 50, "moba_encoder_top_k_right": 60, "moba_decoder_top_k_left": 100, "moba_decoder_top_k_right": 120}'
```

**Note**: If sparse attention is enabled, the system automatically loads the MLP weights from `plas_attention_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling is applied to the key representations.
**Note**: If sparse attention is enabled, the system automatically loads the MLP weights from `moba_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling is applied to the key representations.

**Parameter Description:**

* Setting `FD_ATTENTION_BACKEND="PLAS_ATTN"` enables PLAS sparse attention.
* `plas_encoder_top_k_left=50, plas_encoder_top_k_right=60` means the encoder top-k is chosen from the range [50, 60] when encoder attention is sparse.
* `plas_decoder_top_k_left=100, plas_decoder_top_k_right=120` means the decoder top-k is chosen from the range [100, 120] when decoder attention is sparse.
* Setting `FD_ATTENTION_BACKEND="MOBA_ATTN"` enables MOBA sparse attention.
* `moba_encoder_top_k_left=50, moba_encoder_top_k_right=60` means the encoder top-k is chosen from the range [50, 60] when encoder attention is sparse.
* `moba_decoder_top_k_left=100, moba_decoder_top_k_right=120` means the decoder top-k is chosen from the range [100, 120] when decoder attention is sparse.
74 changes: 37 additions & 37 deletions fastdeploy/config.py
@@ -690,63 +690,63 @@ def update_use_cudagraph(self, argument: bool):
argument = self.use_cudagraph


class PlasAttentionConfig:
class MobaAttentionConfig:
def __init__(
self,
args,
):
self.plas_encoder_top_k_left: int = None
self.plas_encoder_top_k_right: int = None
"The sparse top-k of encoder attention lies in [plas_encoder_top_k_left, plas_encoder_top_k_right]"
self.plas_decoder_top_k_left: int = None
self.plas_decoder_top_k_right: int = None
"The sparse top-k of decoder attention lies in [plas_decoder_top_k_left, plas_decoder_top_k_right]"
self.plas_use_encoder_seq_limit: int = None
"When the number of encoder tokens is less than plas_use_encoder_seq_limit, attention is not sparse"
self.plas_use_decoder_seq_limit: int = None
"When the number of decoder tokens is less than plas_use_decoder_seq_limit, attention is not sparse"
self.plas_block_size: int = 128
self.mlp_weight_name: str = "plas_attention_mlp_weight.safetensors"
self.plas_max_seq_length: int = 128 * 1024
self.moba_encoder_top_k_left: int = None
self.moba_encoder_top_k_right: int = None
"The sparse top-k of encoder attention lies in [moba_encoder_top_k_left, moba_encoder_top_k_right]"
self.moba_decoder_top_k_left: int = None
self.moba_decoder_top_k_right: int = None
"The sparse top-k of decoder attention lies in [moba_decoder_top_k_left, moba_decoder_top_k_right]"
self.moba_use_encoder_seq_limit: int = None
"When the number of encoder tokens is less than moba_use_encoder_seq_limit, attention is not sparse"
self.moba_use_decoder_seq_limit: int = None
"When the number of decoder tokens is less than moba_use_decoder_seq_limit, attention is not sparse"
self.moba_block_size: int = 128
self.mlp_weight_name: str = "moba_mlp_weight.safetensors"
self.moba_max_seq_length: int = 128 * 1024
if args is not None:
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
if self.plas_use_encoder_seq_limit is None and self.plas_encoder_top_k_left is not None:
self.plas_use_encoder_seq_limit = self.plas_encoder_top_k_left * self.plas_block_size
if self.plas_use_decoder_seq_limit is None and self.plas_decoder_top_k_left is not None:
self.plas_use_decoder_seq_limit = self.plas_decoder_top_k_left * self.plas_block_size
if self.moba_use_encoder_seq_limit is None and self.moba_encoder_top_k_left is not None:
self.moba_use_encoder_seq_limit = self.moba_encoder_top_k_left * self.moba_block_size
if self.moba_use_decoder_seq_limit is None and self.moba_decoder_top_k_left is not None:
self.moba_use_decoder_seq_limit = self.moba_decoder_top_k_left * self.moba_block_size
self.check_legality_parameters()

def check_legality_parameters(
self,
) -> None:
if self.plas_encoder_top_k_left is not None:
assert self.plas_encoder_top_k_left > 0, "plas_encoder_top_k_left must be greater than 0"
if self.moba_encoder_top_k_left is not None:
assert self.moba_encoder_top_k_left > 0, "moba_encoder_top_k_left must be greater than 0"

if self.plas_encoder_top_k_right is not None:
assert self.plas_encoder_top_k_right > 0, "plas_encoder_top_k_right must be greater than 0"
if self.moba_encoder_top_k_right is not None:
assert self.moba_encoder_top_k_right > 0, "moba_encoder_top_k_right must be greater than 0"
assert (
self.plas_encoder_top_k_right >= self.plas_encoder_top_k_left
), "plas_encoder_top_k_right must be no less than plas_encoder_top_k_left"
self.moba_encoder_top_k_right >= self.moba_encoder_top_k_left
), "moba_encoder_top_k_right must be no less than moba_encoder_top_k_left"

if self.plas_decoder_top_k_left is not None:
assert self.plas_decoder_top_k_left > 0, "plas_decoder_top_k_left must be greater than 0"
if self.moba_decoder_top_k_left is not None:
assert self.moba_decoder_top_k_left > 0, "moba_decoder_top_k_left must be greater than 0"

if self.plas_decoder_top_k_right is not None:
assert self.plas_decoder_top_k_right > 0, "plas_decoder_top_k_right must be greater than 0"
if self.moba_decoder_top_k_right is not None:
assert self.moba_decoder_top_k_right > 0, "moba_decoder_top_k_right must be greater than 0"
assert (
self.plas_decoder_top_k_right >= self.plas_decoder_top_k_left
), "plas_decoder_top_k_right must be no less than plas_decoder_top_k_left"
self.moba_decoder_top_k_right >= self.moba_decoder_top_k_left
), "moba_decoder_top_k_right must be no less than moba_decoder_top_k_left"

if self.plas_use_encoder_seq_limit is not None and self.plas_encoder_top_k_left is not None:
assert self.plas_use_encoder_seq_limit >= self.plas_encoder_top_k_left * self.plas_block_size
if self.plas_use_decoder_seq_limit is not None and self.plas_decoder_top_k_left is not None:
assert self.plas_use_decoder_seq_limit >= self.plas_decoder_top_k_left * self.plas_block_size
if self.moba_use_encoder_seq_limit is not None and self.moba_encoder_top_k_left is not None:
assert self.moba_use_encoder_seq_limit >= self.moba_encoder_top_k_left * self.moba_block_size
if self.moba_use_decoder_seq_limit is not None and self.moba_decoder_top_k_left is not None:
assert self.moba_use_decoder_seq_limit >= self.moba_decoder_top_k_left * self.moba_block_size

def to_json_string(self):
"""
Convert plas_attention_config to json string.
Convert moba_attention_config to json string.
"""
return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})

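For reference, a minimal sketch of how the class above consumes the same dictionary that the `--moba-attention-config` flag carries (assuming the `MobaAttentionConfig` variant shown in this hunk is the importable one):

```python
# Sketch: building the attention config from the same dict the CLI flag carries.
from fastdeploy.config import MobaAttentionConfig

cfg = MobaAttentionConfig(
    {
        "moba_encoder_top_k_left": 50,
        "moba_encoder_top_k_right": 60,
        "moba_decoder_top_k_left": 100,
        "moba_decoder_top_k_right": 120,
    }
)

# Derived defaults: top_k_left * block_size (block_size defaults to 128).
print(cfg.moba_use_encoder_seq_limit)   # 6400
print(cfg.moba_use_decoder_seq_limit)   # 12800
print(cfg.to_json_string())
```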
@@ -1105,7 +1105,7 @@ def __init__(
decoding_config: DecodingConfig = None,
quant_config: QuantConfigBase = None,
graph_opt_config: GraphOptimizationConfig = None,
plas_attention_config: PlasAttentionConfig = None,
moba_attention_config: MobaAttentionConfig = None,
speculative_config: SpeculativeConfig = None,
tokenizer: str = None,
max_model_len: int = 8192,
@@ -1140,7 +1140,7 @@ def __init__(
self.early_stop_config: Optional[EarlyStopConfig] = early_stop_config
self.decoding_config: DecodingConfig = decoding_config # type: ignore
self.cache_config: CacheConfig = cache_config # type: ignore
self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config
self.moba_attention_config: Optional[MobaAttentionConfig] = moba_attention_config
# Initialize cuda graph capture list
if self.graph_opt_config.cudagraph_capture_sizes is None:
self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs)
26 changes: 13 additions & 13 deletions fastdeploy/engine/args_utils.py
@@ -28,9 +28,9 @@
FDConfig,
GraphOptimizationConfig,
LoadConfig,
MobaAttentionConfig,
ModelConfig,
ParallelConfig,
PlasAttentionConfig,
SpeculativeConfig,
TaskOption,
)
@@ -342,9 +342,9 @@ class EngineArgs:
"""
Configuration for graph optimization backend execution.
"""
plas_attention_config: Optional[Dict[str, Any]] = None
moba_attention_config: Optional[Dict[str, Any]] = None
"""
Configuration for plas attention.
Configuration for moba attention.
"""

enable_logprob: bool = False
@@ -559,9 +559,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help="",
)
model_group.add_argument(
"--plas-attention-config",
"--moba-attention-config",
type=json.loads,
default=EngineArgs.plas_attention_config,
default=EngineArgs.moba_attention_config,
help="",
)
model_group.add_argument(
@@ -959,17 +959,17 @@ def create_graph_optimization_config(self) -> GraphOptimizationConfig:
graph_optimization_args[k] = v
return GraphOptimizationConfig(graph_optimization_args)

def create_plas_attention_config(self) -> PlasAttentionConfig:
def create_moba_attention_config(self) -> MobaAttentionConfig:
"""
Create and return a PlasAttentionConfig object based on the current settings.
Create and return a MobaAttentionConfig object based on the current settings.
"""
attention_args = asdict(self)
if self.plas_attention_config is not None:
for k, v in self.plas_attention_config.items():
if self.moba_attention_config is not None:
for k, v in self.moba_attention_config.items():
attention_args[k] = v
return PlasAttentionConfig(attention_args)
return MobaAttentionConfig(attention_args)
else:
return PlasAttentionConfig(None)
return MobaAttentionConfig(None)

def create_early_stop_config(self) -> EarlyStopConfig:
"""
@@ -1025,7 +1025,7 @@ def create_engine_config(self) -> FDConfig:
scheduler_cfg = self.create_scheduler_config()
graph_opt_cfg = self.create_graph_optimization_config()
graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
plas_attention_config = self.create_plas_attention_config()
moba_attention_config = self.create_moba_attention_config()

early_stop_cfg = self.create_early_stop_config()
early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
@@ -1063,7 +1063,7 @@ def create_engine_config(self) -> FDConfig:
max_long_partial_prefills=self.max_long_partial_prefills,
long_prefill_token_threshold=self.long_prefill_token_threshold,
graph_opt_config=graph_opt_cfg,
plas_attention_config=plas_attention_config,
moba_attention_config=moba_attention_config,
guided_decoding_backend=self.guided_decoding_backend,
disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
early_stop_config=early_stop_cfg,
2 changes: 1 addition & 1 deletion fastdeploy/engine/engine.py
@@ -493,7 +493,7 @@ def _start_worker_service(self):
f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
f" --reasoning_parser {self.cfg.reasoning_parser}"
f" --load_choices {self.cfg.load_config.load_choices}"
f" --plas_attention_config '{self.cfg.plas_attention_config.to_json_string()}'"
f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'"
f" --ips {ips}"
)

4 changes: 2 additions & 2 deletions fastdeploy/model_executor/layers/attention/__init__.py
@@ -20,7 +20,7 @@
from .flash_attn_backend import FlashAttentionBackend
from .iluvatar_attn_backend import IluvatarAttnBackend
from .mla_attention_backend import MLAAttentionBackend
from .moba_attention_backend import PlasAttentionBackend
from .moba_attention_backend import MobaAttentionBackend
from .native_paddle_backend import PaddleNativeAttnBackend
from .xpu_attn_backend import XPUAttentionBackend

@@ -35,5 +35,5 @@
"IluvatarAttnBackend",
"BlockAttentionBackend",
"Attention",
"PlasAttentionBackend",
"MobaAttentionBackend",
]