
Conversation

@RSchwan RSchwan (Contributor) commented Oct 11, 2025

Description

This PR adds alpha and beta scalings to tile_matmul, matching the cuBLASDx API. The main motivation is to generalize the tile_matmul API in order to reduce the number of operations and the amount of shared memory that would otherwise be required. For example, in the tile_block_cholesky example:

@@ -78,8 +78,7 @@ def blocked_cholesky_kernel(
                 for j in range(0, k, block_size):
                     L_block = wp.tile_load(L, shape=(block_size, block_size), offset=(k, j))
                     L_block_T = wp.tile_transpose(L_block)
-                    L_L_T_block = wp.tile_matmul(L_block, L_block_T)
-                    A_kk_tile -= L_L_T_block
+                    wp.tile_matmul(L_block, L_block_T, A_kk_tile, alpha=-1.0)

this reduces shared memory usage (the intermediate L_L_T_block no longer exists) and fuses the subtraction into the matmul.
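To illustrate the new call signature concretely, here is a minimal sketch of the accumulating overload. It is not code from this PR; the array names, tile size, and launch configuration are illustrative assumptions, and it assumes a Warp build that includes this change.

import numpy as np
import warp as wp

TILE = 16  # hypothetical tile size for this sketch

@wp.kernel
def scaled_matmul(A: wp.array2d(dtype=wp.float32),
                  B: wp.array2d(dtype=wp.float32),
                  C: wp.array2d(dtype=wp.float32)):
    a = wp.tile_load(A, shape=(TILE, TILE))
    b = wp.tile_load(B, shape=(TILE, TILE))
    c = wp.tile_load(C, shape=(TILE, TILE))
    # new semantics: c = alpha * (a @ b) + beta * c
    wp.tile_matmul(a, b, c, alpha=0.5, beta=-1.3)
    wp.tile_store(C, c)

wp.init()
A = wp.array(np.random.rand(TILE, TILE), dtype=wp.float32)
B = wp.array(np.random.rand(TILE, TILE), dtype=wp.float32)
C = wp.array(np.random.rand(TILE, TILE), dtype=wp.float32)
wp.launch_tiled(scaled_matmul, dim=[1], inputs=[A, B, C], block_dim=64)

With alpha=-1.0 and the default beta=1.0, the same call reproduces the fused subtraction shown in the diff above.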

Before your PR is "Ready for review"

  • All commits are signed-off to indicate that your contribution adheres to the Developer Certificate of Origin requirements
  • Necessary tests have been added
  • Documentation is up-to-date
  • Auto-generated files modified by compiling Warp and building the documentation have been updated (e.g. __init__.pyi, functions.rst)
  • Code passes formatting and linting checks with pre-commit run -a

Summary by CodeRabbit

  • New Features

    • tile_matmul adds alpha and beta scalars: out = alpha*(A @ B) + beta*out (both default to 1.0); supports in-place scaled matmul.
  • Documentation

    • API docs and overloads updated to describe alpha/beta, defaults, semantics, and note adjoint support is not yet available.
  • Examples

    • Cholesky example simplified to use tile_matmul(alpha=-1.0) for in-place subtraction.
  • Tests

    • Tests updated for alpha/beta forward behavior and adjusted gradient expectations.

@coderabbitai coderabbitai bot commented Oct 11, 2025

📝 Walkthrough

Walkthrough

Add explicit scalar parameters alpha and beta to wp.tile_matmul and propagate them through builtins, LTO/dispatch, CPU/MathDx/CUDA backends, and native C++ tile/scalar implementations; update docs, stubs, examples, and tests. Treat float immediates as SSA constants in codegen and extend adjoint/backward paths to account for beta scaling.

Changes

Cohort / File(s) Summary
Changelog & Docs
CHANGELOG.md, docs/modules/functions.rst
Document new alpha/beta parameters, defaults, and updated computation semantics out = alpha * a@b + beta * out; add param docs and note adjoint limitations.
Public Stubs
warp/__init__.pyi
Update tile_matmul overloads/signatures and docstrings to expose alpha (and beta where applicable) and reflect new accumulation formula.
Builtins & Dispatch
warp/_src/builtins.py
Extend builtin registration and LTO dispatch for tile_matmul to accept alpha and beta (defaults 1.0); thread scalars through LTO/CPU/MathDx/CUDA paths; set beta=0.0 for overwrite paths; update emitted args/template handling and docstrings.
Codegen
warp/_src/codegen.py
Treat Python float immediates as SSA constants in register_var by emitting constants via adj.add_constant(var).
Native Backend (C++)
warp/native/tile.h
Add wp::remove_reference helper; change scalar_matmul, tile_matmul, and adj_tile_matmul signatures to accept alpha/beta (and adjoint alpha/beta); compute C = alpha * sum + beta * C; propagate scalings through forward and backward paths and adjust adjoint scaling for adj_C.
Examples (Cholesky)
warp/examples/tile/example_tile_block_cholesky.py
Replace explicit subtraction-of-product sequences with in-place wp.tile_matmul(..., alpha=-1.0) calls at block update sites.
Tests (Cholesky)
warp/tests/tile/test_tile_cholesky.py
Mirror example changes: use tile_matmul(..., alpha=-1.0) for in-place subtraction updates in kernels.
Tests (MathDx)
warp/tests/tile/test_tile_mathdx.py
Update test kernels to initialize C with prior content and call tile_matmul with explicit alpha/beta (e.g., alpha=0.5, beta=-1.3); update forward/backward expected values and gradient formulas to include scalings.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor Kernel
  participant Builtins as builtins.tile_matmul
  participant LTO as LTO Dispatch
  participant Backend as CPU/MathDx/CUDA
  participant Native as native::tile_matmul

  Kernel->>Builtins: tile_matmul(a, b, out, alpha?, beta?)
  Builtins->>LTO: args (a, b, out, alpha, beta)
  Note right of Builtins: overwrite path -> beta = 0.0\notherwise use provided beta
  LTO->>Backend: call(a, b, out, alpha, beta)
  Backend->>Native: tile_matmul(..., alpha, beta)
  Native->>Native: C = alpha * (A @ B) + beta * C
  Native-->>Backend: out
  Backend-->>Kernel: out

  rect rgba(230,245,255,0.5)
    note over Builtins,Native: alpha/beta threaded through all execution and adjoint paths
  end
sequenceDiagram
  autonumber
  participant AD as Autodiff Engine
  participant NativeAdj as native::adj_tile_matmul

  AD->>NativeAdj: backprop(adj_C, alpha, beta)
  Note right of NativeAdj: scale A/B grads by alpha\nscale/accumulate adj_C contributions by beta where needed
  NativeAdj-->>AD: adj_A, adj_B, adj_C (updated)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 42.11%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The PR title "Add alpha and beta scalings to tile_matmul" accurately and directly describes the primary change across the entire changeset. The main objective is to introduce alpha and beta scaling parameters to the tile_matmul function API, changing its computation from out += a*b to out = alpha * a*b + beta * out. This is reflected consistently across all modified files, including the public API signatures in warp/__init__.pyi and warp/_src/builtins.py, the C++ implementation in warp/native/tile.h, documentation updates, and test/example modifications. The title is concise, specific, and non-misleading.

📜 Recent review details

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 20ca408 and 988737a.

📒 Files selected for processing (9)
  • CHANGELOG.md (1 hunks)
  • docs/modules/functions.rst (4 hunks)
  • warp/__init__.pyi (3 hunks)
  • warp/_src/builtins.py (7 hunks)
  • warp/_src/codegen.py (1 hunks)
  • warp/examples/tile/example_tile_block_cholesky.py (4 hunks)
  • warp/native/tile.h (7 hunks)
  • warp/tests/tile/test_tile_cholesky.py (4 hunks)
  • warp/tests/tile/test_tile_mathdx.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • CHANGELOG.md
  • warp/_src/codegen.py
🧰 Additional context used
🧬 Code graph analysis (6)
warp/examples/tile/example_tile_block_cholesky.py (1)
warp/__init__.pyi (2)
  • tile_matmul (5393-5418)
  • tile_matmul (5421-5441)
warp/tests/tile/test_tile_cholesky.py (1)
warp/__init__.pyi (2)
  • tile_matmul (5393-5418)
  • tile_matmul (5421-5441)
warp/tests/tile/test_tile_mathdx.py (2)
warp/__init__.pyi (3)
  • tile_load (2570-2585)
  • tile_matmul (5393-5418)
  • tile_matmul (5421-5441)
warp/tests/unittest_utils.py (1)
  • assert_np_equal (241-247)
warp/_src/builtins.py (2)
warp/_src/context.py (3)
  • value_func (67-72)
  • value_func (1332-1333)
  • add_builtin (1254-1493)
warp/_src/types.py (2)
  • tile (4023-4112)
  • dtype (4732-4738)
warp/native/tile.h (1)
warp/__init__.pyi (2)
  • tile_matmul (5393-5418)
  • tile_matmul (5421-5441)
warp/__init__.pyi (2)
warp/native/tile.h (3)
  • Tile (2983-2983)
  • Tile (2997-2997)
  • Tile (3011-3011)
warp/_src/types.py (1)
  • Tile (166-167)
🪛 Ruff (0.14.1)
warp/tests/tile/test_tile_mathdx.py

69-69: assert_np_equal may be undefined, or defined from star imports

(F405)


75-75: assert_np_equal may be undefined, or defined from star imports

(F405)


76-76: assert_np_equal may be undefined, or defined from star imports

(F405)


77-77: assert_np_equal may be undefined, or defined from star imports

(F405)

warp/_src/builtins.py

9598-9598: Float may be undefined, or defined from star imports

(F405)


9599-9599: Float may be undefined, or defined from star imports

(F405)


9629-9629: tile may be undefined, or defined from star imports

(F405)


9629-9629: Float may be undefined, or defined from star imports

(F405)


9629-9629: Tuple may be undefined, or defined from star imports

(F405)


9630-9630: tile may be undefined, or defined from star imports

(F405)


9630-9630: Float may be undefined, or defined from star imports

(F405)


9630-9630: Tuple may be undefined, or defined from star imports

(F405)


9631-9631: Float may be undefined, or defined from star imports

(F405)

🔇 Additional comments (13)
docs/modules/functions.rst (1)

2819-2871: Documentation correctly reflects the new API.

The tile_matmul documentation has been properly updated to include the alpha and beta parameters, with correct formulas and appropriate notes about adjoint support. The past formatting issues with extra backticks have been resolved.

warp/tests/tile/test_tile_mathdx.py (4)

43-44: LGTM! Correct usage of the new alpha/beta API.

Loading the existing C content (line 43) is necessary for the beta scaling to work correctly. The in-place operation out = alpha * a*b + beta * out with alpha=0.5 and beta=-1.3 demonstrates the memory efficiency benefit mentioned in the PR objectives.


53-53: Good test setup improvement.

Initializing C with random values instead of zeros provides better test coverage for the beta scaling behavior.


69-69: Forward pass verification is correct.

The assertion correctly verifies the new semantics: out = 0.5 * (A @ B) + (-1.3) * C.


75-77: Gradient checks correctly account for alpha/beta scaling.

The gradient verifications are mathematically correct:

  • Gradients w.r.t. A and B are scaled by alpha (0.5)
  • Gradient w.r.t. C correctly accounts for the in-place operation: (1 + beta) * adj_C = -0.3 * adj_C

As discussed in the past review, the C gradient behavior is technically correct but reflects the non-intuitive nature of in-place operations with autodiff.
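For readers who want to sanity-check the scalings discussed here, the following standalone NumPy sketch (not Warp code) applies the textbook reverse-mode rules for the pure function out = alpha * (A @ B) + beta * C_in, using the test's alpha=0.5 and beta=-1.3. It deliberately does not model the in-place tile load/store aliasing of C that the comment above refers to.

import numpy as np

rng = np.random.default_rng(0)
alpha, beta = 0.5, -1.3

A = rng.standard_normal((4, 3))
B = rng.standard_normal((3, 5))
C_in = rng.standard_normal((4, 5))

out = alpha * (A @ B) + beta * C_in

adj_out = rng.standard_normal((4, 5))  # upstream gradient flowing into out
adj_A = alpha * adj_out @ B.T          # gradient w.r.t. A is scaled by alpha
adj_B = alpha * A.T @ adj_out          # gradient w.r.t. B is scaled by alpha
adj_C_in = beta * adj_out              # gradient w.r.t. the original C value is scaled by beta

# finite-difference spot check on one entry of A
eps = 1e-6
A_pert = A.copy()
A_pert[0, 0] += eps
numeric = ((alpha * (A_pert @ B) + beta * C_in - out) * adj_out).sum() / eps
assert np.isclose(numeric, adj_A[0, 0], rtol=1e-4)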

warp/examples/tile/example_tile_block_cholesky.py (4)

81-81: Excellent demonstration of the in-place accumulation benefit.

Using alpha=-1.0 with in-place matmul elegantly replaces the previous pattern of separate multiplication and subtraction. This reduces both the number of operations and shared memory usage, which directly addresses the PR's motivation for improved kernel occupancy.


110-110: Consistent usage of in-place subtraction.

The pattern is correctly applied here as well: wp.tile_matmul(L_tile, L_T_tile, A_ik_tile, alpha=-1.0) eliminates the need for an intermediate temporary.


151-151: Forward substitution correctly updated.

The in-place matmul with alpha=-1.0 correctly implements the subtraction needed in the forward substitution step.


166-166: Backward substitution correctly updated.

The in-place matmul with alpha=-1.0 correctly implements the subtraction needed in the backward substitution step, completing the consistent usage pattern throughout the solver.

warp/tests/tile/test_tile_cholesky.py (1)

373-373: LGTM! Excellent use of fused accumulation.

The refactoring from explicit subtraction (A -= B) to in-place scaled accumulation (tile_matmul(A, B, out, alpha=-1.0)) is correct and demonstrates the benefit of the new API. By fusing the matrix multiplication and subtraction into a single operation, this eliminates intermediate temporaries and reduces memory traffic, which aligns perfectly with the PR objectives.

Each change is semantically equivalent to the original code:

  • Line 373: A_kk_tile = -1.0 * L_block * L_block_T + 1.0 * A_kk_tile, i.e. A_kk_tile -= L_block * L_block_T
  • Line 388: Similar pattern for off-diagonal updates
  • Lines 415, 429: Similar pattern for forward/backward substitution

Also applies to: 388-388, 415-415, 429-429

warp/native/tile.h (3)

212-214: LGTM: Standard remove_reference implementation.

The remove_reference utility templates are correctly implemented following the standard pattern.


3414-3444: LGTM: Forward pass correctly implements alpha/beta scaling.

The implementation properly:

  • Converts alpha and beta to the tile element type T
  • Passes them to both the scalar and CUDA code paths
  • Maintains compatibility with the existing function pointer dispatch mechanism

The type conversion on lines 3432-3433 assumes Alpha and Beta are convertible to T, which is reasonable for numeric scaling parameters.


3448-3479: LGTM: Backward pass correctly implements gradient flow with beta scaling.

The implementation properly handles the backward pass:

  • Sets beta_A = 1.0 and beta_B = 1.0 for gradient accumulation (lines 3457, 3459)
  • Scales adj_C.grad by beta when beta != 1.0 (lines 3472-3476) to correctly apply the chain rule
  • The adjoints for alpha and beta parameters are placeholders (as documented in the Python API: "computing the adjoints of alpha and beta are not yet supported")


@RSchwan RSchwan changed the title from "add alpha and beta scalings to tile_matmul" to "Add alpha and beta scalings to tile_matmul" on Oct 11, 2025
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
warp/native/tile.h (1)

2977-3000: Don't take alpha/beta by non-const reference.

Requiring non-const references here breaks all existing call sites that pass temporaries or const scalars (e.g. T(1.0f), T(0.0f)), which previously compiled because the parameters were taken by value. This change will surface as hard compile errors for downstream kernels still using rvalues. Keep the new beta logic but accept these parameters as const references (or by value) so the API remains source-compatible.

-inline CUDA_CALLABLE void scalar_matmul(const StorageA& A, const StorageB& B, StorageC& C, T& alpha, T& beta)
+inline CUDA_CALLABLE void scalar_matmul(const StorageA& A, const StorageB& B, StorageC& C, const T& alpha, const T& beta)
🧹 Nitpick comments (2)
docs/modules/functions.rst (2)

2681-2682: Clarify default parameter behavior.

The parameter descriptions state "(default 1.0)" for both alpha and beta, but the function signature shows them as required parameters without default values. This creates ambiguity:

  • If these parameters have default values and can be omitted, the documentation should explicitly state this or the signature should show the defaults.
  • If they are required parameters, remove the "(default 1.0)" text to avoid confusion.

Based on the usage examples in the codebase (e.g., wp.tile_matmul(L_block, L_block_T, A_kk_tile, alpha=-1.0) where beta is not specified), it appears beta does have a default value of 1.0. Consider clarifying this as "optional, defaults to 1.0" rather than just "default 1.0".


2705-2706: Inconsistent adjoint support note.

This overload only has the alpha parameter (no beta), but the note states "computing the adjoints of alpha and beta are not yet supported." This is inconsistent and could confuse readers about which overload they're viewing.

Consider revising to: "Note that computing the adjoint of alpha is not yet supported." to match this overload's actual parameters.

📜 Review details

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 86373e3 and 0638254.

📒 Files selected for processing (8)
  • docs/modules/functions.rst (4 hunks)
  • warp/__init__.pyi (3 hunks)
  • warp/_src/builtins.py (7 hunks)
  • warp/_src/codegen.py (1 hunks)
  • warp/examples/tile/example_tile_block_cholesky.py (4 hunks)
  • warp/native/tile.h (7 hunks)
  • warp/tests/tile/test_tile_cholesky.py (4 hunks)
  • warp/tests/tile/test_tile_mathdx.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
warp/tests/tile/test_tile_cholesky.py (1)
warp/__init__.pyi (2)
  • tile_matmul (5265-5290)
  • tile_matmul (5293-5313)
warp/examples/tile/example_tile_block_cholesky.py (2)
warp/native/tile.h (1)
  • wp (199-266)
warp/__init__.pyi (2)
  • tile_matmul (5265-5290)
  • tile_matmul (5293-5313)
warp/tests/tile/test_tile_mathdx.py (2)
warp/native/tile.h (10)
  • tile_load (1961-1964)
  • T (406-413)
  • T (415-422)
  • T (431-438)
  • T (440-447)
  • T (552-556)
  • T (558-562)
  • T (866-872)
  • T (874-881)
  • T (883-890)
warp/__init__.pyi (3)
  • tile_load (2540-2555)
  • tile_matmul (5265-5290)
  • tile_matmul (5293-5313)
warp/_src/builtins.py (2)
warp/_src/context.py (3)
  • value_func (68-73)
  • value_func (1333-1334)
  • add_builtin (1255-1494)
warp/_src/types.py (2)
  • tile (3992-4081)
  • dtype (4674-4680)
warp/native/tile.h (1)
warp/__init__.pyi (3)
  • tile_matmul (5265-5290)
  • tile_matmul (5293-5313)
  • tile_transpose (2872-2881)
🪛 Ruff (0.13.3)
warp/tests/tile/test_tile_mathdx.py

69-69: assert_np_equal may be undefined, or defined from star imports

(F405)


75-75: assert_np_equal may be undefined, or defined from star imports

(F405)


76-76: assert_np_equal may be undefined, or defined from star imports

(F405)


77-77: assert_np_equal may be undefined, or defined from star imports

(F405)

warp/_src/builtins.py

9269-9269: Float may be undefined, or defined from star imports

(F405)


9270-9270: Float may be undefined, or defined from star imports

(F405)


9300-9300: tile may be undefined, or defined from star imports

(F405)


9300-9300: Float may be undefined, or defined from star imports

(F405)


9300-9300: Tuple may be undefined, or defined from star imports

(F405)


9301-9301: tile may be undefined, or defined from star imports

(F405)


9301-9301: Float may be undefined, or defined from star imports

(F405)


9301-9301: Tuple may be undefined, or defined from star imports

(F405)


9302-9302: Float may be undefined, or defined from star imports

(F405)

🔇 Additional comments (5)
warp/examples/tile/example_tile_block_cholesky.py (4)

77-81: LGTM! Efficient in-place update eliminates intermediate allocation.

The change from explicit subtraction to using alpha=-1.0 correctly implements A_kk_tile -= L_block @ L_block_T in-place. This eliminates the intermediate L_L_T_block allocation mentioned in the PR description, improving memory efficiency while maintaining correctness.

The implicit beta=1.0 default ensures proper accumulation semantics: out = -1.0 * a * b + 1.0 * out.


105-110: Consistent in-place update pattern.

The off-diagonal block update follows the same pattern as the diagonal block (line 81), correctly replacing explicit subtraction with alpha=-1.0 for in-place accumulation.


147-151: Correct forward substitution update.

The forward substitution step correctly uses alpha=-1.0 to perform in-place subtraction: rhs_tile -= L_block @ y_block.


161-166: Correct backward substitution update.

The backward substitution step correctly mirrors the forward substitution pattern, using alpha=-1.0 for in-place update: rhs_tile -= L_T_tile @ x_tile.

warp/tests/tile/test_tile_cholesky.py (1)

373-373: Test updates correctly mirror example code.

All four test updates correctly adopt the new tile_matmul API with alpha=-1.0 for in-place subtraction:

  • Line 373: Diagonal block update
  • Line 388: Off-diagonal block update
  • Line 415: Forward substitution
  • Line 429: Backward substitution

The changes match the patterns established in example_tile_block_cholesky.py and preserve the original test semantics. The existing test assertions will correctly validate the new implementation.

Also applies to: 388-388, 415-415, 429-429

Comment on lines +9151 to 9485
beta = 0.0 # for c = tile_matmul(a,b) case we want to overwrite c value
out = return_values[0]
else:
accumulate = 1 # for tile_matmul(a,b,c) case we want to add to c value
beta = arg_values["beta"]
out = arg_values["out"]


⚠️ Potential issue | 🔴 Critical

Wrap the implicit beta=0 in a Var.

When we take the return-value overload we drop into this branch, set beta = 0.0, and forward it to the MathDx path. The LTO dispatcher and codegen machinery expect every runtime argument to be a Var-like object (so that .type / .constant metadata is available), which is how the existing non-const arguments are handled. Passing a bare Python float here will blow up at codegen time (AttributeError when the serializer tries to inspect the argument). Please wrap the zero in a Var with the same scalar type as alpha before returning.

-        beta = 0.0  # for c = tile_matmul(a,b) case we want to overwrite c value
+        beta = Var(label=None, type=alpha.type, constant=0.0)  # ensure codegen sees a typed constant
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
- beta = 0.0  # for c = tile_matmul(a,b) case we want to overwrite c value
- out = return_values[0]
- else:
- accumulate = 1  # for tile_matmul(a,b,c) case we want to add to c value
- beta = arg_values["beta"]
- out = arg_values["out"]
+ beta = Var(label=None, type=alpha.type, constant=0.0)  # ensure codegen sees a typed constant
+ out = return_values[0]
+ else:
+ beta = arg_values["beta"]
+ out = arg_values["out"]
🤖 Prompt for AI Agents
In warp/_src/builtins.py around lines 9151-9156, the branch that sets beta = 0.0
when using the return-value overload must not assign a bare Python float; create
a Var representing zero with the same scalar type/metadata as the existing alpha
Var (so it exposes .type and .constant) and assign that Var to beta before
returning; in short, wrap the implicit zero in a Var constructed with alpha's
scalar type/metadata so downstream LTO/codegen sees a Var-like runtime argument.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
docs/modules/functions.rst (1)

2667-2675: Document supported mixed-dtype matmul usage
Currently docs mandate uniform dtype (“same datatype”) but warp/tests/tile/test_tile_mathdx.py invokes tile_matmul on fp16, fp32, fp64 arrays. Update docs (docs/modules/functions.rst 2667–2675 & 2696–2704) to specify per-backend dtype combinations and promotion rules or align tests to use a single dtype.

warp/_src/builtins.py (1)

9264-9295: Allow complex tiles in the public signature; fix minor doc typo

  • Signatures use tile(dtype=Float, …) which excludes complex tiles (vec2{h,f,d}), yet docs and runtime checks allow them. Add overloads for tile(dtype=vector(length=2, dtype=Float), …) to make complex matmul callable.
  • Remove stray backticks in the docs.

Apply these small doc fixes:

@@
-    :param alpha: Scaling factor (default 1.0)``
-    :param beta: Accumulator factor (default 1.0)``
+    :param alpha: Scaling factor (default 1.0)
+    :param beta: Accumulator factor (default 1.0)
@@
-    :param alpha: Scaling factor (default 1.0)``
+    :param alpha: Scaling factor (default 1.0)

Add complex overloads mirroring the float ones (place right below the existing add_builtin blocks):

# Complex (vec2*) overload — out variant
add_builtin(
    "tile_matmul",
    input_types={
        "a": tile(dtype=vector(length=2, dtype=Float), shape=Tuple[int, int]),
        "b": tile(dtype=vector(length=2, dtype=Float), shape=Tuple[int, int]),
        "out": tile(dtype=vector(length=2, dtype=Float), shape=Tuple[int, int]),
        "alpha": Float,
        "beta": Float,
    },
    defaults={"alpha": 1.0, "beta": 1.0},
    value_func=tile_matmul_out_value_func,
    lto_dispatch_func=tile_matmul_lto_dispatch_func,
    variadic=False,
    doc="""Computes the matrix product and accumulates out = alpha * a*b + beta * out.

Supported datatypes are:
    * fp16, fp32, fp64 (real)
    * vec2h, vec2f, vec2d (complex)

Note that computing the adjoints of alpha and beta are not yet supported.

:param a: A tile with shape=(M, K)
:param b: A tile with shape=(K, N)
:param out: A tile with shape=(M, N)
:param alpha: Scaling factor (default 1.0)
:param beta: Accumulator factor (default 1.0)
""",
    group="Tile Primitives",
    export=False,
)

# Complex (vec2*) overload — return-value variant
add_builtin(
    "tile_matmul",
    input_types={
        "a": tile(dtype=vector(length=2, dtype=Float), shape=Tuple[int, int]),
        "b": tile(dtype=vector(length=2, dtype=Float), shape=Tuple[int, int]),
        "alpha": Float,
    },
    defaults={"alpha": 1.0},
    value_func=tile_matmul_value_func,
    lto_dispatch_func=tile_matmul_lto_dispatch_func,
    variadic=False,
    doc="""Computes the matrix product out = alpha * a*b.

Supported datatypes are:
    * fp16, fp32, fp64 (real)
    * vec2h, vec2f, vec2d (complex)

Note that computing the adjoints of alpha and beta are not yet supported.

:param a: A tile with shape=(M, K)
:param b: A tile with shape=(K, N)
:param alpha: Scaling factor (default 1.0)
:returns: A tile with shape=(M, N)
""",
    group="Tile Primitives",
    export=False,
)

If you prefer avoiding duplication, we can factor a helper to register both float and vec2 overloads programmatically.

Also applies to: 9298-9326

warp/native/tile.h (1)

3207-3232: Critical: use adj_ret (return-value adjoint) instead of adj_C in out-of-place adjoint overload

In the adj_tile_matmul overload for the out = wp.tile_matmul(a, b) form (warp/native/tile.h, lines ~3207-3232) the code uses adj_C.grad / adj_C.grad.ptr as the gradient source for dA/dB. It must use adj_ret.grad / adj_ret.grad.ptr (the return-value adjoint) — otherwise gradients will be incorrect/zero.

Apply this diff:

-    partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad, alpha_A, beta_A);
-    partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad, alpha_B, beta_B);
+    partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_ret.grad, Bt.data, adj_A.grad, alpha_A, beta_A);
+    partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_ret.grad, adj_B.grad, alpha_B, beta_B);
@@
-    fun_backward_A(&alpha_A, adj_C.grad.ptr, B.data.ptr, &beta_A, adj_A.grad.ptr);
-    fun_backward_B(&alpha_B, A.data.ptr, adj_C.grad.ptr, &beta_B, adj_B.grad.ptr);
+    fun_backward_A(&alpha_A, adj_ret.grad.ptr, B.data.ptr, &beta_A, adj_A.grad.ptr);
+    fun_backward_B(&alpha_B, A.data.ptr, adj_ret.grad.ptr, &beta_B, adj_B.grad.ptr);

Add/extend a test for the out-of-place form (alpha/beta arbitrary) verifying dA/dB against a NumPy/Torch reference.

🧹 Nitpick comments (5)
docs/modules/functions.rst (1)

2659-2660: Clarify defaults/overloads in signature.

Docs mention defaults (alpha/beta=1.0), and examples omit beta (relying on default). Consider stating explicitly that beta is optional and defaults to 1.0 when omitted in the out-accumulating overload. This avoids ambiguity between the 5‑arg signature and usage.

Also applies to: 2686-2691

warp/tests/tile/test_tile_mathdx.py (1)

69-77: Avoid star import to satisfy linters (F405).

Import test helpers explicitly to silence F405 and improve clarity:

- from warp.tests.unittest_utils import *
+ from warp.tests.unittest_utils import assert_np_equal, get_test_devices, get_cuda_test_devices, add_function_test
warp/native/tile.h (3)

214-216: Prefer standard type trait or add a shorthand alias

Re-implementing remove_reference works, but consider using std::remove_reference (when available) or at least adding a remove_reference_t alias for brevity across the file.


2977-3000: Make alpha/beta const refs and avoid double read of C(coord)

  • Pass alpha/beta as const T&; you don't modify them.
  • Cache C(coord) once to prevent double indexing/load.

Apply this diff:

-inline CUDA_CALLABLE void scalar_matmul(const StorageA& A, const StorageB& B, StorageC& C, T& alpha, T& beta)
+inline CUDA_CALLABLE void scalar_matmul(const StorageA& A, const StorageB& B, StorageC& C, const T& alpha, const T& beta)
@@
-        using TypeC = typename remove_reference<decltype(C(coord))>::type;
-        TypeC sum = TypeC(0);
+        using TypeC = typename remove_reference<decltype(C(coord))>::type;
+        TypeC sum = TypeC(0);
@@
-        C(coord) = alpha * sum + beta * C(coord);
+        const TypeC c_old = C(coord);
+        C(coord) = alpha * sum + beta * c_old;

3139-3164: LGTM: forward semantics match out = alpha * (A@B) + beta * out

The new API and dispatch look correct. Consider a future fast path for beta == 0 to skip reading out in CPU path (already handled inside scalar_matmul after caching old C).

📜 Review details

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 86373e3 and 9dc3e93.

📒 Files selected for processing (9)
  • CHANGELOG.md (1 hunks)
  • docs/modules/functions.rst (4 hunks)
  • warp/__init__.pyi (3 hunks)
  • warp/_src/builtins.py (7 hunks)
  • warp/_src/codegen.py (1 hunks)
  • warp/examples/tile/example_tile_block_cholesky.py (4 hunks)
  • warp/native/tile.h (7 hunks)
  • warp/tests/tile/test_tile_cholesky.py (4 hunks)
  • warp/tests/tile/test_tile_mathdx.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
warp/examples/tile/example_tile_block_cholesky.py (1)
warp/__init__.pyi (2)
  • tile_matmul (5265-5290)
  • tile_matmul (5293-5313)
warp/tests/tile/test_tile_cholesky.py (1)
warp/__init__.pyi (2)
  • tile_matmul (5265-5290)
  • tile_matmul (5293-5313)
warp/native/tile.h (1)
warp/__init__.pyi (3)
  • tile_matmul (5265-5290)
  • tile_matmul (5293-5313)
  • tile_transpose (2872-2881)
warp/tests/tile/test_tile_mathdx.py (3)
warp/native/tile.h (10)
  • tile_load (1961-1964)
  • T (406-413)
  • T (415-422)
  • T (431-438)
  • T (440-447)
  • T (552-556)
  • T (558-562)
  • T (866-872)
  • T (874-881)
  • T (883-890)
warp/__init__.pyi (3)
  • tile_load (2540-2555)
  • tile_matmul (5265-5290)
  • tile_matmul (5293-5313)
warp/tests/unittest_utils.py (1)
  • assert_np_equal (241-247)
warp/_src/builtins.py (2)
warp/_src/context.py (3)
  • value_func (68-73)
  • value_func (1333-1334)
  • add_builtin (1255-1494)
warp/_src/types.py (2)
  • tile (3992-4081)
  • dtype (4674-4680)
🪛 Ruff (0.13.3)
warp/tests/tile/test_tile_mathdx.py

69-69: assert_np_equal may be undefined, or defined from star imports

(F405)


75-75: assert_np_equal may be undefined, or defined from star imports

(F405)


76-76: assert_np_equal may be undefined, or defined from star imports

(F405)


77-77: assert_np_equal may be undefined, or defined from star imports

(F405)

warp/_src/builtins.py

9269-9269: Float may be undefined, or defined from star imports

(F405)


9270-9270: Float may be undefined, or defined from star imports

(F405)


9300-9300: tile may be undefined, or defined from star imports

(F405)


9300-9300: Float may be undefined, or defined from star imports

(F405)


9300-9300: Tuple may be undefined, or defined from star imports

(F405)


9301-9301: tile may be undefined, or defined from star imports

(F405)


9301-9301: Float may be undefined, or defined from star imports

(F405)


9301-9301: Tuple may be undefined, or defined from star imports

(F405)


9302-9302: Float may be undefined, or defined from star imports

(F405)

🔇 Additional comments (6)
warp/_src/codegen.py (1)

1258-1261: Float literals now treated as SSA constants — good catch.

Mirrors the int path and prevents unlabeled python floats from leaking past register_var. Nicely done.

warp/examples/tile/example_tile_block_cholesky.py (1)

81-81: Good fusion: use alpha=-1.0 to fold subtraction into matmul.

Removes intermediates and shared-memory traffic; aligns with new API.

Also applies to: 110-110, 151-151, 166-166

warp/tests/tile/test_tile_mathdx.py (1)

43-45: Semantics and gradients look correct.

Forward and backward checks match out = α·A@B + β·C and adjoints (excluding α/β).

Also applies to: 69-78

warp/tests/tile/test_tile_cholesky.py (1)

373-374: LGTM: fused updates via alpha=-1.0.

Cleaner, fewer temporaries, consistent with new API.

Also applies to: 388-389, 415-416, 429-430

warp/_src/builtins.py (1)

9148-9186: Use a typed Var for beta; verify gradient scaling by alpha

  • In the return-value path, beta is set to a raw float (0.0). The dispatch pipeline expects scalar runtime args as Var. Replace with a constant Var of the correct dtype:
@@
-    if len(return_values) > 0:
-        beta = 0.0  # for c = tile_matmul(a,b) case we want to overwrite c value
+    if len(return_values) > 0:
+        beta = Var(None, type=out.type.dtype, constant=0.0)
  • I couldn’t locate the backward kernels in this file. Please verify that the gradient implementation applies the alpha factor (i.e. dA = alpha * dC @ Bᵀ, dB = alpha * Aᵀ @ dC) when alpha != 1. If not, thread alpha through the adjoint or scale dC accordingly.
warp/native/tile.h (1)

3173-3204: LGTM: in-place adjoint scales adj_C by beta after computing dA/dB

This ordering avoids contaminating dA/dB with beta; correct for reverse-mode. No action needed.

Please ensure tests cover the in-place variant with beta ≠ {0,1} for both real and complex dtypes.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (5)
docs/modules/functions.rst (2)

2681-2682: Use consistent “default: …” formatting in params.

Elsewhere in this doc you use “default: …”. Align alpha/beta to match.

-    :param alpha: Scaling factor (default 1.0)
-    :param beta: Accumulator factor (default 1.0)
+    :param alpha: Scaling factor (default: 1.0)
+    :param beta: Accumulator factor (default: 1.0)

2686-2706: Minor clarity: phrase return semantics; keep default formatting consistent.

  • Suggest phrasing the semantics as a return value to avoid “out = …” when no out param is present.
  • Keep “default: …” formatting consistent.
-    Computes the matrix product ``out = alpha * a*b``.
+    Computes the matrix product and returns ``alpha * a * b``.
-    :param alpha: Scaling factor (default 1.0)
+    :param alpha: Scaling factor (default: 1.0)
warp/__init__.pyi (2)

5286-5287: Fix backtick formatting in parameter documentation.

The backticks are malformed (stray closing backticks). Additionally, the docstrings claim "(default 1.0)" but the function signature shows these as required parameters without defaults. Either add defaults to the signature (alpha: Float = 1.0, beta: Float = 1.0) or remove the "(default 1.0)" text from the docstrings.

Apply this diff to fix the formatting and align with the signature:

-    :param alpha: Scaling factor (default 1.0)``
-    :param beta: Accumulator factor (default 1.0)``
+    :param alpha: Scaling factor
+    :param beta: Accumulator factor

Or if defaults are intended:

 def tile_matmul(
     a: Tile[Float, Tuple[int, int]],
     b: Tile[Float, Tuple[int, int]],
     out: Tile[Float, Tuple[int, int]],
-    alpha: Float,
-    beta: Float,
+    alpha: Float = 1.0,
+    beta: Float = 1.0,
 ):

5309-5309: Fix backtick formatting and clarify default parameter value.

The backtick is malformed (stray closing backtick). Additionally, the docstring claims "(default 1.0)" but the function signature shows alpha as a required parameter. Either add a default to the signature (alpha: Float = 1.0) or remove the "(default 1.0)" text from the docstring.

Apply this diff to fix the formatting and align with the signature:

-    :param alpha: Scaling factor (default 1.0)`
+    :param alpha: Scaling factor

Or if a default is intended:

 def tile_matmul(
-    a: Tile[Float, Tuple[int, int]], b: Tile[Float, Tuple[int, int]], alpha: Float
+    a: Tile[Float, Tuple[int, int]], b: Tile[Float, Tuple[int, int]], alpha: Float = 1.0
 ) -> Tile[Float, Tuple[int, int]]:
warp/_src/builtins.py (1)

9148-9156: Wrap implicit beta=0.0 in a Var (codegen will crash on bare float).

LTO/codegen expects runtime args as Var-like. Assigning a Python float to beta in the return‑value overload will fail at serialization. Construct a typed Var using alpha.type.

Apply:

-        beta = 0.0  # for c = tile_matmul(a,b) case we want to overwrite c value
+        beta = Var(label=None, type=alpha.type, constant=0.0)  # ensure typed constant for codegen
🧹 Nitpick comments (3)
docs/modules/functions.rst (1)

2659-2677: API/semantics LGTM; clarify alpha/beta type for complex tiles.

Looks good. One request: for complex tiles (vec2h/f/d), are alpha/beta allowed to be complex or real-only? If real-only, please state that explicitly; if complex is supported, consider documenting that alpha/beta match the tile dtype.

Optional nit: “Computing the adjoints of alpha and beta are not yet supported.” → use singular “is” or rephrase for clarity.

-    Note that computing the adjoints of alpha and beta are not yet supported.
+    Note that computing the adjoints of alpha and beta is not yet supported.
warp/__init__.pyi (1)

5305-5305: Use singular "adjoint" for a single parameter.

Since this overload only has the alpha parameter, the note should use the singular form "adjoint" instead of plural "adjoints".

Apply this diff:

-    Note that computing the adjoints of alpha is not yet supported.
+    Note that computing the adjoint of alpha is not yet supported.
warp/_src/builtins.py (1)

9269-9273: Signatures restrict alpha/beta to Float; docs say vec2(Complex) tiles supported.

If A/B/out are complex (vec2h/f/d), Float alpha/beta block complex scaling and may mismatch backend expectations. Consider adding overloads with alpha/beta: vector(length=2, dtype=Float), or clarify docs if complex scaling isn’t supported.

  • Option A: Add parallel add_builtin overloads for complex tiles with alpha/beta as vector(length=2, dtype=Float).
  • Option B: If only real scaling is supported for complex tiles, update docs to state alpha/beta must be real.

Can you confirm intended behavior and backend support for complex alpha/beta? Based on learnings

Also applies to: 9299-9305

📜 Review details

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9dc3e93 and b058026.

📒 Files selected for processing (3)
  • docs/modules/functions.rst (4 hunks)
  • warp/__init__.pyi (3 hunks)
  • warp/_src/builtins.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
warp/_src/builtins.py (2)
warp/_src/context.py (3)
  • value_func (68-73)
  • value_func (1333-1334)
  • add_builtin (1255-1494)
warp/_src/types.py (2)
  • tile (3992-4081)
  • dtype (4674-4680)
🪛 Ruff (0.13.3)
warp/_src/builtins.py

9269-9269: Float may be undefined, or defined from star imports

(F405)


9270-9270: Float may be undefined, or defined from star imports

(F405)


9300-9300: tile may be undefined, or defined from star imports

(F405)


9300-9300: Float may be undefined, or defined from star imports

(F405)


9300-9300: Tuple may be undefined, or defined from star imports

(F405)


9301-9301: tile may be undefined, or defined from star imports

(F405)


9301-9301: Float may be undefined, or defined from star imports

(F405)


9301-9301: Tuple may be undefined, or defined from star imports

(F405)


9302-9302: Float may be undefined, or defined from star imports

(F405)

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR enhances the wp.tile_matmul() function by adding alpha and beta scaling parameters, implementing the formula out = alpha * (A @ B) + beta * out to match the cuBLASDx API. The enhancement enables fused matrix multiplication with scaling and accumulation operations, reducing memory usage and improving performance by eliminating the need for temporary intermediate results.

The changes include two function overloads: one for in-place accumulation with both alpha and beta parameters (defaulting to 1.0), and another that returns a new tile with only an alpha parameter. The API maintains backward compatibility while enabling more efficient operations, as demonstrated in the updated Cholesky decomposition example where A -= L * L^T can now be expressed as wp.tile_matmul(L, L_T, A, alpha=-1.0).
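As a quick, hypothetical illustration of the two overloads described above (array names and the 8x8 tile shape are placeholders, not code from this changeset):

import warp as wp

@wp.kernel
def overload_demo(A: wp.array2d(dtype=wp.float32),
                  B: wp.array2d(dtype=wp.float32),
                  C: wp.array2d(dtype=wp.float32),
                  D: wp.array2d(dtype=wp.float32)):
    a = wp.tile_load(A, shape=(8, 8))
    b = wp.tile_load(B, shape=(8, 8))
    c = wp.tile_load(C, shape=(8, 8))

    # return-value overload: alpha only, produces a fresh tile
    d = wp.tile_matmul(a, b, alpha=0.5)    # d = 0.5 * (a @ b)

    # accumulating overload: beta defaults to 1.0 when omitted
    wp.tile_matmul(a, b, c, alpha=-1.0)    # c = -1.0 * (a @ b) + 1.0 * c

    wp.tile_store(C, c)
    wp.tile_store(D, d)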

Important Files Changed

Changed Files

  • warp/native/tile.h (3/5): Enhanced C++ implementation to support alpha/beta scaling in tile matrix multiplication
  • warp/_src/builtins.py (4/5): Updated dispatch function to handle alpha/beta parameters and distinguish between overloads
  • warp/__init__.pyi (5/5): Auto-generated type stubs updated with new function signatures and documentation
  • docs/modules/functions.rst (5/5): Documentation updated to describe new parameters and overload variants
  • warp/_src/codegen.py (5/5): Added support for float constants in variable registration for code generation
  • warp/tests/tile/test_tile_mathdx.py (4/5): Updated tests to validate alpha/beta scaling functionality and gradient computations
  • warp/tests/tile/test_tile_cholesky.py (5/5): Modified to use new alpha parameter for in-place matrix operations
  • warp/examples/tile/example_tile_block_cholesky.py (5/5): Optimized example using fused operations to reduce memory usage
  • CHANGELOG.md (5/5): Added changelog entry documenting the new feature

Confidence score: 3/5

  • This PR introduces significant enhancements but has some potential issues in the backward pass implementation that need attention
  • Score lowered due to concerns in tile.h where gradient calculations may not properly handle non-unit beta values, and potential type conversion issues between template parameters
  • Pay close attention to warp/native/tile.h, particularly the backward pass logic for gradient computation with beta scaling

Additional Comments (4)

  1. warp/tests/tile/test_tile_mathdx.py, line 77 (link)

    logic: C gradient calculation appears incorrect - should be -1.3 * adj_C not adj_C - 1.3 * adj_C

  2. warp/native/tile.h, line 3197-3201 (link)

    logic: This multiplication by beta could accumulate incorrectly if the adjoint function is called multiple times on the same tile, as it modifies adj_C.grad(i) in place.

  3. warp/native/tile.h, line 3157-3158 (link)

    style: Type conversion from template parameters to tile type could cause precision loss. Consider using static_cast with appropriate bounds checking for safety.

  4. warp/native/tile.h, line 3181-3184 (link)

    logic: The backward pass assumes beta=1.0 for gradient scaling, which may not produce correct gradients when the forward pass used beta != 1.0.

9 files reviewed, 4 comments


@RSchwan RSchwan force-pushed the rschwan/matmul_scaling branch from b058026 to 20ca408 on October 22, 2025 16:13
Signed-off-by: Roland Schwan <roland.schwan@mikrounix.com>
@RSchwan RSchwan force-pushed the rschwan/matmul_scaling branch from 20ca408 to 988737a on October 22, 2025 16:22
@RSchwan RSchwan (Contributor, Author) commented Oct 22, 2025

Rebased and squashed the commits

@shi-eric shi-eric merged commit fae20c1 into NVIDIA:main Oct 23, 2025
16 checks passed
@shi-eric shi-eric added this to the 1.11.0 milestone Oct 23, 2025
