From 5a69fe0919e989f46b5fcc2a12d8f534afc6ec57 Mon Sep 17 00:00:00 2001 From: Arash Pakbin Date: Thu, 29 May 2025 21:10:45 +0000 Subject: [PATCH] resolve conflicts --- aten/src/ATen/miopen/Descriptors.h | 58 ++++++++++++++++++------------ aten/src/ATen/miopen/Handle.h | 5 ++- aten/src/ATen/miopen/Types.h | 6 ++-- 3 files changed, 41 insertions(+), 28 deletions(-) diff --git a/aten/src/ATen/miopen/Descriptors.h b/aten/src/ATen/miopen/Descriptors.h index 6496fd1727f58..d886cb81a89e2 100644 --- a/aten/src/ATen/miopen/Descriptors.h +++ b/aten/src/ATen/miopen/Descriptors.h @@ -38,9 +38,9 @@ struct DescriptorDeleter { // initialized the first time you call set() or any other initializing // function. template -class Descriptor -{ -public: +// NOLINTNEXTLINE(bugprone-exception-escape) +class TORCH_CUDA_CPP_API Descriptor { + public: // Use desc() to access the underlying descriptor pointer in // a read-only fashion. Most client code should use this. // If the descriptor was never initialized, this will return @@ -56,7 +56,7 @@ class Descriptor protected: void init() { if (desc_ == nullptr) { - T* raw_desc; + T* raw_desc = nullptr; MIOPEN_CHECK(ctor(&raw_desc)); desc_.reset(raw_desc); } @@ -65,13 +65,12 @@ class Descriptor std::unique_ptr> desc_; }; -class TORCH_CUDA_CPP_API TensorDescriptor - : public Descriptor -{ -public: - TensorDescriptor() {} +class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor< + miopenTensorDescriptor, + &miopenCreateTensorDescriptor, + &miopenDestroyTensorDescriptor> { + public: + TensorDescriptor() = default; explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) { set(t, pad); } @@ -89,11 +88,10 @@ class TORCH_CUDA_CPP_API TensorDescriptor std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d); -class FilterDescriptor - : public Descriptor -{ +class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor< + miopenTensorDescriptor, + &miopenCreateTensorDescriptor, + &miopenDestroyTensorDescriptor> { public: void set(const at::Tensor &t, int64_t pad = 0) { set(t, at::MemoryFormat::Contiguous, pad); @@ -107,11 +105,11 @@ class FilterDescriptor } }; -struct ConvolutionDescriptor - : public Descriptor -{ +struct TORCH_CUDA_CPP_API ConvolutionDescriptor + : public Descriptor< + miopenConvolutionDescriptor, + &miopenCreateConvolutionDescriptor, + &miopenDestroyConvolutionDescriptor> { void set(miopenDataType_t dataType, miopenConvolutionMode_t c_mode, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups, bool benchmark, bool deterministic) { MIOPEN_CHECK(miopenInitConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale, c_mode)); MIOPEN_CHECK(miopenSetConvolutionGroupCount(mut_desc(), groups)); @@ -122,8 +120,24 @@ struct ConvolutionDescriptor } }; +// NOLINTNEXTLINE(bugprone-exception-escape) +struct TORCH_CUDA_CPP_API DropoutDescriptor + : public Descriptor< + miopenDropoutDescriptor, + &miopenCreateDropoutDescriptor, + &miopenDestroyDropoutDescriptor> { + void set(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes, + unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) { + MIOPEN_CHECK(miopenSetDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode)); + } + + void restore(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes, + unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) { + MIOPEN_CHECK(miopenRestoreDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode)); + } +}; -struct RNNDescriptor +struct TORCH_CUDA_CPP_API RNNDescriptor : public Descriptor diff --git a/aten/src/ATen/miopen/Handle.h b/aten/src/ATen/miopen/Handle.h index 8747ee6daef82..331c449777623 100644 --- a/aten/src/ATen/miopen/Handle.h +++ b/aten/src/ATen/miopen/Handle.h @@ -3,8 +3,7 @@ #include #include -namespace at { namespace native { +namespace at::native { TORCH_CUDA_CPP_API miopenHandle_t getMiopenHandle(); - -}} // namespace +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/miopen/Types.h b/aten/src/ATen/miopen/Types.h index c098e4e5baa87..0a8a1a952e2e2 100644 --- a/aten/src/ATen/miopen/Types.h +++ b/aten/src/ATen/miopen/Types.h @@ -1,13 +1,13 @@ #pragma once -#include #include +#include #include -namespace at { namespace native { +namespace at::native { TORCH_CUDA_CPP_API miopenDataType_t getMiopenDataType(const at::Tensor& tensor); int64_t miopen_version(); -}} // namespace at::miopen +} // namespace at::native