From 9bf4abd0370dcf05d17094dbe5646aefd4c56b44 Mon Sep 17 00:00:00 2001 From: magaonka Date: Mon, 13 Apr 2026 17:25:28 -0500 Subject: [PATCH] [ROCm] Fix LoadKernel to use refcounted module path for proper cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ROCm's RocmExecutor has two kernel/module loading paths with different caching behavior: 1. LoadModule → LoadModuleFromHsaco → populates both gpu_binary_to_module_ (refcounted) and in_memory_modules_ 2. LoadKernel → directly inserts into in_memory_modules_ only, bypassing gpu_binary_to_module_ entirely When UnloadKernel calls UnloadGpuBinary, it only looks in gpu_binary_to_module_ for cleanup. Since LoadKernel never populated that map, the cleanup was a no-op: in_memory_modules_ entries and loaded hipModule_t objects were never removed, leaking until executor destruction. This caused stale module cache entries when CustomKernelThunk's owned HSACO buffers were freed and reallocated at the same address. The cache returned old modules that didn't contain the expected kernels, producing flaky hipErrorNotFound failures in sort and LuSolve tests under xdist parallelism. Fix by routing LoadKernel through LoadModuleFromHsaco, so every loaded module participates in the refcount mechanism and is properly cleaned up when the last kernel referencing it is destroyed. --- xla/stream_executor/rocm/rocm_executor.cc | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index 024ef0b07ecf1..6d7aa97791ac1 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -660,12 +660,8 @@ absl::StatusOr> RocmExecutor::LoadKernel( const char* hsaco = reinterpret_cast( spec.cuda_cubin_in_memory()->cubin_bytes.data()); absl::MutexLock lock{in_memory_modules_mu_}; - ModuleHandle module_handle{hsaco}; - hipModule_t& module = in_memory_modules_[module_handle]; - - if (module == nullptr) { - TF_ASSIGN_OR_RETURN(module, LoadHsaco(&rocm_context_, hsaco)); - } + TF_ASSIGN_OR_RETURN(ModuleHandle module_handle, LoadModuleFromHsaco(hsaco)); + hipModule_t module = gpu_binary_to_module_.at(module_handle).first; kernel_to_gpu_binary_[rocm_kernel.get()] = module_handle; VLOG(2) << "getting function " << kernel_name << " from module " << module;