
[Feature]【Hackathon 10th Spring No.47】Add MiniMax-M1 integration tests and multi-GPU support #7511

Open
bobby-cloudforge wants to merge 7 commits into PaddlePaddle:develop from CloudForge-Solutions:task/h10-047-minimax-m1-integration1

Conversation

@bobby-cloudforge

Motivation

Companion to the MiniMax-M1 model PR — adds integration tests and multi-GPU validation infrastructure for Hackathon 10th Spring No.47.

Modifications

Integration Tests (tests/model_executor/test_minimax_m1_integration.py)

  • End-to-end construction + forward pass tests with full model config
  • Multi-layer interaction tests (linear + full attention)
  • Weight loading validation (v0 and v1 paths)

Multi-GPU Validation Script (scripts/validate_minimax_m1_multigpu.sh)

  • Automated tensor-parallel validation script for 2/4/8 GPU configurations
  • Includes correctness checks and basic throughput measurement

Test Infrastructure (tests/model_executor/conftest.py)

  • Shared fixtures for model executor tests
  • Config builder helpers for MiniMax-M1 test variants

Model Base Extension (fastdeploy/model_executor/models/model_base.py)

  • Minor extension to support MiniMax-M1 linear attention state management

Usage or Command

# Run integration tests
pytest tests/model_executor/test_minimax_m1_integration.py -v

# Multi-GPU validation (requires 8 GPUs)
bash scripts/validate_minimax_m1_multigpu.sh

Accuracy Tests

Integration tests verify:

  • Model construction with correct layer type dispatch (linear vs full attention)
  • Forward pass shape correctness through mixed attention pipeline
  • Weight loading key mapping for both v0 and v1 loaders
  • DeepNorm scaling coefficients applied correctly

All tests use monkeypatch.setattr + real objects (no MagicMock).
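
For illustration, a minimal sketch of that style, assuming a hypothetical base_model_config fixture and build_model / fake_input_ids helpers (none of these names are from the PR itself):

import pytest

@pytest.fixture
def tiny_minimax_config(base_model_config):
    # Shrink the real config object so construction and forward stay fast in CI.
    base_model_config.num_hidden_layers = 2
    base_model_config.attn_type_list = [0, 1]  # one linear layer, one full-attention layer
    return base_model_config

def test_forward_shapes(monkeypatch, tiny_minimax_config):
    # Patch a real attribute on a real object rather than substituting a MagicMock.
    monkeypatch.setattr(tiny_minimax_config, "hidden_size", 128)
    model = build_model(tiny_minimax_config)
    out = model(fake_input_ids(batch=1, seq_len=8))
    assert out.shape[-1] == tiny_minimax_config.hidden_size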

Checklist

  • Integration tests for mixed attention pipeline
  • Multi-GPU validation script
  • Shared test fixtures
  • Pre-commit hooks passing

@paddle-bot

paddle-bot Bot commented Apr 20, 2026

Thanks for your contribution!

@paddle-bot paddle-bot Bot added the contributor External developers label Apr 20, 2026
- scripts/validate_minimax_m1_multigpu.sh: fix Tier 2 RESPONSE not reaching
  Python (use env var instead of stdin); pipe $MODELS via stdin in Tier 1
  to avoid triple-quote injection; use jq in send_chat for safe JSON
- model_base.py: warn on architecture registration overwrite
- lightning_attn.py: use None + conditional add instead of int 0 accumulator
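
A minimal sketch of the accumulator pattern from the last bullet (the loop body and names are illustrative, not the actual kernel code):

output = None
for i in range(n - 1):
    partial = compute_chunk(i)  # hypothetical per-chunk partial result
    # Conditional add: never mixes an int 0 into tensor arithmetic, and
    # leaves output as None (rather than int 0) if the loop never runs.
    output = partial if output is None else output + partial
if output is None:
    raise RuntimeError("empty accumulation loop; check d % m == 0 upstream")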
PaddlePaddle-bot

This comment was marked as outdated.

…onvention

- scripts/validate_minimax_m1_multigpu.sh: add 'import sys' to Tier 2
  Python heredoc (sys.exit used at lines 213/221/248/253)
- model_base.py: replace stdlib logging with paddleformers.utils.log.logger
  to match project convention (17/17 model files use this pattern)
PaddlePaddle-bot

This comment was marked as outdated.

Addresses PaddlePaddle-bot review: _fwd_none_diag_kernel uses
tl.program_id(2) for feature block indexing but the grid is 2D.
Current NUM_FBLOCK=1 makes off_e=0 correct, but the assumption
is implicit. The assertion documents this and will fail loudly
if NUM_FBLOCK is changed in the future.
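
A sketch of what such a launch-site guard can look like (the constant name comes from the commit text; the surrounding code is illustrative):

NUM_FBLOCK = 1  # feature-dim blocking factor assumed by the 2-D launch grid
# tl.program_id(2) returns 0 on a 2-D grid, so off_e == 0 is only correct
# while NUM_FBLOCK == 1; fail loudly if someone raises it later.
assert NUM_FBLOCK == 1, (
    "_fwd_none_diag_kernel indexes feature blocks with tl.program_id(2) "
    "on a 2-D grid; extend the grid to 3-D before increasing NUM_FBLOCK."
)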
PaddlePaddle-bot

This comment was marked as outdated.

…s for linear attention

- MiniMaxM1LinearAttention: emit logger.warning when _kv_history is reset
  due to batch_size change (continuous batching scenario). Documents the
  known limitation referenced by the existing TODO and surfaces the state
  loss to operators instead of silently dropping accumulated state.

- _get_tensor_parallel_mappings: add output_gate (column) and out_proj
  (row) entries so the v0 (PaddleFormers / load_weight_utils) loader path
  splits linear-attention specific weights correctly. Missing keys for
  full-attention layers are silently ignored by paddleformers, so the
  added entries are safe for all 80 layers.
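
For context, a toy illustration of column- vs row-parallel splitting with tp_size=2 (shapes invented; the real mappings go through paddleformers split helpers). Paddle's Linear stores weights as [in_features, out_features], so a column-parallel entry shards the output axis and a row-parallel entry shards the input axis:

import numpy as np

w_output_gate = np.zeros((1024, 4096))  # column-parallel: shard axis=1 (outputs)
w_out_proj = np.zeros((4096, 1024))     # row-parallel: shard axis=0 (inputs)

col_shards = np.split(w_output_gate, 2, axis=1)  # each rank holds [1024, 2048]
row_shards = np.split(w_out_proj, 2, axis=0)     # each rank holds [2048, 1024]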
PaddlePaddle-bot

This comment was marked as outdated.

@bobby-cloudforge bobby-cloudforge changed the title [Feature]【Hackathon 10th Spring No.47】Add MiniMax-M1 integration tests and multi-GPU support [Models] Add MiniMax-M1 hybrid attention model with Lightning Attention Triton kernel Apr 23, 2026
@bobby-cloudforge bobby-cloudforge changed the title [Models] Add MiniMax-M1 hybrid attention model with Lightning Attention Triton kernel [Feature]【Hackathon 10th Spring No.47】Add MiniMax-M1 integration tests and multi-GPU support Apr 23, 2026
…istory offline-only docstring, attn_type_list fallback caveat

- Rename clear_grpah_opt_backend → clear_graph_opt_backend (matches qwen2/qwen3/glm4_moe/ernie4_5_moe convention; no upstream callers)
- Add inline comment at MiniMaxM1LinearAttention.forward L400 explaining RMSNorm.forward always returns (out, residual_out) tuple — [0] is intentional, not a bug
- Strengthen MiniMaxM1LinearAttention class docstring with explicit OFFLINE/SINGLE-REQUEST-ONLY warning for _kv_history (continuous batching unsafe even under unchanged batch_size)
- Document _build_attn_type_list as MiniMax-Text-01 / M1 80-layer fallback only; both call sites already prefer config-driven attn_type_list via getattr
PaddlePaddle-bot

This comment was marked as outdated.

Round 6 bot review fixes:

- minimax_m1.py: replace slope_rate.squeeze(-1) with .reshape([-1]).
  squeeze(-1) on [heads,1,1] yields [heads,1] (ndim=2); lightning_attention
  only reshapes ndim==1 tensors, so [heads,1] would reach the Triton kernel
  with an incorrect shape and fail with a broadcast/shape assertion.
  .reshape([-1]) always produces 1-D [heads] (see the shape sketch below).

- lightning_attn.py: add assert output is not None after the accumulation
  loop.  When d < m the range(n-1) loop is empty and output stays None;
  the subsequent .transpose() in the caller would crash with AttributeError.
  The d % m == 0 assert upstream already blocks this for well-formed inputs,
  but an explicit guard improves error clarity.

- minimax_m1.py: clarify kv_cache_shape docstring to state it is a
  per-slot shape (no batch dim), distinguishing it from the 4-D runtime
  shape of _kv_history ([batch, heads, head_dim, head_dim]).
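
A quick shape check of the slope_rate fix in the first bullet above (toy head count of 8):

import paddle

slope_rate = paddle.rand([8, 1, 1])    # [heads, 1, 1]
print(slope_rate.squeeze(-1).shape)    # [8, 1] -- still 2-D, the failing path
print(slope_rate.reshape([-1]).shape)  # [8]    -- 1-D, what the kernel expects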

@PaddlePaddle-bot PaddlePaddle-bot left a comment


🤖 AI Code Review | 2026-04-24 13:48:12

📋 Review Summary

PR overview: adds the core implementation, integration tests, and multi-GPU validation script for the MiniMax-M1 hybrid attention model (Hackathon 10th Spring No.47).
Scope of changes: model_executor/models/minimax_m1.py (new, 1009 lines), lightning_attn.py (new Triton kernel), model_base.py (registration logic extension), tests / docs / scripts.
Impact tags: Models, OP


📝 PR Convention Check

The PR title carries a valid [Feature] tag and the Motivation/Modifications sections are complete; the PR meets the contribution conventions.


Issues

| Level | File | Summary |
| --- | --- | --- |
| 🟡 Suggestion | model_base.py:315 | Architecture registration performs a redundant write and triggers a spurious warning on hot reload |
| 🟡 Suggestion | minimax_m1.py | _kv_history only detects batch_size changes; under continuous batching, cross-request state leaks silently |
| 🟡 Suggestion | lightning_attn.py | assert statements in production code are stripped under the -O optimization flag |
| ❓ Question | minimax_m1.py | MiniMaxM1PretrainedModel is not registered with ModelRegistry and is disconnected from the FD serving path |

🟡 Suggestion (minimax_m1.py): _kv_history silently pollutes state across requests

MiniMaxM1LinearAttention.forward() stores _kv_history as an instance variable, and the current needs_init condition only resets that state when batch_size changes.

Problem: under continuous batching, a batch of the same batch_size can hold entirely different requests (old requests finish, new ones arrive), so the recurrent linear-attention state of old requests leaks silently into new ones, producing wrong outputs with no log warning.

Although the docstring and the PR description both note "Offline / single-request use only" and a TODO, FastDeploy's production serving defaults to continuous batching, so consider adding a guard at the forward entry:

# Temporary guard: fail loudly in serving mode instead of corrupting state silently
if forward_meta is not None and getattr(forward_meta, 'is_serving', False):
    raise NotImplementedError(
        "MiniMaxM1LinearAttention: _kv_history is not isolated per-request. "
        "Migrate to slot-based cache before production serving."
    )

Alternatively, add a serving_safe=False flag on the model_base.py / ModelRegistry side so the scheduling layer automatically refuses concurrent multi-request execution when loading this model.
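
A hypothetical sketch of that registry-side flag (serving_safe and the rejection logic are proposals, not existing FastDeploy APIs):

class MiniMaxM1ForCausalLM(ModelForCasualLM):
    serving_safe = False  # linear-attention state is not isolated per request

# In the loader / scheduler, before accepting concurrent requests:
model_cls = ModelRegistry.get_model_cls(architecture)  # hypothetical lookup
if not getattr(model_cls, "serving_safe", True):
    raise RuntimeError(
        f"{architecture} is offline-only; refusing continuous-batching serving."
    )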


🟡 Suggestion (lightning_attn.py): assert is unsuitable for production kernel code

The lightning_attention() function uses assert output is not None and assert NUM_FBLOCK == 1 as runtime checks. Under Python's -O flag, assert statements are stripped entirely, so these error conditions would pass silently.

Consider replacing them with explicit exceptions:

# Replaces assert output is not None
if output is None:
    raise RuntimeError("lightning_attention: output tensor is None; check kernel launch parameters.")

# Replaces assert NUM_FBLOCK == 1
if NUM_FBLOCK != 1:
    raise NotImplementedError(f"lightning_attention: NUM_FBLOCK={NUM_FBLOCK} > 1 not supported.")

❓ Question (minimax_m1.py): MiniMaxM1PretrainedModel is disconnected from the FD architecture

The end of the file defines MiniMaxM1PretrainedModel(PretrainedModel), a class that:

  • inherits from PaddleFormers' PretrainedModel (rather than FD's ModelForCasualLM)
  • carries no @ModelRegistry.register_model_class decorator
  • does not implement the FD core interfaces set_state_dict / load_weights
  • references config.fuse_attention_qkv in _get_tensor_parallel_mappings, a field that is not in FDConfig and may raise AttributeError at runtime

Please confirm whether this class is actually used on the FD serving path. If it is transitional leftover code, consider removing it or adding a # NOTE: PaddleFormers-only, not used in FD serving comment to avoid maintenance confusion.


Overall Assessment

Overall code quality is high: the comments are thorough, the TP sharding logic is correct, and the tests cover the main paths. The core blocker is cross-request state isolation for _kv_history. It is documented, but before this merges into develop we recommend an explicit runtime guard or a tracked follow-up PR in an issue, so production traffic cannot trigger it by accident.

if architecture:
    if architecture in cls._arch_to_model_cls and cls._arch_to_model_cls[architecture] is not model_cls:
        logger.warning("Overwriting model registration for architecture '%s'", architecture)
    cls._arch_to_model_cls[architecture] = model_cls

🟡 Suggestion: architecture registration performs a redundant write and risks spurious warnings on hot reload

When architecture == model_cls.name() (e.g. MiniMaxM1ForCausalLM), the same key is written twice, once by cls._arch_to_model_cls[model_cls.name()] (line 311) and once here by cls._arch_to_model_cls[architecture], which is a redundant assignment.

More importantly, after a module hot reload the old and new class objects differ (is not model_cls evaluates True), which triggers a spurious warning and hurts the development experience. Consider comparing by class name instead of object identity:

if architecture and architecture in cls._arch_to_model_cls:
    existing = cls._arch_to_model_cls[architecture]
    if getattr(existing, '__name__', None) != model_cls.__name__:
        logger.warning(
            "Overwriting model registration for architecture '%s' "
            "(old=%s, new=%s)",
            architecture, getattr(existing, '__name__', existing), model_cls.__name__
        )
cls._arch_to_model_cls[architecture] = model_cls
