diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 623244dfae597a..c4d4f618889fa8 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3630,7 +3630,7 @@ cc_library( features = ["-layering_check"], local_defines = if_cuda(["GOOGLE_CUDA=1"]) + if_rocm(["TENSORFLOW_USE_ROCM=1"]), deps = if_cuda_or_rocm([ - "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:node_hash_map", "@local_xla//xla:status_macros", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/stream_executor/gpu:gpu_blas_lt", diff --git a/tensorflow/core/kernels/matmul_op_fused.cc b/tensorflow/core/kernels/matmul_op_fused.cc index cf6b4f70f699ba..f784cbf38c3eee 100644 --- a/tensorflow/core/kernels/matmul_op_fused.cc +++ b/tensorflow/core/kernels/matmul_op_fused.cc @@ -200,6 +200,10 @@ struct LaunchFusedMatMulOp { namespace { #if GOOGLE_CUDA || TF_HIPBLASLT +/* + Epilogues supported by hipBLASLt: + https://rocm.docs.amd.com/projects/hipBLASLt/en/latest/datatypes.html#hipblasltepilogue-t +*/ StatusOr<se::gpu::BlasLt::Epilogue> GetBlasLtEpilogOp( FusedComputationType fusion) { if (fusion == FusedComputationType::kBiasAdd) { @@ -263,7 +267,7 @@ se::blas::AlgorithmConfig AutotuneMatmul( } return algorithm_config; } -#endif +#endif // GOOGLE_CUDA || TF_HIPBLASLT template StatusOr> AutotuneMatMulImpl( @@ -478,6 +482,17 @@ struct LaunchFusedMatMulOp { se::dnn::ActivationMode matmul_activation_mode; bool use_cudnn = false; + +#if !(GOOGLE_CUDA || TF_HIPBLASLT) + use_cudnn = true; +#endif + const auto& cc = stream->parent()->GetDeviceDescription(). + gpu_compute_capability(); + if (auto *procm = std::get_if< se::RocmComputeCapability >(&cc)) { + use_cudnn = !procm->gfx9_mi200_or_later(); + } + + // Fall back to cuDNN (use_cudnn) for fusions that hipBLASLt does not support yet. switch (fusion) { case FusedComputationType::kBiasAddWithGeluExact: matmul_activation_mode = se::dnn::ActivationMode::kGeluExact; @@ -512,15 +527,6 @@ struct LaunchFusedMatMulOp { default: use_cudnn = false; } -#if !(GOOGLE_CUDA || TF_HIPBLASLT) - use_cudnn = true; -#endif - -#if TF_HIPBLASLT - auto cap = stream->GetRocmComputeCapability(); - // as of ROCm 5.5, hipblaslt only supports MI200.
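A minimal sketch of the device gating added above, using only types that already appear in this patch (se::RocmComputeCapability and its gfx9_mi200_or_later() helper); the helper function name is hypothetical and not part of the change:

static bool FallBackToCudnnForFusedMatmul(se::Stream* stream, bool use_cudnn) {
  // gpu_compute_capability() returns a std::variant; std::get_if only matches
  // the ROCm alternative, so this is a no-op on CUDA builds.
  const auto& cc =
      stream->parent()->GetDeviceDescription().gpu_compute_capability();
  if (const auto* rocm = std::get_if<se::RocmComputeCapability>(&cc)) {
    // hipBLASLt is only taken on MI200 (gfx90a) and later gfx9 GPUs.
    use_cudnn = !rocm->gfx9_mi200_or_later();
  }
  return use_cudnn;
}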
- if (cap.gcn_arch_name().substr(0, 6) != "gfx90a") use_cudnn = true; -#endif BlasScratchAllocator scratch_allocator(context); @@ -591,32 +597,31 @@ struct LaunchFusedMatMulOp { epilog_op}; absl::Mutex* pmu; auto plan_and_algorithms_or = - GetPlanAndAlgorithms(stream, matmul_params, &pmu); + BlasLtMatmulPlanCache::GetOrCreate(stream, matmul_params, &pmu); OP_REQUIRES_OK(context, plan_and_algorithms_or.status()); absl::MutexLock lock(pmu); - const auto* plan_and_algorithms = std::move(plan_and_algorithms_or).value(); - const auto& algorithms = plan_and_algorithms->algorithms; - OP_REQUIRES(context, algorithms.size() > 0, + const auto& entry = *plan_and_algorithms_or.value(); + OP_REQUIRES(context, entry.algorithms.size() > 0, errors::InvalidArgument("No matmul algorithm returned!")); auto launch_func = [&](BlasScratchAllocator& scratch_allocator, size_t alg_idx, se::blas::ProfileResult* profile_result) { - return DoBlasLtMatmul(stream, *plan_and_algorithms, a_ptr, b_ptr, c_ptr, - alg_idx, scratch_allocator, bias_ptr, - profile_result); + return BlasLtMatmulPlanCache::ExecuteOnStream( + stream, entry, a_ptr, b_ptr, c_ptr, alg_idx, + scratch_allocator, bias_ptr, profile_result); }; size_t alg_idx = 0; if (use_autotune) { auto algorithm_config = - AutotuneMatmul(algorithms, matmul_params, context, launch_func); + AutotuneMatmul(entry.algorithms, matmul_params, context, launch_func); alg_idx = algorithm_config.algorithm(); } OP_REQUIRES_OK(context, launch_func(scratch_allocator, alg_idx, nullptr)); -#endif +#endif // GOOGLE_CUDA || TF_HIPBLASLT } }; diff --git a/tensorflow/core/kernels/matmul_op_impl.h b/tensorflow/core/kernels/matmul_op_impl.h index 70fc941e80fa15..db9fde9f0e5296 100644 --- a/tensorflow/core/kernels/matmul_op_impl.h +++ b/tensorflow/core/kernels/matmul_op_impl.h @@ -601,12 +601,13 @@ struct LaunchBatchMatMul { #if GOOGLE_CUDA || TF_HIPBLASLT static const bool use_autotune = MatmulAutotuneEnable(); bool bCublasLtSupport = true; -#if TF_HIPBLASLT - if (!std::is_same_v) bCublasLtSupport = false; - auto cap = stream->GetRocmComputeCapability(); - // as of ROCm 5.5, hipblaslt only supports MI200. - if (cap.gcn_arch_name().substr(0, 6) != "gfx90a") bCublasLtSupport = false; -#endif + + const auto& cc = stream->parent()->GetDeviceDescription(). + gpu_compute_capability(); + if(auto *procm = std::get_if< se::RocmComputeCapability >(&cc)) { + bCublasLtSupport = procm->gfx9_mi200_or_later(); + } + if (EnableCublasLtGemm() && bCublasLtSupport) { static const int64_t max_scratch_size = GetWorkspaceLimit(1LL << 32); // 4GB by default @@ -636,7 +637,7 @@ struct LaunchBatchMatMul { std::optional max_algorithm_count; if (!use_autotune) max_algorithm_count = 1; absl::Mutex* pmu = nullptr; - auto plan_and_algorithms_or = GetPlanAndAlgorithms( + auto plan_and_algorithms_or = BlasLtMatmulPlanCache::GetOrCreate( stream, matmul_params, &pmu, max_algorithm_count); OP_REQUIRES_OK(context, plan_and_algorithms_or.status()); absl::MutexLock lock(pmu); @@ -659,9 +660,10 @@ struct LaunchBatchMatMul { // scratch space is deallocated between runs. 
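For reference, a condensed sketch (inside the op's Compute(), error handling trimmed) of the calling convention both kernels now follow, using only names introduced by this patch:

absl::Mutex* pmu = nullptr;
auto entry_or = BlasLtMatmulPlanCache::GetOrCreate(stream, matmul_params, &pmu);
OP_REQUIRES_OK(context, entry_or.status());
absl::MutexLock lock(pmu);              // per-device mutex owned by the cache
const auto& entry = *entry_or.value();  // cached plan + enumerated algorithms
size_t alg_idx = 0;                     // or the index picked by AutotuneMatmul()
OP_REQUIRES_OK(context, BlasLtMatmulPlanCache::ExecuteOnStream(
                            stream, entry, a_ptr, b_ptr, c_ptr, alg_idx,
                            scratch_allocator, /*bias=*/se::DeviceMemoryBase{},
                            /*profile_result=*/nullptr));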
BlasScratchAllocator scratch_allocator(context, max_scratch_size); Status cublas_launch_status = - DoBlasLtMatmul(stream, *plan_and_algorithms, *a_ptrs[0], - *b_ptrs[0], *c_ptrs[0], i, scratch_allocator, - /*bias = */ {}, &profile_result); + BlasLtMatmulPlanCache::ExecuteOnStream(stream, + *plan_and_algorithms, + *a_ptrs[0], *b_ptrs[0], *c_ptrs[0], i, scratch_allocator, + se::DeviceMemoryBase{}, &profile_result); VLOG(4) << " Autotune algorithm " << i << " result: " << profile_result.elapsed_time_in_ms() @@ -701,8 +703,10 @@ struct LaunchBatchMatMul { OP_REQUIRES_OK( context, - DoBlasLtMatmul(stream, *plan_and_algorithms, *a_ptrs[0], *b_ptrs[0], - *c_ptrs[0], algorithm_idx, scratch_allocator)); + BlasLtMatmulPlanCache::ExecuteOnStream(stream, + *plan_and_algorithms, + *a_ptrs[0], *b_ptrs[0], *c_ptrs[0], + algorithm_idx, scratch_allocator, se::DeviceMemoryBase{})); } else { // requires mixed broadcasting const std::vector& a_batch_indices = bcast.x_batch_indices(); const std::vector& b_batch_indices = bcast.y_batch_indices(); diff --git a/tensorflow/core/kernels/matmul_util.cc b/tensorflow/core/kernels/matmul_util.cc index c4be5da2b62ece..8f95e9a9336fe2 100644 --- a/tensorflow/core/kernels/matmul_util.cc +++ b/tensorflow/core/kernels/matmul_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include #include #include "xla/status_macros.h" @@ -24,6 +25,8 @@ limitations under the License. #include "tensorflow/core/platform/tensor_float_32_utils.h" #include "tensorflow/core/util/env_var.h" #include "tensorflow/core/util/matmul_autotune.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" namespace tensorflow { @@ -44,10 +47,6 @@ int64_t GetWorkspaceLimit(int64_t default_value_in_bytes) { return default_value_in_bytes; } -std::string BlasLtMatmulPlanParams::ToString() const { - return ""; // TODO -} - bool BlasLtMatmulPlanParams::operator==( const BlasLtMatmulPlanParams& other) const { return internal::AsTuple(*this) == internal::AsTuple(other); @@ -55,22 +54,6 @@ bool BlasLtMatmulPlanParams::operator==( namespace { -// Thread-safe map from matmul parameters to their corresponding plan and -// algorithms. -struct BlasLtMatmulPlanMap { - absl::Mutex mu; - - template - auto emplace(Args&&... 
args) { - absl::MutexLock lock(&mu); - return map_.emplace(std::forward(args)...); - } - - private: - absl::flat_hash_map map_ - ABSL_GUARDED_BY(mu); -}; - int MatmulMaxAutotuneAlgorithmCount() { int64_t value; Status status = @@ -110,9 +93,19 @@ StatusOr GetBlasComputationType( } // namespace -StatusOr GetPlanAndAlgorithms( +/* static */ BlasLtMatmulPlanCache& BlasLtMatmulPlanCache::i(se::Stream *stream) { + static absl::Mutex m(absl::kConstInit); + // Each GPU gets different cache instance + static std::deque< BlasLtMatmulPlanCache > meta(8); + absl::MutexLock lock(&m); + size_t dev_id = stream->parent()->device_ordinal(); + if (dev_id >= meta.size()) meta.resize(dev_id + 1); + return meta[dev_id]; +} + +/* static */ auto BlasLtMatmulPlanCache::GetOrCreate( se::Stream* stream, const BlasLtMatmulPlanParams& params, - absl::Mutex** ppmu, std::optional max_algorithm_count) { + absl::Mutex** ppmu, std::optional max_algorithm_count) -> StatusOr{ static const int64_t max_scratch_size = GetWorkspaceLimit(1LL << 32); // 4GB by default static const int64_t max_autotune_algorithm_count = @@ -120,17 +113,17 @@ StatusOr GetPlanAndAlgorithms( if (!max_algorithm_count) max_algorithm_count = max_autotune_algorithm_count; - static BlasLtMatmulPlanMap plan_map; + auto& self = BlasLtMatmulPlanCache::i(stream); - auto [ptr, inserted] = plan_map.emplace(params, PlanAndAlgorithms{}); + absl::MutexLock lock(self.mutex_.get()); + auto [ptr, inserted] = self.map_.emplace(params, Entry{}); + auto& entry = ptr->second; if (inserted) { TF_ASSIGN_OR_RETURN(auto xlatype, se::gpu::AsXlaPrimitiveType(params.dtype)); TF_ASSIGN_OR_RETURN(auto computation_type, GetBlasComputationType(params.dtype)); - auto scale_type = se::gpu::GetScaleType(params.dtype, computation_type); - // row-major output is now handled automatically by blas-lt API constexpr auto kRowMajor = se::gpu::MatrixLayout::Order::kRowMajor; @@ -173,19 +166,42 @@ StatusOr GetPlanAndAlgorithms( .compute_type = computation_type, }; - TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan( + TF_ASSIGN_OR_RETURN(entry.plan, se::gpu::BlasLt::GetMatmulPlan( stream, cfg, params.epilogue)); TF_ASSIGN_OR_RETURN( - auto algorithms, - plan->GetAlgorithms(*max_algorithm_count, max_scratch_size)); - - ptr->second = {std::move(plan), std::move(algorithms), scale_type}; + entry.algorithms, + entry.plan->GetAlgorithms(*max_algorithm_count, max_scratch_size)); } - *ppmu = &plan_map.mu; - return &ptr->second; + *ppmu = self.mutex_.get(); + return &entry; } +/*static */ Status BlasLtMatmulPlanCache::ExecuteOnStream(se::Stream* stream, + const Entry& entry, + const se::DeviceMemoryBase& a, + const se::DeviceMemoryBase& b, + se::DeviceMemoryBase& c, + size_t algorithm_idx, + se::ScratchAllocator& scratch_allocator, + const se::DeviceMemoryBase& bias, + se::blas::ProfileResult* profile_result) { + + return entry.plan->ExecuteOnStream( + stream, a, b, c, c, + bias, // bias_buffer + se::DeviceMemoryBase{}, // aux_buffer + se::DeviceMemoryBase{}, // a_scale_buffer + se::DeviceMemoryBase{}, // b_scale_buffer + se::DeviceMemoryBase{}, // c_scale_buffer + se::DeviceMemoryBase{}, // d_scale_buffer + se::DeviceMemoryBase{}, // d_amax_buffer + entry.algorithms[algorithm_idx], + scratch_allocator, + profile_result); +} + + } // namespace tensorflow #endif \ No newline at end of file diff --git a/tensorflow/core/kernels/matmul_util.h b/tensorflow/core/kernels/matmul_util.h index 371964424eff85..dbf85eab41242c 100644 --- a/tensorflow/core/kernels/matmul_util.h +++ 
b/tensorflow/core/kernels/matmul_util.h @@ -21,7 +21,7 @@ limitations under the License. #if GOOGLE_CUDA || TF_HIPBLASLT -#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" #include "tensorflow/core/framework/types.h" @@ -35,7 +35,8 @@ namespace tensorflow { int64_t GetWorkspaceLimit(int64_t default_value_in_bytes); struct BlasLtMatmulPlanParams { - std::string ToString() const; + + std::string ToString() const { return "NOP"; } bool operator==(const BlasLtMatmulPlanParams& other) const; se::blas::DataType dtype; @@ -50,12 +51,6 @@ struct BlasLtMatmulPlanParams { se::gpu::BlasLt::Epilogue epilogue = se::gpu::BlasLt::Epilogue::kDefault; }; -struct PlanAndAlgorithms { - se::gpu::BlasLt::MatmulPlanPtr plan; - std::vector algorithms; - se::blas::DataType scale_type; // this is needed for half / bf16 treatment -}; - namespace internal { inline auto AsTuple(const BlasLtMatmulPlanParams& p) { @@ -71,37 +66,42 @@ H AbslHashValue(H h, const BlasLtMatmulPlanParams& params) { return H::combine(std::move(h), internal::AsTuple(params)); } -StatusOr GetPlanAndAlgorithms( +struct BlasLtMatmulPlanCache { + struct Entry { + se::gpu::BlasLt::MatmulPlanPtr plan; + std::vector< se::gpu::BlasLt::MatmulAlgorithm > algorithms; + }; + + static StatusOr GetOrCreate( se::Stream* stream, const BlasLtMatmulPlanParams& params, absl::Mutex** pmu, - std::optional max_algorithm_count = std::nullopt); - -template -Status DoBlasLtMatmul(se::Stream* stream, const PlanAndAlgorithms& paa, - const se::DeviceMemory& a, - const se::DeviceMemory& b, se::DeviceMemory& c, - size_t alg_idx, se::ScratchAllocator& scratch_allocator, - const se::DeviceMemory& bias = {}, - se::blas::ProfileResult* profile_result = nullptr) { - se::DeviceMemory aux{}; // We don't use the auxilary buffers. - const auto& algorithm = paa.algorithms[alg_idx]; - - // The scale type may be f32 if the data type is f16 and bf16. 
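The flat_hash_map to node_hash_map switch matters because GetOrCreate() hands out raw Entry pointers: absl::node_hash_map guarantees pointer stability of keys and values across rehashing, which absl::flat_hash_map does not. A standalone toy illustration (not from the patch):

#include <string>
#include "absl/container/node_hash_map.h"

void NodeHashMapPointerStability() {
  absl::node_hash_map<int, std::string> cache;
  std::string* entry = &cache[0];                     // pointer into the cache
  for (int i = 1; i < 1000; ++i) cache[i] = "plan";   // forces rehashes
  *entry = "still valid";                             // safe: nodes never move
  // With absl::flat_hash_map, 'entry' could dangle after a rehash.
}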
- if constexpr (std::is_same_v || - std::is_same_v) { - if (paa.scale_type == se::blas::DataType::kFloat) { - return paa.plan->DoMatmul(stream, se::HostOrDeviceScalar(1.0), b, - a, se::HostOrDeviceScalar(0.0), c, c, - algorithm, scratch_allocator, bias, aux, - profile_result); - } + std::optional max_algorithm_count = std::nullopt + ); + + // helper function for plan execution + static Status ExecuteOnStream(se::Stream* stream, + const Entry& entry, + const se::DeviceMemoryBase& a, + const se::DeviceMemoryBase& b, + se::DeviceMemoryBase& c, + size_t algorithm_idx, + se::ScratchAllocator& scratch_allocator, + const se::DeviceMemoryBase& bias, + se::blas::ProfileResult* profile_result = nullptr); + + BlasLtMatmulPlanCache() : mutex_(new absl::Mutex) { } - return paa.plan->DoMatmul(stream, se::HostOrDeviceScalar(T(1.0)), b, a, - se::HostOrDeviceScalar(T(0.0)), c, c, algorithm, - scratch_allocator, bias, aux, profile_result); -} + +private: + static BlasLtMatmulPlanCache& i(se::Stream *stream); + + std::unique_ptr mutex_; + absl::node_hash_map map_ + ABSL_GUARDED_BY(mutex_); + +}; // BlasLtMatmulPlanCache } // namespace tensorflow -#endif +#endif // GOOGLE_CUDA || TF_HIPBLASLT #endif // TENSORFLOW_CORE_KERNELS_MATMUL_UTIL_H_ diff --git a/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc b/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc index b8e5a8e8d1e662..531ad8d2f605fb 100644 --- a/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc +++ b/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc @@ -103,11 +103,16 @@ __global__ void xla_fp8_e5m2_comparison(__nv_fp8_storage_t* buffer_a, #endif // GOOGLE_CUDA #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200 + __global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a, __hip_fp8_storage_t* buffer_b, float rel_error_threshold, uint64_t buffer_length, int* mismatch_count) { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +// NOTE: according to amd_hip_fp8.h, GFX1200 and GFX1201 support ocp __hip_fp8_e4m3 +// but not __hip_fp8_e4m3_fnuz + int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx >= buffer_length) return; __hip_fp8_e4m3_fnuz elem_a_fp8, elem_b_fp8; @@ -123,6 +128,10 @@ __global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a, if (rel_error > rel_error_threshold || isnan(rel_error)) atomicAdd(mismatch_count, 1); +#else + // on unsupported architectures, this should not / cannot be used! + atomicAdd(mismatch_count, 1); +#endif } __global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a, @@ -130,6 +139,7 @@ __global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a, float rel_error_threshold, uint64_t buffer_length, int* mismatch_count) { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx >= buffer_length) return; __hip_fp8_e5m2_fnuz elem_a_fp8, elem_b_fp8; @@ -145,7 +155,12 @@ __global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a, if (rel_error > rel_error_threshold || isnan(rel_error)) atomicAdd(mismatch_count, 1); +#else + // on unsupported architectures, this should not / cannot be used! 
+ atomicAdd(mismatch_count, 1); +#endif } + #endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200 __global__ void xla_fp16_comparison(__half* buffer_a, __half* buffer_b, diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index c5e651b99f877f..a8635e3455677c 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -720,7 +720,7 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunk( TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::Epilogue blas_lt_epilogue, gpublas_lt::AsBlasLtEpilogue(epilogue)); auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(gemm_config), + instr, std::move(gemm_config), blas_lt_epilogue, algorithm, a, b, c, d, bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax, workspace_buffer); AddThunkToThunkSequence(std::move(thunk)); @@ -810,7 +810,7 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8( TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::Epilogue blas_lt_epilogue, gpublas_lt::AsBlasLtEpilogue(epilogue)); auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(gemm_config), + instr, std::move(gemm_config), blas_lt_epilogue, algorithm, a, b, c, d, bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax, workspace_buffer); AddThunkToThunkSequence(std::move(thunk)); diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index 331e272353f836..8f43585cce11da 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -728,6 +728,8 @@ cc_library( "//xla/service:buffer_assignment", "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:matmul_utils", + "//xla/service/gpu/autotuning:autotuner_util", + "//xla/stream_executor", "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", "//xla/stream_executor/gpu:gpu_blas_lt", diff --git a/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.cc b/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.cc index 8cdcf39773278c..8eb2efb2432e72 100644 --- a/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.cc @@ -25,6 +25,7 @@ limitations under the License. #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/runtime/thunk.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" @@ -35,8 +36,44 @@ limitations under the License. 
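The #if defined(__gfx940__)/__gfx941__/__gfx942__ guards added to the two fnuz comparison kernels follow the pattern sketched below: the __hip_fp8_*_fnuz conversions are only usable on gfx940/941/942, so on any other target the kernel body is compiled out and every element is counted as a mismatch, making accidental use fail loudly rather than silently pass. Kernel and parameter names here are illustrative:

__global__ void fnuz_comparison_guard_sketch(int* mismatch_count
                                             /*, fp8 buffers, threshold... */) {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
  // Real element-wise comparison of __hip_fp8_*_fnuz values goes here.
#else
  // Unsupported architecture: flag everything as mismatching.
  atomicAdd(mismatch_count, 1);
#endif
}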
namespace xla { namespace gpu { +struct MatmulPlanCache { + + static MatmulPlanCache& i(const se::Stream *stream) { + static absl::Mutex m(absl::kConstInit); + // Each GPU gets different cache instance + static std::vector< std::unique_ptr< MatmulPlanCache > > meta(8); + absl::MutexLock lock(&m); + size_t dev_id = stream->parent()->device_ordinal(); + if (dev_id >= meta.size()) meta.resize(dev_id + 1); + auto& res = meta[dev_id]; + if (!res) res.reset(new MatmulPlanCache()); + return *res; + } + + template < class Func > + StatusOr + GetOrCreate(const std::string& key, Func&& create) { + // each GPU has a different mutex => hence different GPU instances can + // create matmul plans in parallel + absl::MutexLock lock(mutex_.get()); + auto res = map_.emplace(key, se::gpu::BlasLt::MatmulPlanPtr{}); + if(res.second) { // new entry inserted + TF_ASSIGN_OR_RETURN(res.first->second, create()); + } + return res.first->second.get(); + } + +private: + MatmulPlanCache() : mutex_(std::make_unique< absl::Mutex >()) { } + +private: + std::unique_ptr< absl::Mutex > mutex_; + absl::flat_hash_map map_; +}; + + CublasLtMatmulThunk::CublasLtMatmulThunk( - ThunkInfo thunk_info, GemmConfig gemm_config, + const HloInstruction *instr, GemmConfig gemm_config, se::gpu::BlasLt::Epilogue epilogue, int64_t algorithm_idx, BufferAllocation::Slice a_buffer, BufferAllocation::Slice b_buffer, BufferAllocation::Slice c_buffer, BufferAllocation::Slice d_buffer, @@ -45,7 +82,7 @@ CublasLtMatmulThunk::CublasLtMatmulThunk( BufferAllocation::Slice c_scale, BufferAllocation::Slice d_scale, BufferAllocation::Slice d_amax, std::optional workspace_buffer) - : Thunk(Kind::kCublasLtMatmul, thunk_info), + : Thunk(Kind::kCublasLtMatmul, Thunk::ThunkInfo::WithProfileAnnotation(instr)), gemm_config_(std::move(gemm_config)), epilogue_(epilogue), algorithm_idx_(algorithm_idx), @@ -60,18 +97,18 @@ CublasLtMatmulThunk::CublasLtMatmulThunk( c_scale_buffer_(c_scale), d_scale_buffer_(d_scale), d_amax_buffer_(d_amax), - workspace_buffer_(workspace_buffer) {} + workspace_buffer_(workspace_buffer) { + + canonical_hlo_ = xla::gpu::AutotuneCacheKey("nope", *instr).GetHlo(); +} absl::Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) { - TF_ASSIGN_OR_RETURN(auto plan, GetMatmulPlan(params.stream)); - TF_ASSIGN_OR_RETURN( - auto algorithm, - GetMatmulAlgorithm(plan, workspace_buffer_.has_value() - ? 
workspace_buffer_.value().size() - : 0)); + TF_ASSIGN_OR_RETURN(auto *plan, GetCachedMatmulPlan(params)); + + VLOG(2) << params.stream->parent()->device_ordinal() << + ": cublas_lt_matmul for: " << canonical_hlo_; - VLOG(3) << "Running cublas_lt matmul thunk"; const BufferAllocations& allocs = *params.buffer_allocations; se::DeviceMemoryBase bias, a_scale, b_scale, c_scale, d_scale, d_amax; @@ -103,47 +140,39 @@ absl::Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) { if (workspace_buffer_.has_value()) { workspace = allocs.GetDeviceAddress(workspace_buffer_.value()); } + return plan->ExecuteOnStream( params.stream, allocs.GetDeviceAddress(a_buffer_), allocs.GetDeviceAddress(b_buffer_), allocs.GetDeviceAddress(c_buffer_), allocs.GetDeviceAddress(d_buffer_), bias, aux, a_scale, b_scale, c_scale, - d_scale, d_amax, algorithm, workspace); + d_scale, d_amax, {}, workspace); } -absl::StatusOr CublasLtMatmulThunk::GetMatmulPlan( - const stream_executor::Stream* stream) { - { - absl::MutexLock lock(&matmul_plans_cache_mutex_); - auto it = matmul_plans_cache_.find(stream); - if (it != matmul_plans_cache_.end()) return it->second.get(); - } - TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan( - stream, gemm_config_, epilogue_)); - absl::MutexLock lock(&matmul_plans_cache_mutex_); - auto [it, _] = matmul_plans_cache_.emplace(stream, std::move(plan)); - return it->second.get(); -} - -absl::StatusOr -CublasLtMatmulThunk::GetMatmulAlgorithm(const se::gpu::BlasLt::MatmulPlan* plan, - int64_t max_workspace) { - { - absl::MutexLock lock(&matmul_algorithm_cache_mutex_); - auto it = matmul_algorithm_cache_.find(plan); - if (it != matmul_algorithm_cache_.end()) return it->second; - } - TF_ASSIGN_OR_RETURN( - auto algorithms, - plan->GetAlgorithms(/*max_algorithm_count*/ 128, - /*max_workspace_size*/ max_workspace)); - TF_RET_CHECK(algorithm_idx_ >= 0 && algorithm_idx_ < algorithms.size()); - - absl::MutexLock lock(&matmul_algorithm_cache_mutex_); - auto [it, _] = - matmul_algorithm_cache_.emplace(plan, algorithms[algorithm_idx_]); - return it->second; +auto CublasLtMatmulThunk::GetCachedMatmulPlan( + const ExecuteParams& params) -> absl::StatusOr { + + auto& cache = MatmulPlanCache::i(params.stream); + + auto create = [&]() -> StatusOr { + VLOG(2) << this << ": Adding new MatmulPlan for stream: " << params.stream << + " instr: " << canonical_hlo_; + + TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan( + params.stream, gemm_config_, epilogue_)); + + int64_t max_workspace = workspace_buffer_.has_value() + ? workspace_buffer_.value().size() : 0; + int64_t num_algorithms = algorithm_idx_ == se::blas::kDefaultAlgorithm ? 
+ 1 : 128; + TF_ASSIGN_OR_RETURN(auto algorithms, + plan->GetAlgorithms(num_algorithms, max_workspace)); + + TF_RETURN_IF_ERROR(plan->SetAlgorithm(algorithms[algorithm_idx_])); + return std::move(plan); + }; + return cache.GetOrCreate(canonical_hlo_, create); } absl::Status CublasLtMatmulThunk::Initialize(const InitializeParams& params) { diff --git a/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h b/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h index aa114bd3ee93fd..5602a22f6fd1fe 100644 --- a/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h @@ -38,7 +38,7 @@ namespace gpu { class CublasLtMatmulThunk : public Thunk { public: CublasLtMatmulThunk( - ThunkInfo thunk_info, GemmConfig gemm_config, + const HloInstruction *instr, GemmConfig gemm_config, se::gpu::BlasLt::Epilogue epilogue, int64_t algorithm_idx, BufferAllocation::Slice a_buffer, BufferAllocation::Slice b_buffer, BufferAllocation::Slice c_buffer, BufferAllocation::Slice d_buffer, @@ -74,24 +74,13 @@ class CublasLtMatmulThunk : public Thunk { } private: - absl::StatusOr GetMatmulPlan( - const stream_executor::Stream* stream); - absl::StatusOr GetMatmulAlgorithm( - const se::gpu::BlasLt::MatmulPlan* plan, int64_t max_workspace); - - absl::Mutex matmul_plans_cache_mutex_; - absl::flat_hash_map - matmul_plans_cache_ ABSL_GUARDED_BY(matmul_plans_cache_mutex_); - - absl::Mutex matmul_algorithm_cache_mutex_; - absl::flat_hash_map - matmul_algorithm_cache_ ABSL_GUARDED_BY(matmul_algorithm_cache_mutex_); + absl::StatusOr GetCachedMatmulPlan( + const ExecuteParams& params); GemmConfig gemm_config_; se::gpu::BlasLt::Epilogue epilogue_; int64_t algorithm_idx_; + std::string canonical_hlo_; BufferAllocation::Slice a_buffer_; BufferAllocation::Slice b_buffer_; BufferAllocation::Slice c_buffer_; diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h index 20caccbe18e62e..c2ebba7e87296b 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h @@ -268,6 +268,9 @@ struct BlasLt { size_t max_algorithm_count = 128, size_t max_workspace_size = 1ll << 32) const = 0; + // Algorithm needs to be set before calling ExecuteOnStream function + virtual absl::Status SetAlgorithm(const MatmulAlgorithm& algorithm) const = 0; + virtual ~MatmulPlan() {} protected: diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc index 9b87f9e4a4e4c0..5389471363f590 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc @@ -383,10 +383,15 @@ absl::Status BlasLt::MatmulPlan::ValidateInputs( return absl::OkStatus(); } +absl::Status BlasLt::MatmulPlan::SetAlgorithm(const MatmulAlgorithm& algorithm) const { + algorithm_ = algorithm; + return absl::OkStatus(); +} + absl::Status BlasLt::MatmulPlan::DoMatmul( Stream* stream, const void* alpha, DeviceMemoryBase a, DeviceMemoryBase b, const void* beta, DeviceMemoryBase c, DeviceMemoryBase d, - const MatmulAlgorithm& algorithm, DeviceMemoryBase bias, + const MatmulAlgorithm& Xalgorithm, DeviceMemoryBase bias, DeviceMemoryBase aux, DeviceMemoryBase a_scale, DeviceMemoryBase b_scale, DeviceMemoryBase c_scale, DeviceMemoryBase d_scale, DeviceMemoryBase d_amax, std::optional workspace, @@ -403,6 +408,8 @@ absl::Status 
BlasLt::MatmulPlan::DoMatmul( profile_result->warmup_run_executed())); } + auto algorithm = algorithm_.has_value() ? *algorithm_ : Xalgorithm; + void* workspace_addr = nullptr; uint64_t workspace_size = 0; if (workspace.has_value()) { @@ -613,4 +620,4 @@ absl::Status BlasLt::MatmulPlan::ExecuteOnStream( } // namespace stream_executor -#endif // TF_HIPBLASLT +#endif // TF_HIPBLASLT \ No newline at end of file diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h index b781e2848cd479..b53ce7f8e5e1a6 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h @@ -114,6 +114,8 @@ class BlasLt : public gpu::BlasLt { absl::StatusOr> GetAlgorithms( size_t max_algorithm_count, size_t max_workspace_size) const override; + absl::Status SetAlgorithm(const MatmulAlgorithm& algorithm) const override; + protected: absl::Status ValidateInputs(blas::DataType scale_type, bool alpha_on_device, bool beta_on_device, blas::DataType A_type, @@ -143,6 +145,7 @@ class BlasLt : public gpu::BlasLt { xla::complex128 alpha_; double beta_; bool must_swap_operands_; + mutable std::optional< MatmulAlgorithm > algorithm_; // selected algorithm }; // class MatmulPlan explicit BlasLt(StreamExecutor* parent) diff --git a/third_party/xla/xla/stream_executor/rocm/hipblaslt_wrapper.h b/third_party/xla/xla/stream_executor/rocm/hipblaslt_wrapper.h index 09cac948f0d185..a92e45cf2716b0 100644 --- a/third_party/xla/xla/stream_executor/rocm/hipblaslt_wrapper.h +++ b/third_party/xla/xla/stream_executor/rocm/hipblaslt_wrapper.h @@ -17,6 +17,8 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_ROCM_HIPBLASLT_WRAPPER_H_ #define XLA_STREAM_EXECUTOR_ROCM_HIPBLASLT_WRAPPER_H_ +#define __HIP_DISABLE_CPP_FUNCTIONS__ + #include "rocm/rocm_config.h" #if TF_HIPBLASLT
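Finally, a usage sketch for the new MatmulPlan::SetAlgorithm() hook, mirroring what CublasLtMatmulThunk::GetCachedMatmulPlan() does: the algorithm is chosen once when the plan is built and remembered by the plan (the mutable algorithm_ member on the ROCm side), so ExecuteOnStream() can later be called with an empty algorithm argument. The helper name and the choice of algorithms[0] are illustrative only:

absl::StatusOr<se::gpu::BlasLt::MatmulPlanPtr> BuildPlanWithAlgorithm(
    se::Stream* stream, const xla::gpu::GemmConfig& config,
    se::gpu::BlasLt::Epilogue epilogue, int64_t max_workspace) {
  TF_ASSIGN_OR_RETURN(auto plan,
                      se::gpu::BlasLt::GetMatmulPlan(stream, config, epilogue));
  TF_ASSIGN_OR_RETURN(auto algorithms,
                      plan->GetAlgorithms(/*max_algorithm_count=*/128,
                                          /*max_workspace_size=*/max_workspace));
  // Remembered by the plan, so callers may pass {} as the algorithm argument
  // to ExecuteOnStream().
  TF_RETURN_IF_ERROR(plan->SetAlgorithm(algorithms[0]));
  return std::move(plan);
}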