Add alpha and beta scalings to tile_matmul #1023
Conversation
📝 Walkthrough

Add explicit scalar parameters `alpha` and `beta` to `tile_matmul`, threaded through the builtin, LTO dispatch, backend, and adjoint paths (see the sequence diagrams below).
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor Kernel
    participant Builtins as builtins.tile_matmul
    participant LTO as LTO Dispatch
    participant Backend as CPU/MathDx/CUDA
    participant Native as native::tile_matmul
    Kernel->>Builtins: tile_matmul(a, b, out, alpha?, beta?)
    Builtins->>LTO: args (a, b, out, alpha, beta)
    Note right of Builtins: overwrite path -> beta = 0.0<br/>otherwise use provided beta
    LTO->>Backend: call(a, b, out, alpha, beta)
    Backend->>Native: tile_matmul(..., alpha, beta)
    Native->>Native: C = alpha * (A @ B) + beta * C
    Native-->>Backend: out
    Backend-->>Kernel: out
    rect rgba(230,245,255,0.5)
        note over Builtins,Native: alpha/beta threaded through all execution and adjoint paths
    end
```
```mermaid
sequenceDiagram
    autonumber
    participant AD as Autodiff Engine
    participant NativeAdj as native::adj_tile_matmul
    AD->>NativeAdj: backprop(adj_C, alpha, beta)
    Note right of NativeAdj: scale A/B grads by alpha<br/>scale/accumulate adj_C contributions by beta where needed
    NativeAdj-->>AD: adj_A, adj_B, adj_C (updated)
```
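For orientation, the semantics the diagrams describe can be exercised from a kernel like the following minimal sketch; the tile shapes, array names, and the shape-keyword `tile_load` form are illustrative, not taken from the diff:

```python
import warp as wp

TILE_M, TILE_K, TILE_N = 16, 16, 16

@wp.kernel
def matmul_scaled(A: wp.array2d(dtype=float),
                  B: wp.array2d(dtype=float),
                  C: wp.array2d(dtype=float),
                  D: wp.array2d(dtype=float)):
    a = wp.tile_load(A, shape=(TILE_M, TILE_K))
    b = wp.tile_load(B, shape=(TILE_K, TILE_N))
    c = wp.tile_load(C, shape=(TILE_M, TILE_N))

    # In-place overload: c = alpha * (a @ b) + beta * c
    wp.tile_matmul(a, b, c, alpha=-1.0, beta=1.0)
    wp.tile_store(C, c)

    # Return-value overload: d = alpha * (a @ b); beta is forced to 0 internally
    d = wp.tile_matmul(a, b, alpha=2.0)
    wp.tile_store(D, d)
```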
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
📜 Recent review details

Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (9)
🚧 Files skipped from review as they are similar to previous changes (2)

🧰 Additional context used
🧬 Code graph analysis (6)
- warp/examples/tile/example_tile_block_cholesky.py (1)
- warp/tests/tile/test_tile_cholesky.py (1)
- warp/tests/tile/test_tile_mathdx.py (2)
- warp/_src/builtins.py (2)
- warp/native/tile.h (1)
- warp/__init__.pyi (2)

🪛 Ruff (0.14.1)
- warp/tests/tile/test_tile_mathdx.py: lines 69, 75, 76, 77 (F405)
- warp/_src/builtins.py: lines 9598, 9599, 9629 (×3), 9630 (×3), 9631 (F405)

🔇 Additional comments (13)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
warp/native/tile.h (1)
2977-3000: Don't take alpha/beta by non-const reference.

Requiring non-const references here breaks all existing call sites that pass temporaries or const scalars (e.g. `T(1.0f)`, `T(0.0f)`), which previously compiled because the parameters were taken by value. This change will surface as hard compile errors for downstream kernels still using rvalues. Keep the new beta logic but accept these parameters as const references (or by value) so the API remains source-compatible.

```diff
-inline CUDA_CALLABLE void scalar_matmul(const StorageA& A, const StorageB& B, StorageC& C, T& alpha, T& beta)
+inline CUDA_CALLABLE void scalar_matmul(const StorageA& A, const StorageB& B, StorageC& C, const T& alpha, const T& beta)
```
🧹 Nitpick comments (2)
docs/modules/functions.rst (2)
2681-2682: Clarify default parameter behavior.

The parameter descriptions state "(default 1.0)" for both `alpha` and `beta`, but the function signature shows them as required parameters without default values. This creates ambiguity:

- If these parameters have default values and can be omitted, the documentation should explicitly state this or the signature should show the defaults.
- If they are required parameters, remove the "(default 1.0)" text to avoid confusion.

Based on the usage examples in the codebase (e.g., `wp.tile_matmul(L_block, L_block_T, A_kk_tile, alpha=-1.0)`, where `beta` is not specified), it appears `beta` does have a default value of 1.0. Consider clarifying this as "optional, defaults to 1.0" rather than just "default 1.0".
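Concretely, under that reading the following two calls would be equivalent; the first is the call quoted above, the second spells out the implied default:

```python
wp.tile_matmul(L_block, L_block_T, A_kk_tile, alpha=-1.0)
wp.tile_matmul(L_block, L_block_T, A_kk_tile, alpha=-1.0, beta=1.0)
```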
2705-2706: Inconsistent adjoint support note.

This overload only has the `alpha` parameter (no `beta`), but the note states "computing the adjoints of alpha and beta are not yet supported." This is inconsistent and could confuse readers about which overload they're viewing. Consider revising to: "Note that computing the adjoint of alpha is not yet supported." to match this overload's actual parameters.
📜 Review details
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- docs/modules/functions.rst (4 hunks)
- warp/__init__.pyi (3 hunks)
- warp/_src/builtins.py (7 hunks)
- warp/_src/codegen.py (1 hunks)
- warp/examples/tile/example_tile_block_cholesky.py (4 hunks)
- warp/native/tile.h (7 hunks)
- warp/tests/tile/test_tile_cholesky.py (4 hunks)
- warp/tests/tile/test_tile_mathdx.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
warp/tests/tile/test_tile_cholesky.py (1)
- warp/__init__.pyi (2): tile_matmul (5265-5290), tile_matmul (5293-5313)

warp/examples/tile/example_tile_block_cholesky.py (2)
- warp/native/tile.h (1): wp (199-266)
- warp/__init__.pyi (2): tile_matmul (5265-5290), tile_matmul (5293-5313)

warp/tests/tile/test_tile_mathdx.py (2)
- warp/native/tile.h (10): tile_load (1961-1964), T (406-413), T (415-422), T (431-438), T (440-447), T (552-556), T (558-562), T (866-872), T (874-881), T (883-890)
- warp/__init__.pyi (3): tile_load (2540-2555), tile_matmul (5265-5290), tile_matmul (5293-5313)

warp/_src/builtins.py (2)
- warp/_src/context.py (3): value_func (68-73), value_func (1333-1334), add_builtin (1255-1494)
- warp/_src/types.py (2): tile (3992-4081), dtype (4674-4680)

warp/native/tile.h (1)
- warp/__init__.pyi (3): tile_matmul (5265-5290), tile_matmul (5293-5313), tile_transpose (2872-2881)
🪛 Ruff (0.13.3)
warp/tests/tile/test_tile_mathdx.py
69-69: assert_np_equal may be undefined, or defined from star imports
(F405)
75-75: assert_np_equal may be undefined, or defined from star imports
(F405)
76-76: assert_np_equal may be undefined, or defined from star imports
(F405)
77-77: assert_np_equal may be undefined, or defined from star imports
(F405)
warp/_src/builtins.py
9269-9269: Float may be undefined, or defined from star imports
(F405)
9270-9270: Float may be undefined, or defined from star imports
(F405)
9300-9300: tile may be undefined, or defined from star imports
(F405)
9300-9300: Float may be undefined, or defined from star imports
(F405)
9300-9300: Tuple may be undefined, or defined from star imports
(F405)
9301-9301: tile may be undefined, or defined from star imports
(F405)
9301-9301: Float may be undefined, or defined from star imports
(F405)
9301-9301: Tuple may be undefined, or defined from star imports
(F405)
9302-9302: Float may be undefined, or defined from star imports
(F405)
🔇 Additional comments (5)
warp/examples/tile/example_tile_block_cholesky.py (4)
77-81: LGTM! Efficient in-place update eliminates intermediate allocation.

The change from explicit subtraction to using `alpha=-1.0` correctly implements `A_kk_tile -= L_block @ L_block_T` in-place. This eliminates the intermediate `L_L_T_block` allocation mentioned in the PR description, improving memory efficiency while maintaining correctness.

The implicit `beta=1.0` default ensures proper accumulation semantics: `out = -1.0 * a * b + 1.0 * out`.
105-110: Consistent in-place update pattern.

The off-diagonal block update follows the same pattern as the diagonal block (line 81), correctly replacing explicit subtraction with `alpha=-1.0` for in-place accumulation.
147-151: Correct forward substitution update.

The forward substitution step correctly uses `alpha=-1.0` to perform in-place subtraction: `rhs_tile -= L_block @ y_block`.
161-166: Correct backward substitution update.

The backward substitution step correctly mirrors the forward substitution pattern, using `alpha=-1.0` for in-place update: `rhs_tile -= L_T_tile @ x_tile`.

warp/tests/tile/test_tile_cholesky.py (1)
373-373: Test updates correctly mirror example code.

All four test updates correctly adopt the new `tile_matmul` API with `alpha=-1.0` for in-place subtraction:

- Line 373: Diagonal block update
- Line 388: Off-diagonal block update
- Line 415: Forward substitution
- Line 429: Backward substitution

The changes match the patterns established in `example_tile_block_cholesky.py` and preserve the original test semantics. The existing test assertions will correctly validate the new implementation.

Also applies to: 388-388, 415-415, 429-429
```python
        beta = 0.0  # for c = tile_matmul(a,b) case we want to overwrite c value
        out = return_values[0]
    else:
        accumulate = 1  # for tile_matmul(a,b,c) case we want to add to c value
        beta = arg_values["beta"]
        out = arg_values["out"]
```
Wrap the implicit beta=0 in a Var.

When we take the return-value overload we drop into this branch, set `beta = 0.0`, and forward it to the MathDx path. The LTO dispatcher and codegen machinery expect every runtime argument to be a Var-like object (so that `.type` / `.constant` metadata is available), which is how the existing non-const arguments are handled. Passing a bare Python float here will blow up at codegen time (an AttributeError when the serializer tries to inspect the argument). Please wrap the zero in a Var with the same scalar type as alpha before returning.

```diff
-    beta = 0.0  # for c = tile_matmul(a,b) case we want to overwrite c value
+    beta = Var(label=None, type=alpha.type, constant=0.0)  # ensure codegen sees a typed constant
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-        beta = 0.0  # for c = tile_matmul(a,b) case we want to overwrite c value
-        out = return_values[0]
-    else:
-        accumulate = 1  # for tile_matmul(a,b,c) case we want to add to c value
-        beta = arg_values["beta"]
-        out = arg_values["out"]
+        beta = Var(label=None, type=alpha.type, constant=0.0)  # ensure codegen sees a typed constant
+        out = return_values[0]
+    else:
+        beta = arg_values["beta"]
+        out = arg_values["out"]
```
🤖 Prompt for AI Agents
In warp/_src/builtins.py around lines 9151-9156, the branch that sets beta = 0.0
when using the return-value overload must not assign a bare Python float; create
a Var representing zero with the same scalar type/metadata as the existing alpha
Var (so it exposes .type and .constant) and assign that Var to beta before
returning; in short, wrap the implicit zero in a Var constructed with alpha's
scalar type/metadata so downstream LTO/codegen sees a Var-like runtime argument.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
docs/modules/functions.rst (1)
2667-2675: Document supported mixed-dtype matmul usage
Currently docs mandate uniform dtype ("same datatype") but warp/tests/tile/test_tile_mathdx.py invokes tile_matmul on fp16, fp32, fp64 arrays. Update the docs (docs/modules/functions.rst 2667-2675 & 2696-2704) to specify per-backend dtype combinations and promotion rules, or align the tests to use a single dtype.

warp/_src/builtins.py (1)
9264-9295: Allow complex tiles in the public signature; fix minor doc typo
- Signatures use tile(dtype=Float, …) which excludes complex tiles (vec2{h,f,d}), yet docs and runtime checks allow them. Add overloads for tile(dtype=vector(length=2, dtype=Float), …) to make complex matmul callable.
- Remove stray backticks in the docs.
Apply these small doc fixes:
```diff
@@
- :param alpha: Scaling factor (default 1.0)``
- :param beta: Accumulator factor (default 1.0)``
+ :param alpha: Scaling factor (default 1.0)
+ :param beta: Accumulator factor (default 1.0)
@@
- :param alpha: Scaling factor (default 1.0)``
+ :param alpha: Scaling factor (default 1.0)
```

Add complex overloads mirroring the float ones (place right below the existing add_builtin blocks):

```python
# Complex (vec2*) overload — out variant
add_builtin(
    "tile_matmul",
    input_types={
        "a": tile(dtype=vector(length=2, dtype=Float), shape=Tuple[int, int]),
        "b": tile(dtype=vector(length=2, dtype=Float), shape=Tuple[int, int]),
        "out": tile(dtype=vector(length=2, dtype=Float), shape=Tuple[int, int]),
        "alpha": Float,
        "beta": Float,
    },
    defaults={"alpha": 1.0, "beta": 1.0},
    value_func=tile_matmul_out_value_func,
    lto_dispatch_func=tile_matmul_lto_dispatch_func,
    variadic=False,
    doc="""Computes the matrix product and accumulates out = alpha * a*b + beta * out.

    Supported datatypes are:
        * fp16, fp32, fp64 (real)
        * vec2h, vec2f, vec2d (complex)

    Note that computing the adjoints of alpha and beta are not yet supported.

    :param a: A tile with shape=(M, K)
    :param b: A tile with shape=(K, N)
    :param out: A tile with shape=(M, N)
    :param alpha: Scaling factor (default 1.0)
    :param beta: Accumulator factor (default 1.0)
    """,
    group="Tile Primitives",
    export=False,
)

# Complex (vec2*) overload — return-value variant
add_builtin(
    "tile_matmul",
    input_types={
        "a": tile(dtype=vector(length=2, dtype=Float), shape=Tuple[int, int]),
        "b": tile(dtype=vector(length=2, dtype=Float), shape=Tuple[int, int]),
        "alpha": Float,
    },
    defaults={"alpha": 1.0},
    value_func=tile_matmul_value_func,
    lto_dispatch_func=tile_matmul_lto_dispatch_func,
    variadic=False,
    doc="""Computes the matrix product out = alpha * a*b.

    Supported datatypes are:
        * fp16, fp32, fp64 (real)
        * vec2h, vec2f, vec2d (complex)

    Note that computing the adjoints of alpha and beta are not yet supported.

    :param a: A tile with shape=(M, K)
    :param b: A tile with shape=(K, N)
    :param alpha: Scaling factor (default 1.0)
    :returns: A tile with shape=(M, N)
    """,
    group="Tile Primitives",
    export=False,
)
```

If you prefer avoiding duplication, we can factor a helper to register both float and vec2 overloads programmatically.
Also applies to: 9298-9326
warp/native/tile.h (1)
3207-3232: Critical: use adj_ret (return-value adjoint) instead of adj_C in out-of-place adjoint overload

In the adj_tile_matmul overload for the `out = wp.tile_matmul(a, b)` form (warp/native/tile.h, lines ~3207-3232) the code uses `adj_C.grad` / `adj_C.grad.ptr` as the gradient source for dA/dB. It must use `adj_ret.grad` / `adj_ret.grad.ptr` (the return-value adjoint); otherwise gradients will be incorrect/zero.

Apply this diff:

```diff
- partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad, alpha_A, beta_A);
- partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad, alpha_B, beta_B);
+ partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_ret.grad, Bt.data, adj_A.grad, alpha_A, beta_A);
+ partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_ret.grad, adj_B.grad, alpha_B, beta_B);
@@
- fun_backward_A(&alpha_A, adj_C.grad.ptr, B.data.ptr, &beta_A, adj_A.grad.ptr);
- fun_backward_B(&alpha_B, A.data.ptr, adj_C.grad.ptr, &beta_B, adj_B.grad.ptr);
+ fun_backward_A(&alpha_A, adj_ret.grad.ptr, B.data.ptr, &beta_A, adj_A.grad.ptr);
+ fun_backward_B(&alpha_B, A.data.ptr, adj_ret.grad.ptr, &beta_B, adj_B.grad.ptr);
```

Add/extend a test for the out-of-place form (alpha/beta arbitrary) verifying dA/dB against a NumPy/Torch reference.
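A minimal NumPy reference for such a check might look like the sketch below; it encodes the closed-form adjoints of `C_out = alpha * (A @ B) + beta * C_in` rather than calling Warp itself, and the shapes are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(0)
M, K, N = 8, 4, 6
A = rng.standard_normal((M, K))
B = rng.standard_normal((K, N))
C_in = rng.standard_normal((M, N))
alpha, beta = -1.3, 0.5

# Forward pass
C_out = alpha * (A @ B) + beta * C_in

# Upstream gradient with respect to C_out
dC = rng.standard_normal((M, N))

# Closed-form adjoints to compare Warp's dA/dB/dC against
dA_ref = alpha * dC @ B.T
dB_ref = alpha * A.T @ dC
dC_in_ref = beta * dC
```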
🧹 Nitpick comments (5)
docs/modules/functions.rst (1)
2659-2660: Clarify defaults/overloads in signature.

Docs mention defaults (alpha/beta=1.0), and examples omit beta (relying on the default). Consider stating explicitly that beta is optional and defaults to 1.0 when omitted in the out-accumulating overload. This avoids ambiguity between the 5-arg signature and usage.
Also applies to: 2686-2691
warp/tests/tile/test_tile_mathdx.py (1)
69-77: Avoid star import to satisfy linters (F405).

Import test helpers explicitly to silence F405 and improve clarity:

```diff
- from warp.tests.unittest_utils import *
+ from warp.tests.unittest_utils import assert_np_equal, get_test_devices, get_cuda_test_devices, add_function_test
```

warp/native/tile.h (3)
214-216: Prefer standard type trait or add a shorthand alias

Re-implementing remove_reference works, but consider using std::remove_reference (when available) or at least adding a remove_reference_t alias for brevity across the file.
2977-3000: Make alpha/beta const refs and avoid double read of C(coord)
- Pass alpha/beta as const T&; you don't modify them.
- Cache C(coord) once to prevent double indexing/load.
Apply this diff:
```diff
-inline CUDA_CALLABLE void scalar_matmul(const StorageA& A, const StorageB& B, StorageC& C, T& alpha, T& beta)
+inline CUDA_CALLABLE void scalar_matmul(const StorageA& A, const StorageB& B, StorageC& C, const T& alpha, const T& beta)
@@
-    using TypeC = typename remove_reference<decltype(C(coord))>::type;
-    TypeC sum = TypeC(0);
+    using TypeC = typename remove_reference<decltype(C(coord))>::type;
+    TypeC sum = TypeC(0);
@@
-    C(coord) = alpha * sum + beta * C(coord);
+    const TypeC c_old = C(coord);
+    C(coord) = alpha * sum + beta * c_old;
```
3139-3164: LGTM: forward semantics match out = alpha*(A@B) + beta*out

The new API and dispatch look correct. Consider a future fast path for beta == 0 to skip reading out in the CPU path (already handled inside scalar_matmul after caching the old C).
📜 Review details
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- CHANGELOG.md (1 hunks)
- docs/modules/functions.rst (4 hunks)
- warp/__init__.pyi (3 hunks)
- warp/_src/builtins.py (7 hunks)
- warp/_src/codegen.py (1 hunks)
- warp/examples/tile/example_tile_block_cholesky.py (4 hunks)
- warp/native/tile.h (7 hunks)
- warp/tests/tile/test_tile_cholesky.py (4 hunks)
- warp/tests/tile/test_tile_mathdx.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
warp/examples/tile/example_tile_block_cholesky.py (1)
- warp/__init__.pyi (2): tile_matmul (5265-5290), tile_matmul (5293-5313)

warp/tests/tile/test_tile_cholesky.py (1)
- warp/__init__.pyi (2): tile_matmul (5265-5290), tile_matmul (5293-5313)

warp/native/tile.h (1)
- warp/__init__.pyi (3): tile_matmul (5265-5290), tile_matmul (5293-5313), tile_transpose (2872-2881)

warp/tests/tile/test_tile_mathdx.py (3)
- warp/native/tile.h (10): tile_load (1961-1964), T (406-413), T (415-422), T (431-438), T (440-447), T (552-556), T (558-562), T (866-872), T (874-881), T (883-890)
- warp/__init__.pyi (3): tile_load (2540-2555), tile_matmul (5265-5290), tile_matmul (5293-5313)
- warp/tests/unittest_utils.py (1): assert_np_equal (241-247)

warp/_src/builtins.py (2)
- warp/_src/context.py (3): value_func (68-73), value_func (1333-1334), add_builtin (1255-1494)
- warp/_src/types.py (2): tile (3992-4081), dtype (4674-4680)
🪛 Ruff (0.13.3)
warp/tests/tile/test_tile_mathdx.py
69-69: assert_np_equal may be undefined, or defined from star imports
(F405)
75-75: assert_np_equal may be undefined, or defined from star imports
(F405)
76-76: assert_np_equal may be undefined, or defined from star imports
(F405)
77-77: assert_np_equal may be undefined, or defined from star imports
(F405)
warp/_src/builtins.py
9269-9269: Float may be undefined, or defined from star imports
(F405)
9270-9270: Float may be undefined, or defined from star imports
(F405)
9300-9300: tile may be undefined, or defined from star imports
(F405)
9300-9300: Float may be undefined, or defined from star imports
(F405)
9300-9300: Tuple may be undefined, or defined from star imports
(F405)
9301-9301: tile may be undefined, or defined from star imports
(F405)
9301-9301: Float may be undefined, or defined from star imports
(F405)
9301-9301: Tuple may be undefined, or defined from star imports
(F405)
9302-9302: Float may be undefined, or defined from star imports
(F405)
🔇 Additional comments (6)
warp/_src/codegen.py (1)
1258-1261: Float literals now treated as SSA constants — good catch.

Mirrors the int path and prevents unlabeled python floats from leaking past `register_var`. Nicely done.

warp/examples/tile/example_tile_block_cholesky.py (1)
81-81: Good fusion: use alpha=-1.0 to fold subtraction into matmul.

Removes intermediates and shared-memory traffic; aligns with the new API.
Also applies to: 110-110, 151-151, 166-166
warp/tests/tile/test_tile_mathdx.py (1)
43-45: Semantics and gradients look correct.

Forward and backward checks match out = α·A@B + β·C and adjoints (excluding α/β).
Also applies to: 69-78
warp/tests/tile/test_tile_cholesky.py (1)
373-374: LGTM: fused updates via alpha=-1.0.

Cleaner, fewer temporaries, consistent with the new API.
Also applies to: 388-389, 415-416, 429-430
warp/_src/builtins.py (1)
9148-9186: Use a typed Var for beta; verify gradient scaling by alpha

- In the return-value path, `beta` is set to a raw float (0.0). The dispatch pipeline expects scalar runtime args as `Var`. Replace with a constant `Var` of the correct dtype:

```diff
@@
-    if len(return_values) > 0:
-        beta = 0.0  # for c = tile_matmul(a,b) case we want to overwrite c value
+    if len(return_values) > 0:
+        beta = Var(None, type=out.type.dtype, constant=0.0)
```

- I couldn't locate the backward kernels in this file. Please verify that the gradient implementation applies the `alpha` factor (i.e. dA = alpha * dC @ Bᵀ, dB = alpha * Aᵀ @ dC) when `alpha != 1`. If not, thread `alpha` through the adjoint or scale `dC` accordingly.

warp/native/tile.h (1)
3173-3204: LGTM: in-place adjoint scales adj_C by beta after computing dA/dB

This ordering avoids contaminating dA/dB with beta; correct for reverse-mode. No action needed.
Please ensure tests cover the in-place variant with beta ≠ {0,1} for both real and complex dtypes.
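For reference, the adjoint relations these tests would exercise follow from differentiating the forward map (standard matrix calculus, independent of the implementation under review):

$$
C_{\text{out}} = \alpha\,A B + \beta\,C_{\text{in}}
\;\;\Rightarrow\;\;
\bar{A} = \alpha\,\bar{C} B^{\top},\qquad
\bar{B} = \alpha\,A^{\top}\bar{C},\qquad
\bar{C}_{\text{in}} = \beta\,\bar{C}.
$$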
Actionable comments posted: 0
♻️ Duplicate comments (5)
docs/modules/functions.rst (2)
2681-2682: Use consistent "default: …" formatting in params.

Elsewhere in this doc you use "default: …". Align alpha/beta to match.

```diff
- :param alpha: Scaling factor (default 1.0)
- :param beta: Accumulator factor (default 1.0)
+ :param alpha: Scaling factor (default: 1.0)
+ :param beta: Accumulator factor (default: 1.0)
```
2686-2706: Minor clarity: phrase return semantics; keep default formatting consistent.
- Suggest phrasing the semantics as a return value to avoid “out = …” when no out param is present.
- Keep “default: …” formatting consistent.
```diff
- Computes the matrix product ``out = alpha * a*b``.
+ Computes the matrix product and returns ``alpha * a * b``.
- :param alpha: Scaling factor (default 1.0)
+ :param alpha: Scaling factor (default: 1.0)
```

warp/__init__.pyi (2)
5286-5287: Fix backtick formatting in parameter documentation.

The backticks are malformed (stray closing backticks). Additionally, the docstrings claim "(default 1.0)" but the function signature shows these as required parameters without defaults. Either add defaults to the signature (`alpha: Float = 1.0, beta: Float = 1.0`) or remove the "(default 1.0)" text from the docstrings.

Apply this diff to fix the formatting and align with the signature:

```diff
- :param alpha: Scaling factor (default 1.0)``
- :param beta: Accumulator factor (default 1.0)``
+ :param alpha: Scaling factor
+ :param beta: Accumulator factor
```

Or if defaults are intended:

```diff
 def tile_matmul(
     a: Tile[Float, Tuple[int, int]],
     b: Tile[Float, Tuple[int, int]],
     out: Tile[Float, Tuple[int, int]],
-    alpha: Float,
-    beta: Float,
+    alpha: Float = 1.0,
+    beta: Float = 1.0,
 ):
```
5309-5309: Fix backtick formatting and clarify default parameter value.

The backtick is malformed (stray closing backtick). Additionally, the docstring claims "(default 1.0)" but the function signature shows `alpha` as a required parameter. Either add a default to the signature (`alpha: Float = 1.0`) or remove the "(default 1.0)" text from the docstring.

Apply this diff to fix the formatting and align with the signature:

```diff
- :param alpha: Scaling factor (default 1.0)`
+ :param alpha: Scaling factor
```

Or if a default is intended:

```diff
 def tile_matmul(
-    a: Tile[Float, Tuple[int, int]], b: Tile[Float, Tuple[int, int]], alpha: Float
+    a: Tile[Float, Tuple[int, int]], b: Tile[Float, Tuple[int, int]], alpha: Float = 1.0
 ) -> Tile[Float, Tuple[int, int]]:
```

warp/_src/builtins.py (1)
9148-9156: Wrap implicit beta=0.0 in a Var (codegen will crash on bare float).

LTO/codegen expects runtime args as Var-like. Assigning a Python float to beta in the return-value overload will fail at serialization. Construct a typed Var using alpha.type.

Apply:

```diff
- beta = 0.0  # for c = tile_matmul(a,b) case we want to overwrite c value
+ beta = Var(label=None, type=alpha.type, constant=0.0)  # ensure typed constant for codegen
```
🧹 Nitpick comments (3)
docs/modules/functions.rst (1)
2659-2677: API/semantics LGTM; clarify alpha/beta type for complex tiles.

Looks good. One request: for complex tiles (vec2h/f/d), are alpha/beta allowed to be complex or real-only? If real-only, please state that explicitly; if complex is supported, consider documenting that alpha/beta match the tile dtype.
Optional nit: "Computing the adjoints of alpha and beta are not yet supported." uses a plural verb; use singular "is" or rephrase for clarity.

```diff
- Note that computing the adjoints of alpha and beta are not yet supported.
+ Note that computing the adjoints of alpha and beta is not yet supported.
```

warp/__init__.pyi (1)
5305-5305: Use singular "adjoint" for a single parameter.

Since this overload only has the `alpha` parameter, the note should use the singular form "adjoint" instead of plural "adjoints". Apply this diff:

```diff
- Note that computing the adjoints of alpha is not yet supported.
+ Note that computing the adjoint of alpha is not yet supported.
```

warp/_src/builtins.py (1)
9269-9273: Signatures restrict alpha/beta to Float; docs say vec2 (complex) tiles are supported.

If A/B/out are complex (vec2h/f/d), Float alpha/beta block complex scaling and may mismatch backend expectations. Consider adding overloads with alpha/beta: vector(length=2, dtype=Float), or clarify the docs if complex scaling isn't supported.
- Option A: Add parallel add_builtin overloads for complex tiles with alpha/beta as vector(length=2, dtype=Float).
- Option B: If only real scaling is supported for complex tiles, update docs to state alpha/beta must be real.
Can you confirm intended behavior and backend support for complex alpha/beta?
Also applies to: 9299-9305
📜 Review details
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- docs/modules/functions.rst (4 hunks)
- warp/__init__.pyi (3 hunks)
- warp/_src/builtins.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
warp/_src/builtins.py (2)
- warp/_src/context.py (3): value_func (68-73), value_func (1333-1334), add_builtin (1255-1494)
- warp/_src/types.py (2): tile (3992-4081), dtype (4674-4680)
🪛 Ruff (0.13.3)
warp/_src/builtins.py
9269-9269: Float may be undefined, or defined from star imports
(F405)
9270-9270: Float may be undefined, or defined from star imports
(F405)
9300-9300: tile may be undefined, or defined from star imports
(F405)
9300-9300: Float may be undefined, or defined from star imports
(F405)
9300-9300: Tuple may be undefined, or defined from star imports
(F405)
9301-9301: tile may be undefined, or defined from star imports
(F405)
9301-9301: Float may be undefined, or defined from star imports
(F405)
9301-9301: Tuple may be undefined, or defined from star imports
(F405)
9302-9302: Float may be undefined, or defined from star imports
(F405)
Greptile Overview
Greptile Summary
This PR enhances the wp.tile_matmul() function by adding alpha and beta scaling parameters, implementing the formula out = alpha * (A @ B) + beta * out to match the cuBLASDx API. The enhancement enables fused matrix multiplication with scaling and accumulation operations, reducing memory usage and improving performance by eliminating the need for temporary intermediate results.
The changes include two function overloads: one for in-place accumulation with both alpha and beta parameters (defaulting to 1.0), and another that returns a new tile with only an alpha parameter. The API maintains backward compatibility while enabling more efficient operations, as demonstrated in the updated Cholesky decomposition example where A -= L * L^T can now be expressed as wp.tile_matmul(L, L_T, A, alpha=-1.0).
Important Files Changed
Changed Files
| Filename | Score | Overview |
|---|---|---|
| warp/native/tile.h | 3/5 | Enhanced C++ implementation to support alpha/beta scaling in tile matrix multiplication |
| warp/_src/builtins.py | 4/5 | Updated dispatch function to handle alpha/beta parameters and distinguish between overloads |
| warp/__init__.pyi | 5/5 | Auto-generated type stubs updated with new function signatures and documentation |
| docs/modules/functions.rst | 5/5 | Documentation updated to describe new parameters and overload variants |
| warp/_src/codegen.py | 5/5 | Added support for float constants in variable registration for code generation |
| warp/tests/tile/test_tile_mathdx.py | 4/5 | Updated tests to validate alpha/beta scaling functionality and gradient computations |
| warp/tests/tile/test_tile_cholesky.py | 5/5 | Modified to use new alpha parameter for in-place matrix operations |
| warp/examples/tile/example_tile_block_cholesky.py | 5/5 | Optimized example using fused operations to reduce memory usage |
| CHANGELOG.md | 5/5 | Added changelog entry documenting the new feature |
Confidence score: 3/5
- This PR introduces significant enhancements but has some potential issues in the backward pass implementation that need attention
- Score lowered due to concerns in tile.h where gradient calculations may not properly handle non-unit beta values, and potential type conversion issues between template parameters
- Pay close attention to warp/native/tile.h, particularly the backward pass logic for gradient computation with beta scaling
Additional Comments (4)
- warp/tests/tile/test_tile_mathdx.py, line 77 (link): logic: C gradient calculation appears incorrect; should be `-1.3 * adj_C`, not `adj_C - 1.3 * adj_C`.
- warp/native/tile.h, lines 3197-3201 (link): logic: This multiplication by beta could accumulate incorrectly if the adjoint function is called multiple times on the same tile, as it modifies `adj_C.grad(i)` in place.
- warp/native/tile.h, lines 3157-3158 (link): style: Type conversion from template parameters to tile type could cause precision loss. Consider using `static_cast` with appropriate bounds checking for safety.
- warp/native/tile.h, lines 3181-3184 (link): logic: The backward pass assumes beta=1.0 for gradient scaling, which may not produce correct gradients when the forward pass used beta != 1.0.
9 files reviewed, 4 comments
Force-pushed b058026 to 20ca408 (Compare)

Signed-off-by: Roland Schwan <roland.schwan@mikrounix.com>

Force-pushed 20ca408 to 988737a (Compare)
Rebased and squashed the commits
Description
This PR adds alpha and beta scalings to `tile_matmul`, as in the cuBLASDx API. The main motivation is to generalize the `tile_matmul` API with the aim of reducing the number of operations and the shared memory of otherwise necessary operations. For example, in the tile_block_cholesky example this reduces shared memory usage (`L_L_T_block` no longer exists) and fuses the subtraction.
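As a sketch of the fusion (the "before" form is reconstructed from the description, not quoted from the old example):

```python
# Before: intermediate tile plus explicit subtraction
L_L_T_block = wp.tile_matmul(L_block, L_block_T)
A_kk_tile = A_kk_tile - L_L_T_block

# After: fused in-place update, no L_L_T_block allocation
# A_kk_tile = -1.0 * (L_block @ L_block_T) + 1.0 * A_kk_tile
wp.tile_matmul(L_block, L_block_T, A_kk_tile, alpha=-1.0)
```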
Before your PR is "Ready for review"

- Updated stubs and docs (`__init__.pyi`, `functions.rst`)
- `pre-commit run -a`

Summary by CodeRabbit
- New Features
- Documentation
- Examples
- Tests