From 11b9f5f96d1276b09896f68535e87825e9da7980 Mon Sep 17 00:00:00 2001 From: Xiaoxu Chen Date: Thu, 28 Oct 2021 11:01:59 +0800 Subject: [PATCH] [Cherry-pick]FFT function enhancements and bugfixes (#36537) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update fft api path (#36219) * update fft api path * add sample code for ihfft2 Co-authored-by: chenfeiyu * fix fft axis (#36321) fix: `-1` is used when fft's axis is `0` * use unified external error message for cufft api (#36114) * fft: modify sample code result (#36325) * dynamic load mkl as a fft backend when it is available and requested (#36414) * add rocm support for fft api (#36415) * move signal apis * move fft and signal API path (#2) * move signal apis * move fft.py and signal.py to paddle/, fix typos * fix relative imports from fft.py and signal.py * fix typos in signal.py (#3) * move signal apis * move fft.py and signal.py to paddle/, fix typos * fix relative imports from fft.py and signal.py * fix typos * disable Cache when CUFFT_VERSION >= 10200 (#4) * move signal apis * move fft.py and signal.py to paddle/, fix typos * fix relative imports from fft.py and signal.py * fix typos * Add LRUCache for fft plans * add LRUCache for cufft and hipfft (#5) * move signal apis * move fft.py and signal.py to paddle/, fix typos * fix relative imports from fft.py and signal.py * fix typos * WIP: add cache * delete move constructor and operator= for CuFFTHandle and FFTConfig * remove log from CuFFTHandle and FFTConfig * add lrucache for fft rocm backend * disable LRUCache when CUFFT_VERSION >= 10200 * disable copy and move for hipFFTHandle; format code Co-authored-by: Xiaoxu Chen * remove debug message of cufftHandler * roll_op: support Tensor as input for shifts (#36727) * fix fftshift/ifftshift on static mode * update roll_op version * add more test cases for fftshift/ifftshift Co-authored-by: zhiboniu <31800336+zhiboniu@users.noreply.github.com> Co-authored-by: chenfeiyu 
Co-authored-by: LJQ❤️ <33169170+lijiaqi0612@users.noreply.github.com> --- cmake/third_party.cmake | 4 +- paddle/fluid/operators/CMakeLists.txt | 16 +- paddle/fluid/operators/roll_op.cc | 52 +- paddle/fluid/operators/roll_op.cu | 20 + paddle/fluid/operators/roll_op.h | 17 + paddle/fluid/operators/spectral_helper.h | 466 +++++++++++++ paddle/fluid/operators/spectral_op.cc | 113 ++- paddle/fluid/operators/spectral_op.cu | 642 +++++++----------- paddle/fluid/platform/dynload/CMakeLists.txt | 8 +- .../fluid/platform/dynload/dynamic_loader.cc | 26 + .../fluid/platform/dynload/dynamic_loader.h | 2 + paddle/fluid/platform/dynload/hipfft.cc | 30 + paddle/fluid/platform/dynload/hipfft.h | 124 ++++ paddle/fluid/platform/dynload/mklrt.cc | 51 ++ paddle/fluid/platform/dynload/mklrt.h | 80 +++ paddle/fluid/platform/enforce.h | 24 + paddle/fluid/platform/enforce_test.cc | 26 +- paddle/fluid/platform/external_error.proto | 1 + python/paddle/__init__.py | 3 +- python/paddle/{tensor => }/fft.py | 87 ++- .../fluid/tests/unittests/fft/test_fft.py | 10 +- .../fluid/tests/unittests/test_roll_op.py | 28 + .../fluid/tests/unittests/test_signal.py | 20 +- python/paddle/{tensor => }/signal.py | 26 +- python/paddle/tensor/__init__.py | 2 - python/paddle/tensor/manipulation.py | 23 +- tools/externalError/README.md | 30 +- tools/externalError/spider.py | 29 +- tools/externalError/start.sh | 2 +- 29 files changed, 1413 insertions(+), 549 deletions(-) create mode 100644 paddle/fluid/operators/spectral_helper.h create mode 100644 paddle/fluid/platform/dynload/hipfft.cc create mode 100644 paddle/fluid/platform/dynload/hipfft.h create mode 100644 paddle/fluid/platform/dynload/mklrt.cc create mode 100644 paddle/fluid/platform/dynload/mklrt.h rename python/paddle/{tensor => }/fft.py (97%) rename python/paddle/{tensor => }/signal.py (97%) diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index 29a5587a07134..d45b5e07bb8f3 100644 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake 
@@ -255,8 +255,8 @@ if(WITH_GPU) include(external/cub) # download cub list(APPEND third_party_deps extern_cub) endif() - set(URL "https://paddlepaddledeps.bj.bcebos.com/externalErrorMsg.tar.gz" CACHE STRING "" FORCE) - file_download_and_uncompress(${URL} "externalError" MD5 061f3b7895aadcbe2c3ed592590f8b10) # download file externalErrorMsg.tar.gz + set(URL "https://paddlepaddledeps.bj.bcebos.com/externalErrorMsg_20210928.tar.gz" CACHE STRING "" FORCE) + file_download_and_uncompress(${URL} "externalError" MD5 a712a49384e77ca216ad866712f7cafa) # download file externalErrorMsg.tar.gz if(WITH_TESTING) # copy externalErrorMsg.pb, just for unittest can get error message correctly. set(SRC_DIR ${THIRD_PARTY_PATH}/externalError/data) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 50b83970ab933..937bfea3a59ef 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -105,10 +105,20 @@ else() op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) endif() -if (WITH_GPU AND (NOT WITH_ROCM)) - op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS}) +if (WITH_GPU OR WITH_ROCM) + if (MKL_FOUND AND WITH_ONEMKL) + op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda dynload_mklrt ${OP_HEADER_DEPS}) + target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE}) + else() + op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS}) + endif() else() - op_library(spectral_op SRCS spectral_op.cc DEPS ${OP_HEADER_DEPS}) + if (MKL_FOUND AND WITH_ONEMKL) + op_library(spectral_op SRCS spectral_op.cc DEPS dynload_mklrt ${OP_HEADER_DEPS}) + target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE}) + else() + op_library(spectral_op SRCS spectral_op.cc DEPS ${OP_HEADER_DEPS}) + endif() endif() op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute) diff --git 
a/paddle/fluid/operators/roll_op.cc b/paddle/fluid/operators/roll_op.cc index b6a8111592fb7..f82510556fde8 100644 --- a/paddle/fluid/operators/roll_op.cc +++ b/paddle/fluid/operators/roll_op.cc @@ -40,21 +40,23 @@ class RollOp : public framework::OperatorWithKernel { auto dims = ctx->Attrs().Get>("axis"); auto shifts = ctx->Attrs().Get>("shifts"); - if (dims.size() != 0) { - PADDLE_ENFORCE_EQ(dims.size(), shifts.size(), - platform::errors::InvalidArgument( - "When dims.size() != 0, dims.size() " - "should be equal to " - "shifts.size(). But received " - "dims.size() = %d, shifts.size() = %d", - dims.size(), shifts.size())); - } else { - PADDLE_ENFORCE_EQ(shifts.size(), 1, - platform::errors::InvalidArgument( - "When dims.size() == 0, shifts.size() " - "should be equal to 1, But received " - "shifts.size() = %d", - shifts.size())); + if (!ctx->HasInput("ShiftsTensor")) { + if (dims.size() != 0) { + PADDLE_ENFORCE_EQ(dims.size(), shifts.size(), + platform::errors::InvalidArgument( + "When dims.size() != 0, dims.size() " + "should be equal to " + "shifts.size(). But received " + "dims.size() = %d, shifts.size() = %d", + dims.size(), shifts.size())); + } else { + PADDLE_ENFORCE_EQ(shifts.size(), 1, + platform::errors::InvalidArgument( + "When dims.size() == 0, shifts.size() " + "should be equal to 1, But received " + "shifts.size() = %d", + shifts.size())); + } } ctx->SetOutputDim("Out", ctx->GetInputDim("X")); @@ -105,6 +107,10 @@ class RollOpMaker : public framework::OpProtoAndCheckerMaker { "The number of places by which the elements " "of the tensor are shifted.") .SetDefault({}); + AddInput("ShiftsTensor", + "The number of places by which the elements of the tensor " + "are shifted.") + .AsDispensable(); AddAttr>( "axis", "Axis along which to roll. 
It must have the same size " @@ -129,6 +135,9 @@ class RollGradMaker : public framework::SingleGradOpMaker { void Apply(GradOpPtr op) const override { op->SetType("roll_grad"); op->SetInput("X", this->Input("X")); + if (this->HasInput("ShiftsTensor")) { + op->SetInput("ShiftsTensor", this->Input("ShiftsTensor")); + } op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetAttrMap(this->Attrs()); @@ -174,7 +183,12 @@ REGISTER_OP_VERSION(roll) "(std::vector) Axis along which to roll. " "It must have the same size with shifts, or size = 0.", std::vector()) - .DeleteAttr( - "dims", - "(std::vector) Dims along which to roll. " - "It must have the same size with shifts, or size = 0.")); + .DeleteAttr("dims", + "(std::vector) Dims along which to roll. " + "It must have the same size with shifts, or size = 0.")) + .AddCheckpoint( + R"ROC(Upgrade roll add a dispensable input "ShiftsTensor".)ROC", + paddle::framework::compatible::OpVersionDesc().NewInput( + "ShiftsTensor", + "The number of places by which the elements of" + "the tensor are shifted.")); diff --git a/paddle/fluid/operators/roll_op.cu b/paddle/fluid/operators/roll_op.cu index a170ce2fb111d..d70bd58887f84 100644 --- a/paddle/fluid/operators/roll_op.cu +++ b/paddle/fluid/operators/roll_op.cu @@ -59,6 +59,16 @@ class RollKernel auto* in = context.Input("X"); auto* out = context.Output("Out"); std::vector shifts = context.Attr>("shifts"); + if (context.HasInput("ShiftsTensor")) { + const auto* shifts_tensor = + context.Input("ShiftsTensor"); + PADDLE_ENFORCE_EQ( + shifts_tensor->dims().size(), 1, + platform::errors::InvalidArgument( + "The rank of ShiftsTensor is expected to be 1, got %s", + shifts_tensor->dims().size())); + shifts = GetDataFromTensor(shifts_tensor); + } std::vector dims = context.Attr>("axis"); auto* in_data = in->data(); @@ -134,6 +144,16 @@ class RollGradKernel auto* in = 
context.Input(framework::GradVarName("Out")); auto* out = context.Output(framework::GradVarName("X")); std::vector shifts = context.Attr>("shifts"); + if (context.HasInput("ShiftsTensor")) { + const auto* shifts_tensor = + context.Input("ShiftsTensor"); + PADDLE_ENFORCE_EQ( + shifts_tensor->dims().size(), 1, + platform::errors::InvalidArgument( + "The rank of ShiftsTensor is expected to be 1, got %s", + shifts_tensor->dims().size())); + shifts = GetDataFromTensor(shifts_tensor); + } std::vector dims = context.Attr>("axis"); auto* in_data = in->data(); diff --git a/paddle/fluid/operators/roll_op.h b/paddle/fluid/operators/roll_op.h index e58ff521d8df7..affb5f226ed55 100644 --- a/paddle/fluid/operators/roll_op.h +++ b/paddle/fluid/operators/roll_op.h @@ -16,6 +16,8 @@ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/utils.h" +#include "paddle/fluid/platform/enforce.h" namespace paddle { namespace operators { @@ -85,6 +87,16 @@ class RollKernel : public framework::OpKernel { auto& input = input_var->Get(); auto* output = output_var->GetMutable(); std::vector shifts = context.Attr>("shifts"); + if (context.HasInput("ShiftsTensor")) { + const auto* shifts_tensor = + context.Input("ShiftsTensor"); + PADDLE_ENFORCE_EQ( + shifts_tensor->dims().size(), 1, + platform::errors::InvalidArgument( + "The rank of ShiftsTensor is expected to be 1, got %s", + shifts_tensor->dims().size())); + shifts = GetDataFromTensor(shifts_tensor); + } std::vector dims = context.Attr>("axis"); std::vector out_vec; @@ -123,6 +135,11 @@ class RollGradKernel : public framework::OpKernel { auto& input = input_var->Get(); auto* output = output_var->GetMutable(); std::vector shifts = context.Attr>("shifts"); + if (context.HasInput("ShiftsTensor")) { + const auto* shifts_tensor = + context.Input("ShiftsTensor"); + shifts = GetDataFromTensor(shifts_tensor); + } std::vector dims = context.Attr>("axis"); std::vector out_vec; diff --git 
a/paddle/fluid/operators/spectral_helper.h b/paddle/fluid/operators/spectral_helper.h new file mode 100644 index 0000000000000..924ec7cd52d50 --- /dev/null +++ b/paddle/fluid/operators/spectral_helper.h @@ -0,0 +1,466 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/operators/spectral_op.h" + +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/dynload/hipfft.h" +#endif + +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/dynload/cufft.h" +#endif + +namespace paddle { +namespace operators { +using ScalarType = framework::proto::VarType::Type; +const int64_t kMaxFFTNdim = 3; +const int64_t kMaxDataNdim = kMaxFFTNdim + 1; +// This struct is used to easily compute hashes of the +// parameters. It will be the **key** to the plan cache. +struct FFTConfigKey { + // between 1 and kMaxFFTNdim, i.e., 1 <= signal_ndim <= 3 + int64_t signal_ndim_; + // These include additional batch dimension as well. 
+ int64_t sizes_[kMaxDataNdim]; + int64_t input_shape_[kMaxDataNdim]; + int64_t output_shape_[kMaxDataNdim]; + FFTTransformType fft_type_; + ScalarType value_type_; + + FFTConfigKey() = default; + + FFTConfigKey(const std::vector& in_shape, + const std::vector& out_shape, + const std::vector& signal_size, + FFTTransformType fft_type, ScalarType value_type) { + // Padding bits must be zeroed for hashing + memset(this, 0, sizeof(*this)); + signal_ndim_ = signal_size.size() - 1; + fft_type_ = fft_type; + value_type_ = value_type; + + std::copy(signal_size.cbegin(), signal_size.cend(), sizes_); + std::copy(in_shape.cbegin(), in_shape.cend(), input_shape_); + std::copy(out_shape.cbegin(), out_shape.cend(), output_shape_); + } +}; + +#if defined(PADDLE_WITH_CUDA) +// An RAII encapsulation of cuFFTHandle +class CuFFTHandle { + ::cufftHandle handle_; + + public: + CuFFTHandle() { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftCreate(&handle_)); + } + + CuFFTHandle(const CuFFTHandle& other) = delete; + CuFFTHandle& operator=(const CuFFTHandle& other) = delete; + + CuFFTHandle(CuFFTHandle&& other) = delete; + CuFFTHandle& operator=(CuFFTHandle&& other) = delete; + + ::cufftHandle& get() { return handle_; } + const ::cufftHandle& get() const { return handle_; } + + ~CuFFTHandle() { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftDestroy(handle_)); + } +}; + +using plan_size_type = long long int; // NOLINT +// This class contains all the information needed to execute a cuFFT plan: +// 1. the plan +// 2. the workspace size needed +class FFTConfig { + public: + // Only move semantics is enought for this class. Although we already use + // unique_ptr for the plan, still remove copy constructor and assignment op so + // we don't accidentally copy and take perf hit. 
+ explicit FFTConfig(const FFTConfigKey& plan_key) + : FFTConfig( + std::vector(plan_key.sizes_, + plan_key.sizes_ + plan_key.signal_ndim_ + 1), + plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {} + + // sizes are full signal, including batch size and always two-sided + FFTConfig(const std::vector& sizes, const int64_t signal_ndim, + FFTTransformType fft_type, ScalarType dtype) + : fft_type_(fft_type), value_type_(dtype) { + // signal sizes (excluding batch dim) + std::vector signal_sizes(sizes.begin() + 1, sizes.end()); + + // input batch size + const auto batch = static_cast(sizes[0]); + // const int64_t signal_ndim = sizes.size() - 1; + PADDLE_ENFORCE_EQ(signal_ndim, sizes.size() - 1, + platform::errors::InvalidArgument( + "The signal_ndim must be equal to sizes.size() - 1," + "But signal_ndim is: [%d], sizes.size() - 1 is: [%d]", + signal_ndim, sizes.size() - 1)); + + cudaDataType itype, otype, exec_type; + const auto complex_input = has_complex_input(fft_type); + const auto complex_output = has_complex_output(fft_type); + if (dtype == framework::proto::VarType::FP32) { + itype = complex_input ? CUDA_C_32F : CUDA_R_32F; + otype = complex_output ? CUDA_C_32F : CUDA_R_32F; + exec_type = CUDA_C_32F; + } else if (dtype == framework::proto::VarType::FP64) { + itype = complex_input ? CUDA_C_64F : CUDA_R_64F; + otype = complex_output ? CUDA_C_64F : CUDA_R_64F; + exec_type = CUDA_C_64F; + } else if (dtype == framework::proto::VarType::FP16) { + itype = complex_input ? CUDA_C_16F : CUDA_R_16F; + otype = complex_output ? 
CUDA_C_16F : CUDA_R_16F; + exec_type = CUDA_C_16F; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "cuFFT only support transforms of type float16, float32 and " + "float64")); + } + + // disable auto allocation of workspace to use allocator from the framework + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftSetAutoAllocation( + plan(), /* autoAllocate */ 0)); + + size_t ws_size_t; + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftXtMakePlanMany( + plan(), signal_ndim, signal_sizes.data(), + /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype, + /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype, + batch, &ws_size_t, exec_type)); + + ws_size = ws_size_t; + } + + FFTConfig(const FFTConfig& other) = delete; + FFTConfig& operator=(const FFTConfig& other) = delete; + + FFTConfig(FFTConfig&& other) = delete; + FFTConfig& operator=(FFTConfig&& other) = delete; + + const cufftHandle& plan() const { return plan_ptr.get(); } + + FFTTransformType transform_type() const { return fft_type_; } + ScalarType data_type() const { return value_type_; } + size_t workspace_size() const { return ws_size; } + + private: + CuFFTHandle plan_ptr; + size_t ws_size; + FFTTransformType fft_type_; + ScalarType value_type_; +}; + +#elif defined(PADDLE_WITH_HIP) +// An RAII encapsulation of cuFFTHandle +class HIPFFTHandle { + ::hipfftHandle handle_; + + public: + HIPFFTHandle() { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftCreate(&handle_)); + } + + HIPFFTHandle(const HIPFFTHandle& other) = delete; + HIPFFTHandle& operator=(const HIPFFTHandle& other) = delete; + + HIPFFTHandle(HIPFFTHandle&& other) = delete; + HIPFFTHandle& operator=(HIPFFTHandle&& other) = delete; + + ::hipfftHandle& get() { return handle_; } + const ::hipfftHandle& get() const { return handle_; } + + ~HIPFFTHandle() { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftDestroy(handle_)); + } +}; +using plan_size_type = int; +// This class contains all the 
information needed to execute a cuFFT plan: +// 1. the plan +// 2. the workspace size needed +class FFTConfig { + public: + // Only move semantics is enought for this class. Although we already use + // unique_ptr for the plan, still remove copy constructor and assignment op so + // we don't accidentally copy and take perf hit. + explicit FFTConfig(const FFTConfigKey& plan_key) + : FFTConfig( + std::vector(plan_key.sizes_, + plan_key.sizes_ + plan_key.signal_ndim_ + 1), + plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {} + + // sizes are full signal, including batch size and always two-sided + FFTConfig(const std::vector& sizes, const int64_t signal_ndim, + FFTTransformType fft_type, ScalarType dtype) + : fft_type_(fft_type), value_type_(dtype) { + // signal sizes (excluding batch dim) + std::vector signal_sizes(sizes.begin() + 1, sizes.end()); + + // input batch size + const auto batch = static_cast(sizes[0]); + // const int64_t signal_ndim = sizes.size() - 1; + PADDLE_ENFORCE_EQ(signal_ndim, sizes.size() - 1, + platform::errors::InvalidArgument( + "The signal_ndim must be equal to sizes.size() - 1," + "But signal_ndim is: [%d], sizes.size() - 1 is: [%d]", + signal_ndim, sizes.size() - 1)); + + hipfftType exec_type = [&] { + if (dtype == framework::proto::VarType::FP32) { + switch (fft_type) { + case FFTTransformType::C2C: + return HIPFFT_C2C; + case FFTTransformType::R2C: + return HIPFFT_R2C; + case FFTTransformType::C2R: + return HIPFFT_C2R; + } + } else if (dtype == framework::proto::VarType::FP64) { + switch (fft_type) { + case FFTTransformType::C2C: + return HIPFFT_Z2Z; + case FFTTransformType::R2C: + return HIPFFT_D2Z; + case FFTTransformType::C2R: + return HIPFFT_Z2D; + } + } + PADDLE_THROW(platform::errors::InvalidArgument( + "hipFFT only support transforms of type float32 and float64")); + }(); + + // disable auto allocation of workspace to use allocator from the framework + 
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftSetAutoAllocation( + plan(), /* autoAllocate */ 0)); + + size_t ws_size_t; + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftMakePlanMany( + plan(), signal_ndim, signal_sizes.data(), + /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, + /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, exec_type, + batch, &ws_size_t)); + + ws_size = ws_size_t; + } + + const hipfftHandle& plan() const { return plan_ptr.get(); } + + FFTTransformType transform_type() const { return fft_type_; } + ScalarType data_type() const { return value_type_; } + size_t workspace_size() const { return ws_size; } + + private: + HIPFFTHandle plan_ptr; + size_t ws_size; + FFTTransformType fft_type_; + ScalarType value_type_; +}; +#endif + +// Hashing machinery for Key +// Fowler–Noll–Vo hash function +// see +// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function +template +struct KeyHash { + // Key must be a POD because we read out its memory + // contenst as char* when hashing + static_assert(std::is_pod::value, "Key must be plain old data type"); + + size_t operator()(const Key& params) const { + auto ptr = reinterpret_cast(¶ms); + uint32_t value = 0x811C9DC5; + for (int i = 0; i < static_cast(sizeof(Key)); ++i) { + value ^= ptr[i]; + value *= 0x01000193; + } + return static_cast(value); + } +}; + +template +struct KeyEqual { + // Key must be a POD because we read out its memory + // contenst as char* when comparing + static_assert(std::is_pod::value, "Key must be plain old data type"); + + bool operator()(const Key& a, const Key& b) const { + auto ptr1 = reinterpret_cast(&a); + auto ptr2 = reinterpret_cast(&b); + return memcmp(ptr1, ptr2, sizeof(Key)) == 0; + } +}; + +#if CUDA_VERSION < 10000 +// Note that the max plan number for CUDA version < 10 has to be 1023 +// due to a bug that fails on the 1024th plan +constexpr size_t CUFFT_MAX_PLAN_NUM = 1023; +constexpr size_t CUFFT_DEFAULT_CACHE_SIZE 
= CUFFT_MAX_PLAN_NUM; +#else +constexpr size_t CUFFT_MAX_PLAN_NUM = std::numeric_limits::max(); +// The default max cache size chosen for CUDA version > 10 is arbitrary. +// This number puts a limit on how big of a plan cache should we maintain by +// default. Users can always configure it via cufft_set_plan_cache_max_size. +constexpr size_t CUFFT_DEFAULT_CACHE_SIZE = 4096; +#endif +static_assert(CUFFT_MAX_PLAN_NUM >= 0 && + CUFFT_MAX_PLAN_NUM <= std::numeric_limits::max(), + "CUFFT_MAX_PLAN_NUM not in size_t range"); +static_assert(CUFFT_DEFAULT_CACHE_SIZE >= 0 && + CUFFT_DEFAULT_CACHE_SIZE <= CUFFT_MAX_PLAN_NUM, + "CUFFT_DEFAULT_CACHE_SIZE not in [0, CUFFT_MAX_PLAN_NUM] range"); + +// This cache assumes that the mapping from key to value never changes. +// This is **NOT** thread-safe. Please use a mutex when using it **AND** the +// value returned from try_emplace_value. +// The contract of using this cache is that try_emplace_value should only be +// used when the max_size is positive. +class FFTConfigCache { + public: + using kv_t = typename std::pair; + using map_t = typename std::unordered_map< + std::reference_wrapper, typename std::list::iterator, + KeyHash, KeyEqual>; + using map_kkv_iter_t = typename map_t::iterator; + + FFTConfigCache() : FFTConfigCache(CUFFT_DEFAULT_CACHE_SIZE) {} + + explicit FFTConfigCache(int64_t max_size) { _set_max_size(max_size); } + + FFTConfigCache(const FFTConfigCache& other) = delete; + FFTConfigCache& operator=(const FFTConfigCache& other) = delete; + + FFTConfigCache(FFTConfigCache&& other) noexcept + : _usage_list(std::move(other._usage_list)), + _cache_map(std::move(other._cache_map)), + _max_size(other._max_size) {} + + FFTConfigCache& operator=(FFTConfigCache&& other) noexcept { + _usage_list = std::move(other._usage_list); + _cache_map = std::move(other._cache_map); + _max_size = other._max_size; + return *this; + } + + // If key is in this cache, return the cached config. 
Otherwise, emplace the + // config in this cache and return it. + FFTConfig& lookup(FFTConfigKey params) { + PADDLE_ENFORCE_GT(_max_size, 0, + platform::errors::InvalidArgument( + "The max size of FFTConfigCache must be great than 0," + "But received is [%d]", + _max_size)); + + map_kkv_iter_t map_it = _cache_map.find(params); + // Hit, put to list front + if (map_it != _cache_map.end()) { + _usage_list.splice(_usage_list.begin(), _usage_list, map_it->second); + return map_it->second->second; + } + + // Miss + // remove if needed + if (_usage_list.size() >= _max_size) { + auto last = _usage_list.end(); + last--; + _cache_map.erase(last->first); + _usage_list.pop_back(); + } + + // construct new plan at list front, then insert into _cache_map + _usage_list.emplace_front(std::piecewise_construct, + std::forward_as_tuple(params), + std::forward_as_tuple(params)); + auto kv_it = _usage_list.begin(); + _cache_map.emplace(std::piecewise_construct, + std::forward_as_tuple(kv_it->first), + std::forward_as_tuple(kv_it)); + return kv_it->second; + } + + void clear() { + _cache_map.clear(); + _usage_list.clear(); + } + + void resize(int64_t new_size) { + _set_max_size(new_size); + auto cur_size = _usage_list.size(); + if (cur_size > _max_size) { + auto delete_it = _usage_list.end(); + for (size_t i = 0; i < cur_size - _max_size; i++) { + delete_it--; + _cache_map.erase(delete_it->first); + } + _usage_list.erase(delete_it, _usage_list.end()); + } + } + + size_t size() const { return _cache_map.size(); } + + size_t max_size() const noexcept { return _max_size; } + + std::mutex mutex; + + private: + // Only sets size and does value check. Does not resize the data structures. + void _set_max_size(int64_t new_size) { + // We check that 0 <= new_size <= CUFFT_MAX_PLAN_NUM here. Since + // CUFFT_MAX_PLAN_NUM is of type size_t, we need to do non-negativity check + // first. 
+ PADDLE_ENFORCE_GE( + new_size, 0, + platform::errors::InvalidArgument( + "cuFFT plan cache size must be non-negative, But received is [%d]", + new_size)); + PADDLE_ENFORCE_LE(new_size, CUFFT_MAX_PLAN_NUM, + platform::errors::InvalidArgument( + "cuFFT plan cache size can not be larger than [%d], " + "But received is [%d]", + CUFFT_MAX_PLAN_NUM, new_size)); + _max_size = static_cast(new_size); + } + + std::list _usage_list; + map_t _cache_map; + size_t _max_size; +}; + +static std::vector> plan_caches; +static std::mutex plan_caches_mutex; + +static inline FFTConfigCache& get_fft_plan_cache(int64_t device_index) { + std::lock_guard guard(plan_caches_mutex); + + if (device_index >= plan_caches.size()) { + plan_caches.resize(device_index + 1); + } + + if (!plan_caches[device_index]) { + plan_caches[device_index] = std::make_unique(); + } + + return *plan_caches[device_index]; +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/spectral_op.cc b/paddle/fluid/operators/spectral_op.cc index fb50702233b3b..b5edc1dda533b 100644 --- a/paddle/fluid/operators/spectral_op.cc +++ b/paddle/fluid/operators/spectral_op.cc @@ -27,7 +27,7 @@ #include "paddle/fluid/platform/complex.h" #if defined(PADDLE_WITH_ONEMKL) -#include +#include "paddle/fluid/platform/dynload/mklrt.h" #elif defined(PADDLE_WITH_POCKETFFT) #include "extern_pocketfft/pocketfft_hdronly.h" #endif @@ -357,46 +357,45 @@ FFTNormMode get_norm_from_string(const std::string& norm, bool forward) { // FFT Functors #if defined(PADDLE_WITH_ONEMKL) +#define MKL_DFTI_CHECK(expr) \ + do { \ + MKL_LONG status = (expr); \ + if (!platform::dynload::DftiErrorClass(status, DFTI_NO_ERROR)) \ + PADDLE_THROW(platform::errors::External( \ + platform::dynload::DftiErrorMessage(status))); \ + } while (0); + namespace { -static inline void MKL_DFTI_CHECK(MKL_INT status) { - if (status && !DftiErrorClass(status, DFTI_NO_ERROR)) { - PADDLE_THROW(platform::errors::External(DftiErrorMessage(status))); - } 
-} struct DftiDescriptorDeleter { void operator()(DFTI_DESCRIPTOR_HANDLE handle) { if (handle != nullptr) { - MKL_DFTI_CHECK(DftiFreeDescriptor(&handle)); + MKL_DFTI_CHECK(platform::dynload::DftiFreeDescriptor(&handle)); } } }; +// A RAII wrapper for MKL_DESCRIPTOR* class DftiDescriptor { public: void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type, MKL_LONG signal_ndim, MKL_LONG* sizes) { - if (desc_ != nullptr) { - PADDLE_THROW(platform::errors::AlreadyExists( - "DFT DESCRIPTOR can only be initialized once.")); - } + PADDLE_ENFORCE_EQ(desc_.get(), nullptr, + platform::errors::AlreadyExists( + "DftiDescriptor has already been initialized.")); + DFTI_DESCRIPTOR* raw_desc; - if (signal_ndim == 1) { - MKL_DFTI_CHECK( - DftiCreateDescriptor(&raw_desc, precision, signal_type, 1, sizes[0])); - } else { - MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type, - signal_ndim, sizes)); - } + MKL_DFTI_CHECK(platform::dynload::DftiCreateDescriptorX( + &raw_desc, precision, signal_type, signal_ndim, sizes)); desc_.reset(raw_desc); } DFTI_DESCRIPTOR* get() const { - if (desc_ == nullptr) { - PADDLE_THROW(platform::errors::PreconditionNotMet( - "DFTI DESCRIPTOR has not been initialized.")); - } - return desc_.get(); + DFTI_DESCRIPTOR* raw_desc = desc_.get(); + PADDLE_ENFORCE_NOT_NULL(raw_desc, + platform::errors::PreconditionNotMet( + "DFTI DESCRIPTOR has not been initialized.")); + return raw_desc; } private: @@ -421,7 +420,9 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, return DFTI_DOUBLE; default: PADDLE_THROW(platform::errors::InvalidArgument( - "Input data type should be FP32, FP64, COMPLEX64 or COMPLEX128.")); + "Invalid input datatype (%s), input data type should be FP32, " + "FP64, COMPLEX64 or COMPLEX128.", + framework::DataTypeToString(in_dtype))); } }(); @@ -430,35 +431,27 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, const DFTI_CONFIG_VALUE domain = (fft_type == 
FFTTransformType::C2C) ? DFTI_COMPLEX : DFTI_REAL; - // const bool complex_input = framework::IsComplexType(in_dtype); - // const bool complex_output = framework::IsComplexType(out_dtype); - // const DFTI_CONFIG_VALUE domain = [&] { - // if (forward) { - // return complex_input ? DFTI_COMPLEX : DFTI_REAL; - // } else { - // return complex_output ? DFTI_COMPLEX : DFTI_REAL; - // } - // }(); - DftiDescriptor descriptor; std::vector fft_sizes(signal_sizes.cbegin(), signal_sizes.cend()); const MKL_LONG signal_ndim = fft_sizes.size() - 1; descriptor.init(precision, domain, signal_ndim, fft_sizes.data() + 1); // placement inplace or not inplace - MKL_DFTI_CHECK( - DftiSetValue(descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE)); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue( + descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE)); // number of transformations const MKL_LONG batch_size = fft_sizes[0]; - MKL_DFTI_CHECK( - DftiSetValue(descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size)); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue( + descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size)); // input & output distance const MKL_LONG idist = in_strides[0]; const MKL_LONG odist = out_strides[0]; - MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist)); - MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_DISTANCE, odist)); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(), + DFTI_INPUT_DISTANCE, idist)); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(), + DFTI_OUTPUT_DISTANCE, odist)); // input & output stride std::vector mkl_in_stride(1 + signal_ndim, 0); @@ -467,15 +460,15 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, mkl_in_stride[i] = in_strides[i]; mkl_out_stride[i] = out_strides[i]; } - MKL_DFTI_CHECK( - DftiSetValue(descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data())); - MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_STRIDES, - 
mkl_out_stride.data())); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue( + descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data())); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue( + descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_out_stride.data())); // conjugate even storage if (!(fft_type == FFTTransformType::C2C)) { - MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, - DFTI_COMPLEX_COMPLEX)); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue( + descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX)); } MKL_LONG signal_numel = @@ -496,11 +489,12 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, return DFTI_BACKWARD_SCALE; } }(); - MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), scale_direction, scale)); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(), + scale_direction, scale)); } // commit the descriptor - MKL_DFTI_CHECK(DftiCommitDescriptor(descriptor.get())); + MKL_DFTI_CHECK(platform::dynload::DftiCommitDescriptor(descriptor.get())); return descriptor; } @@ -592,15 +586,16 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out, collapsed_input.numel(), collapsed_input_conj.data()); for_range(functor); - MKL_DFTI_CHECK(DftiComputeBackward(desc.get(), - collapsed_input_conj.data(), - collapsed_output.data())); + MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward( + desc.get(), collapsed_input_conj.data(), + collapsed_output.data())); } else if (fft_type == FFTTransformType::R2C && !forward) { framework::Tensor collapsed_output_conj(collapsed_output.type()); collapsed_output_conj.mutable_data(collapsed_output.dims(), ctx.GetPlace()); - MKL_DFTI_CHECK(DftiComputeForward(desc.get(), collapsed_input.data(), - collapsed_output_conj.data())); + MKL_DFTI_CHECK(platform::dynload::DftiComputeForward( + desc.get(), collapsed_input.data(), + collapsed_output_conj.data())); // conjugate the output platform::ForRange for_range(ctx, collapsed_output.numel()); 
math::ConjFunctor functor(collapsed_output_conj.data(), @@ -609,13 +604,13 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out, for_range(functor); } else { if (forward) { - MKL_DFTI_CHECK(DftiComputeForward(desc.get(), - collapsed_input.data(), - collapsed_output.data())); + MKL_DFTI_CHECK(platform::dynload::DftiComputeForward( + desc.get(), collapsed_input.data(), + collapsed_output.data())); } else { - MKL_DFTI_CHECK(DftiComputeBackward(desc.get(), - collapsed_input.data(), - collapsed_output.data())); + MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward( + desc.get(), collapsed_input.data(), + collapsed_output.data())); } } diff --git a/paddle/fluid/operators/spectral_op.cu b/paddle/fluid/operators/spectral_op.cu index 9aa5ca39d737e..dee5315b67fb4 100644 --- a/paddle/fluid/operators/spectral_op.cu +++ b/paddle/fluid/operators/spectral_op.cu @@ -8,10 +8,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ - -#include -#include - #include #include #include @@ -24,313 +20,244 @@ #include #include "paddle/fluid/operators/conj_op.h" +#include "paddle/fluid/operators/spectral_helper.h" #include "paddle/fluid/operators/spectral_op.h" #include "paddle/fluid/operators/transpose_op.h" -#include "paddle/fluid/platform/dynload/cufft.h" +#include "paddle/fluid/platform/enforce.h" namespace paddle { namespace operators { namespace { -using ScalarType = framework::proto::VarType::Type; -const int64_t kMaxCUFFTNdim = 3; -const int64_t kMaxDataNdim = kMaxCUFFTNdim + 1; - -static inline std::string get_cufft_error_info(cufftResult error) { - switch (error) { - case CUFFT_SUCCESS: - return "CUFFT_SUCCESS"; - case CUFFT_INVALID_PLAN: - return "CUFFT_INVALID_PLAN"; - case CUFFT_ALLOC_FAILED: - return "CUFFT_ALLOC_FAILED"; - case CUFFT_INVALID_TYPE: - return "CUFFT_INVALID_TYPE"; - case CUFFT_INVALID_VALUE: - return "CUFFT_INVALID_VALUE"; - case CUFFT_INTERNAL_ERROR: - return "CUFFT_INTERNAL_ERROR"; - case CUFFT_EXEC_FAILED: - return "CUFFT_EXEC_FAILED"; - case CUFFT_SETUP_FAILED: - return "CUFFT_SETUP_FAILED"; - case CUFFT_INVALID_SIZE: - return "CUFFT_INVALID_SIZE"; - case CUFFT_UNALIGNED_DATA: - return "CUFFT_UNALIGNED_DATA"; - case CUFFT_INCOMPLETE_PARAMETER_LIST: - return "CUFFT_INCOMPLETE_PARAMETER_LIST"; - case CUFFT_INVALID_DEVICE: - return "CUFFT_INVALID_DEVICE"; - case CUFFT_PARSE_ERROR: - return "CUFFT_PARSE_ERROR"; - case CUFFT_NO_WORKSPACE: - return "CUFFT_NO_WORKSPACE"; - case CUFFT_NOT_IMPLEMENTED: - return "CUFFT_NOT_IMPLEMENTED"; -#ifndef __HIPCC__ - case CUFFT_LICENSE_ERROR: - return "CUFFT_LICENSE_ERROR"; -#endif - case CUFFT_NOT_SUPPORTED: - return "CUFFT_NOT_SUPPORTED"; - default: - std::ostringstream ss; - ss << "unknown error " << error; - return ss.str(); +// Calculates the normalization constant +double fft_normalization_scale(FFTNormMode normalization, + const std::vector& sizes, + const std::vector& dims) { + // auto norm = static_cast(normalization); + if 
(normalization == FFTNormMode::none) { + return static_cast(1.0); } -} -static inline void CUFFT_CHECK(cufftResult error) { - if (error != CUFFT_SUCCESS) { - PADDLE_THROW(platform::errors::External(get_cufft_error_info(error))); + int64_t signal_numel = 1; + for (auto dim : dims) { + signal_numel *= sizes[dim]; } + const double scale_denom = (normalization == FFTNormMode::by_sqrt_n) + ? std::sqrt(signal_numel) + : static_cast(signal_numel); + return static_cast(1.0 / scale_denom); } -// This struct is used to easily compute hashes of the -// parameters. It will be the **key** to the plan cache. -struct PlanKey { - // between 1 and kMaxCUFFTNdim, i.e., 1 <= signal_ndim <= 3 - int64_t signal_ndim_; - // These include additional batch dimension as well. - int64_t sizes_[kMaxDataNdim]; - int64_t input_shape_[kMaxDataNdim]; - int64_t output_shape_[kMaxDataNdim]; - FFTTransformType fft_type_; - ScalarType value_type_; - - PlanKey() = default; - - PlanKey(const std::vector& in_shape, - const std::vector& out_shape, - const std::vector& signal_size, FFTTransformType fft_type, - ScalarType value_type) { - // Padding bits must be zeroed for hashing - memset(this, 0, sizeof(*this)); - signal_ndim_ = signal_size.size() - 1; - fft_type_ = fft_type; - value_type_ = value_type; - - std::copy(signal_size.cbegin(), signal_size.cend(), sizes_); - std::copy(in_shape.cbegin(), in_shape.cend(), input_shape_); - std::copy(out_shape.cbegin(), out_shape.cend(), output_shape_); +template +void exec_normalization(const DeviceContext& ctx, const Tensor* in, Tensor* out, + FFTNormMode normalization, + const std::vector& sizes, + const std::vector& axes) { + double scale = fft_normalization_scale(normalization, sizes, axes); + if (scale != 1.0) { + auto eigen_out = framework::EigenVector::Flatten(*out); + auto eigen_in = framework::EigenVector::Flatten(*in); + auto dev = ctx.eigen_device(); + EigenScale::Eval(*dev, eigen_out, eigen_in, + static_cast(scale), + static_cast(0), false); + } else { 
+ framework::TensorCopy(*in, ctx.GetPlace(), out); } -}; - -// An RAII encapsulation of cuFFTHandle -class CuFFTHandle { - ::cufftHandle handle_; - - public: - CuFFTHandle() { CUFFT_CHECK(platform::dynload::cufftCreate(&handle_)); } +} - ::cufftHandle& get() { return handle_; } - const ::cufftHandle& get() const { return handle_; } +#if defined(PADDLE_WITH_CUDA) +FFTConfigKey create_fft_configkey(const framework::Tensor& input, + const framework::Tensor& output, + int signal_ndim) { + // Create the transform plan (either from cache or locally) + const auto value_type = framework::IsComplexType(input.type()) + ? framework::ToRealType(input.type()) + : input.type(); + auto fft_type = GetFFTTransformType(input.type(), output.type()); + // signal sizes + std::vector signal_size(signal_ndim + 1); - ~CuFFTHandle() { -// Not using fftDestroy() for rocFFT to work around double freeing of handles -#ifndef __HIPCC__ - CUFFT_CHECK(platform::dynload::cufftDestroy(handle_)); -#endif + signal_size[0] = input.dims()[0]; + for (int64_t i = 1; i <= signal_ndim; ++i) { + auto in_size = input.dims()[i]; + auto out_size = output.dims()[i]; + signal_size[i] = std::max(in_size, out_size); } -}; - -#ifdef __HIPCC__ -using plan_size_type = int; -#else -using plan_size_type = long long int; // NOLINT -#endif + FFTConfigKey key(framework::vectorize(input.dims()), + framework::vectorize(output.dims()), signal_size, fft_type, + value_type); + return key; +} -// This class contains all the information needed to execute a cuFFT plan: -// 1. the plan -// 2. the workspace size needed -class CuFFTConfig { - public: - // Only move semantics is enought for this class. Although we already use - // unique_ptr for the plan, still remove copy constructor and assignment op so - // we don't accidentally copy and take perf hit. 
- CuFFTConfig(const CuFFTConfig&) = delete; - CuFFTConfig& operator=(CuFFTConfig const&) = delete; - - explicit CuFFTConfig(const PlanKey& plan_key) - : CuFFTConfig( - std::vector(plan_key.sizes_, - plan_key.sizes_ + plan_key.signal_ndim_ + 1), - plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {} - - // sizes are full signal, including batch size and always two-sided - CuFFTConfig(const std::vector& sizes, const int64_t signal_ndim, - FFTTransformType fft_type, ScalarType dtype) - : fft_type_(fft_type), value_type_(dtype) { - // signal sizes (excluding batch dim) - std::vector signal_sizes(sizes.begin() + 1, sizes.end()); - - // input batch size - const auto batch = static_cast(sizes[0]); - // const int64_t signal_ndim = sizes.size() - 1; - PADDLE_ENFORCE_EQ(signal_ndim, sizes.size() - 1, - platform::errors::InvalidArgument( - "The signal_ndim must be equal to sizes.size() - 1," - "But signal_ndim is: [%d], sizes.size() - 1 is: [%d]", - signal_ndim, sizes.size() - 1)); - -#ifdef __HIPCC__ - hipfftType exec_type = [&] { - if (dtype == framework::proto::VarType::FP32) { - switch (fft_type) { - case FFTTransformType::C2C: - return HIPFFT_C2C; - case FFTTransformType::R2C: - return HIPFFT_R2C; - case FFTTransformType::C2R: - return HIPFFT_C2R; - } - } else if (dtype == framework::proto::VarType::FP64) { - switch (fft_type) { - case FFTTransformType::C2C: - return HIPFFT_Z2Z; - case FFTTransformType::R2C: - return HIPFFT_D2Z; - case FFTTransformType::C2R: - return HIPFFT_Z2D; - } - } - PADDLE_THROW(platform::errors::InvalidArgument( - "hipFFT only support transforms of type float32 and float64")); - }(); -#else - cudaDataType itype, otype, exec_type; - const auto complex_input = has_complex_input(fft_type); - const auto complex_output = has_complex_output(fft_type); - if (dtype == framework::proto::VarType::FP32) { - itype = complex_input ? CUDA_C_32F : CUDA_R_32F; - otype = complex_output ? 
CUDA_C_32F : CUDA_R_32F; - exec_type = CUDA_C_32F; - } else if (dtype == framework::proto::VarType::FP64) { - itype = complex_input ? CUDA_C_64F : CUDA_R_64F; - otype = complex_output ? CUDA_C_64F : CUDA_R_64F; - exec_type = CUDA_C_64F; - } else if (dtype == framework::proto::VarType::FP16) { - itype = complex_input ? CUDA_C_16F : CUDA_R_16F; - otype = complex_output ? CUDA_C_16F : CUDA_R_16F; - exec_type = CUDA_C_16F; - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "cuFFT only support transforms of type float16, float32 and " - "float64")); - } -#endif +// Execute a pre-planned transform +static void exec_cufft_plan_raw(const FFTConfig& config, void* in_data, + void* out_data, bool forward) { + auto& plan = config.plan(); - // disable auto allocation of workspace to use allocator from the framework - CUFFT_CHECK(platform::dynload::cufftSetAutoAllocation( - plan(), /* autoAllocate */ 0)); - - size_t ws_size_t; - -// make plan -#ifdef __HIPCC__ - CUFFT_CHECK(hipfftMakePlanMany( - plan(), signal_ndim, signal_sizes.data(), - /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, - /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, exec_type, - batch, &ws_size_t)); -#else - - CUFFT_CHECK(platform::dynload::cufftXtMakePlanMany( - plan(), signal_ndim, signal_sizes.data(), - /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype, - /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype, - batch, &ws_size_t, exec_type)); -#endif + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftXtExec( + plan, in_data, out_data, forward ? 
CUFFT_FORWARD : CUFFT_INVERSE)); +} - ws_size = ws_size_t; +template +void exec_cufft_plan(const DeviceContext& ctx, const FFTConfig& config, + framework::Tensor* input, framework::Tensor* output, + bool forward) { + // execute transform plan + auto fft_type = config.transform_type(); + if (fft_type == FFTTransformType::C2R && forward) { + forward = false; + framework::Tensor input_conj(input->type()); + input_conj.mutable_data(input->dims(), ctx.GetPlace()); + platform::ForRange for_range(ctx, input->numel()); + math::ConjFunctor functor(input->data(), input->numel(), + input_conj.data()); + for_range(functor); + exec_cufft_plan_raw(config, input_conj.data(), output->data(), + forward); + } else if (fft_type == FFTTransformType::R2C && !forward) { + forward = true; + framework::Tensor out_conj(output->type()); + out_conj.mutable_data(output->dims(), ctx.GetPlace()); + exec_cufft_plan_raw(config, input->data(), out_conj.data(), + forward); + + platform::ForRange for_range(ctx, output->numel()); + math::ConjFunctor functor(out_conj.data(), output->numel(), + output->data()); + for_range(functor); + } else { + exec_cufft_plan_raw(config, input->data(), output->data(), + forward); } +} - const cufftHandle& plan() const { return plan_ptr.get(); } +#elif defined(PADDLE_WITH_HIP) - FFTTransformType transform_type() const { return fft_type_; } - ScalarType data_type() const { return value_type_; } - size_t workspace_size() const { return ws_size; } +FFTConfigKey create_fft_configkey(const framework::Tensor& input, + const framework::Tensor& output, + int signal_ndim) { + // Create the transform plan (either from cache or locally) + const auto value_type = framework::IsComplexType(input.type()) + ? 
framework::ToRealType(input.type()) + : input.type(); + auto fft_type = GetFFTTransformType(input.type(), output.type()); + // signal sizes + std::vector signal_size(signal_ndim + 1); - private: - CuFFTHandle plan_ptr; - size_t ws_size; - FFTTransformType fft_type_; - ScalarType value_type_; -}; + signal_size[0] = input.dims()[0]; + for (int64_t i = 1; i <= signal_ndim; ++i) { + auto in_size = input.dims()[i]; + auto out_size = output.dims()[i]; + signal_size[i] = std::max(in_size, out_size); + } + FFTConfigKey key(framework::vectorize(input.dims()), + framework::vectorize(output.dims()), signal_size, fft_type, + value_type); + return key; +} // Execute a pre-planned transform -static void exec_cufft_plan(const CuFFTConfig& config, void* in_data, - void* out_data, bool forward) { +static void exec_hipfft_plan_raw(const FFTConfig& config, void* in_data, + void* out_data, bool forward) { auto& plan = config.plan(); -#ifdef __HIPCC__ + auto value_type = config.data_type(); if (value_type == framework::proto::VarType::FP32) { switch (config.transform_type()) { case FFTTransformType::C2C: { - CUFFT_CHECK(hipfftExecC2C(plan, static_cast(in_data), - static_cast(out_data), - forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecC2C( + plan, static_cast(in_data), + static_cast(out_data), + forward ? 
HIPFFT_FORWARD : HIPFFT_BACKWARD)); return; } case FFTTransformType::R2C: { - CUFFT_CHECK(hipfftExecR2C(plan, static_cast(in_data), - static_cast(out_data))); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecR2C( + plan, static_cast(in_data), + static_cast(out_data))); return; } case FFTTransformType::C2R: { - CUFFT_CHECK(hipfftExecC2R(plan, static_cast(in_data), - static_cast(out_data))); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecC2R( + plan, static_cast(in_data), + static_cast(out_data))); return; } } } else if (value_type == framework::proto::VarType::FP64) { switch (config.transform_type()) { case FFTTransformType::C2C: { - CUFFT_CHECK(hipfftExecZ2Z(plan, - static_cast(in_data), - static_cast(out_data), - forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecZ2Z( + plan, static_cast(in_data), + static_cast(out_data), + forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD)); return; } case FFTTransformType::R2C: { - CUFFT_CHECK(hipfftExecD2Z(plan, static_cast(in_data), - static_cast(out_data))); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecD2Z( + plan, static_cast(in_data), + static_cast(out_data))); return; } case FFTTransformType::C2R: { - CUFFT_CHECK(hipfftExecZ2D(plan, - static_cast(in_data), - static_cast(out_data))); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecZ2D( + plan, static_cast(in_data), + static_cast(out_data))); return; } } } PADDLE_THROW(platform::errors::InvalidArgument( "hipFFT only support transforms of type float32 and float64")); -#else - CUFFT_CHECK(platform::dynload::cufftXtExec( - plan, in_data, out_data, forward ? 
CUFFT_FORWARD : CUFFT_INVERSE)); -#endif } +template +void exec_hipfft_plan(const DeviceContext& ctx, const FFTConfig& config, + framework::Tensor* input, framework::Tensor* output, + bool forward) { + auto fft_type = config.transform_type(); + if (fft_type == FFTTransformType::C2R && forward) { + forward = false; + framework::Tensor input_conj(input->type()); + input_conj.mutable_data(input->dims(), ctx.GetPlace()); + platform::ForRange for_range(ctx, input->numel()); + math::ConjFunctor functor(input->data(), input->numel(), + input_conj.data()); + for_range(functor); + exec_hipfft_plan_raw(config, input_conj.data(), output->data(), + forward); + } else if (fft_type == FFTTransformType::R2C && !forward) { + forward = true; + framework::Tensor out_conj(output->type()); + out_conj.mutable_data(output->dims(), ctx.GetPlace()); + exec_hipfft_plan_raw(config, input->data(), out_conj.data(), + forward); + + platform::ForRange for_range(ctx, output->numel()); + math::ConjFunctor functor(out_conj.data(), output->numel(), + output->data()); + for_range(functor); + } else { + exec_hipfft_plan_raw(config, input->data(), output->data(), + forward); + } +} + +#endif + // Execute a general unnormalized fft operation (can be c2c, onesided r2c or // onesided c2r) template void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out, const std::vector& dim, bool forward) { const auto x_dims = framework::vectorize(X->dims()); - const auto out_dims = framework::vectorize(out->dims()); const int64_t ndim = static_cast(X->dims().size()); - const int64_t signal_ndim = static_cast(dim.size()); - const int64_t batch_dims = ndim - signal_ndim; auto tensor_place = ctx.GetPlace(); - // Transpose batch dimensions first, then with transforming dims + // make a dim permutation std::vector dim_permute(ndim); - std::vector reverse_dim_permute(ndim); - std::vector trans_dims(ndim); std::iota(dim_permute.begin(), dim_permute.end(), int{0}); std::vector is_transformed_dim(ndim); for (const 
auto& d : dim) { @@ -342,167 +269,120 @@ void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out, std::sort(dim_permute.begin(), batch_end); std::copy(dim.cbegin(), dim.cend(), batch_end); - for (size_t i = 0; i < ndim; i++) { - trans_dims[i] = x_dims[dim_permute[i]]; // shape of input transpose - reverse_dim_permute[dim_permute[i]] = - static_cast(i); // reverse of dim permute - } - framework::Tensor input; - input.Resize(framework::make_ddim(trans_dims)); - input.mutable_data(tensor_place); - /* - auto in_ret = TransposeSimple::run(ctx, *X, dim_permute, input); - if (!in_ret) { - TransCompute(ndim, ctx, *X, input, dim_permute); - } - */ - TransCompute(ndim, ctx, *X, &input, dim_permute); + // transpose input according to dim permutation + auto transposed_input_shape = X->dims().transpose(dim_permute); + framework::Tensor transposed_input; + transposed_input.Resize(transposed_input_shape); + transposed_input.mutable_data(tensor_place); + TransCompute(ndim, ctx, *X, &transposed_input, + dim_permute); // Reshape batch dimensions into a single dimension - std::vector batched_sizes(signal_ndim + 1); + const int64_t signal_ndim = static_cast(dim.size()); + std::vector collapsed_input_shape(signal_ndim + 1); + + auto transposed_input_shape_ = framework::vectorize(transposed_input_shape); + const int64_t batch_dims = ndim - signal_ndim; auto batch_size = - std::accumulate(trans_dims.begin(), trans_dims.begin() + batch_dims, + std::accumulate(transposed_input_shape_.begin(), + transposed_input_shape_.begin() + batch_dims, static_cast(1), std::multiplies()); - batched_sizes[0] = batch_size; - std::copy(trans_dims.begin() + batch_dims, trans_dims.end(), - batched_sizes.begin() + 1); - input.Resize(framework::make_ddim(batched_sizes)); + collapsed_input_shape[0] = batch_size; - // Check the shape of transforming dims with input and output - std::vector signal_size(signal_ndim + 1); - signal_size[0] = batch_size; - for (int64_t i = 0; i < signal_ndim; ++i) { - 
auto in_size = input.dims()[i + 1]; - auto out_size = out_dims[dim[i]]; - signal_size[i + 1] = std::max(in_size, out_size); - PADDLE_ENFORCE_EQ( - (in_size == signal_size[i + 1] || - in_size == (signal_size[i + 1] / 2) + 1), - true, - platform::errors::InvalidArgument( - "The dimension[%d] of Input size: [%d] must be equal or half to " - "The dimension[%d] of Output size: [%d]", - dim[i], in_size, dim[i], out_size)); - PADDLE_ENFORCE_EQ( - (out_size == signal_size[i + 1] || - out_size == (signal_size[i + 1] / 2) + 1), - true, - platform::errors::InvalidArgument( - "The dimension[%d] of Output size: [%d] must be equal or half to " - "The dimension[%d] of Input size: [%d]", - dim[i], out_size, dim[i], in_size)); - } + std::copy(transposed_input_shape_.begin() + batch_dims, + transposed_input_shape_.end(), collapsed_input_shape.begin() + 1); - std::vector reshape_out_sizes(ndim); - for (size_t i = 0; i < ndim; ++i) { - reshape_out_sizes[i] = out_dims[dim_permute[i]]; - } - std::vector batched_out_sizes(batched_sizes.begin(), - batched_sizes.end()); + framework::Tensor& collapsed_input = transposed_input; + collapsed_input.Resize(framework::make_ddim(collapsed_input_shape)); + + // make a collpased output + const auto out_dims = framework::vectorize(out->dims()); + std::vector collapsed_output_shape(1 + signal_ndim); + collapsed_output_shape[0] = batch_size; for (size_t i = 0; i < dim.size(); ++i) { - batched_out_sizes[i + 1] = out_dims[dim[i]]; + collapsed_output_shape[i + 1] = out_dims[dim[i]]; + } + framework::Tensor collapsed_output; + collapsed_output.Resize(framework::make_ddim(collapsed_output_shape)); + collapsed_output.mutable_data(tensor_place); + + FFTConfig* config = nullptr; + +#if defined(PADDLE_WITH_CUDA) + std::unique_ptr config_ = nullptr; + // create plan + FFTConfigKey key = + create_fft_configkey(collapsed_input, collapsed_output, signal_ndim); + if (CUFFT_VERSION < 10200) { + const int64_t device_id = static_cast( + 
reinterpret_cast(&collapsed_input.place()) + ->GetDeviceId()); + FFTConfigCache& plan_cache = get_fft_plan_cache(device_id); + std::unique_lock guard(plan_cache.mutex, std::defer_lock); + guard.lock(); + config = &(plan_cache.lookup(key)); + } else { + config_ = std::make_unique(key); + config = config_.get(); } - - // output - framework::Tensor output; - output.Resize(framework::make_ddim(batched_out_sizes)); - output.mutable_data(tensor_place); - - // Create the transform plan (either from cache or locally) - const auto value_type = framework::IsComplexType(input.type()) - ? framework::ToRealType(input.type()) - : input.type(); - auto fft_type = GetFFTTransformType(input.type(), output.type()); - PlanKey Key(framework::vectorize(input.dims()), - framework::vectorize(output.dims()), signal_size, fft_type, - value_type); - CuFFTConfig uncached_plan(Key); - CuFFTConfig* config = &uncached_plan; - auto& plan = config->plan(); - // prepare cufft for execution - CUFFT_CHECK(platform::dynload::cufftSetStream(plan, ctx.stream())); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cufftSetStream(config->plan(), ctx.stream())); framework::Tensor workspace_tensor; workspace_tensor.mutable_data(tensor_place, config->workspace_size()); - CUFFT_CHECK( - platform::dynload::cufftSetWorkArea(plan, workspace_tensor.data())); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftSetWorkArea( + config->plan(), workspace_tensor.data())); + // execute transform plan + exec_cufft_plan(ctx, *config, &collapsed_input, + &collapsed_output, forward); + +#elif defined(PADDLE_WITH_HIP) + // create plan + FFTConfigKey key = + create_fft_configkey(collapsed_input, collapsed_output, signal_ndim); + const int64_t device_id = static_cast( + reinterpret_cast(&collapsed_input.place()) + ->GetDeviceId()); + FFTConfigCache& plan_cache = get_fft_plan_cache(device_id); + std::unique_lock guard(plan_cache.mutex, std::defer_lock); + guard.lock(); + config = &(plan_cache.lookup(key)); + // prepare cufft 
for execution + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::hipfftSetStream(config->plan(), ctx.stream())); + framework::Tensor workspace_tensor; + workspace_tensor.mutable_data(tensor_place, config->workspace_size()); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftSetWorkArea( + config->plan(), workspace_tensor.data())); // execute transform plan - if (fft_type == FFTTransformType::C2R && forward) { - forward = false; - framework::Tensor input_conj(input.type()); - input_conj.mutable_data(input.dims(), ctx.GetPlace()); - platform::ForRange for_range(ctx, input.numel()); - math::ConjFunctor functor(input.data(), input.numel(), - input_conj.data()); - for_range(functor); - exec_cufft_plan(*config, input_conj.data(), output.data(), - forward); - } else if (fft_type == FFTTransformType::R2C && !forward) { - forward = true; - framework::Tensor out_conj(output.type()); - out_conj.mutable_data(output.dims(), ctx.GetPlace()); - exec_cufft_plan(*config, input.data(), out_conj.data(), - forward); - - platform::ForRange for_range(ctx, output.numel()); - math::ConjFunctor functor(out_conj.data(), output.numel(), - output.data()); - for_range(functor); - } else { - exec_cufft_plan(*config, input.data(), output.data(), forward); - } + exec_hipfft_plan(ctx, *config, &collapsed_input, + &collapsed_output, forward); +#endif // Inverting output by reshape and transpose to original batch and dimension - output.Resize(framework::make_ddim(reshape_out_sizes)); - out->Resize(framework::make_ddim(out_dims)); - TransCompute(ndim, ctx, output, out, reverse_dim_permute); -} + auto transposed_out_shape = out->dims().transpose(dim_permute); -// Calculates the normalization constant -double fft_normalization_scale(FFTNormMode normalization, - const std::vector& sizes, - const std::vector& dims) { - // auto norm = static_cast(normalization); - if (normalization == FFTNormMode::none) { - return static_cast(1.0); - } + collapsed_output.Resize(transposed_out_shape); + auto& 
transposed_output = collapsed_output; - int64_t signal_numel = 1; - for (auto dim : dims) { - signal_numel *= sizes[dim]; + std::vector reverse_dim_permute(ndim); + for (size_t i = 0; i < ndim; i++) { + reverse_dim_permute[dim_permute[i]] = i; } - const double scale_denom = (normalization == FFTNormMode::by_sqrt_n) - ? std::sqrt(signal_numel) - : static_cast(signal_numel); - return static_cast(1.0 / scale_denom); -} -template -void exec_normalization(const DeviceContext& ctx, const Tensor* in, Tensor* out, - FFTNormMode normalization, - const std::vector& sizes, - const std::vector& axes) { - double scale = fft_normalization_scale(normalization, sizes, axes); - if (scale != 1.0) { - auto eigen_out = framework::EigenVector::Flatten(*out); - auto eigen_in = framework::EigenVector::Flatten(*in); - auto dev = ctx.eigen_device(); - EigenScale::Eval(*dev, eigen_out, eigen_in, - static_cast(scale), - static_cast(0), false); - } else { - framework::TensorCopy(*in, ctx.GetPlace(), out); - } + TransCompute(ndim, ctx, transposed_output, out, + reverse_dim_permute); } + } // anonymous namespace // Use the optimized path to perform single R2C or C2R if transformation dim is // supported by cuFFT -bool use_optimized_cufft_path(const std::vector& axes) { +bool use_optimized_fft_path(const std::vector& axes) { // For performance reason, when axes starts with (0, 1), do not use the // optimized path. 
- if (axes.size() > kMaxCUFFTNdim || + if (axes.size() > kMaxFFTNdim || (axes.size() >= 2 && axes[0] == 0 && axes[1] == 1)) { return false; } else { @@ -532,7 +412,7 @@ struct FFTC2CFunctor { while (true) { max_dims = - std::min(static_cast(kMaxCUFFTNdim), working_axes.size()); + std::min(static_cast(kMaxFFTNdim), working_axes.size()); first_dims.assign(working_axes.end() - max_dims, working_axes.end()); exec_fft(ctx, p_working_tensor, @@ -559,7 +439,7 @@ struct FFTC2RFunctor { std::vector in_dims = framework::vectorize(X->dims()); std::vector out_dims = framework::vectorize(out->dims()); - if (use_optimized_cufft_path(axes)) { + if (use_optimized_fft_path(axes)) { framework::Tensor x_copy(X->type()); x_copy.mutable_data(X->dims(), ctx.GetPlace()); framework::TensorCopy(*X, ctx.GetPlace(), &x_copy); diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index c0d4b349a9e09..6e90ccfc51e1b 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -7,7 +7,7 @@ if (NOT WITH_NV_JETSON) endif() if (WITH_ROCM) - list(APPEND HIP_SRCS rocblas.cc miopen.cc hiprand.cc) + list(APPEND HIP_SRCS rocblas.cc miopen.cc hiprand.cc hipfft.cc) endif() # There is no macOS version of NCCL. @@ -49,3 +49,9 @@ endif() cc_library(dynload_lapack SRCS lapack.cc DEPS dynamic_loader) add_dependencies(dynload_lapack extern_lapack) # TODO(TJ): add iomp, mkldnn? 
+ +if (MKL_FOUND AND WITH_ONEMKL) + message("ONEMKL INCLUDE directory is ${MKL_INCLUDE}") + cc_library(dynload_mklrt SRCS mklrt.cc DEPS dynamic_loader) + target_include_directories(dynload_mklrt PRIVATE ${MKL_INCLUDE}) +endif() diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index a83f085f7d2d8..1bfd48b133907 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -53,6 +53,12 @@ DEFINE_string(mklml_dir, "", "Specify path for loading libmklml_intel.so."); DEFINE_string(lapack_dir, "", "Specify path for loading liblapack.so."); +DEFINE_string(mkl_dir, "", + "Specify path for loading libmkl_rt.so. " + "For insrance, /opt/intel/oneapi/mkl/latest/lib/intel64/." + "If default, " + "dlopen will search mkl from LD_LIBRARY_PATH"); + DEFINE_string(op_dir, "", "Specify path for loading user-defined op library."); #ifdef PADDLE_WITH_HIP @@ -350,6 +356,16 @@ void* GetCurandDsoHandle() { #endif } +#ifdef PADDLE_WITH_HIP +void* GetROCFFTDsoHandle() { +#if defined(__APPLE__) || defined(__OSX__) + return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "librocfft.dylib"); +#else + return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "librocfft.so"); +#endif +} +#endif + void* GetNvjpegDsoHandle() { #if defined(__APPLE__) || defined(__OSX__) return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvjpeg.dylib"); @@ -518,6 +534,16 @@ void* GetCUFFTDsoHandle() { #endif } +void* GetMKLRTDsoHandle() { +#if defined(__APPLE__) || defined(__OSX__) + return GetDsoHandleFromSearchPath(FLAGS_mkl_dir, "libmkl_rt.dylib"); +#elif defined(_WIN32) + return GetDsoHandleFromSearchPath(FLAGS_mkl_dir, "mkl_rt.dll"); +#else + return GetDsoHandleFromSearchPath(FLAGS_mkl_dir, "libmkl_rt.so"); +#endif +} + } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/dynamic_loader.h b/paddle/fluid/platform/dynload/dynamic_loader.h index 
82c36d9e224f4..1a66f4b979207 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.h +++ b/paddle/fluid/platform/dynload/dynamic_loader.h @@ -43,6 +43,8 @@ void* GetLAPACKDsoHandle(); void* GetOpDsoHandle(const std::string& dso_name); void* GetNvtxDsoHandle(); void* GetCUFFTDsoHandle(); +void* GetMKLRTDsoHandle(); +void* GetROCFFTDsoHandle(); void SetPaddleLibPath(const std::string&); } // namespace dynload diff --git a/paddle/fluid/platform/dynload/hipfft.cc b/paddle/fluid/platform/dynload/hipfft.cc new file mode 100644 index 0000000000000..767d2161be9d8 --- /dev/null +++ b/paddle/fluid/platform/dynload/hipfft.cc @@ -0,0 +1,30 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/dynload/hipfft.h" + +namespace paddle { +namespace platform { +namespace dynload { + +std::once_flag hipfft_dso_flag; +void *hipfft_dso_handle; + +#define DEFINE_WRAP(__name) DynLoad__##__name __name + +HIPFFT_FFT_ROUTINE_EACH(DEFINE_WRAP); + +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/hipfft.h b/paddle/fluid/platform/dynload/hipfft.h new file mode 100644 index 0000000000000..50c25935e41b7 --- /dev/null +++ b/paddle/fluid/platform/dynload/hipfft.h @@ -0,0 +1,124 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once +#ifdef PADDLE_WITH_HIP +#include + +#include // NOLINT + +#include "paddle/fluid/platform/dynload/dynamic_loader.h" +#include "paddle/fluid/platform/port.h" + +namespace paddle { +namespace platform { +namespace dynload { +extern std::once_flag hipfft_dso_flag; +extern void *hipfft_dso_handle; + +#define DECLARE_DYNAMIC_LOAD_HIPFFT_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) 
{ \ + using hipfftFunc = decltype(&::__name); \ + std::call_once(hipfft_dso_flag, []() { \ + hipfft_dso_handle = paddle::platform::dynload::GetROCFFTDsoHandle(); \ + }); \ + static void *p_##__name = dlsym(hipfft_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + }; \ + extern DynLoad__##__name __name + +#define HIPFFT_FFT_ROUTINE_EACH(__macro) \ + __macro(hipfftPlan1d); \ + __macro(hipfftPlan2d); \ + __macro(hipfftPlan3d); \ + __macro(hipfftPlanMany); \ + __macro(hipfftMakePlan1d); \ + __macro(hipfftMakePlanMany); \ + __macro(hipfftMakePlanMany64); \ + __macro(hipfftGetSizeMany64); \ + __macro(hipfftEstimate1d); \ + __macro(hipfftEstimate2d); \ + __macro(hipfftEstimate3d); \ + __macro(hipfftEstimateMany); \ + __macro(hipfftCreate); \ + __macro(hipfftGetSize1d); \ + __macro(hipfftGetSizeMany); \ + __macro(hipfftGetSize); \ + __macro(hipfftSetWorkArea); \ + __macro(hipfftSetAutoAllocation); \ + __macro(hipfftExecC2C); \ + __macro(hipfftExecR2C); \ + __macro(hipfftExecC2R); \ + __macro(hipfftExecZ2Z); \ + __macro(hipfftExecD2Z); \ + __macro(hipfftExecZ2D); \ + __macro(hipfftSetStream); \ + __macro(hipfftDestroy); \ + __macro(hipfftGetVersion); \ + __macro(hipfftGetProperty); + +HIPFFT_FFT_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_HIPFFT_WRAP); + +inline const char *hipfftGetErrorString(hipfftResult_t status) { + switch (status) { + case HIPFFT_SUCCESS: + return "'HIPFFT_SUCCESS'. The hipFFT operation was successful."; + case HIPFFT_INVALID_PLAN: + return "'HIPFFT_INVALID_PLAN'. hipFFT was passed an invalid plan handle."; + case HIPFFT_ALLOC_FAILED: + return "'HIPFFT_ALLOC_FAILED'. hipFFT failed to allocate GPU or CPU " + "memory."; + case HIPFFT_INVALID_TYPE: + return "'HIPFFT_INVALID_TYPE'. No longer used."; + case HIPFFT_INVALID_VALUE: + return "'HIPFFT_INVALID_VALUE'. User specified an invalid pointer or " + "parameter."; + case HIPFFT_INTERNAL_ERROR: + return "'HIPFFT_INTERNAL_ERROR'. 
Driver or internal hipFFT library " + "error."; + case HIPFFT_EXEC_FAILED: + return "'HIPFFT_EXEC_FAILED'. Failed to execute an FFT on the GPU."; + case HIPFFT_SETUP_FAILED: + return "'HIPFFT_SETUP_FAILED'. The hipFFT library failed to initialize."; + case HIPFFT_INVALID_SIZE: + return "'HIPFFT_INVALID_SIZE'. User specified an invalid transform size."; + case HIPFFT_UNALIGNED_DATA: + return "'HIPFFT_UNALIGNED_DATA'. No longer used."; + case HIPFFT_INCOMPLETE_PARAMETER_LIST: + return "'HIPFFT_INCOMPLETE_PARAMETER_LIST'. Missing parameters in call."; + case HIPFFT_INVALID_DEVICE: + return "'HIPFFT_INVALID_DEVICE'. Execution of a plan was on different " + "GPU than plan creation."; + case HIPFFT_PARSE_ERROR: + return "'HIPFFT_PARSE_ERROR'. Internal plan database error."; + case HIPFFT_NO_WORKSPACE: + return "'HIPFFT_NO_WORKSPACE'. No workspace has been provided prior to " + "plan execution."; + case HIPFFT_NOT_IMPLEMENTED: + return "'HIPFFT_NOT_IMPLEMENTED'. Function does not implement " + "functionality for parameters given."; + case HIPFFT_NOT_SUPPORTED: + return "'HIPFFT_NOT_SUPPORTED'. Operation is not supported for " + "parameters given."; + default: + return "HIPFFT_STATUS_UNKNOWN_ERROR"; + } +} +} // namespace dynload +} // namespace platform +} // namespace paddle + +#endif diff --git a/paddle/fluid/platform/dynload/mklrt.cc b/paddle/fluid/platform/dynload/mklrt.cc new file mode 100644 index 0000000000000..45fad15fb583e --- /dev/null +++ b/paddle/fluid/platform/dynload/mklrt.cc @@ -0,0 +1,51 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/dynload/mklrt.h" + +namespace paddle { +namespace platform { +namespace dynload { + +std::once_flag mklrt_dso_flag; +void* mklrt_dso_handle = nullptr; + +#define DEFINE_WRAP(__name) DynLoad__##__name __name + +MKLDFTI_ROUTINE_EACH(DEFINE_WRAP); + +DFTI_EXTERN MKL_LONG DftiCreateDescriptorX(DFTI_DESCRIPTOR_HANDLE* desc, + enum DFTI_CONFIG_VALUE prec, + enum DFTI_CONFIG_VALUE domain, + MKL_LONG dim, MKL_LONG* sizes) { + if (prec == DFTI_SINGLE) { + if (dim == 1) { + return DftiCreateDescriptor_s_1d(desc, domain, sizes[0]); + } else { + return DftiCreateDescriptor_s_md(desc, domain, dim, sizes); + } + } else if (prec == DFTI_DOUBLE) { + if (dim == 1) { + return DftiCreateDescriptor_d_1d(desc, domain, sizes[0]); + } else { + return DftiCreateDescriptor_d_md(desc, domain, dim, sizes); + } + } else { + return DftiCreateDescriptor(desc, prec, domain, dim, sizes); + } +} + +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/mklrt.h b/paddle/fluid/platform/dynload/mklrt.h new file mode 100644 index 0000000000000..423cd4d0a254c --- /dev/null +++ b/paddle/fluid/platform/dynload/mklrt.h @@ -0,0 +1,80 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include // NOLINT + +#include "paddle/fluid/platform/dynload/dynamic_loader.h" +#include "paddle/fluid/platform/port.h" + +namespace paddle { +namespace platform { +namespace dynload { + +extern std::once_flag mklrt_dso_flag; +extern void* mklrt_dso_handle; + +/** + * The following macro definition can generate structs + * (for each function) to dynamic load mkldfti routine + * via operator overloading. + */ +#define DYNAMIC_LOAD_MKLRT_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ + using mklrtFunc = decltype(&::__name); \ + std::call_once(mklrt_dso_flag, []() { \ + mklrt_dso_handle = paddle::platform::dynload::GetMKLRTDsoHandle(); \ + }); \ + static void* p_##__name = dlsym(mklrt_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + }; \ + extern DynLoad__##__name __name + +// mkl_dfti.h has a macro that shadows the function with the same name +// un-define this macro so as to export that function +#undef DftiCreateDescriptor + +#define MKLDFTI_ROUTINE_EACH(__macro) \ + __macro(DftiCreateDescriptor); \ + __macro(DftiCreateDescriptor_s_1d); \ + __macro(DftiCreateDescriptor_d_1d); \ + __macro(DftiCreateDescriptor_s_md); \ + __macro(DftiCreateDescriptor_d_md); \ + __macro(DftiSetValue); \ + __macro(DftiGetValue); \ + __macro(DftiCommitDescriptor); \ + __macro(DftiComputeForward); \ + __macro(DftiComputeBackward); \ + __macro(DftiFreeDescriptor); \ + __macro(DftiErrorClass); \ + __macro(DftiErrorMessage); +
+MKLDFTI_ROUTINE_EACH(DYNAMIC_LOAD_MKLRT_WRAP) + +#undef DYNAMIC_LOAD_MKLRT_WRAP + +// define another function to avoid naming conflict +DFTI_EXTERN MKL_LONG DftiCreateDescriptorX(DFTI_DESCRIPTOR_HANDLE* desc, + enum DFTI_CONFIG_VALUE prec, + enum DFTI_CONFIG_VALUE domain, + MKL_LONG dim, MKL_LONG* sizes); + +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index c420a5a64be06..caa495bb7f8c5 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -31,6 +31,7 @@ limitations under the License. */ #ifdef PADDLE_WITH_CUDA #include #include +#include #include #include #include @@ -85,6 +86,7 @@ limitations under the License. */ #endif // PADDLE_WITH_CUDA #ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/dynload/hipfft.h" #include "paddle/fluid/platform/dynload/hiprand.h" #include "paddle/fluid/platform/dynload/miopen.h" #include "paddle/fluid/platform/dynload/rocblas.h" @@ -714,6 +716,7 @@ DEFINE_EXTERNAL_API_TYPE(curandStatus_t, CURAND_STATUS_SUCCESS, CURAND); DEFINE_EXTERNAL_API_TYPE(cudnnStatus_t, CUDNN_STATUS_SUCCESS, CUDNN); DEFINE_EXTERNAL_API_TYPE(cublasStatus_t, CUBLAS_STATUS_SUCCESS, CUBLAS); DEFINE_EXTERNAL_API_TYPE(cusolverStatus_t, CUSOLVER_STATUS_SUCCESS, CUSOLVER); +DEFINE_EXTERNAL_API_TYPE(cufftResult_t, CUFFT_SUCCESS, CUFFT); #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) DEFINE_EXTERNAL_API_TYPE(ncclResult_t, ncclSuccess, NCCL); @@ -751,6 +754,8 @@ inline const char* GetErrorMsgUrl(T status) { return "https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/" "types.html#ncclresult-t"; break; + case platform::proto::ApiType::CUFFT: + return "https://docs.nvidia.com/cuda/cufft/index.html#cufftresult"; default: return "Unknown type of External API, can't get error message URL!"; break; @@ -839,6 +844,7 @@ template std::string GetExternalErrorMsg(curandStatus_t); template std::string 
GetExternalErrorMsg(cudnnStatus_t); template std::string GetExternalErrorMsg(cublasStatus_t); template std::string GetExternalErrorMsg(cusolverStatus_t); +template std::string GetExternalErrorMsg(cufftResult_t); #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) template std::string GetExternalErrorMsg(ncclResult_t); #endif @@ -899,6 +905,15 @@ inline std::string build_nvidia_error_msg(cusolverStatus_t stat) { return sout.str(); } +/*************** CUFFT ERROR ***************/ +inline bool is_error(cufftResult_t stat) { return stat != CUFFT_SUCCESS; } + +inline std::string build_nvidia_error_msg(cufftResult_t stat) { + std::ostringstream sout; + sout << "CUFFT error(" << stat << "). " << GetExternalErrorMsg(stat); + return sout.str(); +} + /**************** NCCL ERROR ****************/ #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) inline bool is_error(ncclResult_t nccl_result) { @@ -1099,6 +1114,14 @@ inline std::string build_rocm_error_msg(ncclResult_t nccl_result) { } #endif // not(__APPLE__) and PADDLE_WITH_NCCL +/***** HIPFFT ERROR *****/ +inline bool is_error(hipfftResult_t stat) { return stat != HIPFFT_SUCCESS; } + +inline std::string build_rocm_error_msg(hipfftResult_t stat) { + std::string msg(" HIPFFT error, "); + return msg + platform::dynload::hipfftGetErrorString(stat) + " "; +} + namespace details { template @@ -1115,6 +1138,7 @@ DEFINE_EXTERNAL_API_TYPE(hipError_t, hipSuccess); DEFINE_EXTERNAL_API_TYPE(hiprandStatus_t, HIPRAND_STATUS_SUCCESS); DEFINE_EXTERNAL_API_TYPE(miopenStatus_t, miopenStatusSuccess); DEFINE_EXTERNAL_API_TYPE(rocblas_status, rocblas_status_success); +DEFINE_EXTERNAL_API_TYPE(hipfftResult_t, HIPFFT_SUCCESS); #if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL) DEFINE_EXTERNAL_API_TYPE(ncclResult_t, ncclSuccess); diff --git a/paddle/fluid/platform/enforce_test.cc b/paddle/fluid/platform/enforce_test.cc index 95a852ad6e92a..6ff9e6ea903cd 100644 --- a/paddle/fluid/platform/enforce_test.cc +++ 
b/paddle/fluid/platform/enforce_test.cc @@ -9,10 +9,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/platform/enforce.h" + #include #include "gtest/gtest.h" -#include "paddle/fluid/platform/enforce.h" TEST(ENFORCE, OK) { PADDLE_ENFORCE(true, paddle::platform::errors::Unavailable( @@ -330,6 +331,10 @@ TEST(enforce, hip_success) { CheckCudaStatusFailure(rocblas_status_invalid_handle, "Rocblas error")); EXPECT_TRUE( CheckCudaStatusFailure(rocblas_status_invalid_value, "Rocblas error")); + EXPECT_TRUE(CheckCudaStatusSuccess(HIPFFT_SUCCESS)); + EXPECT_TRUE(CheckCudaStatusFailure(HIPFFT_INVALID_PLAN, "HIPFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(HIPFFT_ALLOC_FAILED, "HIPFFT error")); + #if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL) EXPECT_TRUE(CheckCudaStatusSuccess(ncclSuccess)); EXPECT_TRUE(CheckCudaStatusFailure(ncclUnhandledCudaError, "Rccl error")); @@ -418,6 +423,25 @@ TEST(enforce, cuda_success) { "negative vector size, for example).To correct: ensure that all the " "parameters being passed have valid values")); + EXPECT_TRUE(CheckCudaStatusSuccess(CUFFT_SUCCESS)); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_INVALID_PLAN, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_ALLOC_FAILED, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_INVALID_TYPE, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_INVALID_VALUE, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_INTERNAL_ERROR, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_EXEC_FAILED, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_SETUP_FAILED, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_INVALID_SIZE, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_UNALIGNED_DATA, "CUFFT error")); + EXPECT_TRUE( + CheckCudaStatusFailure(CUFFT_INCOMPLETE_PARAMETER_LIST, 
"CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_INVALID_DEVICE, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_PARSE_ERROR, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_NO_WORKSPACE, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_NOT_IMPLEMENTED, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_LICENSE_ERROR, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_NOT_SUPPORTED, "CUFFT error")); + #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) EXPECT_TRUE(CheckCudaStatusSuccess(ncclSuccess)); EXPECT_TRUE(CheckCudaStatusFailure(ncclUnhandledCudaError, "NCCL error")); diff --git a/paddle/fluid/platform/external_error.proto b/paddle/fluid/platform/external_error.proto index 2094de7e10f69..cbbf803492e64 100644 --- a/paddle/fluid/platform/external_error.proto +++ b/paddle/fluid/platform/external_error.proto @@ -24,6 +24,7 @@ enum ApiType { CUBLAS = 3; CUSOLVER = 4; NCCL = 5; + CUFFT = 6; } message MessageDesc { diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 147dc3be4c154..351b6ecb9f780 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -64,7 +64,6 @@ import paddle.static # noqa: F401 import paddle.vision # noqa: F401 -from .tensor import fft from .tensor.random import bernoulli # noqa: F401 from .tensor.attribute import rank # noqa: F401 @@ -297,6 +296,8 @@ from .hapi import flops # noqa: F401 from . import hub # noqa: F401 from . import linalg # noqa: F401 +from . import fft # noqa: F401 +from . 
import signal # noqa: F401 import paddle.text # noqa: F401 import paddle.vision # noqa: F401 diff --git a/python/paddle/tensor/fft.py b/python/paddle/fft.py similarity index 97% rename from python/paddle/tensor/fft.py rename to python/paddle/fft.py index 98ca858c0eb85..7399ccc1ace59 100644 --- a/python/paddle/tensor/fft.py +++ b/python/paddle/fft.py @@ -15,30 +15,30 @@ from typing import Sequence import numpy as np import paddle -from .attribute import is_complex, is_floating_point, is_interger, _real_to_complex_dtype, _complex_to_real_dtype -from ..fluid.framework import in_dygraph_mode -from .. import _C_ops -from ..fluid.data_feeder import check_variable_and_dtype -from ..fluid.layer_helper import LayerHelper +from .tensor.attribute import is_complex, is_floating_point, is_interger, _real_to_complex_dtype, _complex_to_real_dtype +from .fluid.framework import in_dygraph_mode +from . import _C_ops +from .fluid.data_feeder import check_variable_and_dtype +from .fluid.layer_helper import LayerHelper __all__ = [ 'fft', - 'fft2', - 'fftn', 'ifft', - 'ifft2', - 'ifftn', 'rfft', - 'rfft2', - 'rfftn', 'irfft', - 'irfft2', - 'irfftn', 'hfft', - 'hfft2', - 'hfftn', 'ihfft', + 'fft2', + 'ifft2', + 'rfft2', + 'irfft2', + 'hfft2', 'ihfft2', + 'fftn', + 'ifftn', + 'rfftn', + 'irfftn', + 'hfftn', 'ihfftn', 'fftfreq', 'rfftfreq', @@ -362,7 +362,7 @@ def irfft(x, n=None, axis=-1, norm="backward", name=None): xp = paddle.to_tensor(x) irfft_xp = paddle.fft.irfft(xp).numpy() print(irfft_xp) - # [0. 0. 0. 4.] + # [0. 1. 0. 0.] 
""" return fft_c2r(x, n, axis, norm, forward=False, name=name) @@ -500,7 +500,7 @@ def fftn(x, s=None, axes=None, norm="backward", name=None): import numpy as np import paddle - x = x = np.mgrid[:4, :4, :4][1] + x = np.mgrid[:4, :4, :4][1] xp = paddle.to_tensor(x) fftn_xp = paddle.fft.fftn(xp, axes=(1, 2)).numpy() print(fftn_xp) @@ -654,9 +654,9 @@ def rfftn(x, s=None, axes=None, norm="backward", name=None): # use axes(2, 0) print(paddle.fft.rfftn(x, axes=(2, 0))) # Tensor(shape=[2, 3, 3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, - # [[[(24+0j), 0j , 0j ], - # [0j , 0j , 0j ], - # [0j , 0j , 0j ]], + # [[[(8+0j), 0j , 0j ], + # [(8+0j), 0j , 0j ], + # [(8+0j), 0j , 0j ]], # # [[0j , 0j , 0j ], # [0j , 0j , 0j ], @@ -1135,7 +1135,24 @@ def ihfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): refer to :ref:`api_guide_Name` . Returns: - out(Tensor) : The result of the inverse real 2-D FFT. + out(Tensor) : The result of the inverse hermitian 2-D FFT. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.mgrid[:5, :5][0].astype(np.float64) + xp = paddle.to_tensor(x) + ihfft2_xp = paddle.fft.ihfft2(xp).numpy() + print(ihfft2_xp) + # [[ 2. +0.j 0. +0.j 0. +0.j ] + # [-0.5-0.68819096j 0. +0.j 0. +0.j ] + # [-0.5-0.16245985j 0. +0.j 0. +0.j ] + # [-0.5+0.16245985j 0. +0.j 0. +0.j ] + # [-0.5+0.68819096j 0. +0.j 0. +0.j ]] """ _check_at_least_ndim(x, 2) if s is not None: @@ -1273,9 +1290,8 @@ def fftshift(x, axes=None, name=None): import paddle x = np.array([3, 1, 2, 2, 3], dtype=float) - scalar_temp = 0.3 n = x.size - fftfreq_xp = paddle.fft.fftfreq(n, d=scalar_temp) + fftfreq_xp = paddle.fft.fftfreq(n, d=0.3) res = paddle.fft.fftshift(fftfreq_xp).numpy() print(res) # [-1.3333334 -0.6666667 0. 
0.6666667 1.3333334] @@ -1284,13 +1300,13 @@ def fftshift(x, axes=None, name=None): shape = paddle.shape(x) if axes is None: # shift all axes - rank = paddle.rank(x).reshape([1]) - axes = axes or paddle.arange(0, rank) - shifts = [size // 2 for size in shape] + rank = len(x.shape) + axes = list(range(0, rank)) + shifts = shape // 2 elif isinstance(axes, int): shifts = shape[axes] // 2 else: - shifts = [shape[ax] // 2 for ax in axes] + shifts = paddle.concat([shape[ax] // 2 for ax in axes]) return paddle.roll(x, shifts, axes, name=name) @@ -1317,9 +1333,8 @@ def ifftshift(x, axes=None, name=None): import paddle x = np.array([3, 1, 2, 2, 3], dtype=float) - scalar_temp = 0.3 n = x.size - fftfreq_xp = paddle.fft.fftfreq(n, d=scalar_temp) + fftfreq_xp = paddle.fft.fftfreq(n, d=0.3) res = paddle.fft.ifftshift(fftfreq_xp).numpy() print(res) # [ 1.3333334 -1.3333334 -0.6666667 0. 0.6666667] @@ -1328,13 +1343,13 @@ def ifftshift(x, axes=None, name=None): shape = paddle.shape(x) if axes is None: # shift all axes - rank = paddle.rank(x).reshape([1]) - axes = axes or paddle.arange(0, rank) - shifts = [-size // 2 for size in shape] + rank = len(x.shape) + axes = list(range(0, rank)) + shifts = -(shape // 2) elif isinstance(axes, int): shifts = -shape[axes] // 2 else: - shifts = [-shape[ax] // 2 for ax in axes] + shifts = paddle.concat([-shape[ax] // 2 for ax in axes]) return paddle.roll(x, shifts, axes, name=name) @@ -1346,7 +1361,7 @@ def fft_c2c(x, n, axis, norm, forward, name): x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) _check_normalization(norm) - axis = axis or -1 + axis = axis if axis is not None else -1 _check_fft_axis(x, axis) axes = [axis] axes = _normalize_axes(x,
axes) @@ -1415,7 +1430,7 @@ def fft_c2r(x, n, axis, norm, forward, name): elif is_floating_point(x): x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) _check_normalization(norm) - axis = axis or -1 + axis = axis if axis is not None else -1 _check_fft_axis(x, axis) axes = [axis] axes = _normalize_axes(x, axes) diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft.py b/python/paddle/fluid/tests/unittests/fft/test_fft.py index c83c943217d4e..604de11521b7d 100644 --- a/python/paddle/fluid/tests/unittests/fft/test_fft.py +++ b/python/paddle/fluid/tests/unittests/fft/test_fft.py @@ -1009,10 +1009,11 @@ def test_rfftfreq(self): @place(DEVICES) -@parameterize((TEST_CASE_NAME, 'x', 'axes', 'dtype'), [ - ('test_1d', np.random.randn(10), (0, ), 'float64'), - ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), -]) +@parameterize( + (TEST_CASE_NAME, 'x', 'axes', 'dtype'), + [('test_1d', np.random.randn(10), (0, ), 'float64'), + ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), + ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64')]) class TestFftShift(unittest.TestCase): def test_fftshift(self): """Test fftshift with norm condition @@ -1030,6 +1031,7 @@ def test_fftshift(self): @parameterize((TEST_CASE_NAME, 'x', 'axes'), [ ('test_1d', np.random.randn(10), (0, ), 'float64'), ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), + ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'), ]) class TestIfftShift(unittest.TestCase): def test_ifftshift(self): diff --git a/python/paddle/fluid/tests/unittests/test_roll_op.py b/python/paddle/fluid/tests/unittests/test_roll_op.py index 99121d2953a14..bca7665b814db 100644 --- a/python/paddle/fluid/tests/unittests/test_roll_op.py +++ b/python/paddle/fluid/tests/unittests/test_roll_op.py @@ -122,6 +122,34 @@ def test_axis_out_range(): self.assertRaises(ValueError, test_axis_out_range) + def test_shifts_as_tensor_dygraph(self): + with fluid.dygraph.guard(): + x = 
paddle.arange(9).reshape([3, 3]) + shape = paddle.shape(x) + shifts = shape // 2 + axes = [0, 1] + out = paddle.roll(x, shifts=shifts, axis=axes).numpy() + expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]]) + self.assertTrue(np.allclose(out, expected_out)) + + def test_shifts_as_tensor_static(self): + with program_guard(Program(), Program()): + x = paddle.arange(9).reshape([3, 3]).astype('float32') + shape = paddle.shape(x) + shifts = shape // 2 + axes = [0, 1] + out = paddle.roll(x, shifts=shifts, axis=axes) + expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]]) + + exe = fluid.Executor(fluid.CPUPlace()) + [out_np] = exe.run(fetch_list=[out]) + self.assertTrue(np.allclose(out_np, expected_out)) + + if paddle.is_compiled_with_cuda(): + exe = fluid.Executor(fluid.CUDAPlace(0)) + [out_np] = exe.run(fetch_list=[out]) + self.assertTrue(np.allclose(out_np, expected_out)) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_signal.py b/python/paddle/fluid/tests/unittests/test_signal.py index a109a5aa5d1a6..ecbbd8f52db9b 100644 --- a/python/paddle/fluid/tests/unittests/test_signal.py +++ b/python/paddle/fluid/tests/unittests/test_signal.py @@ -652,7 +652,7 @@ def test_frame(self): self.assertTrue( np.allclose( frame_for_api_test(self.x, self.frame_length, self.hop_length, self.axis), - paddle.tensor.signal.frame( + paddle.signal.frame( paddle.to_tensor(self.x), self.frame_length, self.hop_length, @@ -678,7 +678,7 @@ def test_frame_static(self): mp, sp = paddle.static.Program(), paddle.static.Program() with paddle.static.program_guard(mp, sp): input = paddle.static.data('input', self.x.shape, dtype=self.x.dtype) - output = paddle.tensor.signal.frame( + output = paddle.signal.frame( input, self.frame_length, self.hop_length, @@ -708,7 +708,7 @@ def test_frame_static(self): class TestFrameException(unittest.TestCase): def test_frame(self): with self.assertRaises(self.expect_exception): - paddle.tensor.signal.frame( +
paddle.signal.frame( paddle.to_tensor(self.x), self.frame_length, self.hop_length, @@ -731,7 +731,7 @@ def test_overlap_add(self): self.assertTrue( np.allclose( overlap_add_for_api_test(self.x, self.hop_length, self.axis), - paddle.tensor.signal.overlap_add( + paddle.signal.overlap_add( paddle.to_tensor(self.x), self.hop_length, self.axis), @@ -756,7 +756,7 @@ def test_overlap_add_static(self): mp, sp = paddle.static.Program(), paddle.static.Program() with paddle.static.program_guard(mp, sp): input = paddle.static.data('input', self.x.shape, dtype=self.x.dtype) - output = paddle.tensor.signal.overlap_add( + output = paddle.signal.overlap_add( input, self.hop_length, self.axis), @@ -783,7 +783,7 @@ def test_overlap_add_static(self): class TestOverlapAddException(unittest.TestCase): def test_overlap_add(self): with self.assertRaises(self.expect_exception): - paddle.tensor.signal.overlap_add( + paddle.signal.overlap_add( paddle.to_tensor(self.x), self.hop_length, self.axis) @@ -848,7 +848,7 @@ def test_stft(self): self.assertTrue( np.allclose( stft(self.x, self.n_fft, self.hop_length, self.win_length, win_l, self.center, self.pad_mode), - paddle.tensor.signal.stft( + paddle.signal.stft( paddle.to_tensor(self.x), self.n_fft, self.hop_length, @@ -891,7 +891,7 @@ def test_stft(self): win_p = paddle.to_tensor(self.window) with self.assertRaises(self.expect_exception): - paddle.tensor.signal.stft( + paddle.signal.stft( paddle.to_tensor(self.x), self.n_fft, self.hop_length, @@ -934,7 +934,7 @@ def test_istft(self): self.assertTrue( np.allclose( istft(self.x, self.hop_length, self.win_length, win_l, self.center, self.length), - paddle.tensor.signal.istft( + paddle.signal.istft( paddle.to_tensor(self.x), self.n_fft, self.hop_length, @@ -986,7 +986,7 @@ def test_istft(self): win_p = paddle.to_tensor(self.window) with self.assertRaises(self.expect_exception): - paddle.tensor.signal.istft( + paddle.signal.istft( paddle.to_tensor(self.x), self.n_fft, self.hop_length, diff --git 
a/python/paddle/tensor/signal.py b/python/paddle/signal.py similarity index 97% rename from python/paddle/tensor/signal.py rename to python/paddle/signal.py index 86022a1748356..fc80c7cbc80f3 100644 --- a/python/paddle/tensor/signal.py +++ b/python/paddle/signal.py @@ -16,16 +16,14 @@ import paddle -from .attribute import is_complex, is_floating_point +from .tensor.attribute import is_complex, is_floating_point from .fft import fft_r2c, fft_c2r, fft_c2c -from ..fluid.data_feeder import check_variable_and_dtype -from ..fluid.framework import in_dygraph_mode -from ..fluid.layer_helper import LayerHelper -from .. import _C_ops +from .fluid.data_feeder import check_variable_and_dtype +from .fluid.framework import in_dygraph_mode +from .fluid.layer_helper import LayerHelper +from . import _C_ops __all__ = [ - 'frame', - 'overlap_add', 'stft', 'istft', ] @@ -56,7 +54,7 @@ def frame(x, frame_length, hop_length, axis=-1, name=None): .. code-block:: python import paddle - from paddle.tensor.signal import frame + from paddle.signal import frame # 1D x = paddle.arange(8) @@ -177,7 +175,7 @@ def overlap_add(x, hop_length, axis=-1, name=None): .. code-block:: python import paddle - from paddle.tensor.signal import overlap_add + from paddle.signal import overlap_add # 2D x0 = paddle.arange(16).reshape([8, 2]) @@ -291,11 +289,11 @@ def stft(x, real-valued input and `onesided` is `True`) or `[..., n_fft, num_frames]`( `onesided` is `False`) - Exampels: + Examples: .. code-block:: python import paddle - from paddle.tensor.signal import stft + from paddle.signal import stft # real-valued input x = paddle.randn([8, 48000], dtype=paddle.float64) @@ -415,7 +413,7 @@ def istft(x, - :math:`N`: Value of `n_fft`. - :math:`H`: Value of `hop_length`. 
- Result of `istft` expected to be the inverse of `paddle.tensor.signal.stft`, but it is + Result of `istft` expected to be the inverse of `paddle.signal.stft`, but it is not guaranteed to reconstruct a exactly realizible time-domain signal from a STFT complex tensor which has been modified (via masking or otherwise). Therefore, `istft` gives the [Griffin-Lim optimal estimate](https://ieeexplore.ieee.org/document/1164317) @@ -454,12 +452,12 @@ def istft(x, A tensor of least squares estimation of the reconstructed signal(s) with shape `[..., seq_length]` - Exampels: + Examples: .. code-block:: python import numpy as np import paddle - from paddle.tensor.signal import stft, istft + from paddle.signal import stft, istft paddle.seed(0) diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 26a44fc39def9..d046b666c3ef3 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -222,8 +222,6 @@ from .array import create_array # noqa: F401 from .einsum import einsum # noqa: F401 -from . import fft -from . 
import signal #this list used in math_op_patch.py for _binary_creator_ tensor_method_func = [ #noqa diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 5f7588cb2a9a0..9b9b2d9431eeb 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -696,15 +696,24 @@ def roll(x, shifts, axis=None, name=None): helper = LayerHelper("roll", **locals()) check_type(axis, 'axis', (list, tuple), 'roll') - check_type(shifts, 'shifts', (list, tuple), 'roll') + out = helper.create_variable_for_type_inference(x.dtype) - helper.append_op( - type='roll', - inputs={'X': x}, - outputs={'Out': out}, - attrs={'axis': axis, - 'shifts': shifts}) + if isinstance(shifts, Variable): + helper.append_op( + type='roll', + inputs={'X': x, + "ShiftsTensor": shifts}, + outputs={'Out': out}, + attrs={'axis': axis}) + else: + check_type(shifts, 'shifts', (list, tuple), 'roll') + helper.append_op( + type='roll', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'axis': axis, + 'shifts': shifts}) return out diff --git a/tools/externalError/README.md b/tools/externalError/README.md index 029efd8cb9491..0c2ac626991da 100644 --- a/tools/externalError/README.md +++ b/tools/externalError/README.md @@ -1,9 +1,25 @@ -Usage: +#### **Instructions for crawling new error messages:** -Please run: -``` -bash start.sh -``` -If you want to update all external error message, you need to run command `bash start.sh` in current directory, -and upload the generated file `externalErrorMsg.tar.gz` to https://paddlepaddledeps.bj.bcebos.com/externalErrorMsg.tar.gz + +1. add new spider code in spider.py for crawling error messages from the website. + +2. run `bash start.sh` in the current directory to generate a new externalErrorMsg_${date}.tar.gz file, for example `externalErrorMsg_20210928.tar.gz`. + +3. upload the above tar file to the **paddlepaddledeps** BOS bucket at https://paddlepaddledeps.bj.bcebos.com, and copy the download link `${download_url}`. 
**Be careful not to delete the original tar file.** + +4. compute the md5 value `${md5}` of the above tar file, and modify the cmake/third_party.cmake file + + ``` + set(URL "${download_url}" CACHE STRING "" FORCE) + file_download_and_uncompress(${URL} "externalError" MD5 ${md5}) + ``` + + for example: + + ``` + set(URL "https://paddlepaddledeps.bj.bcebos.com/externalErrorMsg_20210928.tar.gz" CACHE STRING "" FORCE) + file_download_and_uncompress(${URL} "externalError" MD5 a712a49384e77ca216ad866712f7cafa) + ``` + +5. commit your changes, and create a pull request. diff --git a/tools/externalError/spider.py b/tools/externalError/spider.py index a74d82f40ebeb..e07f05f561cb5 100644 --- a/tools/externalError/spider.py +++ b/tools/externalError/spider.py @@ -17,8 +17,10 @@ import urllib.request import json import collections -import sys, getopt +import sys +import getopt import external_error_pb2 +from html.parser import HTMLParser def parsing(externalErrorDesc): @@ -335,6 +337,31 @@ def parsing(externalErrorDesc): _Messages.message = "'%s'. %s" % (error[0], m_message) print("End crawling errorMessage for nvidia NCCL API!\n") + #*************************************************************************************************# + #*********************************** CUFFT Error Message **************************************# + print("start crawling errorMessage for nvidia CUFFT API--->") + url = 'https://docs.nvidia.com/cuda/cufft/index.html#cufftresult' + + allMessageDesc = externalErrorDesc.errors.add() + allMessageDesc.type = external_error_pb2.CUFFT + + html = urllib.request.urlopen(url).read().decode('utf-8') + + class CUFFTHTMLParser(HTMLParser): + '''CUFFTHTML Parser + ''' + + def handle_data(self, data): + if 'typedef enum cufftResult_t' in data: + for line in data.strip().splitlines()[1:-1]: + status, code, desc = re.split('=|//', line.strip()) + _Messages = allMessageDesc.messages.add() + _Messages.code = int(code.strip(' ,')) + _Messages.message = "'%s'. 
%s" % (status.strip(), + desc.strip()) + + CUFFTHTMLParser().feed(html) + def main(argv): try: diff --git a/tools/externalError/start.sh b/tools/externalError/start.sh index 32ef63c261268..82715dd47326c 100644 --- a/tools/externalError/start.sh +++ b/tools/externalError/start.sh @@ -32,4 +32,4 @@ fi protobuf/bin/protoc -I../../paddle/fluid/platform/ --python_out . ../../paddle/fluid/platform/external_error.proto python3.7 spider.py -tar czvf externalErrorMsg.tar.gz externalErrorMsg.pb +tar czvf externalErrorMsg_$(date +'%Y%m%d').tar.gz externalErrorMsg.pb