2 changes: 1 addition & 1 deletion tensorflow/core/kernels/BUILD
@@ -3630,7 +3630,7 @@ cc_library(
features = ["-layering_check"],
local_defines = if_cuda(["GOOGLE_CUDA=1"]) + if_rocm(["TENSORFLOW_USE_ROCM=1"]),
deps = if_cuda_or_rocm([
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:node_hash_map",
"@local_xla//xla:status_macros",
"@local_xla//xla:xla_data_proto_cc",
"@local_xla//xla/stream_executor/gpu:gpu_blas_lt",
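Why the container dependency changes: the plan cache (see matmul_util.cc below) hands out raw pointers to its mapped entries, and absl::node_hash_map keeps each value in a separately allocated node, so those pointers stay valid across rehashes; absl::flat_hash_map gives no such guarantee. A minimal, hypothetical illustration of that property (plain int/string types, not the TF ones):

#include <cassert>
#include <string>

#include "absl/container/node_hash_map.h"

int main() {
  absl::node_hash_map<int, std::string> cache;
  std::string* stable = &cache.emplace(1, "plan-1").first->second;
  // Insert enough entries to force several rehashes.
  for (int k = 2; k < 1000; ++k) cache.emplace(k, "filler");
  // Values live in separately allocated nodes, so the old pointer is still
  // valid; with absl::flat_hash_map it could dangle after a rehash.
  assert(stable == &cache.find(1)->second);
  return 0;
}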
43 changes: 24 additions & 19 deletions tensorflow/core/kernels/matmul_op_fused.cc
@@ -200,6 +200,10 @@ struct LaunchFusedMatMulOp<CPUDevice, T> {
namespace {

#if GOOGLE_CUDA || TF_HIPBLASLT
/*
Epilogues supported by hipBLASLt:
https://rocm.docs.amd.com/projects/hipBLASLt/en/latest/datatypes.html#hipblasltepilogue-t
*/
StatusOr<se::gpu::BlasLt::Epilogue> GetBlasLtEpilogOp(
FusedComputationType fusion) {
if (fusion == FusedComputationType::kBiasAdd) {
@@ -263,7 +267,7 @@ se::blas::AlgorithmConfig AutotuneMatmul(
}
return algorithm_config;
}
#endif
#endif // GOOGLE_CUDA || TF_HIPBLASLT

template <typename LaunchFunc, typename Sig>
StatusOr<std::vector<xla::AutotuneResult>> AutotuneMatMulImpl(
@@ -478,6 +482,17 @@ struct LaunchFusedMatMulOp<GPUDevice, T> {

se::dnn::ActivationMode matmul_activation_mode;
bool use_cudnn = false;

#if !(GOOGLE_CUDA || TF_HIPBLASLT)
use_cudnn = true;
#endif
const auto& cc = stream->parent()->GetDeviceDescription().gpu_compute_capability();
if (auto* procm = std::get_if<se::RocmComputeCapability>(&cc)) {
use_cudnn = !procm->gfx9_mi200_or_later();
}

// use_cudnn is also set below for fusions that hipBLASLt does not support yet.
switch (fusion) {
case FusedComputationType::kBiasAddWithGeluExact:
matmul_activation_mode = se::dnn::ActivationMode::kGeluExact;
@@ -512,15 +527,6 @@ struct LaunchFusedMatMulOp<GPUDevice, T> {
default:
use_cudnn = false;
}
#if !(GOOGLE_CUDA || TF_HIPBLASLT)
use_cudnn = true;
#endif

#if TF_HIPBLASLT
auto cap = stream->GetRocmComputeCapability();
// as of ROCm 5.5, hipblaslt only supports MI200.
if (cap.gcn_arch_name().substr(0, 6) != "gfx90a") use_cudnn = true;
#endif

BlasScratchAllocator scratch_allocator(context);

@@ -591,32 +597,31 @@ struct LaunchFusedMatMulOp<GPUDevice, T> {
epilog_op};
absl::Mutex* pmu;
auto plan_and_algorithms_or =
GetPlanAndAlgorithms(stream, matmul_params, &pmu);
BlasLtMatmulPlanCache::GetOrCreate(stream, matmul_params, &pmu);
OP_REQUIRES_OK(context, plan_and_algorithms_or.status());
absl::MutexLock lock(pmu);
const auto* plan_and_algorithms = std::move(plan_and_algorithms_or).value();
const auto& algorithms = plan_and_algorithms->algorithms;
OP_REQUIRES(context, algorithms.size() > 0,
const auto& entry = *plan_and_algorithms_or.value();
OP_REQUIRES(context, entry.algorithms.size() > 0,
errors::InvalidArgument("No matmul algorithm returned!"));

auto launch_func = [&](BlasScratchAllocator& scratch_allocator,
size_t alg_idx,
se::blas::ProfileResult* profile_result) {
return DoBlasLtMatmul(stream, *plan_and_algorithms, a_ptr, b_ptr, c_ptr,
alg_idx, scratch_allocator, bias_ptr,
profile_result);
return BlasLtMatmulPlanCache::ExecuteOnStream(
stream, entry, a_ptr, b_ptr, c_ptr, alg_idx,
scratch_allocator, bias_ptr, profile_result);
};

size_t alg_idx = 0;
if (use_autotune) {
auto algorithm_config =
AutotuneMatmul(algorithms, matmul_params, context, launch_func);
AutotuneMatmul(entry.algorithms, matmul_params, context, launch_func);

alg_idx = algorithm_config.algorithm();
}

OP_REQUIRES_OK(context, launch_func(scratch_allocator, alg_idx, nullptr));
#endif
#endif // GOOGLE_CUDA || TF_HIPBLASLT
}
};

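The capability check added above replaces the old string comparison against "gfx90a" with a variant-based query on the device description. A standalone sketch of that check, assuming the stream_executor headers listed below; ShouldFallBackToCudnn is a hypothetical helper name, not part of this change:

#include <variant>

#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"

namespace se = ::stream_executor;

// Hypothetical helper mirroring the check above: hipBLASLt is only used on
// gfx9 MI200-or-later parts; any other ROCm device falls back to the
// cuDNN/MIOpen fused path.
bool ShouldFallBackToCudnn(se::Stream* stream) {
  const auto& cc =
      stream->parent()->GetDeviceDescription().gpu_compute_capability();
  if (const auto* rocm = std::get_if<se::RocmComputeCapability>(&cc)) {
    return !rocm->gfx9_mi200_or_later();
  }
  return false;  // CUDA devices: cuBLASLt availability is decided elsewhere.
}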
28 changes: 16 additions & 12 deletions tensorflow/core/kernels/matmul_op_impl.h
@@ -601,12 +601,13 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
#if GOOGLE_CUDA || TF_HIPBLASLT
static const bool use_autotune = MatmulAutotuneEnable();
bool bCublasLtSupport = true;
#if TF_HIPBLASLT
if (!std::is_same_v<Scalar, float>) bCublasLtSupport = false;
auto cap = stream->GetRocmComputeCapability();
// as of ROCm 5.5, hipblaslt only supports MI200.
if (cap.gcn_arch_name().substr(0, 6) != "gfx90a") bCublasLtSupport = false;
#endif

const auto& cc = stream->parent()->GetDeviceDescription().gpu_compute_capability();
if (auto* procm = std::get_if<se::RocmComputeCapability>(&cc)) {
bCublasLtSupport = procm->gfx9_mi200_or_later();
}

if (EnableCublasLtGemm() && bCublasLtSupport) {
static const int64_t max_scratch_size =
GetWorkspaceLimit(1LL << 32); // 4GB by default
@@ -636,7 +637,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
std::optional<int> max_algorithm_count;
if (!use_autotune) max_algorithm_count = 1;
absl::Mutex* pmu = nullptr;
auto plan_and_algorithms_or = GetPlanAndAlgorithms(
auto plan_and_algorithms_or = BlasLtMatmulPlanCache::GetOrCreate(
stream, matmul_params, &pmu, max_algorithm_count);
OP_REQUIRES_OK(context, plan_and_algorithms_or.status());
absl::MutexLock lock(pmu);
@@ -659,9 +660,10 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
// scratch space is deallocated between runs.
BlasScratchAllocator scratch_allocator(context, max_scratch_size);
Status cublas_launch_status =
DoBlasLtMatmul(stream, *plan_and_algorithms, *a_ptrs[0],
*b_ptrs[0], *c_ptrs[0], i, scratch_allocator,
/*bias = */ {}, &profile_result);
BlasLtMatmulPlanCache::ExecuteOnStream(stream,
*plan_and_algorithms,
*a_ptrs[0], *b_ptrs[0], *c_ptrs[0], i, scratch_allocator,
se::DeviceMemoryBase{}, &profile_result);

VLOG(4) << " Autotune algorithm " << i
<< " result: " << profile_result.elapsed_time_in_ms()
@@ -701,8 +703,10 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {

OP_REQUIRES_OK(
context,
DoBlasLtMatmul(stream, *plan_and_algorithms, *a_ptrs[0], *b_ptrs[0],
*c_ptrs[0], algorithm_idx, scratch_allocator));
BlasLtMatmulPlanCache::ExecuteOnStream(stream,
*plan_and_algorithms,
*a_ptrs[0], *b_ptrs[0], *c_ptrs[0],
algorithm_idx, scratch_allocator, se::DeviceMemoryBase{}));
} else { // requires mixed broadcasting
const std::vector<int64_t>& a_batch_indices = bcast.x_batch_indices();
const std::vector<int64_t>& b_batch_indices = bcast.y_batch_indices();
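Both kernels now follow the same sequence against the new cache API: GetOrCreate to look up or build the plan for a parameter set, take the cache's mutex, then ExecuteOnStream with a chosen algorithm index. A condensed sketch of that sequence, assuming the declarations from tensorflow/core/kernels/matmul_util.h and the usual TF kernel context; RunCachedBlasLtMatmul is a hypothetical helper, and the empty DeviceMemoryBase stands in for "no bias" as in the batched path above:

#include "absl/synchronization/mutex.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/matmul_util.h"

namespace tensorflow {

// Hypothetical helper (not part of this change) condensing the call pattern
// the kernels above now share.
void RunCachedBlasLtMatmul(OpKernelContext* context, se::Stream* stream,
                           const BlasLtMatmulPlanParams& matmul_params,
                           const se::DeviceMemoryBase& a,
                           const se::DeviceMemoryBase& b,
                           se::DeviceMemoryBase& c,
                           se::ScratchAllocator& scratch_allocator) {
  absl::Mutex* pmu = nullptr;
  auto entry_or =
      BlasLtMatmulPlanCache::GetOrCreate(stream, matmul_params, &pmu);
  OP_REQUIRES_OK(context, entry_or.status());
  absl::MutexLock lock(pmu);  // Serializes use of the cached plan.
  const auto& entry = *entry_or.value();
  OP_REQUIRES(context, !entry.algorithms.empty(),
              errors::InvalidArgument("No matmul algorithm returned!"));

  // Run the first algorithm with no bias; autotuning callers instead loop
  // over entry.algorithms and pass a ProfileResult.
  OP_REQUIRES_OK(context,
                 BlasLtMatmulPlanCache::ExecuteOnStream(
                     stream, entry, a, b, c, /*algorithm_idx=*/0,
                     scratch_allocator, /*bias=*/se::DeviceMemoryBase{},
                     /*profile_result=*/nullptr));
}

}  // namespace tensorflow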
82 changes: 49 additions & 33 deletions tensorflow/core/kernels/matmul_util.cc
@@ -16,6 +16,7 @@ limitations under the License.

#include <optional>
#include <string>
#include <deque>
#include <utility>

#include "xla/status_macros.h"
@@ -24,6 +25,8 @@ limitations under the License.
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/matmul_autotune.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"

namespace tensorflow {

@@ -44,33 +47,13 @@ int64_t GetWorkspaceLimit(int64_t default_value_in_bytes) {
return default_value_in_bytes;
}

std::string BlasLtMatmulPlanParams::ToString() const {
return ""; // TODO
}

bool BlasLtMatmulPlanParams::operator==(
const BlasLtMatmulPlanParams& other) const {
return internal::AsTuple(*this) == internal::AsTuple(other);
}

namespace {

// Thread-safe map from matmul parameters to their corresponding plan and
// algorithms.
struct BlasLtMatmulPlanMap {
absl::Mutex mu;

template <class... Args>
auto emplace(Args&&... args) {
absl::MutexLock lock(&mu);
return map_.emplace(std::forward<Args>(args)...);
}

private:
absl::flat_hash_map<BlasLtMatmulPlanParams, PlanAndAlgorithms> map_
ABSL_GUARDED_BY(mu);
};

int MatmulMaxAutotuneAlgorithmCount() {
int64_t value;
Status status =
@@ -110,27 +93,37 @@ StatusOr<se::blas::ComputationType> GetBlasComputationType(

} // namespace

StatusOr<const PlanAndAlgorithms*> GetPlanAndAlgorithms(
/* static */ BlasLtMatmulPlanCache& BlasLtMatmulPlanCache::i(se::Stream* stream) {
static absl::Mutex m(absl::kConstInit);
// Each GPU gets its own cache instance.
static std::deque<BlasLtMatmulPlanCache> meta(8);
absl::MutexLock lock(&m);
size_t dev_id = stream->parent()->device_ordinal();
if (dev_id >= meta.size()) meta.resize(dev_id + 1);
return meta[dev_id];
}

/* static */ auto BlasLtMatmulPlanCache::GetOrCreate(
se::Stream* stream, const BlasLtMatmulPlanParams& params,
absl::Mutex** ppmu, std::optional<int> max_algorithm_count) {
absl::Mutex** ppmu, std::optional<int> max_algorithm_count) -> StatusOr<const Entry*> {
static const int64_t max_scratch_size =
GetWorkspaceLimit(1LL << 32); // 4GB by default
static const int64_t max_autotune_algorithm_count =
MatmulMaxAutotuneAlgorithmCount();

if (!max_algorithm_count) max_algorithm_count = max_autotune_algorithm_count;

static BlasLtMatmulPlanMap plan_map;
auto& self = BlasLtMatmulPlanCache::i(stream);

auto [ptr, inserted] = plan_map.emplace(params, PlanAndAlgorithms{});
absl::MutexLock lock(self.mutex_.get());
auto [ptr, inserted] = self.map_.emplace(params, Entry{});
auto& entry = ptr->second;
if (inserted) {
TF_ASSIGN_OR_RETURN(auto xlatype,
se::gpu::AsXlaPrimitiveType(params.dtype));
TF_ASSIGN_OR_RETURN(auto computation_type,
GetBlasComputationType(params.dtype));

auto scale_type = se::gpu::GetScaleType(params.dtype, computation_type);

// row-major output is now handled automatically by blas-lt API
constexpr auto kRowMajor = se::gpu::MatrixLayout::Order::kRowMajor;

@@ -173,19 +166,42 @@ StatusOr<const PlanAndAlgorithms*> GetPlanAndAlgorithms(
.compute_type = computation_type,
};

TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan(
TF_ASSIGN_OR_RETURN(entry.plan, se::gpu::BlasLt::GetMatmulPlan(
stream, cfg, params.epilogue));

TF_ASSIGN_OR_RETURN(
auto algorithms,
plan->GetAlgorithms(*max_algorithm_count, max_scratch_size));

ptr->second = {std::move(plan), std::move(algorithms), scale_type};
entry.algorithms,
entry.plan->GetAlgorithms(*max_algorithm_count, max_scratch_size));
}
*ppmu = &plan_map.mu;
return &ptr->second;
*ppmu = self.mutex_.get();
return &entry;
}

/* static */ Status BlasLtMatmulPlanCache::ExecuteOnStream(se::Stream* stream,
const Entry& entry,
const se::DeviceMemoryBase& a,
const se::DeviceMemoryBase& b,
se::DeviceMemoryBase& c,
size_t algorithm_idx,
se::ScratchAllocator& scratch_allocator,
const se::DeviceMemoryBase& bias,
se::blas::ProfileResult* profile_result) {

return entry.plan->ExecuteOnStream(
stream, a, b, c, c,
bias, // bias_buffer
se::DeviceMemoryBase{}, // aux_buffer
se::DeviceMemoryBase{}, // a_scale_buffer
se::DeviceMemoryBase{}, // b_scale_buffer
se::DeviceMemoryBase{}, // c_scale_buffer
se::DeviceMemoryBase{}, // d_scale_buffer
se::DeviceMemoryBase{}, // d_amax_buffer
entry.algorithms[algorithm_idx],
scratch_allocator,
profile_result);
}


} // namespace tensorflow

#endif
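BlasLtMatmulPlanCache::i() keeps one cache instance per GPU ordinal in a std::deque guarded by a static mutex; std::deque is used because growing it never relocates existing elements, so references already handed out remain valid. A stripped-down, hypothetical illustration of that pattern (simplified payload, not the TF types):

#include <deque>

#include "absl/container/node_hash_map.h"
#include "absl/synchronization/mutex.h"

// Simplified stand-in for the per-device cache; the real one maps
// BlasLtMatmulPlanParams to plan/algorithm entries.
struct DeviceLocalCache {
  absl::Mutex mu;
  absl::node_hash_map<int, int> entries;
};

// One cache per device ordinal. std::deque::resize only constructs new
// elements in place, so references returned earlier stay valid as more
// devices are seen.
DeviceLocalCache& CacheForDevice(int device_ordinal) {
  static absl::Mutex m(absl::kConstInit);
  static std::deque<DeviceLocalCache> caches(8);
  absl::MutexLock lock(&m);
  if (device_ordinal >= static_cast<int>(caches.size())) {
    caches.resize(device_ordinal + 1);
  }
  return caches[device_ordinal];
}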