
Add cuSolverDx JIT fusion and solver projections#1176

Open
cliffburdick wants to merge 1 commit into main from cburdick/cusolverdx-jit-fusion

Conversation

@cliffburdick
Collaborator

Upgrades the MathDx/libmathdx integration to the latest runtime codegen packages, preserves the runtime descriptor queries for FFT and BLAS, and adds cuSolverDx-backed JIT support for solver operators. It also introduces lazy solver projections so that multi-output APIs such as QR, LU, SVD, and eig can participate in single expressions, with tests covering both the fused and projection paths.

Also adds a new interface that allows multi-output transforms to be used in a fusion context.

@copy-pr-bot

copy-pr-bot Bot commented May 9, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@cliffburdick
Collaborator Author

/build

@greptile-apps
Contributor

greptile-apps Bot commented May 9, 2026

Greptile Summary

This PR upgrades the MathDx/libmathdx integration to version 26.03, adds cuSolverDx-backed JIT fusion for Cholesky and matrix inverse, and introduces a SolverProjectionOp / *State pattern so multi-output solvers (QR, LU, SVD, eig) can expose individual outputs (.Q, .R, .LU, .Piv, etc.) as lazy operators usable inside single-expression fusions.

  • cuSolverDx JIT backend (solver_cusolverdx.h): generates LTOIR via libcusolverdx, injects it through nvJitLink at launch, and provides block-dim/shared-memory traits to the capability system.
  • Lazy solver projections (solver_projection.h, lu.h, qr.h, svd.h, eig.h): each multi-output solver now holds a shared *State object; SolverProjectionOp wraps one output component and drives Materialize/Release through PreRun/PostRun.
  • set.h JIT bounds guard: a per-index validity check is inserted in the generated assignment kernel to avoid out-of-bounds writes when the input operator's domain exceeds the output tensor's extents.
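The per-index validity check described for set.h can be sketched in host C++. The names below (`in_bounds`, `guarded_store`) are illustrative stand-ins, not MatX's actual generated kernel code; the sketch only shows the guard-before-store pattern under those assumptions.

```cpp
#include <array>
#include <cstddef>
#include <vector>

// Sketch of the per-index validity check the generated assignment kernel
// inserts: only write when every index lies inside the output tensor's
// extents. Names here are illustrative, not MatX's.
template <std::size_t Rank>
bool in_bounds(const std::array<long long, Rank>& idx,
               const std::array<long long, Rank>& extents) {
  for (std::size_t d = 0; d < Rank; ++d) {
    if (idx[d] < 0 || idx[d] >= extents[d]) return false;
  }
  return true;
}

// As the review notes, the input operator may still be evaluated for
// out-of-bounds indices; only the store itself is guarded.
template <std::size_t Rank, typename T>
void guarded_store(std::vector<T>& out, const std::array<long long, Rank>& idx,
                   const std::array<long long, Rank>& extents, T value) {
  if (!in_bounds(idx, extents)) return;  // skip out-of-bounds writes
  long long linear = 0;                  // row-major linearization
  for (std::size_t d = 0; d < Rank; ++d) linear = linear * extents[d] + idx[d];
  out[static_cast<std::size_t>(linear)] = value;
}
```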

Confidence Score: 3/5

The JIT fusion path for cuSolverDx carries two correctness defects that would produce silent wrong answers for specific inputs before triggering any observable error.

The rank > 4 JIT kernel bug means any user with a batched solver expression of rank 5 or higher gets a kernel that runs the solver on uninitialised shared memory without any assertion or launch failure. The shared memory undercount for GESV/inverse could cause smem region overlap. Both are silent — easy to miss if the test suite only exercises rank-2 and rank-3 cases. The projection and resource-leak issues are lower urgency but should also be addressed before the new JIT paths go into production use.

include/matx/transforms/solver_cusolverdx.h (rank guard and smem sizing), include/matx/operators/solver_projection.h (double-release ordering), include/matx/operators/inverse.h (shares both issues from solver_cusolverdx.h)

Important Files Changed

Filename Overview
include/matx/transforms/solver_cusolverdx.h New cuSolverDx JIT backend; contains rank > 4 silent correctness bug, potential shared memory undercount for GESV, and resource leak paths in GenerateLTOIR
include/matx/operators/solver_projection.h New lazy projection operator for multi-output solver results; potential double-release ordering concern when two projections from the same state are used together
include/matx/operators/chol.h Adds cuSolverDx JIT path via dx_potrf_helper_; rank constraint and get_capability logic look correct for the supported (rank 2-4) range
include/matx/operators/inverse.h Adds cuSolverDx GESV JIT path for matrix inverse; shares the shared-memory undercount concern from solver_cusolverdx.h and the rank>4 issue
include/matx/operators/lu.h Refactored to use LUState/shared_ptr pattern; adds SolverProjectionOp members LU and Piv for single-expression composition
include/matx/operators/qr.h Refactored to QRState; adds Q and R projection members; qr_solver also refactored with SolverQRState
include/matx/operators/svd.h Refactored to SVDState; U, S, VT projection members added; shape computation for reduced mode looks correct
include/matx/operators/eig.h Refactored to EigState with Vectors/Values projections; straightforward state migration with no new logic risks
include/matx/operators/set.h Adds bounds guard in JIT-generated assignment code; evaluates op_ for all indices (including out-of-bounds) before the validity check, which is benign for most operators
include/matx/core/nvrtc_helper.h Adds conditional compile-time path definitions for MathDx/CCCL and a cusolverdx linker injection; environment variable fallback for fatbin/library path is functional
cmake/FindMathDx.cmake Substantially simplified; now hard-requires CUDA 13+, pins to MathDx 26.03 and libmathdx 0.3.2, adds cusolverdx component; warning-to-fatal_error is a breaking change for CUDA 12 users
CMakeLists.txt Adds cusolverdx link targets, CCCL source dir definition, and fatbin/library detection; straightforward build system updates

Reviews (1): Last reviewed commit: "Add cuSolverDx JIT fusion and solver pro..."

Comment on lines +327 to +388
for (index_t linear = tid; linear < elems; linear += blockDim.x * blockDim.y * blockDim.z) {
  const index_t row = linear / n;
  const index_t col = linear % n;
  if constexpr (Rank() == 2) {
    smem_a[linear] = a_.template operator()<CapType>(row, col);
  }
  else if constexpr (Rank() == 3) {
    smem_a[linear] = a_.template operator()<CapType>(idx[0], row, col);
  }
  else if constexpr (Rank() == 4) {
    smem_a[linear] = a_.template operator()<CapType>(idx[0], idx[1], row, col);
  }
}

if (tid == 0) {
  *info = 0;
}
__syncthreads();
)";
    result += solver_func_name;
    result += R"((smem_a, info);
__syncthreads();

if (tid < elems) {
  return smem_a[tid];
}
return value_type{};
)";
    return result;
  }

  std::string GetGesvInverseFuncStr(const std::string &solver_func_name) const
  {
    std::string result = R"(
using value_type = )";
    result += detail::type_to_string<InputType>();
    result += R"(;
static constexpr index_t n = )";
    result += std::to_string(static_cast<int>(n_));
    result += R"(;
static constexpr index_t elems = n * n;
extern __shared__ __align__(16) char smem[];
value_type* smem_a = reinterpret_cast<value_type*>(smem);
value_type* smem_b = reinterpret_cast<value_type*>(smem + elems * sizeof(value_type));
int* ipiv = reinterpret_cast<int*>(smem + (2 * elems * sizeof(value_type)));
int* info = reinterpret_cast<int*>(smem + (2 * elems * sizeof(value_type)) + n * sizeof(int));
const int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;
cuda::std::array<index_t, Rank()> idx = { static_cast<index_t>(indices)... };

for (index_t linear = tid; linear < elems; linear += blockDim.x * blockDim.y * blockDim.z) {
  const index_t row = linear / n;
  const index_t col = linear % n;
  if constexpr (Rank() == 2) {
    smem_a[linear] = a_.template operator()<CapType>(row, col);
  }
  else if constexpr (Rank() == 3) {
    smem_a[linear] = a_.template operator()<CapType>(idx[0], row, col);
  }
  else if constexpr (Rank() == 4) {
    smem_a[linear] = a_.template operator()<CapType>(idx[0], idx[1], row, col);
  }
  smem_b[linear] = row == col ? value_type{1} : value_type{0};

P1 Rank > 4 silently produces wrong results in JIT kernels

Both GetPotrfFuncStr and GetGesvInverseFuncStr generate code with branches only for Rank() == 2, Rank() == 3, and Rank() == 4. When Rank() > 4, none of the if constexpr arms match, so the loop iteration body becomes a no-op and smem_a is never loaded with input data — the solver runs on garbage shared memory. Because get_capability(SUPPORTS_JIT) has no upper-bound rank check (OpA::Rank() >= 2 is the only condition), a rank-5 tensor will pass the support check and silently produce incorrect results at runtime. The same issue affects the GetGesvInverseFuncStr path used by InvOp.
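One hedged way to close this gap is to unpack the leading batch indices generically instead of enumerating ranks 2-4, so any rank either works or fails to compile. The sketch below is host-testable C++; `load_element` and the operator shape are illustrative assumptions, not MatX's actual codegen.

```cpp
#include <array>
#include <cstddef>
#include <tuple>    // std::apply
#include <utility>

// Sketch of a rank-generic load that avoids the Rank() == 2/3/4 ladder:
// split the index array into (batch..., row, col) and forward the batch
// indices with std::apply. A rank-5 operator is then loaded correctly
// instead of silently skipped. A static_assert (or an upper-bound check
// in get_capability) would make unsupported ranks fail loudly.
template <std::size_t Rank, typename Op>
auto load_element(const Op& a, const std::array<long long, Rank>& idx,
                  long long row, long long col) {
  static_assert(Rank >= 2, "solver operators require rank >= 2");
  std::array<long long, Rank - 2> batch{};  // leading batch dimensions
  for (std::size_t d = 0; d < Rank - 2; ++d) batch[d] = idx[d];
  return std::apply([&](auto... b) { return a(b..., row, col); }, batch);
}
```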

Comment on lines +222 to +232
  }
  else {
    if (m_ <= 0 || n_ <= 0) {
      return false;
    }
    auto handle = GeneratePlan();
    long long int shm = 0;
    const bool supported =
        cusolverdxGetTraitInt64(handle, CUSOLVERDX_TRAIT_SHARED_MEMORY_SIZE, &shm) == COMMONDX_SUCCESS;
    cusolverdxDestroyDescriptor(handle);
    return supported;

P1 Shared memory calculation may be insufficient for GESV/inverse

GetShmRequired computes max(cusolverdx_shm, matrix_bytes + info_bytes) where matrix_bytes = m * n * sizeof(T). However, the kernel in GetGesvInverseFuncStr lays out four distinct regions in shared memory: smem_a (n² × sizeof(T)), smem_b (n² × sizeof(T)), ipiv (n × sizeof(int)), and info (sizeof(int)) — a total of 2·n²·sizeof(T) + (n+1)·sizeof(int). The fallback floor matrix_bytes + info_bytes only covers one matrix's worth plus 20 bytes. If CUSOLVERDX_TRAIT_SHARED_MEMORY_SIZE does not already account for all user-visible buffers (A, B, ipiv, info), the kernel will overrun the allocated shared memory region and corrupt adjacent state.
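A floor that covers all four regions the GESV kernel lays out can be sketched as follows. This is an illustrative calculation under the layout described above (`gesv_smem_floor` is a hypothetical name, not MatX's `GetShmRequired`); the real code would still take the max with the value reported by `CUSOLVERDX_TRAIT_SHARED_MEMORY_SIZE`.

```cpp
#include <cstddef>

// Byte floor matching the kernel's shared-memory layout in
// GetGesvInverseFuncStr: smem_a and smem_b (n*n elements each),
// ipiv (n ints), and info (1 int), laid out back to back.
template <typename T>
std::size_t gesv_smem_floor(std::size_t n) {
  const std::size_t matrix_bytes = n * n * sizeof(T);
  return 2 * matrix_bytes    // smem_a + smem_b
         + n * sizeof(int)   // ipiv
         + sizeof(int);      // info
}
```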

Comment on lines +257 to +308
                     static_cast<int>(block_dim[2])};
  }

  int GetWorkspaceSize() const
  {
    auto handle = GeneratePlan();
    long long int workspace_size = 0;
    LIBCUSOLVERDX_CHECK(cusolverdxGetTraitInt64(handle, CUSOLVERDX_TRAIT_WORKSPACE_SIZE, &workspace_size));
    cusolverdxDestroyDescriptor(handle);
    return static_cast<int>(workspace_size);
  }

  bool GenerateLTOIR(std::set<std::string> &ltoir_symbols)
  {
    LTOIRData ltoir;
    const auto symbol_name = std::string(SOLVER_DX_FUNC_PREFIX) + "_" + GetSymbolName();
    ltoir_symbols.insert(symbol_name);

    if (detail::GetCache().GetLTOIRCachedBytes(symbol_name) != nullptr) {
      return true;
    }

    auto handle = GeneratePlan();
    LIBCUSOLVERDX_CHECK(cusolverdxSetOptionStr(handle, COMMONDX_OPTION_SYMBOL_NAME, symbol_name.c_str()));
    const auto trait_symbol_name = GetTraitSymbolName(handle);
    MATX_ASSERT_STR(trait_symbol_name == symbol_name,
                    matxInvalidParameter,
                    "cuSolverDx returned an unexpected symbol name");

    commondxCode code;
    LIBCUSOLVERDX_CHECK(commondxCreateCode(&code));
    LIBCUSOLVERDX_CHECK(commondxSetCodeOptionInt64(code, COMMONDX_OPTION_TARGET_SM, cc_));
    LIBCUSOLVERDX_CHECK(cusolverdxFinalizeCode(code, handle));

    LIBCUSOLVERDX_CHECK(commondxGetCodeLTOIRSize(code, &ltoir.length));
    ltoir.data = static_cast<char*>(malloc(ltoir.length));
    MATX_ASSERT_STR(ltoir.data != nullptr, matxInvalidParameter, "Failed to allocate cuSolverDx LTOIR data");
    LIBCUSOLVERDX_CHECK(commondxGetCodeLTOIR(code, ltoir.length, ltoir.data));

    if (!detail::GetCache().StoreLTOIRCachedBytes(symbol_name, ltoir.data, ltoir.length)) {
      free(ltoir.data);
      MATX_LOG_ERROR("Failed to store cuSolverDx LTOIR cached bytes for: {}", symbol_name);
      return false;
    }

    ltoir.data = nullptr;
    ltoir.length = 0;

    LIBCUSOLVERDX_CHECK(commondxDestroyCode(code));
    LIBCUSOLVERDX_CHECK(cusolverdxDestroyDescriptor(handle));

    return true;

P2 cusolverdxDescriptor and commondxCode leaked on any exception path

GenerateLTOIR acquires handle via GeneratePlan() and code via commondxCreateCode, but neither is wrapped in RAII. A LIBCUSOLVERDX_CHECK assertion between acquisition and the paired cusolverdxDestroyDescriptor / commondxDestroyCode calls will leave both objects unreleased. The same pattern occurs in GetShmRequired, GetBlockDim, GetWorkspaceSize, and GetTraitSymbolName. Additionally, ltoir.data (allocated with malloc on line 292) is leaked if commondxGetCodeLTOIR or a subsequent check throws before StoreLTOIRCachedBytes is reached.
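A generic scope guard would release the handles on every exit path, including assertion throws. The sketch below uses a stand-in `make_scoped` helper with `std::unique_ptr` and a custom deleter; the handle and destroy types are placeholders for `cusolverdxDescriptor`/`commondxCode` and their destroy functions (and `std::vector<char>` would similarly replace the raw `malloc` for the LTOIR buffer).

```cpp
#include <memory>

// Sketch of RAII ownership for C-style handles: the destroy callback runs
// exactly once when the scope unwinds, whether normally or via exception.
// `make_scoped` is a hypothetical helper, not an existing MatX utility.
template <typename Handle, typename Destroy>
auto make_scoped(Handle h, Destroy destroy) {
  auto deleter = [destroy](Handle* p) {
    destroy(*p);  // release the underlying C handle
    delete p;
  };
  return std::unique_ptr<Handle, decltype(deleter)>(new Handle(h), deleter);
}
```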

Comment on lines +107 to +118
  template <typename ShapeType, typename Executor>
  __MATX_INLINE__ void PreRun([[maybe_unused]] ShapeType &&shape, Executor &&ex) const noexcept
  {
    state_->Materialize(std::forward<Executor>(ex));
    tensor_ = state_->template Tensor<Component>();
  }

  template <typename ShapeType, typename Executor>
  __MATX_INLINE__ void PostRun([[maybe_unused]] ShapeType &&shape, Executor &&ex) const noexcept
  {
    state_->Release(std::forward<Executor>(ex));
  }

P2 Double-Release when two projections from the same op are composed

When two SolverProjectionOp instances sharing the same State* — e.g., lu().LU and lu().Piv — appear in the same expression, both call state_->Release(ex) in PostRun. The first call frees both scratch buffers and resets materialized_ = false. The second call is a no-op, which is safe in the normal post-kernel flow. However, the tensor_ member of the second projection still holds the now-freed pointer captured during PreRun; if PostRun ordering is ever interleaved with another PreRun on the same state, the second projection can access freed memory.
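A reference count on the shared state is one hedged fix: Release only tears down scratch once the last projection that materialized it has finished. `StateSketch` below is a minimal illustration of the idea, not MatX's actual State API.

```cpp
// Sketch of reference-counted materialization for a shared solver state:
// the first Materialize allocates and runs the solver, the last Release
// frees scratch, so a second projection's tensor_ never outlives the
// buffers even when PreRun/PostRun calls interleave.
struct StateSketch {
  int users = 0;            // projections currently holding the buffers
  bool materialized = false;

  void Materialize() {
    if (users++ == 0) materialized = true;  // allocate + run solver once
  }
  void Release() {
    if (users > 0 && --users == 0) materialized = false;  // free scratch
  }
};
```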

@coveralls

Coverage Status

Coverage is 94.402% when pulling cburdick/cusolverdx-jit-fusion into main. No base build found for main.
