
feat(benchmark): HSTU E2E training benchmark suite with progressive optimizations #340

Merged
shijieliu merged 12 commits into NVIDIA:main from JacoCheung:junzhang/benchmark_only on Apr 14, 2026
Conversation

@JacoCheung (Collaborator) commented Apr 2, 2026

Summary

Add a comprehensive HSTU end-to-end training benchmark suite that measures the impact of progressive optimizations on H100 GPUs.

  • Benchmark infrastructure: automated experiment generation (generate_gin_config.py), SLURM submission scripts, result analysis and visualization tools
  • 5 progressive experiments: Baseline → Workload-Balanced Shuffler → CUTLASS Attention → Selective Recompute → Tensor Parallel (TP=2)
  • Results on 2×H100 nodes (16 GPUs): Baseline 1092 TFLOPS (6.38% MFU) → Best 3933 TFLOPS (22.96% MFU), 3.60× speedup
  • CUTLASS attention kernel MFU heatmap: batch sizes 1–128 × seqlens 64–4096, peak ~60% MFU on H100

Key code changes (non-benchmark)

  • perf: eliminate D2H sync in DP embedding forward and enable DP/MP overlap (#307)
  • perf: avoid D2H sync in _Split2DJaggedFunction by precomputing split lengths
  • fix: triton attention correctness for num_contextuals, empty batch handling
  • fix(tp): scale total_candidates_seq_len for TP and shuffler compatibility
  • fix: shuffler race condition, batch counter, batch all2all support
  • fix: num_loss_tokens() added to BaseBatch/GPTSIDBatch, sid_gr progress() signature updated
  • fix: watchdog log decoupled from MEM_DEBUG
  • fix: batch_allgather dense tensor padding guard for pre-padded tensors (#361)
  • style: apply pre-commit formatting (isort 5.12.0, black, autoflake, codespell)
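The `num_loss_tokens()` item above follows a base-method-with-override pattern: `BaseBatch` supplies a generic fallback (labels count, else batch size) and subclasses override with task-specific counting. A minimal illustrative sketch — field names here are assumptions for illustration, not the real API:

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class BaseBatch:
    batch_size: int
    labels: Optional[List[int]] = None

    def num_loss_tokens(self) -> int:
        # default: number of labels if present, else one token per sample
        if self.labels is not None:
            return len(self.labels)
        return self.batch_size

@dataclass
class GPTSIDBatch(BaseBatch):
    # candidate-specific counting: loss is taken over candidate tokens only
    # (num_candidates_per_sample is a hypothetical field for this sketch)
    num_candidates_per_sample: List[int] = field(default_factory=list)

    def num_loss_tokens(self) -> int:
        if self.num_candidates_per_sample:
            return sum(self.num_candidates_per_sample)
        return super().num_loss_tokens()
```

The pipeline can then call `batch.num_loss_tokens()` uniformly before the forward pass, regardless of the concrete batch type.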

Benchmark experiments

| # | Experiment | TFLOPS | MFU (%) | Speedup |
|---|------------|--------|---------|---------|
| 0 | Baseline   | 1092   | 6.38    | 1.00×   |
| 1 | +Shuffler  | 1667   | 9.73    | 1.53×   |
| 2 | +CUTLASS   | 3933   | 22.96   | 3.60×   |
| 3 | +Recompute | 3919   | 22.88   | 3.59×   |
| 4 | +TP=2      | 2880   | 16.81   | 2.64×   |

Known Issues

Test plan

Closes #307

CI

@JacoCheung JacoCheung requested review from jiashuy and shijieliu April 2, 2026 02:25
@JacoCheung JacoCheung marked this pull request as draft April 2, 2026 02:25
@JacoCheung (Collaborator, Author) commented:

Need to rectify the benchmark result once #313 is done

greptile-apps bot (Contributor) commented Apr 2, 2026

Greptile Summary

This PR delivers an HSTU E2E training benchmark suite (5 progressive experiments reaching a 3.6× speedup on H100), alongside several correctness and performance fixes: Triton attention masks switched from Python `and`/`or` to element-wise `&`/`|` for proper tensor masking, DP/MP embedding overlap to eliminate D2H sync, a batch_allgather zero-batch guard, shuffler NVTX instrumentation, and a StackDumpWatchdog/CudaMemoryWatchdog for hang detection.

  • The loss normalization refactor moves from a per-step DP all-reduce of (sum_loss, num_tokens) to a pre-forward world all-reduce of batch.num_loss_tokens(). The global_tokens denominator is intentionally world-wide (all ranks have unique data per the developer's note). However, for logging, tokens_logged accumulates this world-summed value while loss_logged is only DP-all-reduced at log time, causing avg_loss to appear tp_size times smaller in TP>1 runs — making cross-experiment loss comparisons in the benchmark results potentially misleading for Experiment 4.
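The logging skew described above can be seen with plain arithmetic. An illustrative sketch, assuming for simplicity that every rank holds the same loss `L` over `T` tokens (the variable names mirror the description; this is not the actual code):

```python
# Illustrative arithmetic for the avg_loss skew in TP>1 runs.
dp_size, tp_size = 8, 2
L, T = 4.0, 100  # per-rank loss sum and token count (uniform, for simplicity)

# tokens_logged accumulates the WORLD all-reduce: dp_size * tp_size ranks
tokens_logged = dp_size * tp_size * T
# loss_logged is only DP-all-reduced at log time: dp_size ranks
loss_logged = dp_size * L

true_avg = L / T
reported_avg = loss_logged / tokens_logged

# the reported value is tp_size times smaller than the true average
print(reported_avg, true_avg)
```

With TP=1 the two denominators coincide and the discrepancy disappears, which is why only Experiment 4 is affected.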

Confidence Score: 5/5

  • Safe to merge; the one finding is a P2 logging inaccuracy that only affects reported avg_loss for TP=2 runs, not training correctness.
  • All key correctness fixes (Triton masking, batch_allgather empty-batch guard, DP/MP D2H sync elimination, labels allgather for TP) look correct. The global_tokens world all-reduce in backward is intentional and confirmed by the developer. The only unresolved finding is the avg_loss denominator mismatch for TP>1, which is a logging concern only — gradients and optimizer updates are unaffected.
  • examples/hstu/training/trainer/training.py and examples/sid_gr/training/trainer/training.py — avg_loss logging denominator for TP>1 configurations.

Important Files Changed

  • `examples/commons/pipeline/train_pipeline.py` — Refactors loss normalization: replaces the per-step DP all-reduce of (loss, tokens) with a pre-forward world all-reduce of `batch.num_loss_tokens()`. Backward is `local_loss_sum * dp_size / global_tokens`. Functionally correct for TP=1; the world all-reduce of `global_tokens` is deliberate per developer confirmation. The return signature changes to `(local_loss_sum, global_tokens, output)`, introducing a `None` second element in eval mode — callers updated accordingly.
  • `examples/hstu/training/trainer/training.py` — Adopts the new `(local_loss_sum, global_tokens, output)` return tuple, adds MFU logging via `cal_hstu_flops`/`cal_mfu`, and wraps the training loop with `watched_iter`. Loss-logging inconsistency: `tokens_logged` accumulates world-all-reduced values but `loss_logged` is only DP-all-reduced, causing `avg_loss` to be off by `tp_size` for TP>1.
  • `examples/sid_gr/training/trainer/training.py` — Adopts the new three-tuple return; same `tokens_logged` vs `loss_logged` all-reduce mismatch as the HSTU `training.py` — `avg_loss` will be off by `tp_size` for TP>1 configurations.
  • `examples/commons/modules/embedding.py` — Adds `OVERLAP_DP_MP` env-var gating for DP/MP stream overlap. Eliminates D2H sync by precomputing `length_per_key` from the KJT cache in `compute_dp_length_per_key()` and passing it through `forward()`. `_build_output_dict()` uses pure-Python int arithmetic to slice embeddings without further D2H copies. Logic looks correct.
  • `examples/commons/distributed/batch_allgather.py` — Introduces `_elems_per_sample()` to handle a zero `actual_batch_size` during padding, fixing the division-by-zero crash. The known pre-padded tensor issue (#361) is explicitly tracked separately.
  • `examples/hstu/ops/triton_ops/triton_hstu_attention.py` — Replaces Python short-circuit `or`/`and` with Triton element-wise `|`/`&` so attention masks are evaluated per element rather than short-circuited.
  • `examples/commons/sequence_batch/batch.py` — Adds a `num_loss_tokens()` base method to `BaseBatch` with sensible defaults; subclasses (`HSTUBatch`, `GPTSIDBatch`) override with task-specific counts. Used by `train_pipeline.py` to compute global token normalization before the forward pass.
  • `examples/commons/utils/watchdog.py` — New `StackDumpWatchdog` and `CudaMemoryWatchdog` utilities. The stack-dump watchdog fires when no iteration heartbeat occurs within the timeout; `CudaMemoryWatchdog` triggers `empty_cache()` on a fragmentation threshold. Daemon threads, proper shutdown guarding. Code is clean and well documented.
  • `examples/hstu/training/benchmark/scripts/generate_gin_config.py` — New CLI tool that generates gin configs by filling a parameterized template. Works correctly; has a stale comment in the template (`low = 1` but the comment says "256 is the minimum sequence length").
  • `examples/commons/distributed/dmp_to_tp.py` — Switches labels gathering from `gatherv_along_first_dim` (gather-to-root) to `keyed_jagged_tensor_allgather` (all ranks get data), which is correct for TP where all peers need the full labels. Type-gates on `isinstance(batch.labels, KeyedJaggedTensor)`, safe since labels are consistently KJT in this codebase.
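The watchdog pattern described for `examples/commons/utils/watchdog.py` — a daemon thread that dumps stacks when no iteration heartbeat arrives within a timeout — can be sketched as follows. The class and function names come from the PR; the internals here are an assumed minimal implementation, not the actual code:

```python
import sys
import threading
import time
import traceback

class StackDumpWatchdog:
    """Sketch of a heartbeat watchdog: fires on_hang if no heartbeat
    arrives within timeout_s. Internals are illustrative assumptions."""

    def __init__(self, timeout_s=60.0, on_hang=None):
        self.timeout_s = timeout_s
        self.on_hang = on_hang or self._dump_stacks
        self._last_beat = time.monotonic()
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._watch, daemon=True)
        self._thread.start()

    def heartbeat(self):
        self._last_beat = time.monotonic()

    def _watch(self):
        # poll a few times per timeout window; exit promptly on stop()
        while not self._stop.wait(self.timeout_s / 4):
            if time.monotonic() - self._last_beat > self.timeout_s:
                self.on_hang()
                self._last_beat = time.monotonic()  # avoid re-firing every poll

    @staticmethod
    def _dump_stacks():
        # dump every live thread's stack for hang diagnosis
        for tid, frame in sys._current_frames().items():
            print(f"--- thread {tid} ---")
            print("".join(traceback.format_stack(frame)))

    def stop(self):
        self._stop.set()
        self._thread.join()

def watched_iter(iterable, watchdog):
    # wrap an iterator so each yielded item counts as a heartbeat
    for item in iterable:
        watchdog.heartbeat()
        yield item
```

Wrapping the training loop with `watched_iter(dataloader, wd)` then makes any stall longer than the timeout produce a stack dump, which matches the role `watched_iter` plays in `training.py` above.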

Sequence Diagram

```mermaid
sequenceDiagram
    participant DL as DataLoader
    participant PP as TrainPipeline
    participant WD as StackDumpWatchdog
    participant EM as ShardedEmbedding
    participant MS as MainStream
    participant SS as SideStream(DP)
    participant NC as NCCL(World)

    PP->>WD: "watched_iter(count(step), timeout=60s)"
    WD-->>PP: heartbeat on each step

    PP->>DL: next batch
    PP->>PP: batch.num_loss_tokens()
    PP->>NC: all_reduce(global_tokens) [world group]
    NC-->>PP: global_tokens (sum all ranks)

    PP->>EM: forward(kjt)
    EM->>EM: compute_dp_length_per_key() [main stream, no D2H]
    par DP embedding on side stream
        EM->>SS: DataParallelEmbeddingCollection(kjt, lpk)
        SS-->>EM: dp_embeddings
    and MP embedding on main stream
        EM->>MS: ModelParallelEmbeddingCollection(kjt)
        MS-->>EM: mp_embeddings (awaitable.wait())
    end
    EM->>MS: wait_stream(side_stream)
    EM-->>PP: merged embeddings dict

    PP->>PP: model_fwd → losses
    PP->>PP: "local_loss_sum = sum(losses)"
    PP->>PP: "(local_loss_sum * dp_size / global_tokens).backward()"
    PP->>PP: optimizer.step()
    PP->>WD: heartbeat
```
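The backward scaling in the diagram, `local_loss_sum * dp_size / global_tokens`, combines with DDP's gradient averaging (which divides the gradient sum by `dp_size`) to yield the global token-mean loss. A plain-arithmetic sketch, under the assumption that DDP averages over DP ranks:

```python
# Arithmetic behind the loss normalization: each rank scales its local loss
# by dp_size / global_tokens; DDP then averages gradients over dp_size ranks.
dp_size = 4
local_loss = [3.0, 5.0, 2.0, 6.0]   # per-rank local_loss_sum (illustrative)
local_tokens = [10, 20, 5, 15]      # per-rank num_loss_tokens() (illustrative)

# pre-forward world all-reduce of token counts
global_tokens = sum(local_tokens)

# each rank calls (local_loss_sum * dp_size / global_tokens).backward()
scaled = [l * dp_size / global_tokens for l in local_loss]

# DDP gradient averaging divides the sum over ranks by dp_size
effective = sum(scaled) / dp_size

# result equals the global token-mean loss: sum(losses) / sum(tokens)
print(effective, sum(local_loss) / global_tokens)
```

The `dp_size` factor exactly cancels DDP's averaging, so the optimizer sees the same normalization the old per-step `(sum_loss, num_tokens)` all-reduce produced, without the extra collective in the hot path.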


Comment thread examples/hstu/ops/fused_hstu_op.py Outdated
Comment thread examples/commons/distributed/batch_shuffler.py
Comment thread examples/commons/distributed/batch_shuffler.py
@JacoCheung JacoCheung force-pushed the junzhang/benchmark_only branch from b41488b to 3a8ad56 Compare April 2, 2026 03:05
@JacoCheung JacoCheung force-pushed the junzhang/benchmark_only branch from c1356f2 to 93b8cf4 Compare April 2, 2026 06:28
Comment thread examples/hstu/training/benchmark/figs/hstu_attn_mfu.png
Comment thread examples/hstu/training/benchmark/E2E_BENCHMARK.md Outdated
Comment thread examples/commons/utils/logger.py
@JacoCheung JacoCheung force-pushed the junzhang/benchmark_only branch 19 times, most recently from e2440fc to 8485813 Compare April 10, 2026 03:34
@JacoCheung JacoCheung force-pushed the junzhang/benchmark_only branch 8 times, most recently from fa2e12a to ee542c6 Compare April 13, 2026 10:18
@JacoCheung JacoCheung marked this pull request as ready for review April 13, 2026 10:18
@JacoCheung JacoCheung force-pushed the junzhang/benchmark_only branch 2 times, most recently from bbf6103 to e26061f Compare April 13, 2026 14:50
Comment thread examples/commons/pipeline/train_pipeline.py
JacoCheung and others added 11 commits April 14, 2026 00:58
…watchdog

- Fix Triton kernel mask from Python and/or to bitwise &/| (correctness)
- Fix torch.distributed.gather destination rank for non-default DP groups
- Handle 0-D tensor in num_contextuals via .view(-1)[0].item()
- Fix shuffler race condition, batch counter, and batch a2a support
- Scale total_candidates_seq_len for TP and shuffler
- Add collective watchdog for hang detection
- Avoid D2H sync in _Split2DJaggedFunction by precomputing split lengths

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Eliminate D2H sync in DP embedding forward and enable DP/MP overlap
- Optimize loss normalization: move global token count before forward,
  defer loss all_reduce to log intervals only
- Add MEM_DEBUG instrumentation for GPU physical memory tracking

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add comprehensive HSTU training benchmark infrastructure:
- Experiment generation via gin configs with configurable optimizations
  (CUTLASS, recompute, shuffler, caching, TP, value distribution)
- SLURM batch submission with remote clone isolation (submit_remote.sh)
- Automated result analysis, comparison plots, and MFU heatmaps
- CUTLASS attention kernel micro-benchmark (benchmark_hstu_attn_mfu.py)
- GPU memory watchdog, cache hit rate debug logging, TFLOPS/MFU utils
- Zipf and uniform value distribution support for embedding keys
- E2E_BENCHMARK.md documentation with results and optimization space

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…M_DEBUG

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ogress() signature

BaseBatch gets a default num_loss_tokens() (labels count or batch_size fallback).
GPTSIDBatch overrides with candidate-specific token counting.
sid_gr training.py updated to handle 3-value return from progress().

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@JacoCheung JacoCheung force-pushed the junzhang/benchmark_only branch from e26061f to 84c9087 Compare April 14, 2026 07:59
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@JacoCheung JacoCheung force-pushed the junzhang/benchmark_only branch from 84c9087 to a4b7e34 Compare April 14, 2026 09:08
@shijieliu shijieliu merged commit 49bfd3a into NVIDIA:main Apr 14, 2026
JacoCheung added a commit to JacoCheung/recsys-examples that referenced this pull request Apr 16, 2026
Update from 04df536 to 65bad42 which adds fake tensor implementations
for torch.export (hstu_ops_gpu.py). This was missing since PR NVIDIA#340
accidentally reverted the submodule pointer.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
shijieliu pushed a commit that referenced this pull request Apr 17, 2026
…363)

* fix: reduce Docker image layers to avoid overlay2 max depth limit

Aggressively merge RUN instructions in the Dockerfile to reduce total
layer count from ~126 to ~119. The inference image was hitting the
overlay2 128-layer limit ("failed to register layer: max depth
exceeded") on CI nodes.

devel stage: 8 RUN + 1 COPY -> 4 RUN + 1 COPY (-4 layers)
build stage: 4 RUN + 1 COPY -> 1 RUN + 1 COPY (-3 layers)
FBGEMM and TorchRec kept as separate layers for build cache efficiency.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* ci: add pull_request_target trigger for auto CI on PR open/sync

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Fix imports for fake ops wrapper used in export

* fix: remove invalid import of hstu.hstu_ops_gpu

The module hstu.hstu_ops_gpu does not exist as a Python module.
The C++ source hstu_ops_gpu.cpp compiles into hstu/fbgemm_gpu_experimental_hstu.so,
not a separate hstu_ops_gpu submodule. This import was incorrectly added in PR #327
and causes ModuleNotFoundError in CI.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: update FBGEMM submodule to include hstu_ops_gpu.py fake impl

Update from 04df536 to 65bad42 which adds fake tensor implementations
for torch.export (hstu_ops_gpu.py). This was missing since PR #340
accidentally reverted the submodule pointer.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* ci: allow /build with flags by matching prefix instead of exact string

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* ci: remove pull_request_target trigger, keep only /build comment

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Junyi Qiu <junyiq@nvidia.com>