Add cuSolverDx JIT fusion and solver projections #1176
cliffburdick wants to merge 1 commit into `main`.
Conversation
Upgrade MathDx/libmathdx integration to the latest runtime codegen packages, preserve runtime descriptor queries for FFT and BLAS, add cuSolverDx-backed JIT support for solver operators, and introduce lazy solver projections so multi-output APIs like QR, LU, SVD, and eig can participate in single expressions with tests covering the fused and projection paths.
/build
**Greptile Summary**: This PR upgrades the MathDx/libmathdx integration to version 26.03, adds cuSolverDx-backed JIT fusion for Cholesky and matrix inverse, and introduces a lazy solver-projection mechanism so multi-output APIs such as QR, LU, SVD, and eig can participate in single expressions.
**Confidence Score: 3/5**

The JIT fusion path for cuSolverDx carries two correctness defects that would produce silently wrong answers for specific inputs before triggering any observable error. The rank > 4 JIT kernel bug means any user with a batched solver expression of rank 5 or higher gets a kernel that runs the solver on uninitialized shared memory, with no assertion or launch failure. The shared memory undercount for GESV/inverse could cause shared memory regions to overlap. Both defects are silent and easy to miss if the test suite only exercises rank-2 and rank-3 cases. The projection and resource-leak issues are lower urgency but should also be addressed before the new JIT paths go into production use.

Important files changed:
- `include/matx/transforms/solver_cusolverdx.h` (rank guard and smem sizing)
- `include/matx/operators/solver_projection.h` (double-release ordering)
- `include/matx/operators/inverse.h` (shares both issues from solver_cusolverdx.h)
Reviews (1). Last reviewed commit: "Add cuSolverDx JIT fusion and solver pro..."
```cpp
        for (index_t linear = tid; linear < elems; linear += blockDim.x * blockDim.y * blockDim.z) {
          const index_t row = linear / n;
          const index_t col = linear % n;
          if constexpr (Rank() == 2) {
            smem_a[linear] = a_.template operator()<CapType>(row, col);
          }
          else if constexpr (Rank() == 3) {
            smem_a[linear] = a_.template operator()<CapType>(idx[0], row, col);
          }
          else if constexpr (Rank() == 4) {
            smem_a[linear] = a_.template operator()<CapType>(idx[0], idx[1], row, col);
          }
        }

        if (tid == 0) {
          *info = 0;
        }
        __syncthreads();
    )";
    result += solver_func_name;
    result += R"((smem_a, info);
        __syncthreads();

        if (tid < elems) {
          return smem_a[tid];
        }
        return value_type{};
    )";
    return result;
  }

  std::string GetGesvInverseFuncStr(const std::string &solver_func_name) const
  {
    std::string result = R"(
        using value_type = )";
    result += detail::type_to_string<InputType>();
    result += R"(;
        static constexpr index_t n = )";
    result += std::to_string(static_cast<int>(n_));
    result += R"(;
        static constexpr index_t elems = n * n;
        extern __shared__ __align__(16) char smem[];
        value_type* smem_a = reinterpret_cast<value_type*>(smem);
        value_type* smem_b = reinterpret_cast<value_type*>(smem + elems * sizeof(value_type));
        int* ipiv = reinterpret_cast<int*>(smem + (2 * elems * sizeof(value_type)));
        int* info = reinterpret_cast<int*>(smem + (2 * elems * sizeof(value_type)) + n * sizeof(int));
        const int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;
        cuda::std::array<index_t, Rank()> idx = { static_cast<index_t>(indices)... };

        for (index_t linear = tid; linear < elems; linear += blockDim.x * blockDim.y * blockDim.z) {
          const index_t row = linear / n;
          const index_t col = linear % n;
          if constexpr (Rank() == 2) {
            smem_a[linear] = a_.template operator()<CapType>(row, col);
          }
          else if constexpr (Rank() == 3) {
            smem_a[linear] = a_.template operator()<CapType>(idx[0], row, col);
          }
          else if constexpr (Rank() == 4) {
            smem_a[linear] = a_.template operator()<CapType>(idx[0], idx[1], row, col);
          }
          smem_b[linear] = row == col ? value_type{1} : value_type{0};
```
**Rank > 4 silently produces wrong results in JIT kernels**

Both `GetPotrfFuncStr` and `GetGesvInverseFuncStr` generate code with branches only for `Rank() == 2`, `Rank() == 3`, and `Rank() == 4`. When `Rank() > 4`, none of the `if constexpr` arms match, so the loop iteration body becomes a no-op and `smem_a` is never loaded with input data — the solver runs on garbage shared memory. Because `get_capability(SUPPORTS_JIT)` has no upper-bound rank check (`OpA::Rank() >= 2` is the only condition), a rank-5 tensor will pass the support check and silently produce incorrect results at runtime. The same issue affects the `GetGesvInverseFuncStr` path used by `InvOp`.
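A minimal sketch of the missing guard. The helper name `solver_jit_rank_supported` and its placement are hypothetical, not the MatX API; the point is that the support check must bound the rank to the window the generated kernel actually covers:

```cpp
#include <cstdint>

// Hypothetical helper mirroring the get_capability(SUPPORTS_JIT) condition.
// The generated kernel only emits load branches for ranks 2-4, so the
// support check should reject anything outside that window instead of
// letting a rank-5 operator fall through to an empty loop body.
constexpr bool solver_jit_rank_supported(int32_t rank) {
  constexpr int32_t kMinJitRank = 2;  // plain 2-D matrix is the base case
  constexpr int32_t kMaxJitRank = 4;  // highest rank the codegen branches cover
  return rank >= kMinJitRank && rank <= kMaxJitRank;
}
```

With this guard in place, a rank-5 expression falls back to the non-JIT path instead of silently computing on uninitialized shared memory.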
```cpp
  }
  else {
    if (m_ <= 0 || n_ <= 0) {
      return false;
    }
    auto handle = GeneratePlan();
    long long int shm = 0;
    const bool supported =
        cusolverdxGetTraitInt64(handle, CUSOLVERDX_TRAIT_SHARED_MEMORY_SIZE, &shm) == COMMONDX_SUCCESS;
    cusolverdxDestroyDescriptor(handle);
    return supported;
```
**Shared memory calculation may be insufficient for GESV/inverse**

`GetShmRequired` computes `max(cusolverdx_shm, matrix_bytes + info_bytes)` where `matrix_bytes = m * n * sizeof(T)`. However, the kernel in `GetGesvInverseFuncStr` lays out four distinct regions in shared memory: `smem_a` (n² × sizeof(T)), `smem_b` (n² × sizeof(T)), `ipiv` (n × sizeof(int)), and `info` (sizeof(int)) — a total of 2·n²·sizeof(T) + (n+1)·sizeof(int). The fallback floor `matrix_bytes + info_bytes` only covers one matrix's worth plus a few bytes. If `CUSOLVERDX_TRAIT_SHARED_MEMORY_SIZE` does not already account for all user-visible buffers (A, B, ipiv, info), the kernel will overrun the allocated shared memory region and corrupt adjacent state.
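A sizing helper matching the four-region layout the comment describes could look like the sketch below (the function name is hypothetical; the host code would take a `max()` of this floor against the library-reported `CUSOLVERDX_TRAIT_SHARED_MEMORY_SIZE`):

```cpp
#include <cstddef>

// Hypothetical floor for the GESV/inverse shared-memory layout shown above:
// smem_a (n*n elements), smem_b (n*n elements), ipiv (n ints), info (1 int).
constexpr std::size_t gesv_inverse_smem_floor(std::size_t n, std::size_t elem_size) {
  const std::size_t matrix_bytes = n * n * elem_size;       // one n x n matrix
  const std::size_t pivot_bytes  = n * sizeof(int);         // ipiv
  const std::size_t info_bytes   = sizeof(int);             // info flag
  return 2 * matrix_bytes + pivot_bytes + info_bytes;       // A + B + ipiv + info
}
```

Note this sketch ignores alignment padding between regions; if the real layout aligns `ipiv` or `info`, the floor must round each region's offset up accordingly.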
```cpp
            static_cast<int>(block_dim[2])};
  }

  int GetWorkspaceSize() const
  {
    auto handle = GeneratePlan();
    long long int workspace_size = 0;
    LIBCUSOLVERDX_CHECK(cusolverdxGetTraitInt64(handle, CUSOLVERDX_TRAIT_WORKSPACE_SIZE, &workspace_size));
    cusolverdxDestroyDescriptor(handle);
    return static_cast<int>(workspace_size);
  }

  bool GenerateLTOIR(std::set<std::string> &ltoir_symbols)
  {
    LTOIRData ltoir;
    const auto symbol_name = std::string(SOLVER_DX_FUNC_PREFIX) + "_" + GetSymbolName();
    ltoir_symbols.insert(symbol_name);

    if (detail::GetCache().GetLTOIRCachedBytes(symbol_name) != nullptr) {
      return true;
    }

    auto handle = GeneratePlan();
    LIBCUSOLVERDX_CHECK(cusolverdxSetOptionStr(handle, COMMONDX_OPTION_SYMBOL_NAME, symbol_name.c_str()));
    const auto trait_symbol_name = GetTraitSymbolName(handle);
    MATX_ASSERT_STR(trait_symbol_name == symbol_name,
                    matxInvalidParameter,
                    "cuSolverDx returned an unexpected symbol name");

    commondxCode code;
    LIBCUSOLVERDX_CHECK(commondxCreateCode(&code));
    LIBCUSOLVERDX_CHECK(commondxSetCodeOptionInt64(code, COMMONDX_OPTION_TARGET_SM, cc_));
    LIBCUSOLVERDX_CHECK(cusolverdxFinalizeCode(code, handle));

    LIBCUSOLVERDX_CHECK(commondxGetCodeLTOIRSize(code, &ltoir.length));
    ltoir.data = static_cast<char*>(malloc(ltoir.length));
    MATX_ASSERT_STR(ltoir.data != nullptr, matxInvalidParameter, "Failed to allocate cuSolverDx LTOIR data");
    LIBCUSOLVERDX_CHECK(commondxGetCodeLTOIR(code, ltoir.length, ltoir.data));

    if (!detail::GetCache().StoreLTOIRCachedBytes(symbol_name, ltoir.data, ltoir.length)) {
      free(ltoir.data);
      MATX_LOG_ERROR("Failed to store cuSolverDx LTOIR cached bytes for: {}", symbol_name);
      return false;
    }

    ltoir.data = nullptr;
    ltoir.length = 0;

    LIBCUSOLVERDX_CHECK(commondxDestroyCode(code));
    LIBCUSOLVERDX_CHECK(cusolverdxDestroyDescriptor(handle));

    return true;
```
**cusolverdxDescriptor and commondxCode leaked on any exception path**

`GenerateLTOIR` acquires `handle` via `GeneratePlan()` and `code` via `commondxCreateCode`, but neither is wrapped in RAII. A `LIBCUSOLVERDX_CHECK` assertion between acquisition and the paired `cusolverdxDestroyDescriptor` / `commondxDestroyCode` calls will leave both objects unreleased. The same pattern occurs in `GetShmRequired`, `GetBlockDim`, `GetWorkspaceSize`, and `GetTraitSymbolName`. Additionally, `ltoir.data` (allocated with `malloc` on line 292) is leaked if `commondxGetCodeLTOIR` or a subsequent check throws before `StoreLTOIRCachedBytes` is reached.
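One way to close these leaks without restructuring the functions is a small scope guard, sketched below. The `ScopeGuard` name is illustrative (MatX may already have an equivalent utility); the cuSolverDx calls in the usage comment come from the diff above:

```cpp
#include <utility>

// Minimal scope-guard sketch: runs the stored callable on scope exit
// unless dismissed, so early returns and thrown assertions still release
// the descriptor and code objects.
template <typename F>
class ScopeGuard {
public:
  explicit ScopeGuard(F f) : f_(std::move(f)), active_(true) {}
  ~ScopeGuard() { if (active_) f_(); }
  void dismiss() { active_ = false; }  // e.g. after ownership transfer
  ScopeGuard(const ScopeGuard&) = delete;
  ScopeGuard& operator=(const ScopeGuard&) = delete;
private:
  F f_;
  bool active_;
};

// Hypothetical usage inside GenerateLTOIR:
//   auto handle = GeneratePlan();
//   ScopeGuard handle_guard{[&] { cusolverdxDestroyDescriptor(handle); }};
//   ...
//   commondxCode code;
//   LIBCUSOLVERDX_CHECK(commondxCreateCode(&code));
//   ScopeGuard code_guard{[&] { commondxDestroyCode(code); }};
// Every LIBCUSOLVERDX_CHECK failure after the guards now releases both
// handles; the same applies to the malloc'd ltoir.data with a free() guard.
```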
```cpp
  template <typename ShapeType, typename Executor>
  __MATX_INLINE__ void PreRun([[maybe_unused]] ShapeType &&shape, Executor &&ex) const noexcept
  {
    state_->Materialize(std::forward<Executor>(ex));
    tensor_ = state_->template Tensor<Component>();
  }

  template <typename ShapeType, typename Executor>
  __MATX_INLINE__ void PostRun([[maybe_unused]] ShapeType &&shape, Executor &&ex) const noexcept
  {
    state_->Release(std::forward<Executor>(ex));
  }
```
**Double-Release when two projections from the same op are composed**

When two `SolverProjectionOp` instances sharing the same `State*` — e.g., `lu().LU` and `lu().Piv` — appear in the same expression, both call `state_->Release(ex)` in `PostRun`. The first call frees both scratch buffers and resets `materialized_ = false`. The second call is a no-op, which is safe in the normal post-kernel flow. However, the `tensor_` member of the second projection still holds the now-freed pointer captured during `PreRun`; if `PostRun` ordering is ever interleaved with another `PreRun` on the same state, the second projection can access freed memory.
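A use-count on the shared state is one way to make `Materialize`/`Release` pair up safely. The `SolverState` shape below is a hypothetical reduction of the real class (no executor, no actual allocation) just to show the counting discipline:

```cpp
// Sketch: each projection's PreRun bumps the use count and each PostRun
// drops it; scratch buffers are freed only when the last user releases,
// so a sibling projection's tensor_ pointer never dangles mid-expression.
class SolverState {
public:
  void Materialize() {
    if (users_++ == 0) {
      materialized_ = true;   // allocate scratch buffers here
    }
  }
  void Release() {
    if (users_ > 0 && --users_ == 0) {
      materialized_ = false;  // free scratch buffers for the last user only
    }
  }
  bool materialized() const { return materialized_; }
private:
  int users_ = 0;
  bool materialized_ = false;
};
```

With this scheme, `lu().LU` and `lu().Piv` in one expression each call `Materialize` once and `Release` once, and the buffers outlive both projections regardless of PreRun/PostRun interleaving.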
This PR also adds a new interface that allows multi-output transforms to be used in a fusion context.