diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index 024ef0b07ecf1..6d7aa97791ac1 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -660,12 +660,8 @@ absl::StatusOr> RocmExecutor::LoadKernel( const char* hsaco = reinterpret_cast( spec.cuda_cubin_in_memory()->cubin_bytes.data()); absl::MutexLock lock{in_memory_modules_mu_}; - ModuleHandle module_handle{hsaco}; - hipModule_t& module = in_memory_modules_[module_handle]; - - if (module == nullptr) { - TF_ASSIGN_OR_RETURN(module, LoadHsaco(&rocm_context_, hsaco)); - } + TF_ASSIGN_OR_RETURN(ModuleHandle module_handle, LoadModuleFromHsaco(hsaco)); + hipModule_t module = gpu_binary_to_module_.at(module_handle).first; kernel_to_gpu_binary_[rocm_kernel.get()] = module_handle; VLOG(2) << "getting function " << kernel_name << " from module " << module;