Support subquadratic-ops kernels in evo2 autoregressive inference #1565
Signed-off-by: Farhad Ramezanghorbani <farhadr@nvidia.com>
📝 Walkthrough
This pull request introduces …
Changes
Sequence Diagram(s)
sequenceDiagram
participant User as User/CLI
participant Infer as infer()
participant Setup as setup_inference_engine()
participant Provider as ModelProvider
participant Mixer as HyenaMixer.forward()
participant Engine as parallel_fir()
participant SubqOps as subquadratic_ops_torch
participant BaseOps as F.conv1d/FFTConv
User->>Infer: --use-subquadratic-ops flag
Infer->>Setup: use_subquadratic_ops=True
Setup->>Provider: set use_subquadratic_ops
rect rgb(100, 150, 200, 0.5)
Note over Infer,Provider: Prefill Phase
Infer->>Mixer: forward(prefill_tokens)
Mixer->>Engine: parallel_fir(..., use_subquadratic_ops=True)
alt use_subquadratic_ops enabled
Engine->>SubqOps: _subq_fft_causal_conv1d (fir_len ≥ 128)
SubqOps-->>Engine: optimized output
else standard path
Engine->>BaseOps: fftconv_func/F.conv1d
BaseOps-->>Engine: output
end
Engine-->>Mixer: FIR result + inference_state
Mixer->>Mixer: _populate_b2b_inference_state()
Mixer-->>Infer: prefill hidden states
end
rect rgb(200, 150, 100, 0.5)
Note over Infer,Mixer: Decode Phase (unchanged)
Infer->>Mixer: forward(next_token)
Mixer->>Mixer: use cached inference_state
Mixer-->>Infer: decoded token
end
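The routing shown in the prefill phase of the diagram can be sketched roughly as follows. This is a minimal illustration, not the code in `engine.py`: the function name `select_fir_backend` and the string labels are hypothetical, and only the 128 threshold and the kernel names come from the PR text.

```python
# Threshold taken from the diagram above (fir_len >= 128 uses the FFT kernel).
FFT_MIN_FILTER_LEN = 128


def select_fir_backend(fir_len: int, use_subquadratic_ops: bool) -> str:
    """Return a label for the convolution path used for a given filter length.

    Hypothetical helper: the real dispatch lives inside parallel_fir().
    """
    if not use_subquadratic_ops:
        # Standard path (fftconv_func / F.conv1d); its own selection logic is not shown here.
        return "baseline"
    if fir_len >= FFT_MIN_FILTER_LEN:
        # Long filters go to the FFT-based fused kernel.
        return "fft_causal_conv1d"
    # Short filters go to the direct causal convolution kernel.
    return "causal_conv1d"


print(select_fir_backend(7, True))
print(select_fir_backend(128, True))
print(select_fir_backend(128, False))
```

Keeping the decision in one place like this makes it easy to see that the flag only changes which kernel runs, not the decode path.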
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
@coderabbitai review
✅ Actions performed: review triggered.
Actionable comments posted: 1
🧹 Nitpick comments (2)
bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py (2)
524-562: ⚡ Quick win: Cover the prompt-segmentation handoff too.
This only asserts equivalence for a single-shot prefill. The new cache-seeding logic is most likely to regress when `--prompt-segmentation-threshold` forces a prefill→decode handoff inside the prompt, so I'd add one companion case with a threshold below `len(PROMPT_1)` and require the same equality there.
306-334: ⚡ Quick win: Point this helper at the source `infer.py`.
`run_infer_subprocess_parallel()` already shells out to `_infer_script_path()` with `src/` prepended to `PYTHONPATH`, but this helper still uses `-m bionemo.evo2.run.infer`. That means the new subquadratic regression test can validate an older installed package instead of the code under review.
Suggested fix:
    - "-m",
    - "bionemo.evo2.run.infer",
    + str(_infer_script_path()),
      "--ckpt-dir",
      str(mbridge_checkpoint_path),
      "--prompt",
      prompt,
    @@
    - env = copy.deepcopy(PRETEST_ENV)
    + env = copy.deepcopy(PRETEST_ENV)
    + src_dir = str(_recipe_root() / "src")
    + env["PYTHONPATH"] = src_dir + os.pathsep + env.get("PYTHONPATH", "")
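The environment half of this suggestion can be sketched in isolation. This is an illustrative helper only: the name `env_with_source_first` is hypothetical and does not exist in `test_infer.py`. Unlike the one-line suggestion, it skips the trailing separator when `PYTHONPATH` is unset, since an empty path entry is treated as the current directory.

```python
import os


def env_with_source_first(base_env: dict, src_dir: str) -> dict:
    """Return a copy of base_env whose PYTHONPATH starts with src_dir.

    Hypothetical helper for illustration; the test file builds its env inline.
    """
    env = dict(base_env)  # shallow copy so the caller's mapping is not mutated
    existing = env.get("PYTHONPATH", "")
    if existing:
        # Source tree first, so it shadows any installed copy of the package.
        env["PYTHONPATH"] = src_dir + os.pathsep + existing
    else:
        env["PYTHONPATH"] = src_dir
    return env


example = env_with_source_first({"PYTHONPATH": "/opt/site"}, "/repo/src")
print(example["PYTHONPATH"])
```

The resulting mapping would then be passed as `env=` to `subprocess.run` together with the script path, so the interpreter imports from the source tree rather than from an installed package.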
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py`:
- Around line 381-385: The mixer FIR cache is seeded incorrectly: the helper
always stores x2 * v in mixer_input_tail even though
ParallelShortHyenaOperator.forward uses v (not x2*v) when self.mixer.pregate is
False. Update the logic that sets mixer_input_tail (currently "mixer_input_tail
= (x2 * v).contiguous()") to check self.mixer.pregate and store v.contiguous()
when pregate is False, otherwise keep (x2 * v).contiguous(); reference
ParallelShortHyenaOperator.forward, the mixer.pregate flag, and variables x2, v,
and mixer_input_tail to locate the change.
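The selection described in this comment can be sketched with plain Python lists standing in for tensors. The function name `mixer_cache_seed` is hypothetical, and the behavior attributed to `pregate` is taken from the review comment rather than verified against `ParallelShortHyenaOperator.forward`.

```python
def mixer_cache_seed(x2, v, pregate):
    """Return the sequence the mixer FIR is assumed to consume.

    Assumption (from the review comment): with pregate enabled the FIR sees
    x2 * v, otherwise it sees v alone, so the cache must be seeded to match.
    """
    if pregate:
        # Element-wise gate applied before the convolution.
        return [a * b for a, b in zip(x2, v)]
    # No pre-gating: the convolution input is v itself.
    return list(v)


print(mixer_cache_seed([2.0, 3.0], [5.0, 7.0], pregate=True))
print(mixer_cache_seed([2.0, 3.0], [5.0, 7.0], pregate=False))
```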
---
Nitpick comments:
In `@bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py`:
- Around line 524-562: Add a companion assertion that verifies subquadratic ops
match baseline when prompt segmentation forces a prefill→decode handoff: call
run_infer_subprocess twice (once with use_subquadratic_ops=False and once with
True) using the same mbridge_checkpoint_path, PROMPT_1, seed, top_k=1, etc., but
pass prompt_segmentation_threshold set to a value smaller than len(PROMPT_1)
(e.g., 1 less than the prompt length) and the same other args (max_new_tokens,
temperature), then assert both outputs are non-empty and equal (use the same
failure message pattern); reference run_infer_subprocess, PROMPT_1, and
prompt_segmentation_threshold to locate where to add this companion case.
- Around line 306-334: The helper builds cmd using "-m bionemo.evo2.run.infer",
which lets the installed package be executed instead of the local source; change
the command construction in this helper to invoke the actual source script
returned by _infer_script_path(), as run_infer_subprocess_parallel() already
does (replace the "-m bionemo.evo2.run.infer" element with
str(_infer_script_path()) and ensure the subprocess uses the same Python
executable so the prepended "src/" PYTHONPATH is applied), leaving the rest of
the args (ckpt-dir, prompt, etc.) unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 0e318f54-a8af-4548-aa42-88bf6f61e018
📒 Files selected for processing (7)
- bionemo-recipes/recipes/evo2_megatron/README.md
- bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py
- bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py
- bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py
- bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py
- bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py
- bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_evo2.py
/ok to test 05e1356
@farhadrgh, there was an error processing your request. See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/1/
/ok to test 05e1356
/ok to test 05e1356
LGTM @farhadrgh |
Description
Closes the gap noted in `hyena_mixer.py` (`# todo: support inference_context for b2b_kernel`) and the README caveat that `--use-subquadratic-ops` "does not apply to autoregressive inference (infer_evo2)". After this PR, the same fused kernels that accelerate training and batch prediction also accelerate the prefill phase of autoregressive inference.
Summary of change:
- `engine.parallel_fir` now accepts `use_subquadratic_ops` and routes to `fft_causal_conv1d` (filters ≥ 128) or `causal_conv1d` (short filters), wired through both call sites in `hyena_utils.py`.
- `HyenaMixer.forward` detects prefill (no FIR cache yet) and runs `b2b_causal_conv1d` for the fused proj+mixer convolution. The kernel doesn't expose its intermediate, so we run a tiny windowed proj-conv on the last `K_proj + K_mixer − 2` input positions to materialize the `(x2*v)` tail and seed the mixer's FIR cache. Works for both `hyena_short_conv` and `hyena_medium_conv`.
- Removes the `del self._parameters["short_conv_weight"]` micro-optimization in `ParallelCausalDepthwiseConv1dWithState._get_weight()`. `B2BCausalConv1dModule` reads that raw param on every prefill, so deleting it after the first decode broke multi-prompt inference. Memory cost is ~4 MB for a 1B model.
- `infer_evo2` gets a `--use-subquadratic-ops` flag.
Testing
- `test_forward_manual[1b-8k-bf16-subquadratic-ops-flash]` covers the `(flash_decode=True, subquadratic_ops=True)` combination that was previously skipped.
- `test_subquadratic_ops_matches_baseline` runs greedy autoregressive generation with and without `--use-subquadratic-ops` and asserts identical output. This is the strict check that Phase 2 state population is correct (a wrong cache would diverge during decode).
- Tests in `test_hyena_mixer_kernel.py` and the inference-context unit tests pass unchanged.
Performance
Setup: `infer_evo2`, evo2/1b-8k-bf16, single A6000, multiple identical prompts in one process to amortize the one-time JIT compile cost (~15 s the first time each subq-ops kernel sees a new shape). Steady-state numbers were taken from batches 3+.
The speedup is concentrated in prefill. The relative improvement grows with prompt length and shrinks as more decode tokens are amortized in.
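The cache-seeding idea from the summary, and the reason the greedy-equality test is a strict check, can be illustrated with a toy single-channel model. This is a sketch under simplifying assumptions, not the PR's implementation: plain Python lists replace tensors, gating is omitted (the intermediate is just the proj-conv output), and all function names are hypothetical. It relies on the fact that the last `K_mixer − 1` intermediate values depend only on the last `K_proj + K_mixer − 2` inputs.

```python
def causal_fir(x, w, history=None):
    """y[t] = sum_j w[j] * x[t-j]; `history` holds the len(w)-1 samples preceding x."""
    k = len(w)
    pad = list(history) if history is not None else [0.0] * (k - 1)
    assert len(pad) == k - 1
    ext = pad + list(x)  # x[t] sits at ext[t + k - 1]
    return [sum(w[j] * ext[t + k - 1 - j] for j in range(k)) for t in range(len(x))]


def fused_prefill(x, w_proj, w_mixer):
    """Stand-in for a fused kernel: returns the final output only, not the intermediate."""
    return causal_fir(causal_fir(x, w_proj), w_mixer)


def recover_intermediate_tail(x, w_proj, k_mixer):
    """Recompute the last k_mixer-1 intermediate values from a short input window."""
    window = len(w_proj) + k_mixer - 2
    assert k_mixer >= 2 and len(x) >= window
    z_win = causal_fir(x[-window:], w_proj)
    # Zero left-padding corrupts only the first K_proj-1 outputs; the tail is exact.
    return z_win[-(k_mixer - 1):]


def decode_step(x_new, proj_cache, mixer_cache, w_proj, w_mixer):
    """Process one token using the two FIR caches and return the updated caches."""
    z_new = causal_fir([x_new], w_proj, history=proj_cache)[0]
    y_new = causal_fir([z_new], w_mixer, history=mixer_cache)[0]
    return y_new, (proj_cache + [x_new])[1:], (mixer_cache + [z_new])[1:]


w_proj = [0.5, -0.25, 0.125]         # K_proj = 3
w_mixer = [1.0, 0.5, 0.25, 0.125]    # K_mixer = 4
prompt = [float(i % 5) for i in range(12)]
new_tokens = [2.0, -1.0, 3.0]

# Prefill with the "fused" path, then seed both caches from the prompt tail.
prefill_out = fused_prefill(prompt, w_proj, w_mixer)
proj_cache = prompt[-(len(w_proj) - 1):]
mixer_cache = recover_intermediate_tail(prompt, w_proj, len(w_mixer))

decoded = []
for x_new in new_tokens:
    y, proj_cache, mixer_cache = decode_step(x_new, proj_cache, mixer_cache, w_proj, w_mixer)
    decoded.append(y)

# Reference: run the whole sequence through the convolutions in one shot.
reference = fused_prefill(prompt + new_tokens, w_proj, w_mixer)
matches = all(abs(a - b) < 1e-9 for a, b in zip(prefill_out + decoded, reference))
print("prefill + cached decode matches one-shot reference:", matches)
```

If the mixer cache were seeded with anything other than the true intermediate tail, the first decoded value would already differ from the reference, which is why comparing generated output with and without the flag exercises the seeding logic.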
Summary by CodeRabbit
New Features
- `--use-subquadratic-ops` CLI option to optimize prompt/prefill processing during inference while leaving per-token decode unchanged.
Documentation
Tests