From 937b9635ce7ea73be20381efc166ec35e9f09d06 Mon Sep 17 00:00:00 2001 From: JakubO Date: Thu, 18 May 2023 20:25:24 +0200 Subject: [PATCH 01/14] add MotherWavelet helper and WaveletGpu kernel --- dali/kernels/signal/CMakeLists.txt | 1 + dali/kernels/signal/wavelet/CMakeLists.txt | 17 ++ dali/kernels/signal/wavelet/mother_wavelet.cu | 169 ++++++++++++++++++ .../kernels/signal/wavelet/mother_wavelet.cuh | 63 +++++++ dali/kernels/signal/wavelet/wavelet_gpu.cu | 90 ++++++++++ dali/kernels/signal/wavelet/wavelet_gpu.cuh | 51 ++++++ 6 files changed, 391 insertions(+) create mode 100644 dali/kernels/signal/wavelet/CMakeLists.txt create mode 100644 dali/kernels/signal/wavelet/mother_wavelet.cu create mode 100644 dali/kernels/signal/wavelet/mother_wavelet.cuh create mode 100644 dali/kernels/signal/wavelet/wavelet_gpu.cu create mode 100644 dali/kernels/signal/wavelet/wavelet_gpu.cuh diff --git a/dali/kernels/signal/CMakeLists.txt b/dali/kernels/signal/CMakeLists.txt index 431ae39629..74ca2e8970 100644 --- a/dali/kernels/signal/CMakeLists.txt +++ b/dali/kernels/signal/CMakeLists.txt @@ -17,6 +17,7 @@ add_subdirectory(decibel) if (BUILD_FFTS) add_subdirectory(fft) endif() +add_subdirectory(wavelet) add_subdirectory(window) collect_headers(DALI_INST_HDRS PARENT_SCOPE) diff --git a/dali/kernels/signal/wavelet/CMakeLists.txt b/dali/kernels/signal/wavelet/CMakeLists.txt new file mode 100644 index 0000000000..f3a24faa7c --- /dev/null +++ b/dali/kernels/signal/wavelet/CMakeLists.txt @@ -0,0 +1,17 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +collect_headers(DALI_INST_HDRS PARENT_SCOPE) +collect_sources(DALI_KERNEL_SRCS PARENT_SCOPE) +collect_test_sources(DALI_KERNEL_TEST_SRCS PARENT_SCOPE) \ No newline at end of file diff --git a/dali/kernels/signal/wavelet/mother_wavelet.cu b/dali/kernels/signal/wavelet/mother_wavelet.cu new file mode 100644 index 0000000000..664d7aab69 --- /dev/null +++ b/dali/kernels/signal/wavelet/mother_wavelet.cu @@ -0,0 +1,169 @@ +// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "dali/kernels/signal/wavelet/mother_wavelet.cuh" + +namespace dali { +namespace kernels { +namespace signal { + +template +__device__ +T HaarWavelet(T t, T a, T b) { + T x = std::pow(2.0, a) - b; + if (0.0 <= x && x < 0.5) { + return std::pow(2.0, a / 2.0); + } + if (0.5 <= x && x < 1.0) { + return -std::pow(2.0, a / 2.0); + } + return 0.0; +} + +template +__device__ +T DaubechiesWavelet(T t, T a, T b) { + return 0.0; +} + +template +__device__ +T SymletWavelet(T t, T a, T b) { + return 0.0; +} + +template +__device__ +T CoifletWavelet(T t, T a, T b) { + return 0.0; +} + +template +__device__ +T BiorthogonalWavelet(T t, T a, T b) { + return 0.0; +} + +template +__device__ +T MeyerWavelet(T t, T a, T b) { + return 0.0; +} + +template +__device__ +T GaussianWavelet(T t, T a, T b) { + return 0.0; +} + +template +__device__ +T MexicanHatWavelet(T t, T a, T b) { + return 0.0; +} + +template +__device__ +T MorletWavelet(T t, T a, T b) { + return 0.0; +} + +template +__device__ +T ComplexGaussianWavelet(T t, T a, T b) { + return 0.0; +} + +template +__device__ +T ShannonWavelet(T t, T a, T b) { + return 0.0; +} + +template +__device__ +T FbspWavelet(T t, T a, T b) { + return 0.0; +} + +template +__device__ +T ComplexMorletWavelet(T t, T a, T b) { + return 0.0; +} + +template +MotherWavelet::MotherWavelet(const WaveletName& name) { + switch(name) { + case WaveletName::HAAR: + waveletFunc = &HaarWavelet; + break; + + case WaveletName::DB: + waveletFunc = &DaubechiesWavelet; + break; + + case WaveletName::SYM: + waveletFunc = &SymletWavelet; + break; + + case WaveletName::COIF: + waveletFunc = &CoifletWavelet; + break; + + case WaveletName::BIOR: + waveletFunc = &BiorthogonalWavelet; + break; + + case WaveletName::MEY: + waveletFunc = &MeyerWavelet; + break; + + case WaveletName::GAUS: + waveletFunc = &GaussianWavelet; + break; + + case WaveletName::MEXH: + waveletFunc = &MexicanHatWavelet; + break; + + case WaveletName::MORL: + waveletFunc = &MorletWavelet; + break; + + case WaveletName::CGAU: + waveletFunc = &ComplexGaussianWavelet; + break; + + case WaveletName::SHAN: + waveletFunc = &ShannonWavelet; + break; + + case WaveletName::FBSP: + waveletFunc = &FbspWavelet; + break; + + case WaveletName::CMOR: + waveletFunc = &ComplexMorletWavelet; + break; + + default: + throw new std::invalid_argument("Unknown wavelet name."); + } +} + +} // namespace signal +} // namespace kernel +} // namespace dali diff --git a/dali/kernels/signal/wavelet/mother_wavelet.cuh b/dali/kernels/signal/wavelet/mother_wavelet.cuh new file mode 100644 index 0000000000..76eddd0278 --- /dev/null +++ b/dali/kernels/signal/wavelet/mother_wavelet.cuh @@ -0,0 +1,63 @@ +// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_KERNELS_SIGNAL_WAVELET_MOTHER_WAVELET_CUH_ +#define DALI_KERNELS_SIGNAL_WAVELET_MOTHER_WAVELET_CUH_ + +#include "dali/core/common.h" +#include "dali/core/error_handling.h" +#include "dali/core/format.h" +#include "dali/core/util.h" +#include "dali/kernels/kernel.h" + +namespace dali { +namespace kernels { +namespace signal { + +enum class WaveletName { + HAAR, + DB, + SYM, + COIF, + BIOR, + MEY, + GAUS, + MEXH, + MORL, + CGAU, + SHAN, + FBSP, + CMOR +}; + +template +class MotherWavelet { + static_assert(std::is_floating_point::value, + "Data type should be floating point"); + + public: + MotherWavelet(const WaveletName &name); + ~MotherWavelet() = default; + + __device__ T (*waveletFunc)(T t, T a, T b); +}; + +template class MotherWavelet; +template class MotherWavelet; + +} // namespace signal +} // namespace kernel +} // namespace dali + +#endif // DALI_KERNELS_SIGNAL_WAVELET_MOTHER_WAVELET_CUH_ diff --git a/dali/kernels/signal/wavelet/wavelet_gpu.cu b/dali/kernels/signal/wavelet/wavelet_gpu.cu new file mode 100644 index 0000000000..9ebf4512c2 --- /dev/null +++ b/dali/kernels/signal/wavelet/wavelet_gpu.cu @@ -0,0 +1,90 @@ +// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dali/kernels/signal/wavelet/wavelet_gpu.cuh" +#include +#include +#include +#include "dali/core/common.h" +#include "dali/core/error_handling.h" +#include "dali/core/format.h" +#include "dali/kernels/kernel.h" +#include "dali/kernels/signal/wavelet/mother_wavelet.cuh" + +namespace dali { +namespace kernels { +namespace signal { + +template +struct SampleDesc { + const T *a = nullptr; + int64_t size_a = 0; + T *out = nullptr; + int64_t size_out = 0; +}; + +template +__global__ void ComputeWavelet(const SampleDesc* sample_data, + T begin, T sampling_rate, T b, MotherWavelet wavelet) { + const int64_t block_size = blockDim.x * blockDim.y; + const int64_t tid = threadIdx.y * blockDim.x + threadIdx.x; + const T t = begin + (T)tid / sampling_rate; + const T a = sample_data->a[blockIdx.x]; + sample_data->out[tid + blockIdx.x * block_size] = wavelet.waveletFunc(t, a, b); +} + +template +WaveletGpu::~WaveletGpu() = default; + +template +KernelRequirements WaveletGpu::Setup(KernelContext &context, + const WaveletArgs &args) { + ScratchpadEstimator se; + se.add>(1); + se.add>(1); + KernelRequirements req; + req.scratch_sizes = se.sizes; + return req; +} + +template +void WaveletGpu::Run(KernelContext &context, + const OutListGPU &out, + const InListGPU &a, + const WaveletArgs &args) { + auto* sample_data = context.scratchpad->AllocateHost>(1); + + sample_data[0].out = out.tensor_data(0); + sample_data[0].a = a.tensor_data(0); + sample_data[0].size_a = volume(a.tensor_shape(0)); + auto in_size = (args.end - args.begin) * args.sampling_rate; + sample_data[0].size_out = in_size * sample_data[0].size_a; + + auto* sample_data_gpu = context.scratchpad->AllocateGPU>(1); + CUDA_CALL( + cudaMemcpyAsync(sample_data_gpu, sample_data, sizeof(SampleDesc), + cudaMemcpyHostToDevice, context.gpu.stream)); + + dim3 block(sample_data[0].size_a); + dim3 grid(in_size); + ComputeWavelet<<>>( + sample_data_gpu, args.begin, args.sampling_rate, args.b, MotherWavelet(args.wavelet)); +} + +template class WaveletGpu; +template class WaveletGpu; + +} // namespace signal +} // namespace kernels +} // namespace dali diff --git a/dali/kernels/signal/wavelet/wavelet_gpu.cuh b/dali/kernels/signal/wavelet/wavelet_gpu.cuh new file mode 100644 index 0000000000..f25cbc9d16 --- /dev/null +++ b/dali/kernels/signal/wavelet/wavelet_gpu.cuh @@ -0,0 +1,51 @@ +// Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_KERNELS_SIGNAL_WAVELET_GPU_CUH_ +#define DALI_KERNELS_SIGNAL_WAVELET_GPU_CUH_ + +#include +#include "dali/core/common.h" +#include "dali/core/error_handling.h" +#include "dali/core/format.h" +#include "dali/core/util.h" +#include "dali/kernels/kernel.h" +#include "dali/kernels/signal/wavelet/wavelet_args.h" + +namespace dali { +namespace kernels { +namespace signal { + +template +class DLL_PUBLIC WaveletGpu { + public: + static_assert(std::is_floating_point::value, + "Only floating point types are supported"); + + DLL_PUBLIC ~WaveletGpu(); + + DLL_PUBLIC KernelRequirements Setup(KernelContext &context, + const WaveletArgs &args); + + DLL_PUBLIC void Run(KernelContext &context, + const OutListGPU &out, + const InListGPU &a, + const WaveletArgs &args); +}; + +} // namespace signal +} // namespace kernels +} // namespace dali + +#endif // DALI_KERNELS_SIGNAL_WAVELET_GPU_CUH_ From cf7b6a6f46aef6b2ab26378088ed3971cd5041ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Wdowski?= Date: Thu, 18 May 2023 21:39:30 +0200 Subject: [PATCH 02/14] Cwt WIP --- dali/kernels/signal/CMakeLists.txt | 1 + dali/kernels/signal/wavelets/CMakeLists.txt | 17 ++++ dali/kernels/signal/wavelets/cwt_args.h | 33 +++++++ dali/kernels/signal/wavelets/cwt_gpu.cu | 98 +++++++++++++++++++ dali/kernels/signal/wavelets/cwt_gpu.h | 50 ++++++++++ dali/operators/signal/CMakeLists.txt | 1 + dali/operators/signal/wavelets/CMakeLists.txt | 17 ++++ dali/operators/signal/wavelets/cwt_op.h | 65 ++++++++++++ dali/operators/signal/wavelets/cwt_op_gpu.cu | 80 +++++++++++++++ 9 files changed, 362 insertions(+) create mode 100644 dali/kernels/signal/wavelets/CMakeLists.txt create mode 100644 dali/kernels/signal/wavelets/cwt_args.h create mode 100644 dali/kernels/signal/wavelets/cwt_gpu.cu create mode 100644 dali/kernels/signal/wavelets/cwt_gpu.h create mode 100644 dali/operators/signal/wavelets/CMakeLists.txt create mode 100644 dali/operators/signal/wavelets/cwt_op.h create mode 100644 dali/operators/signal/wavelets/cwt_op_gpu.cu diff --git a/dali/kernels/signal/CMakeLists.txt b/dali/kernels/signal/CMakeLists.txt index 431ae39629..07b62d5342 100644 --- a/dali/kernels/signal/CMakeLists.txt +++ b/dali/kernels/signal/CMakeLists.txt @@ -17,6 +17,7 @@ add_subdirectory(decibel) if (BUILD_FFTS) add_subdirectory(fft) endif() +add_subdirectory(wavelets) add_subdirectory(window) collect_headers(DALI_INST_HDRS PARENT_SCOPE) diff --git a/dali/kernels/signal/wavelets/CMakeLists.txt b/dali/kernels/signal/wavelets/CMakeLists.txt new file mode 100644 index 0000000000..c3ee135e61 --- /dev/null +++ b/dali/kernels/signal/wavelets/CMakeLists.txt @@ -0,0 +1,17 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +collect_headers(DALI_INST_HDRS PARENT_SCOPE) +collect_sources(DALI_KERNEL_SRCS PARENT_SCOPE) +collect_test_sources(DALI_KERNEL_TEST_SRCS PARENT_SCOPE) \ No newline at end of file diff --git a/dali/kernels/signal/wavelets/cwt_args.h b/dali/kernels/signal/wavelets/cwt_args.h new file mode 100644 index 0000000000..14f6cb7d5b --- /dev/null +++ b/dali/kernels/signal/wavelets/cwt_args.h @@ -0,0 +1,33 @@ +// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_KERNELS_SIGNAL_WAVELETS_CWT_ARGS_H_ +#define DALI_KERNELS_SIGNAL_WAVELETS_CWT_ARGS_H_ + +namespace dali { +namespace kernels { +namespace signal { +namespace wavelets { + +template +struct CwtArgs { + T a; +}; + +} // namespace wavelets +} // namespace signal +} // namespace kernels +} // namespace dali + +#endif // DALI_KERNELS_SIGNAL_WAVELETS_CWT_ARGS_H_ diff --git a/dali/kernels/signal/wavelets/cwt_gpu.cu b/dali/kernels/signal/wavelets/cwt_gpu.cu new file mode 100644 index 0000000000..be2b19bde2 --- /dev/null +++ b/dali/kernels/signal/wavelets/cwt_gpu.cu @@ -0,0 +1,98 @@ +// Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "dali/core/common.h" +#include "dali/core/error_handling.h" +#include "dali/core/format.h" +#include "dali/kernels/kernel.h" +#include "dali/kernels/signal/wavelets/cwt_args.h" +#include "dali/kernels/signal/wavelets/cwt_gpu.h" + +namespace dali { +namespace kernels { +namespace signal { +namespace wavelets { + +template +struct SampleDesc { + const T *in = nullptr; + T *out = nullptr; + int64_t size = 0; +}; + +template +__global__ void CwtKernel(const SampleDesc *sample_data, CwtArgs args) { + const int64_t block_size = blockDim.y * blockDim.x; + const int64_t grid_size = gridDim.x * block_size; + const int sample_idx = blockIdx.y; + const auto sample = sample_data[sample_idx]; + const int64_t offset = block_size * blockIdx.x; + const int64_t tid = threadIdx.y * blockDim.x + threadIdx.x; + + for (int64_t idx = offset + tid; idx < sample.size; idx += grid_size) { + sample.out[idx] = sample.in[idx] * args.a; + } +} + +template +CwtGpu::~CwtGpu() = default; + +template +KernelRequirements CwtGpu::Setup(KernelContext &context, + const InListGPU &in) { + auto out_shape = in.shape; + const size_t num_samples = in.size(); + ScratchpadEstimator se; + se.add>(num_samples); + se.add>(num_samples); + KernelRequirements req; + req.scratch_sizes = se.sizes; + req.output_shapes = {out_shape}; + return req; +} + +template +void CwtGpu::Run(KernelContext &context, const OutListGPU &out, + const InListGPU &in, const CwtArgs &args) { + auto num_samples = in.size(); + auto *sample_data = context.scratchpad->AllocateHost>(num_samples); + + for (int i = 0; i < num_samples; i++) { + auto &sample = sample_data[i]; + sample.out = out.tensor_data(i); + sample.in = in.tensor_data(i); + sample.size = volume(in.tensor_shape(i)); + assert(sample.size == volume(out.tensor_shape(i))); + } + + auto *sample_data_gpu = context.scratchpad->AllocateGPU>(num_samples); + CUDA_CALL(cudaMemcpyAsync(sample_data_gpu, sample_data, num_samples * sizeof(SampleDesc), + cudaMemcpyHostToDevice, context.gpu.stream)); + + dim3 block(32, 32); + auto blocks_per_sample = std::max(32, 1024 / num_samples); + dim3 grid(blocks_per_sample, num_samples); + CwtKernel<<>>(sample_data_gpu, args); +} + +template class CwtGpu; +template class CwtGpu; + +} // namespace wavelets +} // namespace signal +} // namespace kernels +} // namespace dali diff --git a/dali/kernels/signal/wavelets/cwt_gpu.h b/dali/kernels/signal/wavelets/cwt_gpu.h new file mode 100644 index 0000000000..2fa4fd939e --- /dev/null +++ b/dali/kernels/signal/wavelets/cwt_gpu.h @@ -0,0 +1,50 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_KERNELS_SIGNAL_WAVELETS_CWT_GPU_H_ +#define DALI_KERNELS_SIGNAL_WAVELETS_CWT_GPU_H_ + +#include +#include "dali/core/common.h" +#include "dali/core/error_handling.h" +#include "dali/core/format.h" +#include "dali/core/util.h" +#include "dali/kernels/kernel.h" +#include "dali/kernels/signal/wavelets/cwt_args.h" + +namespace dali { +namespace kernels { +namespace signal { +namespace wavelets { + +template +class DLL_PUBLIC CwtGpu { + public: + static_assert(std::is_floating_point::value, "Only floating point types are supported"); + + DLL_PUBLIC ~CwtGpu(); + + DLL_PUBLIC KernelRequirements Setup(KernelContext &context, + const InListGPU &in); + + DLL_PUBLIC void Run(KernelContext &context, const OutListGPU &out, + const InListGPU &in, const CwtArgs &args); +}; + +} // namespace wavelets +} // namespace signal +} // namespace kernels +} // namespace dali + +#endif // DALI_KERNELS_SIGNAL_WAVELETS_CWT_GPU_H_ diff --git a/dali/operators/signal/CMakeLists.txt b/dali/operators/signal/CMakeLists.txt index 217f785aa2..c16a5d4687 100644 --- a/dali/operators/signal/CMakeLists.txt +++ b/dali/operators/signal/CMakeLists.txt @@ -16,6 +16,7 @@ add_subdirectory(decibel) if (BUILD_FFTS) add_subdirectory(fft) endif() +add_subdirectory(wavelets) collect_headers(DALI_INST_HDRS PARENT_SCOPE) collect_sources(DALI_OPERATOR_SRCS PARENT_SCOPE) diff --git a/dali/operators/signal/wavelets/CMakeLists.txt b/dali/operators/signal/wavelets/CMakeLists.txt new file mode 100644 index 0000000000..0dba230abf --- /dev/null +++ b/dali/operators/signal/wavelets/CMakeLists.txt @@ -0,0 +1,17 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +collect_headers(DALI_INST_HDRS PARENT_SCOPE) +collect_sources(DALI_OPERATOR_SRCS PARENT_SCOPE) +collect_test_sources(DALI_OPERATOR_TEST_SRCS PARENT_SCOPE) diff --git a/dali/operators/signal/wavelets/cwt_op.h b/dali/operators/signal/wavelets/cwt_op.h new file mode 100644 index 0000000000..3d6e439d49 --- /dev/null +++ b/dali/operators/signal/wavelets/cwt_op.h @@ -0,0 +1,65 @@ +// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_OPERATORS_SIGNAL_WAVELETS_CWT_H_ +#define DALI_OPERATORS_SIGNAL_WAVELETS_CWT_H_ + +#include +#include +#include "dali/core/common.h" +#include "dali/kernels/signal/wavelets/cwt_args.h" +#include "dali/pipeline/operator/common.h" +#include "dali/pipeline/operator/operator.h" +#include "dali/pipeline/util/operator_impl_utils.h" + +namespace dali { + +template +class Cwt : public Operator { + public: + explicit Cwt(const OpSpec &spec) : Operator(spec) { + if (!spec.HasArgument("a")) { + DALI_ENFORCE("`a` argument must be provided."); + } + args_.a = spec.GetArgument("a"); + } + + protected: + bool CanInferOutputs() const override { + return true; + } + + bool SetupImpl(std::vector &output_desc, const Workspace &ws) override { + assert(impl_ != nullptr); + return impl_->SetupImpl(output_desc, ws); + } + + void RunImpl(Workspace &ws) override { + assert(impl_ != nullptr); + impl_->RunImpl(ws); + } + + USE_OPERATOR_MEMBERS(); + using Operator::RunImpl; + + kernels::KernelManager kmgr_; + kernels::signal::wavelets::CwtArgs args_; + + std::unique_ptr> impl_; + DALIDataType type_ = DALI_NO_TYPE; +}; + +} // namespace dali + +#endif // DALI_OPERATORS_SIGNAL_WAVELETS_CWT_H_ diff --git a/dali/operators/signal/wavelets/cwt_op_gpu.cu b/dali/operators/signal/wavelets/cwt_op_gpu.cu new file mode 100644 index 0000000000..3cea5427d6 --- /dev/null +++ b/dali/operators/signal/wavelets/cwt_op_gpu.cu @@ -0,0 +1,80 @@ +// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "dali/core/static_switch.h" +#include "dali/kernels/kernel_manager.h" +#include "dali/kernels/kernel_params.h" +#include "dali/kernels/signal/wavelets/cwt_args.h" +#include "dali/kernels/signal/wavelets/cwt_gpu.h" +#include "dali/operators/signal/wavelets/cwt_op.h" +#include "dali/pipeline/data/views.h" + +namespace dali { + +DALI_SCHEMA(Cwt).DocStr("by MW").NumInput(1).NumOutput(1).AddArg("a", "costam", + type2id::value); + +template +struct CwtImplGPU : public OpImplBase { + public: + using CwtArgs = kernels::signal::wavelets::CwtArgs; + using CwtKernel = kernels::signal::wavelets::CwtGpu; + + explicit CwtImplGPU(CwtArgs args) : args_(std::move(args)) { + kmgr_cwt_.Resize(1); + } + + bool SetupImpl(std::vector &output_desc, const Workspace &ws) override { + const auto &input = ws.Input(0); + auto in_view = view(input); + + auto type = type2id::value; + + kernels::KernelContext ctx; + ctx.gpu.stream = ws.stream(); + + auto &req = kmgr_cwt_.Setup(0, ctx, in_view); + output_desc.resize(1); + output_desc[0].type = type; + output_desc[0].shape = req.output_shapes[0]; + + return true; + } + + void RunImpl(Workspace &ws) override { + const auto &input = ws.Input(0); + auto &output = ws.Output(0); + + auto in_view = view(input); + auto out_view = view(output); + + kernels::KernelContext ctx; + ctx.gpu.stream = ws.stream(); + + kmgr_cwt_.Run(0, ctx, out_view, in_view, args_); + } + + private: + CwtArgs args_; + kernels::KernelManager kmgr_cwt_; + std::vector cwt_out_desc_; + TensorList cwt_out_; +}; + +DALI_REGISTER_OPERATOR(Cwt, Cwt, GPU); + +} // namespace dali From b0346197840c83a654143f6e34d55d1cef1a8f36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Wdowski?= Date: Thu, 18 May 2023 22:20:41 +0200 Subject: [PATCH 03/14] Rename namespace --- dali/kernels/signal/CMakeLists.txt | 2 +- .../signal/{wavelets => wavelet}/CMakeLists.txt | 0 dali/kernels/signal/{wavelets => wavelet}/cwt_args.h | 10 +++++----- dali/kernels/signal/{wavelets => wavelet}/cwt_gpu.cu | 4 ++-- dali/kernels/signal/{wavelets => wavelet}/cwt_gpu.h | 10 +++++----- dali/operators/signal/CMakeLists.txt | 2 +- .../signal/{wavelets => wavelet}/CMakeLists.txt | 0 dali/operators/signal/{wavelets => wavelet}/cwt_op.h | 0 .../signal/{wavelets => wavelet}/cwt_op_gpu.cu | 0 9 files changed, 14 insertions(+), 14 deletions(-) rename dali/kernels/signal/{wavelets => wavelet}/CMakeLists.txt (100%) rename dali/kernels/signal/{wavelets => wavelet}/cwt_args.h (80%) rename dali/kernels/signal/{wavelets => wavelet}/cwt_gpu.cu (98%) rename dali/kernels/signal/{wavelets => wavelet}/cwt_gpu.h (88%) rename dali/operators/signal/{wavelets => wavelet}/CMakeLists.txt (100%) rename dali/operators/signal/{wavelets => wavelet}/cwt_op.h (100%) rename dali/operators/signal/{wavelets => wavelet}/cwt_op_gpu.cu (100%) diff --git a/dali/kernels/signal/CMakeLists.txt b/dali/kernels/signal/CMakeLists.txt index 07b62d5342..74ca2e8970 100644 --- a/dali/kernels/signal/CMakeLists.txt +++ b/dali/kernels/signal/CMakeLists.txt @@ -17,7 +17,7 @@ add_subdirectory(decibel) if (BUILD_FFTS) add_subdirectory(fft) endif() -add_subdirectory(wavelets) +add_subdirectory(wavelet) add_subdirectory(window) collect_headers(DALI_INST_HDRS PARENT_SCOPE) diff --git a/dali/kernels/signal/wavelets/CMakeLists.txt b/dali/kernels/signal/wavelet/CMakeLists.txt similarity index 100% rename from dali/kernels/signal/wavelets/CMakeLists.txt rename to dali/kernels/signal/wavelet/CMakeLists.txt diff --git a/dali/kernels/signal/wavelets/cwt_args.h b/dali/kernels/signal/wavelet/cwt_args.h similarity index 80% rename from dali/kernels/signal/wavelets/cwt_args.h rename to dali/kernels/signal/wavelet/cwt_args.h index 14f6cb7d5b..b61d064a9e 100644 --- a/dali/kernels/signal/wavelets/cwt_args.h +++ b/dali/kernels/signal/wavelet/cwt_args.h @@ -12,22 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef DALI_KERNELS_SIGNAL_WAVELETS_CWT_ARGS_H_ -#define DALI_KERNELS_SIGNAL_WAVELETS_CWT_ARGS_H_ +#ifndef DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_ +#define DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_ namespace dali { namespace kernels { namespace signal { -namespace wavelets { +namespace wavelet { template struct CwtArgs { T a; }; -} // namespace wavelets +} // namespace wavelet } // namespace signal } // namespace kernels } // namespace dali -#endif // DALI_KERNELS_SIGNAL_WAVELETS_CWT_ARGS_H_ +#endif // DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_ diff --git a/dali/kernels/signal/wavelets/cwt_gpu.cu b/dali/kernels/signal/wavelet/cwt_gpu.cu similarity index 98% rename from dali/kernels/signal/wavelets/cwt_gpu.cu rename to dali/kernels/signal/wavelet/cwt_gpu.cu index be2b19bde2..a15f82929a 100644 --- a/dali/kernels/signal/wavelets/cwt_gpu.cu +++ b/dali/kernels/signal/wavelet/cwt_gpu.cu @@ -25,7 +25,7 @@ namespace dali { namespace kernels { namespace signal { -namespace wavelets { +namespace wavelet { template struct SampleDesc { @@ -92,7 +92,7 @@ void CwtGpu::Run(KernelContext &context, const OutListGPU; template class CwtGpu; -} // namespace wavelets +} // namespace wavelet } // namespace signal } // namespace kernels } // namespace dali diff --git a/dali/kernels/signal/wavelets/cwt_gpu.h b/dali/kernels/signal/wavelet/cwt_gpu.h similarity index 88% rename from dali/kernels/signal/wavelets/cwt_gpu.h rename to dali/kernels/signal/wavelet/cwt_gpu.h index 2fa4fd939e..62f9cef738 100644 --- a/dali/kernels/signal/wavelets/cwt_gpu.h +++ b/dali/kernels/signal/wavelet/cwt_gpu.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef DALI_KERNELS_SIGNAL_WAVELETS_CWT_GPU_H_ -#define DALI_KERNELS_SIGNAL_WAVELETS_CWT_GPU_H_ +#ifndef DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_ +#define DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_ #include #include "dali/core/common.h" @@ -26,7 +26,7 @@ namespace dali { namespace kernels { namespace signal { -namespace wavelets { +namespace wavelet { template class DLL_PUBLIC CwtGpu { @@ -42,9 +42,9 @@ class DLL_PUBLIC CwtGpu { const InListGPU &in, const CwtArgs &args); }; -} // namespace wavelets +} // namespace wavelet } // namespace signal } // namespace kernels } // namespace dali -#endif // DALI_KERNELS_SIGNAL_WAVELETS_CWT_GPU_H_ +#endif // DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_ diff --git a/dali/operators/signal/CMakeLists.txt b/dali/operators/signal/CMakeLists.txt index c16a5d4687..44d93c05ba 100644 --- a/dali/operators/signal/CMakeLists.txt +++ b/dali/operators/signal/CMakeLists.txt @@ -16,7 +16,7 @@ add_subdirectory(decibel) if (BUILD_FFTS) add_subdirectory(fft) endif() -add_subdirectory(wavelets) +add_subdirectory(wavelet) collect_headers(DALI_INST_HDRS PARENT_SCOPE) collect_sources(DALI_OPERATOR_SRCS PARENT_SCOPE) diff --git a/dali/operators/signal/wavelets/CMakeLists.txt b/dali/operators/signal/wavelet/CMakeLists.txt similarity index 100% rename from dali/operators/signal/wavelets/CMakeLists.txt rename to dali/operators/signal/wavelet/CMakeLists.txt diff --git a/dali/operators/signal/wavelets/cwt_op.h b/dali/operators/signal/wavelet/cwt_op.h similarity index 100% rename from dali/operators/signal/wavelets/cwt_op.h rename to dali/operators/signal/wavelet/cwt_op.h diff --git a/dali/operators/signal/wavelets/cwt_op_gpu.cu b/dali/operators/signal/wavelet/cwt_op_gpu.cu similarity index 100% rename from dali/operators/signal/wavelets/cwt_op_gpu.cu rename to dali/operators/signal/wavelet/cwt_op_gpu.cu From 5eed0c564a24847f32a697b7f06ce2200833945e Mon Sep 17 00:00:00 2001 From: JakubO Date: Mon, 22 May 2023 18:46:05 +0200 Subject: [PATCH 04/14] add WaveletArgs class --- dali/kernels/signal/wavelet/wavelet_args.h | 51 ++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 dali/kernels/signal/wavelet/wavelet_args.h diff --git a/dali/kernels/signal/wavelet/wavelet_args.h b/dali/kernels/signal/wavelet/wavelet_args.h new file mode 100644 index 0000000000..81a5d20a50 --- /dev/null +++ b/dali/kernels/signal/wavelet/wavelet_args.h @@ -0,0 +1,51 @@ +// Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_KERNELS_SIGNAL_WAVELET_ARGS_H_ +#define DALI_KERNELS_SIGNAL_WAVELET_ARGS_H_ + +#include +#include +#include "dali/kernels/signal/wavelet/mother_wavelet.cuh" + +namespace dali { +namespace kernels { +namespace signal { + +template +struct WaveletArgs { + // mother wavelet name + WaveletName wavelet = WaveletName::HAAR; + + // wavelet shift parameter + T b = 0.0; + + // lower limit of wavelet samples + T begin = -1.0; + + // upper limit of wavelet samples + T end = 1.0; + + // wavelet sampling rate (samples/s) + T sampling_rate = 1000.0; +}; + +template class WaveletArgs; +template class WaveletArgs; + +} // namespace signal +} // namespace kernels +} // namespace dali + +#endif // DALI_KERNELS_SIGNAL_WAVELET_ARGS_H_ From 279e61b69549150e3e9a148fbc387dbc5b0e58e0 Mon Sep 17 00:00:00 2001 From: JakubO Date: Mon, 5 Jun 2023 16:32:08 +0200 Subject: [PATCH 05/14] Improve wavelet computing kernel This change was mainly about moving from storing wavelets as functions to functors. Now wavelets can have extra parameters. This introduced a challenge of making the CUDA kernel accept these functors so templates were used. A helper utility was also introduced on operator side. RunForName function translates wavelet names and runs the right DALI kernel. --- dali/kernels/signal/wavelet/mother_wavelet.cu | 189 ++++++++++-------- .../kernels/signal/wavelet/mother_wavelet.cuh | 146 ++++++++++++-- dali/kernels/signal/wavelet/wavelet_args.h | 51 ----- dali/kernels/signal/wavelet/wavelet_gpu.cu | 138 ++++++++----- dali/kernels/signal/wavelet/wavelet_gpu.cuh | 63 +++++- dali/operators/signal/wavelet/wavelet_run.h | 99 +++++++++ 6 files changed, 477 insertions(+), 209 deletions(-) delete mode 100644 dali/kernels/signal/wavelet/wavelet_args.h create mode 100644 dali/operators/signal/wavelet/wavelet_run.h diff --git a/dali/kernels/signal/wavelet/mother_wavelet.cu b/dali/kernels/signal/wavelet/mother_wavelet.cu index 664d7aab69..7b3afe1eef 100644 --- a/dali/kernels/signal/wavelet/mother_wavelet.cu +++ b/dali/kernels/signal/wavelet/mother_wavelet.cu @@ -12,17 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#include #include "dali/kernels/signal/wavelet/mother_wavelet.cuh" +#include "dali/core/math_util.h" namespace dali { namespace kernels { namespace signal { template -__device__ -T HaarWavelet(T t, T a, T b) { - T x = std::pow(2.0, a) - b; +HaarWavelet::HaarWavelet(const std::vector &args) { + if (args.size() != 0) { + throw new std::invalid_argument("HaarWavelet doesn't accept any arguments."); + } +} + +template +__device__ T HaarWavelet::operator()(const T &t, const T &a, const T &b) const { + T x = std::pow(2.0, a) * t - b; if (0.0 <= x && x < 0.5) { return std::pow(2.0, a / 2.0); } @@ -32,138 +39,148 @@ T HaarWavelet(T t, T a, T b) { return 0.0; } -template -__device__ -T DaubechiesWavelet(T t, T a, T b) { - return 0.0; -} +template class HaarWavelet; +template class HaarWavelet; template -__device__ -T SymletWavelet(T t, T a, T b) { - return 0.0; +DaubechiesWavelet::DaubechiesWavelet(const std::vector &args) { + } template -__device__ -T CoifletWavelet(T t, T a, T b) { +__device__ T DaubechiesWavelet::operator()(const T &t, const T &a, const T &b) const { return 0.0; } +template class DaubechiesWavelet; +template class DaubechiesWavelet; + template -__device__ -T BiorthogonalWavelet(T t, T a, T b) { - return 0.0; +SymletWavelet::SymletWavelet(const std::vector &args) { + } template -__device__ -T MeyerWavelet(T t, T a, T b) { +__device__ T SymletWavelet::operator()(const T &t, const T &a, const T &b) const { return 0.0; } +template class SymletWavelet; +template class SymletWavelet; + template -__device__ -T GaussianWavelet(T t, T a, T b) { - return 0.0; +CoifletWavelet::CoifletWavelet(const std::vector &args) { + } template -__device__ -T MexicanHatWavelet(T t, T a, T b) { +__device__ T CoifletWavelet::operator()(const T &t, const T &a, const T &b) const { return 0.0; } +template class CoifletWavelet; +template class CoifletWavelet; + template -__device__ -T MorletWavelet(T t, T a, T b) { - return 0.0; +MeyerWavelet::MeyerWavelet(const std::vector &args) { + if (args.size() != 0) { + throw new std::invalid_argument("MeyerWavelet doesn't accept any arguments."); + } } template -__device__ -T ComplexGaussianWavelet(T t, T a, T b) { - return 0.0; +__device__ T MeyerWavelet::operator()(const T &t, const T &a, const T &b) const { + T x = std::pow(2.0, a) * t - b - 0.5; + T psi1 = (4/(3*M_PI)*x*std::cos((2*M_PI)/3*x)-1/M_PI*std::sin((4*M_PI)/3*x))/(x-16/9*std::pow(x, 3.0)); + T psi2 = (8/(3*M_PI)*x*std::cos(8*M_PI/3*x)+1/M_PI*std::sin((4*M_PI)/3)*x)/(x-64/9*std::pow(x, 3.0)); + return std::pow(2.0, a / 2.0) * (psi1 + psi2); } +template class MeyerWavelet; +template class MeyerWavelet; + template -__device__ -T ShannonWavelet(T t, T a, T b) { - return 0.0; +GaussianWavelet::GaussianWavelet(const std::vector &args) { + } template -__device__ -T FbspWavelet(T t, T a, T b) { +__device__ T GaussianWavelet::operator()(const T &t, const T &a, const T &b) const { return 0.0; } +template class GaussianWavelet; +template class GaussianWavelet; + template -__device__ -T ComplexMorletWavelet(T t, T a, T b) { - return 0.0; +MexicanHatWavelet::MexicanHatWavelet(const std::vector &args) { + if (args.size() != 1) { + throw new std::invalid_argument("MexicanHatWavelet accepts exactly one argument - sigma."); + } + this->sigma = *args.begin(); } template -MotherWavelet::MotherWavelet(const WaveletName& name) { - switch(name) { - case WaveletName::HAAR: - waveletFunc = &HaarWavelet; - break; - - case WaveletName::DB: - waveletFunc = &DaubechiesWavelet; - break; - - case WaveletName::SYM: - waveletFunc = &SymletWavelet; - break; - - case WaveletName::COIF: - waveletFunc = &CoifletWavelet; - break; - - case WaveletName::BIOR: - waveletFunc = &BiorthogonalWavelet; - break; +__device__ T MexicanHatWavelet::operator()(const T &t, const T &a, const T &b) const { + T x = std::pow(2.0, a) * t - b; + return std::pow(2.0, a / 2.0) * (2/(std::sqrt(3*sigma)*std::pow(M_PI, 0.25))*(1-std::pow(x/sigma, 2.0))*std::exp(-std::pow(x, 2.0)/(2*std::pow(sigma, 2.0)))); +} - case WaveletName::MEY: - waveletFunc = &MeyerWavelet; - break; +template class MexicanHatWavelet; +template class MexicanHatWavelet; - case WaveletName::GAUS: - waveletFunc = &GaussianWavelet; - break; +template +MorletWavelet::MorletWavelet(const std::vector &args) { + if (args.size() != 1) { + throw new std::invalid_argument("MorletWavelet accepts exactly 1 argument - C."); + } + this->C = *args.begin(); +} - case WaveletName::MEXH: - waveletFunc = &MexicanHatWavelet; - break; +template +__device__ T MorletWavelet::operator()(const T &t, const T &a, const T &b) const { + T x = std::pow(2.0, a) * t - b; + return std::pow(2.0, a / 2.0) * (C * std::exp(-std::pow(x, 2.0)) * std::cos(5 * x)); +} - case WaveletName::MORL: - waveletFunc = &MorletWavelet; - break; +template class MorletWavelet; +template class MorletWavelet; - case WaveletName::CGAU: - waveletFunc = &ComplexGaussianWavelet; - break; +template +ShannonWavelet::ShannonWavelet(const std::vector &args) { + if (args.size() != 0) { + throw new std::invalid_argument("ShannonWavelet doesn't accept any arguments."); + } +} - case WaveletName::SHAN: - waveletFunc = &ShannonWavelet; - break; +template +__device__ T ShannonWavelet::operator()(const T &t, const T &a, const T &b) const { + T x = std::pow(2.0, a) * t - b; + return std::pow(2.0, a / 2.0) * (sinc(x - 0.5) - 2 * sinc(2 * x - 1)); +} - case WaveletName::FBSP: - waveletFunc = &FbspWavelet; - break; +template class ShannonWavelet; +template class ShannonWavelet; - case WaveletName::CMOR: - waveletFunc = &ComplexMorletWavelet; - break; - - default: - throw new std::invalid_argument("Unknown wavelet name."); +template +FbspWavelet::FbspWavelet(const std::vector &args) { + if (args.size() != 0) { + throw new std::invalid_argument("FbspWavelet accepts exactly 3 arguments -> m, fb, fc in that order."); } + this->m = *args.begin(); + this->fb = *(args.begin()+1); + this->fc = *(args.begin()+2); } +template +__device__ T FbspWavelet::operator()(const T &t, const T &a, const T &b) const { + T x = std::pow(2.0, a) * t - b; + return std::pow(2.0, a / 2.0) * (std::sqrt(fb)*std::pow(sinc(x/std::pow(fb, m)), m)*std::exp(2*M_PI*fc*x)); +} + +template class FbspWavelet; +template class FbspWavelet; + } // namespace signal } // namespace kernel } // namespace dali diff --git a/dali/kernels/signal/wavelet/mother_wavelet.cuh b/dali/kernels/signal/wavelet/mother_wavelet.cuh index 76eddd0278..045e4f5065 100644 --- a/dali/kernels/signal/wavelet/mother_wavelet.cuh +++ b/dali/kernels/signal/wavelet/mother_wavelet.cuh @@ -21,40 +21,146 @@ #include "dali/core/util.h" #include "dali/kernels/kernel.h" +#include + namespace dali { namespace kernels { namespace signal { -enum class WaveletName { - HAAR, - DB, - SYM, - COIF, - BIOR, - MEY, - GAUS, - MEXH, - MORL, - CGAU, - SHAN, - FBSP, - CMOR +// wavelets are represented by functors +// they can store any necessary parameters +// they must overload () operator + +template +class HaarWavelet { + static_assert(std::is_floating_point::value, + "Data type should be floating point"); + public: + HaarWavelet() = default; + HaarWavelet(const std::vector &args); + ~HaarWavelet() = default; + + __device__ T operator()(const T &t, const T &a, const T &b) const; +}; + +template +class DaubechiesWavelet { + static_assert(std::is_floating_point::value, + "Data type should be floating point"); + public: + DaubechiesWavelet() = default; + DaubechiesWavelet(const std::vector &args); + ~DaubechiesWavelet() = default; + + __device__ T operator()(const T &t, const T &a, const T &b) const; +}; + +template +class SymletWavelet { + static_assert(std::is_floating_point::value, + "Data type should be floating point"); + public: + SymletWavelet() = default; + SymletWavelet(const std::vector &args); + ~SymletWavelet() = default; + + __device__ T operator()(const T &t, const T &a, const T &b) const; +}; + +template +class CoifletWavelet { + static_assert(std::is_floating_point::value, + "Data type should be floating point"); + public: + CoifletWavelet() = default; + CoifletWavelet(const std::vector &args); + ~CoifletWavelet() = default; + + __device__ T operator()(const T &t, const T &a, const T &b) const; +}; + +template +class MeyerWavelet { + static_assert(std::is_floating_point::value, + "Data type should be floating point"); + public: + MeyerWavelet() = default; + MeyerWavelet(const std::vector &args); + ~MeyerWavelet() = default; + + __device__ T operator()(const T &t, const T &a, const T &b) const; +}; + +template +class GaussianWavelet { + static_assert(std::is_floating_point::value, + "Data type should be floating point"); + public: + GaussianWavelet() = default; + GaussianWavelet(const std::vector &args); + ~GaussianWavelet() = default; + + __device__ T operator()(const T &t, const T &a, const T &b) const; +}; + +template +class MexicanHatWavelet { + static_assert(std::is_floating_point::value, + "Data type should be floating point"); + public: + MexicanHatWavelet() = default; + MexicanHatWavelet(const std::vector &args); + ~MexicanHatWavelet() = default; + + __device__ T operator()(const T &t, const T &a, const T &b) const; + + private: + T sigma; }; template -class MotherWavelet { +class MorletWavelet { static_assert(std::is_floating_point::value, "Data type should be floating point"); + public: + MorletWavelet() = default; + MorletWavelet(const std::vector &args); + ~MorletWavelet() = default; + + __device__ T operator()(const T &t, const T &a, const T &b) const; + private: + T C; +}; + +template +class ShannonWavelet { + static_assert(std::is_floating_point::value, + "Data type should be floating point"); public: - MotherWavelet(const WaveletName &name); - ~MotherWavelet() = default; + ShannonWavelet() = default; + ShannonWavelet(const std::vector &args); + ~ShannonWavelet() = default; - __device__ T (*waveletFunc)(T t, T a, T b); + __device__ T operator()(const T &t, const T &a, const T &b) const; }; -template class MotherWavelet; -template class MotherWavelet; +template +class FbspWavelet { + static_assert(std::is_floating_point::value, + "Data type should be floating point"); + public: + FbspWavelet() = default; + FbspWavelet(const std::vector &args); + ~FbspWavelet() = default; + + __device__ T operator()(const T &t, const T &a, const T &b) const; + + private: + T m; + T fb; + T fc; +}; } // namespace signal } // namespace kernel diff --git a/dali/kernels/signal/wavelet/wavelet_args.h b/dali/kernels/signal/wavelet/wavelet_args.h deleted file mode 100644 index 81a5d20a50..0000000000 --- a/dali/kernels/signal/wavelet/wavelet_args.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef DALI_KERNELS_SIGNAL_WAVELET_ARGS_H_ -#define DALI_KERNELS_SIGNAL_WAVELET_ARGS_H_ - -#include -#include -#include "dali/kernels/signal/wavelet/mother_wavelet.cuh" - -namespace dali { -namespace kernels { -namespace signal { - -template -struct WaveletArgs { - // mother wavelet name - WaveletName wavelet = WaveletName::HAAR; - - // wavelet shift parameter - T b = 0.0; - - // lower limit of wavelet samples - T begin = -1.0; - - // upper limit of wavelet samples - T end = 1.0; - - // wavelet sampling rate (samples/s) - T sampling_rate = 1000.0; -}; - -template class WaveletArgs; -template class WaveletArgs; - -} // namespace signal -} // namespace kernels -} // namespace dali - -#endif // DALI_KERNELS_SIGNAL_WAVELET_ARGS_H_ diff --git a/dali/kernels/signal/wavelet/wavelet_gpu.cu b/dali/kernels/signal/wavelet/wavelet_gpu.cu index 9ebf4512c2..457e14c27d 100644 --- a/dali/kernels/signal/wavelet/wavelet_gpu.cu +++ b/dali/kernels/signal/wavelet/wavelet_gpu.cu @@ -21,69 +21,117 @@ #include "dali/core/format.h" #include "dali/kernels/kernel.h" #include "dali/kernels/signal/wavelet/mother_wavelet.cuh" +#include "dali/core/tensor_shape.h" namespace dali { namespace kernels { namespace signal { -template -struct SampleDesc { - const T *a = nullptr; - int64_t size_a = 0; - T *out = nullptr; - int64_t size_out = 0; -}; - -template -__global__ void ComputeWavelet(const SampleDesc* sample_data, - T begin, T sampling_rate, T b, MotherWavelet wavelet) { +template class W > +__global__ void ComputeWavelet(const SampleDesc* sample_data, W wavelet) { + auto& sample = sample_data[blockIdx.z]; + auto a = sample.a[blockIdx.y]; const int64_t block_size = blockDim.x * blockDim.y; - const int64_t tid = threadIdx.y * blockDim.x + threadIdx.x; - const T t = begin + (T)tid / sampling_rate; - const T a = sample_data->a[blockIdx.x]; - sample_data->out[tid + blockIdx.x * block_size] = wavelet.waveletFunc(t, a, b); + const int64_t t_id = block_size * blockIdx.x + threadIdx.y * blockDim.x + threadIdx.x; + if (t_id >= sample.size_in) return; + const T t = sample.span.begin + (T)t_id / sample.span.sampling_rate; + for (int i = 0; i < sample.size_b; ++i) { + const int64_t out_id = blockIdx.y * sample.size_b * sample.size_in + i * sample.size_in + t_id; + auto b = sample.b[i]; + sample.out[out_id] = wavelet(t, a, b); + } } -template -WaveletGpu::~WaveletGpu() = default; - -template -KernelRequirements WaveletGpu::Setup(KernelContext &context, - const WaveletArgs &args) { - ScratchpadEstimator se; - se.add>(1); - se.add>(1); +template class W > +DLL_PUBLIC KernelRequirements WaveletGpu::Setup(KernelContext &context, + const InListGPU &a, + const InListGPU &b, + const WaveletSpan &span, + const std::vector &args) { + ENFORCE_SHAPES(a.shape, b.shape); + auto out_shape = this->GetOutputShape(a.shape, b.shape, span); KernelRequirements req; - req.scratch_sizes = se.sizes; + req.output_shapes = {out_shape}; + wavelet_ = W(args); return req; } -template -void WaveletGpu::Run(KernelContext &context, - const OutListGPU &out, - const InListGPU &a, - const WaveletArgs &args) { - auto* sample_data = context.scratchpad->AllocateHost>(1); +template class W > +DLL_PUBLIC void WaveletGpu::Run(KernelContext &ctx, + OutListGPU &out, + const InListGPU &a, + const InListGPU &b, + const WaveletSpan &span) { + ENFORCE_SHAPES(a.shape, b.shape); + + auto num_samples = a.num_samples(); + //std::vector> sample_data = std::vector>(num_samples); + auto* sample_data = ctx.scratchpad->AllocateHost>(num_samples); + int64_t max_size_in = 0, max_size_a = 0; - sample_data[0].out = out.tensor_data(0); - sample_data[0].a = a.tensor_data(0); - sample_data[0].size_a = volume(a.tensor_shape(0)); - auto in_size = (args.end - args.begin) * args.sampling_rate; - sample_data[0].size_out = in_size * sample_data[0].size_a; + for (int i = 0; i < num_samples; i++) { + auto &sample = sample_data[i]; + sample.out = out.tensor_data(i); + sample.a = a.tensor_data(i); + sample.size_a = a.shape.tensor_size(i); + max_size_a = std::max(max_size_a, sample.size_a); + sample.b = b.tensor_data(i); + sample.size_b = b.shape.tensor_size(i); + sample.span = span; + sample.size_in = std::ceil((sample.span.end - sample.span.begin) * sample.span.sampling_rate); + max_size_in = std::max(max_size_in, sample.size_in); + } - auto* sample_data_gpu = context.scratchpad->AllocateGPU>(1); + // auto sample_data_gpu = std::get<0>(ctx.scratchpad->ToContiguousGPU(ctx.gpu.stream, sample_data)); + auto* sample_data_gpu = ctx.scratchpad->AllocateGPU>(num_samples); CUDA_CALL( - cudaMemcpyAsync(sample_data_gpu, sample_data, sizeof(SampleDesc), - cudaMemcpyHostToDevice, context.gpu.stream)); + cudaMemcpyAsync(sample_data_gpu, sample_data, num_samples * sizeof(SampleDesc), + cudaMemcpyHostToDevice, ctx.gpu.stream)); + + dim3 block(32, 32); + const int64_t block_size = block.x * block.y; + dim3 grid((max_size_in + block_size - 1) / block_size, max_size_a, num_samples); + + ComputeWavelet<<>>(sample_data_gpu, wavelet_); +} - dim3 block(sample_data[0].size_a); - dim3 grid(in_size); - ComputeWavelet<<>>( - sample_data_gpu, args.begin, args.sampling_rate, args.b, MotherWavelet(args.wavelet)); +template class W > +TensorListShape<> WaveletGpu::GetOutputShape(const TensorListShape<> &a_shape, + const TensorListShape<> &b_shape, + const WaveletSpan &span) { + int N = a_shape.num_samples(); + int in_size = std::ceil((span.end - span.begin) * span.sampling_rate); + TensorListShape<> out_shape(N, 3); + TensorShape<> tshape; + for (int i = 0; i < N; i++) { + // output tensor will be 3-dimensional of shape: + // a coeffs x b coeffs x signal samples + tshape = TensorShape<>({a_shape.tensor_shape(i).num_elements(), b_shape.tensor_shape(i).num_elements(), in_size}); + out_shape.set_tensor_shape(i, tshape); + } + return out_shape; } -template class WaveletGpu; -template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; } // namespace signal } // namespace kernels diff --git a/dali/kernels/signal/wavelet/wavelet_gpu.cuh b/dali/kernels/signal/wavelet/wavelet_gpu.cuh index f25cbc9d16..3a9a532b88 100644 --- a/dali/kernels/signal/wavelet/wavelet_gpu.cuh +++ b/dali/kernels/signal/wavelet/wavelet_gpu.cuh @@ -16,32 +16,81 @@ #define DALI_KERNELS_SIGNAL_WAVELET_GPU_CUH_ #include +#include #include "dali/core/common.h" #include "dali/core/error_handling.h" #include "dali/core/format.h" #include "dali/core/util.h" #include "dali/kernels/kernel.h" -#include "dali/kernels/signal/wavelet/wavelet_args.h" +#include "dali/kernels/signal/wavelet/mother_wavelet.cuh" + +// makes sure both tensors have the same number of samples and +// that they're one-dimensional +#define ENFORCE_SHAPES(a_shape, b_shape) do { \ +DALI_ENFORCE(a_shape.num_samples() == b_shape.num_samples(),"a and b tensors must have the same amount of samples."); \ +for (int i = 0; i < a_shape.num_samples(); ++i) { \ + DALI_ENFORCE(a_shape.tensor_shape(i).size() == 1, "Tensor of a coeffs should be 1-dimensional."); \ + DALI_ENFORCE(b_shape.tensor_shape(i).size() == 1, "Tensor of b coeffs should be 1-dimensional."); \ +} \ +} while(0); namespace dali { namespace kernels { namespace signal { +// stores data needed to reconstruct wavelet input arguments template +struct WaveletSpan { + // lower limit of wavelet samples + T begin = -1.0; + + // upper limit of wavelet samples + T end = 1.0; + + // wavelet sampling rate (samples/s) + T sampling_rate = 1000.0; +}; + +template class WaveletSpan; +template class WaveletSpan; + +template +struct SampleDesc { + const T *a = nullptr; + int64_t size_a = 0; + const T *b = nullptr; + int64_t size_b = 0; + T *out = nullptr; + int64_t size_in = 0; + WaveletSpan span; +}; + +template class W > class DLL_PUBLIC WaveletGpu { public: static_assert(std::is_floating_point::value, "Only floating point types are supported"); - DLL_PUBLIC ~WaveletGpu(); + DLL_PUBLIC WaveletGpu() = default; + DLL_PUBLIC ~WaveletGpu() = default; DLL_PUBLIC KernelRequirements Setup(KernelContext &context, - const WaveletArgs &args); + const InListGPU &a, + const InListGPU &b, + const WaveletSpan &span, + const std::vector &args); + + DLL_PUBLIC void Run(KernelContext &ctx, + OutListGPU &out, + const InListGPU &a, + const InListGPU &b, + const WaveletSpan &span); - DLL_PUBLIC void Run(KernelContext &context, - const OutListGPU &out, - const InListGPU &a, - const WaveletArgs &args); + static TensorListShape<> GetOutputShape(const TensorListShape<> &a_shape, + const TensorListShape<> &b_shape, + const WaveletSpan &span); + private: + W wavelet_; }; } // namespace signal diff --git a/dali/operators/signal/wavelet/wavelet_run.h b/dali/operators/signal/wavelet/wavelet_run.h new file mode 100644 index 0000000000..2c6cbd2d7f --- /dev/null +++ b/dali/operators/signal/wavelet/wavelet_run.h @@ -0,0 +1,99 @@ +// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_OPERATORS_SIGNAL_WAVELET_WAVELET_RUN_H_ +#define DALI_OPERATORS_SIGNAL_WAVELET_WAVELET_RUN_H_ + +#include +#include "dali/core/format.h" +#include "dali/core/geom/mat.h" +#include "dali/core/static_switch.h" +#include "dali/kernels/kernel_manager.h" +#include "dali/pipeline/operator/operator.h" +#include "dali/kernels/signal/wavelet/mother_wavelet.cuh" +#include "dali/kernels/signal/wavelet/wavelet_gpu.cuh" + +namespace dali { + +// setups and runs kernel for specific wavelet type +template class W > +void RunWaveletKernel(kernels::KernelManager &kmgr, + size_t size, + size_t device, + kernels::KernelContext &ctx, + TensorListView &out, + TensorListView &a, + TensorListView &b, + const kernels::signal::WaveletSpan &span, + const std::vector &args) { + using Kernel = kernels::signal::WaveletGpu; + kmgr.template Resize(1); + kmgr.Setup(0, ctx, a, b, span, args); + kmgr.Run(0, ctx, out, a, b, span); +} + +// translates wavelet name to type and runs RunWaveletKernel() for that type +template +void RunForName(const std::string &name, + kernels::KernelManager &kmgr, + size_t size, + size_t device, + kernels::KernelContext &ctx, + TensorListView &out, + TensorListView &a, + TensorListView &b, + const kernels::signal::WaveletSpan &span, + const std::vector &args) { + if (name == "HAAR") { + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + } + else if (name == "DB") { + throw new std::logic_error("Not implemented."); + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + } + else if (name == "SYM") { + throw new std::logic_error("Not implemented."); + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + } + else if (name == "COIF") { + throw new std::logic_error("Not implemented."); + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + } + else if (name == "MEY") { + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + } + else if (name == "GAUS") { + throw new std::logic_error("Not implemented."); + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + } + else if (name == "MEXH") { + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + } + else if (name == "MORL") { + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + } + else if (name == "SHAN") { + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + } + else if (name == "FBSP") { + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + } + else { + throw new std::invalid_argument("Unknown wavelet name."); + } +} + +} // namespace dali + +#endif // DALI_OPERATORS_SIGNAL_WAVELET_RUN_H_ \ No newline at end of file From c4814f99896fa52764b9e65cf6b422208cbec8ba Mon Sep 17 00:00:00 2001 From: JakubO Date: Wed, 7 Jun 2023 03:59:09 +0200 Subject: [PATCH 06/14] Optimize and remove discrete wavelets Discrete wavelets have been discarded since we're currently focusing on continuous wavelet transform. Computation of wavelet input samples has been moved to a separate cuda kernel which should give a speedup when computing wavelets for multiple a and b parameters. Input wavelet samples, their scaled values and b coefficient are stored in shared memory instead of global memory which should speedup computation. --- dali/kernels/signal/wavelet/mother_wavelet.cu | 92 ++++--------------- .../kernels/signal/wavelet/mother_wavelet.cuh | 53 ++--------- dali/kernels/signal/wavelet/wavelet_gpu.cu | 64 ++++++++----- dali/kernels/signal/wavelet/wavelet_gpu.cuh | 3 +- dali/operators/signal/wavelet/wavelet_run.h | 16 ---- 5 files changed, 71 insertions(+), 157 deletions(-) diff --git a/dali/kernels/signal/wavelet/mother_wavelet.cu b/dali/kernels/signal/wavelet/mother_wavelet.cu index 7b3afe1eef..79bc695a36 100644 --- a/dali/kernels/signal/wavelet/mother_wavelet.cu +++ b/dali/kernels/signal/wavelet/mother_wavelet.cu @@ -28,13 +28,12 @@ HaarWavelet::HaarWavelet(const std::vector &args) { } template -__device__ T HaarWavelet::operator()(const T &t, const T &a, const T &b) const { - T x = std::pow(2.0, a) * t - b; - if (0.0 <= x && x < 0.5) { - return std::pow(2.0, a / 2.0); +__device__ T HaarWavelet::operator()(const T &t) const { + if (0.0 <= t && t < 0.5) { + return 1.0; } - if (0.5 <= x && x < 1.0) { - return -std::pow(2.0, a / 2.0); + if (0.5 <= t && t < 1.0) { + return -1.0; } return 0.0; } @@ -42,45 +41,6 @@ __device__ T HaarWavelet::operator()(const T &t, const T &a, const T &b) cons template class HaarWavelet; template class HaarWavelet; -template -DaubechiesWavelet::DaubechiesWavelet(const std::vector &args) { - -} - -template -__device__ T DaubechiesWavelet::operator()(const T &t, const T &a, const T &b) const { - return 0.0; -} - -template class DaubechiesWavelet; -template class DaubechiesWavelet; - -template -SymletWavelet::SymletWavelet(const std::vector &args) { - -} - -template -__device__ T SymletWavelet::operator()(const T &t, const T &a, const T &b) const { - return 0.0; -} - -template class SymletWavelet; -template class SymletWavelet; - -template -CoifletWavelet::CoifletWavelet(const std::vector &args) { - -} - -template -__device__ T CoifletWavelet::operator()(const T &t, const T &a, const T &b) const { - return 0.0; -} - -template class CoifletWavelet; -template class CoifletWavelet; - template MeyerWavelet::MeyerWavelet(const std::vector &args) { if (args.size() != 0) { @@ -89,29 +49,15 @@ MeyerWavelet::MeyerWavelet(const std::vector &args) { } template -__device__ T MeyerWavelet::operator()(const T &t, const T &a, const T &b) const { - T x = std::pow(2.0, a) * t - b - 0.5; - T psi1 = (4/(3*M_PI)*x*std::cos((2*M_PI)/3*x)-1/M_PI*std::sin((4*M_PI)/3*x))/(x-16/9*std::pow(x, 3.0)); - T psi2 = (8/(3*M_PI)*x*std::cos(8*M_PI/3*x)+1/M_PI*std::sin((4*M_PI)/3)*x)/(x-64/9*std::pow(x, 3.0)); - return std::pow(2.0, a / 2.0) * (psi1 + psi2); +__device__ T MeyerWavelet::operator()(const T &t) const { + T psi1 = (4/(3*M_PI)*t*std::cos((2*M_PI)/3*t)-1/M_PI*std::sin((4*M_PI)/3*t))/(t-16/9*std::pow(t, 3.0)); + T psi2 = (8/(3*M_PI)*t*std::cos(8*M_PI/3*t)+1/M_PI*std::sin((4*M_PI)/3)*t)/(t-64/9*std::pow(t, 3.0)); + return psi1 + psi2; } template class MeyerWavelet; template class MeyerWavelet; -template -GaussianWavelet::GaussianWavelet(const std::vector &args) { - -} - -template -__device__ T GaussianWavelet::operator()(const T &t, const T &a, const T &b) const { - return 0.0; -} - -template class GaussianWavelet; -template class GaussianWavelet; - template MexicanHatWavelet::MexicanHatWavelet(const std::vector &args) { if (args.size() != 1) { @@ -121,9 +67,8 @@ MexicanHatWavelet::MexicanHatWavelet(const std::vector &args) { } template -__device__ T MexicanHatWavelet::operator()(const T &t, const T &a, const T &b) const { - T x = std::pow(2.0, a) * t - b; - return std::pow(2.0, a / 2.0) * (2/(std::sqrt(3*sigma)*std::pow(M_PI, 0.25))*(1-std::pow(x/sigma, 2.0))*std::exp(-std::pow(x, 2.0)/(2*std::pow(sigma, 2.0)))); +__device__ T MexicanHatWavelet::operator()(const T &t) const { + return 2/(std::sqrt(3*sigma)*std::pow(M_PI, 0.25))*(1-std::pow(t/sigma, 2.0))*std::exp(-std::pow(t, 2.0)/(2*std::pow(sigma, 2.0))); } template class MexicanHatWavelet; @@ -138,9 +83,8 @@ MorletWavelet::MorletWavelet(const std::vector &args) { } template -__device__ T MorletWavelet::operator()(const T &t, const T &a, const T &b) const { - T x = std::pow(2.0, a) * t - b; - return std::pow(2.0, a / 2.0) * (C * std::exp(-std::pow(x, 2.0)) * std::cos(5 * x)); +__device__ T MorletWavelet::operator()(const T &t) const { + return C * std::exp(-std::pow(t, 2.0)) * std::cos(5 * t); } template class MorletWavelet; @@ -154,9 +98,8 @@ ShannonWavelet::ShannonWavelet(const std::vector &args) { } template -__device__ T ShannonWavelet::operator()(const T &t, const T &a, const T &b) const { - T x = std::pow(2.0, a) * t - b; - return std::pow(2.0, a / 2.0) * (sinc(x - 0.5) - 2 * sinc(2 * x - 1)); +__device__ T ShannonWavelet::operator()(const T &t) const { + return sinc(t - 0.5) - 2 * sinc(2 * t - 1); } template class ShannonWavelet; @@ -173,9 +116,8 @@ FbspWavelet::FbspWavelet(const std::vector &args) { } template -__device__ T FbspWavelet::operator()(const T &t, const T &a, const T &b) const { - T x = std::pow(2.0, a) * t - b; - return std::pow(2.0, a / 2.0) * (std::sqrt(fb)*std::pow(sinc(x/std::pow(fb, m)), m)*std::exp(2*M_PI*fc*x)); +__device__ T FbspWavelet::operator()(const T &t) const { + return std::sqrt(fb)*std::pow(sinc(t/std::pow(fb, m)), m)*std::exp(2*M_PI*fc*t); } template class FbspWavelet; diff --git a/dali/kernels/signal/wavelet/mother_wavelet.cuh b/dali/kernels/signal/wavelet/mother_wavelet.cuh index 045e4f5065..52388b97e6 100644 --- a/dali/kernels/signal/wavelet/mother_wavelet.cuh +++ b/dali/kernels/signal/wavelet/mother_wavelet.cuh @@ -40,43 +40,7 @@ class HaarWavelet { HaarWavelet(const std::vector &args); ~HaarWavelet() = default; - __device__ T operator()(const T &t, const T &a, const T &b) const; -}; - -template -class DaubechiesWavelet { - static_assert(std::is_floating_point::value, - "Data type should be floating point"); - public: - DaubechiesWavelet() = default; - DaubechiesWavelet(const std::vector &args); - ~DaubechiesWavelet() = default; - - __device__ T operator()(const T &t, const T &a, const T &b) const; -}; - -template -class SymletWavelet { - static_assert(std::is_floating_point::value, - "Data type should be floating point"); - public: - SymletWavelet() = default; - SymletWavelet(const std::vector &args); - ~SymletWavelet() = default; - - __device__ T operator()(const T &t, const T &a, const T &b) const; -}; - -template -class CoifletWavelet { - static_assert(std::is_floating_point::value, - "Data type should be floating point"); - public: - CoifletWavelet() = default; - CoifletWavelet(const std::vector &args); - ~CoifletWavelet() = default; - - __device__ T operator()(const T &t, const T &a, const T &b) const; + __device__ T operator()(const T &t) const; }; template @@ -88,7 +52,7 @@ class MeyerWavelet { MeyerWavelet(const std::vector &args); ~MeyerWavelet() = default; - __device__ T operator()(const T &t, const T &a, const T &b) const; + __device__ T operator()(const T &t) const; }; template @@ -100,7 +64,10 @@ class GaussianWavelet { GaussianWavelet(const std::vector &args); ~GaussianWavelet() = default; - __device__ T operator()(const T &t, const T &a, const T &b) const; + __device__ T operator()(const T &t) const; + + private: + uint8_t N; }; template @@ -112,7 +79,7 @@ class MexicanHatWavelet { MexicanHatWavelet(const std::vector &args); ~MexicanHatWavelet() = default; - __device__ T operator()(const T &t, const T &a, const T &b) const; + __device__ T operator()(const T &t) const; private: T sigma; @@ -127,7 +94,7 @@ class MorletWavelet { MorletWavelet(const std::vector &args); ~MorletWavelet() = default; - __device__ T operator()(const T &t, const T &a, const T &b) const; + __device__ T operator()(const T &t) const; private: T C; @@ -142,7 +109,7 @@ class ShannonWavelet { ShannonWavelet(const std::vector &args); ~ShannonWavelet() = default; - __device__ T operator()(const T &t, const T &a, const T &b) const; + __device__ T operator()(const T &t) const; }; template @@ -154,7 +121,7 @@ class FbspWavelet { FbspWavelet(const std::vector &args); ~FbspWavelet() = default; - __device__ T operator()(const T &t, const T &a, const T &b) const; + __device__ T operator()(const T &t) const; private: T m; diff --git a/dali/kernels/signal/wavelet/wavelet_gpu.cu b/dali/kernels/signal/wavelet/wavelet_gpu.cu index 457e14c27d..d8b07667e0 100644 --- a/dali/kernels/signal/wavelet/wavelet_gpu.cu +++ b/dali/kernels/signal/wavelet/wavelet_gpu.cu @@ -27,21 +27,50 @@ namespace dali { namespace kernels { namespace signal { +// computes wavelet value for each sample in specified range, +// and each a and b coeff template class W > __global__ void ComputeWavelet(const SampleDesc* sample_data, W wavelet) { + // id inside block + const int64_t b_id = threadIdx.y * blockDim.x + threadIdx.x; + // wavelet sample id + const int64_t t_id = blockDim.x * blockDim.y * blockIdx.x + b_id; auto& sample = sample_data[blockIdx.z]; - auto a = sample.a[blockIdx.y]; - const int64_t block_size = blockDim.x * blockDim.y; - const int64_t t_id = block_size * blockIdx.x + threadIdx.y * blockDim.x + threadIdx.x; if (t_id >= sample.size_in) return; - const T t = sample.span.begin + (T)t_id / sample.span.sampling_rate; + __shared__ T shm[1025]; + auto a = sample.a[blockIdx.y]; + auto x = std::pow(2.0, a); + if (a == 0.0) { + shm[b_id] = sample.in[t_id]; + } + else { + shm[b_id] = x * sample.in[t_id]; + shm[1024] = std::pow(2.0, a / 2.0); + } for (int i = 0; i < sample.size_b; ++i) { const int64_t out_id = blockIdx.y * sample.size_b * sample.size_in + i * sample.size_in + t_id; auto b = sample.b[i]; - sample.out[out_id] = wavelet(t, a, b); + if (b == 0.0) { + sample.out[out_id] = wavelet(shm[b_id]); + } + else { + sample.out[out_id] = wavelet(shm[b_id] - b); + } + if (a != 0.0) { + sample.out[out_id] *= shm[1024]; + } } } +// translate input range information to input samples +template +__global__ void ComputeInputSamples(const SampleDesc* sample_data) { + const int64_t t_id = blockDim.x * blockDim.y * blockIdx.x + threadIdx.y * blockDim.x + threadIdx.x; + auto& sample = sample_data[blockIdx.y]; + if (t_id >= sample.size_in) return; + sample.in[t_id] = sample.span.begin + (T)t_id / sample.span.sampling_rate; +} + template class W > DLL_PUBLIC KernelRequirements WaveletGpu::Setup(KernelContext &context, const InListGPU &a, @@ -65,8 +94,7 @@ DLL_PUBLIC void WaveletGpu::Run(KernelContext &ctx, ENFORCE_SHAPES(a.shape, b.shape); auto num_samples = a.num_samples(); - //std::vector> sample_data = std::vector>(num_samples); - auto* sample_data = ctx.scratchpad->AllocateHost>(num_samples); + std::vector> sample_data = std::vector>(num_samples); int64_t max_size_in = 0, max_size_a = 0; for (int i = 0; i < num_samples; i++) { @@ -79,20 +107,20 @@ DLL_PUBLIC void WaveletGpu::Run(KernelContext &ctx, sample.size_b = b.shape.tensor_size(i); sample.span = span; sample.size_in = std::ceil((sample.span.end - sample.span.begin) * sample.span.sampling_rate); + CUDA_CALL(cudaMalloc(&(sample.in), sizeof(T) * sample.size_in)); max_size_in = std::max(max_size_in, sample.size_in); } - // auto sample_data_gpu = std::get<0>(ctx.scratchpad->ToContiguousGPU(ctx.gpu.stream, sample_data)); - auto* sample_data_gpu = ctx.scratchpad->AllocateGPU>(num_samples); - CUDA_CALL( - cudaMemcpyAsync(sample_data_gpu, sample_data, num_samples * sizeof(SampleDesc), - cudaMemcpyHostToDevice, ctx.gpu.stream)); + auto* sample_data_gpu = std::get<0>(ctx.scratchpad->ToContiguousGPU(ctx.gpu.stream, sample_data)); dim3 block(32, 32); const int64_t block_size = block.x * block.y; - dim3 grid((max_size_in + block_size - 1) / block_size, max_size_a, num_samples); + dim3 grid1((max_size_in + block_size - 1) / block_size, num_samples); + dim3 grid2((max_size_in + block_size - 1) / block_size, max_size_a, num_samples); - ComputeWavelet<<>>(sample_data_gpu, wavelet_); + ComputeInputSamples<<>>(sample_data_gpu); + auto shared_mem_size = (block_size + 1) * sizeof(T); + ComputeWavelet<<>>(sample_data_gpu, wavelet_); } template class W > @@ -114,16 +142,8 @@ TensorListShape<> WaveletGpu::GetOutputShape(const TensorListShape<> &a_sh template class WaveletGpu; template class WaveletGpu; -template class WaveletGpu; -template class WaveletGpu; -template class WaveletGpu; -template class WaveletGpu; -template class WaveletGpu; -template class WaveletGpu; template class WaveletGpu; template class WaveletGpu; -template class WaveletGpu; -template class WaveletGpu; template class WaveletGpu; template class WaveletGpu; template class WaveletGpu; diff --git a/dali/kernels/signal/wavelet/wavelet_gpu.cuh b/dali/kernels/signal/wavelet/wavelet_gpu.cuh index 3a9a532b88..0026ff58c4 100644 --- a/dali/kernels/signal/wavelet/wavelet_gpu.cuh +++ b/dali/kernels/signal/wavelet/wavelet_gpu.cuh @@ -60,8 +60,9 @@ struct SampleDesc { int64_t size_a = 0; const T *b = nullptr; int64_t size_b = 0; - T *out = nullptr; + T *in = nullptr; int64_t size_in = 0; + T *out = nullptr; WaveletSpan span; }; diff --git a/dali/operators/signal/wavelet/wavelet_run.h b/dali/operators/signal/wavelet/wavelet_run.h index 2c6cbd2d7f..93b2c1840d 100644 --- a/dali/operators/signal/wavelet/wavelet_run.h +++ b/dali/operators/signal/wavelet/wavelet_run.h @@ -58,25 +58,9 @@ void RunForName(const std::string &name, if (name == "HAAR") { RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); } - else if (name == "DB") { - throw new std::logic_error("Not implemented."); - RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); - } - else if (name == "SYM") { - throw new std::logic_error("Not implemented."); - RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); - } - else if (name == "COIF") { - throw new std::logic_error("Not implemented."); - RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); - } else if (name == "MEY") { RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); } - else if (name == "GAUS") { - throw new std::logic_error("Not implemented."); - RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); - } else if (name == "MEXH") { RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); } From d3a8d6a4f3694c763b3c4c3304ca2cbde7a1da0e Mon Sep 17 00:00:00 2001 From: JakubO Date: Sun, 11 Jun 2023 17:57:24 +0200 Subject: [PATCH 07/14] add DALIWaveletName enum --- dali/kernels/signal/wavelet/mother_wavelet.cu | 33 ++++++++------- .../kernels/signal/wavelet/mother_wavelet.cuh | 19 ++------- dali/kernels/signal/wavelet/wavelet_gpu.cu | 7 ++-- dali/operators/signal/wavelet/wavelet_name.h | 34 +++++++++++++++ dali/operators/signal/wavelet/wavelet_run.h | 29 ++++++------- dali/pipeline/data/types.h | 42 +++++++++++-------- dali/python/backend_impl.cc | 13 ++++++ dali/python/nvidia/dali/types.py | 3 +- 8 files changed, 115 insertions(+), 65 deletions(-) create mode 100644 dali/operators/signal/wavelet/wavelet_name.h diff --git a/dali/kernels/signal/wavelet/mother_wavelet.cu b/dali/kernels/signal/wavelet/mother_wavelet.cu index 79bc695a36..2a2a95ef80 100644 --- a/dali/kernels/signal/wavelet/mother_wavelet.cu +++ b/dali/kernels/signal/wavelet/mother_wavelet.cu @@ -50,8 +50,9 @@ MeyerWavelet::MeyerWavelet(const std::vector &args) { template __device__ T MeyerWavelet::operator()(const T &t) const { - T psi1 = (4/(3*M_PI)*t*std::cos((2*M_PI)/3*t)-1/M_PI*std::sin((4*M_PI)/3*t))/(t-16/9*std::pow(t, 3.0)); - T psi2 = (8/(3*M_PI)*t*std::cos(8*M_PI/3*t)+1/M_PI*std::sin((4*M_PI)/3)*t)/(t-64/9*std::pow(t, 3.0)); + T tt = t - 0.5; + T psi1 = (4.0/(3.0*M_PI)*tt*std::cos((2.0*M_PI)/3.0*tt)-1.0/M_PI*std::sin((4.0*M_PI)/3.0*tt))/(tt-16.0/9.0*std::pow(tt, 3.0)); + T psi2 = (8.0/(3.0*M_PI)*tt*std::cos(8.0*M_PI/3.0*tt)+1.0/M_PI*std::sin((4.0*M_PI)/3.0)*tt)/(tt-64.0/9.0*std::pow(tt, 3.0)); return psi1 + psi2; } @@ -63,12 +64,12 @@ MexicanHatWavelet::MexicanHatWavelet(const std::vector &args) { if (args.size() != 1) { throw new std::invalid_argument("MexicanHatWavelet accepts exactly one argument - sigma."); } - this->sigma = *args.begin(); + this->sigma = args[0]; } template __device__ T MexicanHatWavelet::operator()(const T &t) const { - return 2/(std::sqrt(3*sigma)*std::pow(M_PI, 0.25))*(1-std::pow(t/sigma, 2.0))*std::exp(-std::pow(t, 2.0)/(2*std::pow(sigma, 2.0))); + return 2.0/(std::sqrt(3.0*sigma)*std::pow(M_PI, 0.25))*(1.0-std::pow(t/sigma, 2.0))*std::exp(-std::pow(t, 2.0)/(2.0*std::pow(sigma, 2.0))); } template class MexicanHatWavelet; @@ -79,12 +80,12 @@ MorletWavelet::MorletWavelet(const std::vector &args) { if (args.size() != 1) { throw new std::invalid_argument("MorletWavelet accepts exactly 1 argument - C."); } - this->C = *args.begin(); + this->C = args[0]; } template __device__ T MorletWavelet::operator()(const T &t) const { - return C * std::exp(-std::pow(t, 2.0)) * std::cos(5 * t); + return C * std::exp(-std::pow(t, 2.0) / 2.0) * std::cos(5.0 * t); } template class MorletWavelet; @@ -92,14 +93,17 @@ template class MorletWavelet; template ShannonWavelet::ShannonWavelet(const std::vector &args) { - if (args.size() != 0) { - throw new std::invalid_argument("ShannonWavelet doesn't accept any arguments."); + if (args.size() != 2) { + throw new std::invalid_argument("ShannonWavelet accepts exactly 2 arguments -> fb, fc in that order."); } + this->fb = args[0]; + this->fc = args[1]; } template __device__ T ShannonWavelet::operator()(const T &t) const { - return sinc(t - 0.5) - 2 * sinc(2 * t - 1); + auto res = std::cos((T)(2.0*M_PI)*fc*t)*std::sqrt(fb); + return t == 0.0 ? res : res*std::sin(t*fb*(T)(M_PI))/(t*fb*(T)(M_PI)); } template class ShannonWavelet; @@ -107,17 +111,18 @@ template class ShannonWavelet; template FbspWavelet::FbspWavelet(const std::vector &args) { - if (args.size() != 0) { + if (args.size() != 3) { throw new std::invalid_argument("FbspWavelet accepts exactly 3 arguments -> m, fb, fc in that order."); } - this->m = *args.begin(); - this->fb = *(args.begin()+1); - this->fc = *(args.begin()+2); + this->m = args[0]; + this->fb = args[1]; + this->fc = args[2]; } template __device__ T FbspWavelet::operator()(const T &t) const { - return std::sqrt(fb)*std::pow(sinc(t/std::pow(fb, m)), m)*std::exp(2*M_PI*fc*t); + auto res = std::cos((T)(2.0*M_PI)*fc*t)*std::sqrt(fb); + return t == 0.0 ? res : res*std::pow(std::sin((T)(M_PI)*t*fb/m)/((T)(M_PI)*t*fb/m), m); } template class FbspWavelet; diff --git a/dali/kernels/signal/wavelet/mother_wavelet.cuh b/dali/kernels/signal/wavelet/mother_wavelet.cuh index 52388b97e6..cbdd6da716 100644 --- a/dali/kernels/signal/wavelet/mother_wavelet.cuh +++ b/dali/kernels/signal/wavelet/mother_wavelet.cuh @@ -55,21 +55,6 @@ class MeyerWavelet { __device__ T operator()(const T &t) const; }; -template -class GaussianWavelet { - static_assert(std::is_floating_point::value, - "Data type should be floating point"); - public: - GaussianWavelet() = default; - GaussianWavelet(const std::vector &args); - ~GaussianWavelet() = default; - - __device__ T operator()(const T &t) const; - - private: - uint8_t N; -}; - template class MexicanHatWavelet { static_assert(std::is_floating_point::value, @@ -110,6 +95,10 @@ class ShannonWavelet { ~ShannonWavelet() = default; __device__ T operator()(const T &t) const; + + private: + T fb; + T fc; }; template diff --git a/dali/kernels/signal/wavelet/wavelet_gpu.cu b/dali/kernels/signal/wavelet/wavelet_gpu.cu index d8b07667e0..9829ef1ca8 100644 --- a/dali/kernels/signal/wavelet/wavelet_gpu.cu +++ b/dali/kernels/signal/wavelet/wavelet_gpu.cu @@ -47,6 +47,7 @@ __global__ void ComputeWavelet(const SampleDesc* sample_data, W wavelet) { shm[b_id] = x * sample.in[t_id]; shm[1024] = std::pow(2.0, a / 2.0); } + __syncthreads(); for (int i = 0; i < sample.size_b; ++i) { const int64_t out_id = blockIdx.y * sample.size_b * sample.size_in + i * sample.size_in + t_id; auto b = sample.b[i]; @@ -106,8 +107,8 @@ DLL_PUBLIC void WaveletGpu::Run(KernelContext &ctx, sample.b = b.tensor_data(i); sample.size_b = b.shape.tensor_size(i); sample.span = span; - sample.size_in = std::ceil((sample.span.end - sample.span.begin) * sample.span.sampling_rate); - CUDA_CALL(cudaMalloc(&(sample.in), sizeof(T) * sample.size_in)); + sample.size_in = std::ceil((sample.span.end - sample.span.begin) * sample.span.sampling_rate) + 1; + sample.in = ctx.scratchpad->AllocateGPU(sample.size_in); max_size_in = std::max(max_size_in, sample.size_in); } @@ -128,7 +129,7 @@ TensorListShape<> WaveletGpu::GetOutputShape(const TensorListShape<> &a_sh const TensorListShape<> &b_shape, const WaveletSpan &span) { int N = a_shape.num_samples(); - int in_size = std::ceil((span.end - span.begin) * span.sampling_rate); + int in_size = std::ceil((span.end - span.begin) * span.sampling_rate) + 1; TensorListShape<> out_shape(N, 3); TensorShape<> tshape; for (int i = 0; i < N; i++) { diff --git a/dali/operators/signal/wavelet/wavelet_name.h b/dali/operators/signal/wavelet/wavelet_name.h new file mode 100644 index 0000000000..c9c94dff06 --- /dev/null +++ b/dali/operators/signal/wavelet/wavelet_name.h @@ -0,0 +1,34 @@ +// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_OPERATORS_SIGNAL_WAVELET_WAVELET_NAME_H_ +#define DALI_OPERATORS_SIGNAL_WAVELET_WAVELET_NAME_H_ + +namespace dali { + +/** + * @brief Supported wavelet names + */ +enum DALIWaveletName { + DALI_HAAR = 0, + DALI_MEY = 1, + DALI_MEXH = 2, + DALI_MORL = 3, + DALI_SHAN = 4, + DALI_FBSP = 5 +}; + +} // namespace dali + +#endif // DALI_OPERATORS_SIGNAL_WAVELET_NAME_H_ \ No newline at end of file diff --git a/dali/operators/signal/wavelet/wavelet_run.h b/dali/operators/signal/wavelet/wavelet_run.h index 93b2c1840d..7ce6b3a5f4 100644 --- a/dali/operators/signal/wavelet/wavelet_run.h +++ b/dali/operators/signal/wavelet/wavelet_run.h @@ -45,7 +45,7 @@ void RunWaveletKernel(kernels::KernelManager &kmgr, // translates wavelet name to type and runs RunWaveletKernel() for that type template -void RunForName(const std::string &name, +void RunForName(const DALIWaveletName &name, kernels::KernelManager &kmgr, size_t size, size_t device, @@ -55,25 +55,26 @@ void RunForName(const std::string &name, TensorListView &b, const kernels::signal::WaveletSpan &span, const std::vector &args) { - if (name == "HAAR") { + switch (name) { + case DALIWaveletName::DALI_HAAR: RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); - } - else if (name == "MEY") { + break; + case DALIWaveletName::DALI_MEY: RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); - } - else if (name == "MEXH") { + break; + case DALIWaveletName::DALI_MEXH: RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); - } - else if (name == "MORL") { + break; + case DALIWaveletName::DALI_MORL: RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); - } - else if (name == "SHAN") { + break; + case DALIWaveletName::DALI_SHAN: RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); - } - else if (name == "FBSP") { + break; + case DALIWaveletName::DALI_FBSP: RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); - } - else { + break; + default: throw new std::invalid_argument("Unknown wavelet name."); } } diff --git a/dali/pipeline/data/types.h b/dali/pipeline/data/types.h index 0efa36e5a1..eed79432c6 100644 --- a/dali/pipeline/data/types.h +++ b/dali/pipeline/data/types.h @@ -30,6 +30,7 @@ #include "dali/core/float16.h" #include "dali/core/cuda_error.h" #include "dali/core/tensor_layout.h" +#include "dali/operators/signal/wavelet/wavelet_name.h" #ifdef DALI_BUILD_PROTO3 #include "dali/operators/reader/parser/tf_feature.h" @@ -123,6 +124,7 @@ enum DALIDataType : int { DALI_PYTHON_OBJECT = 24, DALI_TENSOR_LAYOUT_VEC = 25, DALI_DATA_TYPE_VEC = 26, + DALI_WAVELET_NAME = 27, DALI_DATATYPE_END = 1000 }; @@ -202,6 +204,9 @@ inline const char *GetBuiltinTypeName(DALIDataType t) { case DALI_INTERP_TYPE: return "DALIInterpType"; break; + case DALI_WAVELET_NAME: + return "DALIWaveletName"; + break; case DALI_TENSOR_LAYOUT: return "TensorLayout"; break; @@ -557,24 +562,25 @@ DLL_PUBLIC inline bool IsValidType(const TypeInfo &type) { DALI_REGISTER_TYPE_IMPL(Type, dtype); // Instantiate some basic types -DALI_REGISTER_TYPE(NoType, DALI_NO_TYPE); -DALI_REGISTER_TYPE(uint8_t, DALI_UINT8); -DALI_REGISTER_TYPE(uint16_t, DALI_UINT16); -DALI_REGISTER_TYPE(uint32_t, DALI_UINT32); -DALI_REGISTER_TYPE(uint64_t, DALI_UINT64); -DALI_REGISTER_TYPE(int8_t, DALI_INT8); -DALI_REGISTER_TYPE(int16_t, DALI_INT16); -DALI_REGISTER_TYPE(int32_t, DALI_INT32); -DALI_REGISTER_TYPE(int64_t, DALI_INT64); -DALI_REGISTER_TYPE(float16, DALI_FLOAT16); -DALI_REGISTER_TYPE(float, DALI_FLOAT); -DALI_REGISTER_TYPE(double, DALI_FLOAT64); -DALI_REGISTER_TYPE(bool, DALI_BOOL); -DALI_REGISTER_TYPE(string, DALI_STRING); -DALI_REGISTER_TYPE(DALIImageType, DALI_IMAGE_TYPE); -DALI_REGISTER_TYPE(DALIDataType, DALI_DATA_TYPE); -DALI_REGISTER_TYPE(DALIInterpType, DALI_INTERP_TYPE); -DALI_REGISTER_TYPE(TensorLayout, DALI_TENSOR_LAYOUT); +DALI_REGISTER_TYPE(NoType, DALI_NO_TYPE); +DALI_REGISTER_TYPE(uint8_t, DALI_UINT8); +DALI_REGISTER_TYPE(uint16_t, DALI_UINT16); +DALI_REGISTER_TYPE(uint32_t, DALI_UINT32); +DALI_REGISTER_TYPE(uint64_t, DALI_UINT64); +DALI_REGISTER_TYPE(int8_t, DALI_INT8); +DALI_REGISTER_TYPE(int16_t, DALI_INT16); +DALI_REGISTER_TYPE(int32_t, DALI_INT32); +DALI_REGISTER_TYPE(int64_t, DALI_INT64); +DALI_REGISTER_TYPE(float16, DALI_FLOAT16); +DALI_REGISTER_TYPE(float, DALI_FLOAT); +DALI_REGISTER_TYPE(double, DALI_FLOAT64); +DALI_REGISTER_TYPE(bool, DALI_BOOL); +DALI_REGISTER_TYPE(string, DALI_STRING); +DALI_REGISTER_TYPE(DALIImageType, DALI_IMAGE_TYPE); +DALI_REGISTER_TYPE(DALIDataType, DALI_DATA_TYPE); +DALI_REGISTER_TYPE(DALIInterpType, DALI_INTERP_TYPE); +DALI_REGISTER_TYPE(DALIWaveletName, DALI_WAVELET_NAME); +DALI_REGISTER_TYPE(TensorLayout, DALI_TENSOR_LAYOUT); #ifdef DALI_BUILD_PROTO3 diff --git a/dali/python/backend_impl.cc b/dali/python/backend_impl.cc index 262c2ee907..504be8def0 100644 --- a/dali/python/backend_impl.cc +++ b/dali/python/backend_impl.cc @@ -27,6 +27,7 @@ #include "dali/operators.h" #include "dali/kernels/kernel.h" #include "dali/operators/reader/parser/tfrecord_parser.h" +#include "dali/operators/signal/wavelet/wavelet_name.h" #include "dali/pipeline/data/copy_to_external.h" #include "dali/pipeline/data/dltensor.h" #include "dali/pipeline/data/tensor.h" @@ -1672,6 +1673,7 @@ PYBIND11_MODULE(backend_impl, m) { .value("IMAGE_TYPE", DALI_IMAGE_TYPE) .value("DATA_TYPE", DALI_DATA_TYPE) .value("INTERP_TYPE", DALI_INTERP_TYPE) + .value("WAVELET_NAME", DALI_WAVELET_NAME) .value("TENSOR_LAYOUT", DALI_TENSOR_LAYOUT) .value("PYTHON_OBJECT", DALI_PYTHON_OBJECT) .value("_TENSOR_LAYOUT_VEC", DALI_TENSOR_LAYOUT_VEC) @@ -1716,6 +1718,16 @@ PYBIND11_MODULE(backend_impl, m) { .value("INTERP_GAUSSIAN", DALI_INTERP_GAUSSIAN) .export_values(); + // DALIWaveletName + py::enum_(types_m, "DALIWaveletName", "Wavelet name\n") + .value("HAAR", DALI_HAAR) + .value("MEY", DALI_MEY) + .value("MEXH", DALI_MEXH) + .value("MORL", DALI_MORL) + .value("SHAN", DALI_SHAN) + .value("FBSP", DALI_FBSP) + .export_values(); + // Operator node py::class_(m, "OpNode") .def("instance_name", @@ -1998,6 +2010,7 @@ PYBIND11_MODULE(backend_impl, m) { DALI_OPSPEC_ADDARG(DALIDataType) DALI_OPSPEC_ADDARG(DALIImageType) DALI_OPSPEC_ADDARG(DALIInterpType) + DALI_OPSPEC_ADDARG(DALIWaveletName) #ifdef DALI_BUILD_PROTO3 DALI_OPSPEC_ADDARG(TFFeature) #endif diff --git a/dali/python/nvidia/dali/types.py b/dali/python/nvidia/dali/types.py index f4362fd224..405165847a 100644 --- a/dali/python/nvidia/dali/types.py +++ b/dali/python/nvidia/dali/types.py @@ -16,7 +16,7 @@ from enum import Enum, unique import re -from nvidia.dali.backend_impl.types import DALIDataType, DALIImageType, DALIInterpType +from nvidia.dali.backend_impl.types import DALIDataType, DALIImageType, DALIInterpType, DALIWaveletName # TODO: Handle forwarding imports from backend_impl from nvidia.dali.backend_impl.types import * # noqa: F401, F403 @@ -63,6 +63,7 @@ def _not_implemented(val): DALIDataType.DATA_TYPE: ("nvidia.dali.types.DALIDataType", lambda x: DALIDataType(int(x))), DALIDataType.INTERP_TYPE: ("nvidia.dali.types.DALIInterpType", lambda x: DALIInterpType(int(x))), + DALIDataType.WAVELET_NAME: ("nvidia.dali.types.DALIWaveletName", lambda x: DALIWaveletName(int(x))), DALIDataType.TENSOR_LAYOUT: (":ref:`layout str`", lambda x: str(x)), DALIDataType.PYTHON_OBJECT: ("object", lambda x: x), DALIDataType._TENSOR_LAYOUT_VEC: From 27cedd3642fb28abfa8f4d8b58d394590a368829 Mon Sep 17 00:00:00 2001 From: JakubO Date: Sun, 11 Jun 2023 18:47:04 +0200 Subject: [PATCH 08/14] fix linting errors --- dali/kernels/signal/wavelet/mother_wavelet.cu | 18 ++++++++----- .../kernels/signal/wavelet/mother_wavelet.cuh | 18 ++++++------- dali/kernels/signal/wavelet/wavelet_gpu.cu | 18 +++++++------ dali/kernels/signal/wavelet/wavelet_gpu.cuh | 25 +++++++++++-------- dali/operators/signal/wavelet/wavelet_name.h | 4 +-- dali/operators/signal/wavelet/wavelet_run.h | 22 ++++++++++------ dali/python/nvidia/dali/types.py | 6 +++-- 7 files changed, 66 insertions(+), 45 deletions(-) diff --git a/dali/kernels/signal/wavelet/mother_wavelet.cu b/dali/kernels/signal/wavelet/mother_wavelet.cu index 2a2a95ef80..66d83936ea 100644 --- a/dali/kernels/signal/wavelet/mother_wavelet.cu +++ b/dali/kernels/signal/wavelet/mother_wavelet.cu @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include "dali/kernels/signal/wavelet/mother_wavelet.cuh" #include "dali/core/math_util.h" @@ -51,8 +52,10 @@ MeyerWavelet::MeyerWavelet(const std::vector &args) { template __device__ T MeyerWavelet::operator()(const T &t) const { T tt = t - 0.5; - T psi1 = (4.0/(3.0*M_PI)*tt*std::cos((2.0*M_PI)/3.0*tt)-1.0/M_PI*std::sin((4.0*M_PI)/3.0*tt))/(tt-16.0/9.0*std::pow(tt, 3.0)); - T psi2 = (8.0/(3.0*M_PI)*tt*std::cos(8.0*M_PI/3.0*tt)+1.0/M_PI*std::sin((4.0*M_PI)/3.0)*tt)/(tt-64.0/9.0*std::pow(tt, 3.0)); + T psi1 = (4.0/(3.0*M_PI)*tt*std::cos((2.0*M_PI)/3.0*tt)-1.0/M_PI*std::sin((4.0*M_PI)/3.0*tt))/ + (tt-16.0/9.0*std::pow(tt, 3.0)); + T psi2 = (8.0/(3.0*M_PI)*tt*std::cos(8.0*M_PI/3.0*tt)+1.0/M_PI*std::sin((4.0*M_PI)/3.0)*tt)/ + (tt-64.0/9.0*std::pow(tt, 3.0)); return psi1 + psi2; } @@ -69,7 +72,8 @@ MexicanHatWavelet::MexicanHatWavelet(const std::vector &args) { template __device__ T MexicanHatWavelet::operator()(const T &t) const { - return 2.0/(std::sqrt(3.0*sigma)*std::pow(M_PI, 0.25))*(1.0-std::pow(t/sigma, 2.0))*std::exp(-std::pow(t, 2.0)/(2.0*std::pow(sigma, 2.0))); + return 2.0/(std::sqrt(3.0*sigma)*std::pow(M_PI, 0.25))*(1.0-std::pow(t/sigma, 2.0))* + std::exp(-std::pow(t, 2.0)/(2.0*std::pow(sigma, 2.0))); } template class MexicanHatWavelet; @@ -94,7 +98,8 @@ template class MorletWavelet; template ShannonWavelet::ShannonWavelet(const std::vector &args) { if (args.size() != 2) { - throw new std::invalid_argument("ShannonWavelet accepts exactly 2 arguments -> fb, fc in that order."); + throw new std::invalid_argument( + "ShannonWavelet accepts exactly 2 arguments -> fb, fc in that order."); } this->fb = args[0]; this->fc = args[1]; @@ -112,7 +117,8 @@ template class ShannonWavelet; template FbspWavelet::FbspWavelet(const std::vector &args) { if (args.size() != 3) { - throw new std::invalid_argument("FbspWavelet accepts exactly 3 arguments -> m, fb, fc in that order."); + throw new std::invalid_argument( + "FbspWavelet accepts exactly 3 arguments -> m, fb, fc in that order."); } this->m = args[0]; this->fb = args[1]; @@ -129,5 +135,5 @@ template class FbspWavelet; template class FbspWavelet; } // namespace signal -} // namespace kernel +} // namespace kernels } // namespace dali diff --git a/dali/kernels/signal/wavelet/mother_wavelet.cuh b/dali/kernels/signal/wavelet/mother_wavelet.cuh index cbdd6da716..70ae5d8ca2 100644 --- a/dali/kernels/signal/wavelet/mother_wavelet.cuh +++ b/dali/kernels/signal/wavelet/mother_wavelet.cuh @@ -15,14 +15,14 @@ #ifndef DALI_KERNELS_SIGNAL_WAVELET_MOTHER_WAVELET_CUH_ #define DALI_KERNELS_SIGNAL_WAVELET_MOTHER_WAVELET_CUH_ +#include + #include "dali/core/common.h" #include "dali/core/error_handling.h" #include "dali/core/format.h" #include "dali/core/util.h" #include "dali/kernels/kernel.h" -#include - namespace dali { namespace kernels { namespace signal { @@ -37,7 +37,7 @@ class HaarWavelet { "Data type should be floating point"); public: HaarWavelet() = default; - HaarWavelet(const std::vector &args); + explicit HaarWavelet(const std::vector &args); ~HaarWavelet() = default; __device__ T operator()(const T &t) const; @@ -49,7 +49,7 @@ class MeyerWavelet { "Data type should be floating point"); public: MeyerWavelet() = default; - MeyerWavelet(const std::vector &args); + explicit MeyerWavelet(const std::vector &args); ~MeyerWavelet() = default; __device__ T operator()(const T &t) const; @@ -61,7 +61,7 @@ class MexicanHatWavelet { "Data type should be floating point"); public: MexicanHatWavelet() = default; - MexicanHatWavelet(const std::vector &args); + explicit MexicanHatWavelet(const std::vector &args); ~MexicanHatWavelet() = default; __device__ T operator()(const T &t) const; @@ -76,7 +76,7 @@ class MorletWavelet { "Data type should be floating point"); public: MorletWavelet() = default; - MorletWavelet(const std::vector &args); + explicit MorletWavelet(const std::vector &args); ~MorletWavelet() = default; __device__ T operator()(const T &t) const; @@ -91,7 +91,7 @@ class ShannonWavelet { "Data type should be floating point"); public: ShannonWavelet() = default; - ShannonWavelet(const std::vector &args); + explicit ShannonWavelet(const std::vector &args); ~ShannonWavelet() = default; __device__ T operator()(const T &t) const; @@ -107,7 +107,7 @@ class FbspWavelet { "Data type should be floating point"); public: FbspWavelet() = default; - FbspWavelet(const std::vector &args); + explicit FbspWavelet(const std::vector &args); ~FbspWavelet() = default; __device__ T operator()(const T &t) const; @@ -119,7 +119,7 @@ class FbspWavelet { }; } // namespace signal -} // namespace kernel +} // namespace kernels } // namespace dali #endif // DALI_KERNELS_SIGNAL_WAVELET_MOTHER_WAVELET_CUH_ diff --git a/dali/kernels/signal/wavelet/wavelet_gpu.cu b/dali/kernels/signal/wavelet/wavelet_gpu.cu index 9829ef1ca8..8bbb9df961 100644 --- a/dali/kernels/signal/wavelet/wavelet_gpu.cu +++ b/dali/kernels/signal/wavelet/wavelet_gpu.cu @@ -42,8 +42,7 @@ __global__ void ComputeWavelet(const SampleDesc* sample_data, W wavelet) { auto x = std::pow(2.0, a); if (a == 0.0) { shm[b_id] = sample.in[t_id]; - } - else { + } else { shm[b_id] = x * sample.in[t_id]; shm[1024] = std::pow(2.0, a / 2.0); } @@ -53,8 +52,7 @@ __global__ void ComputeWavelet(const SampleDesc* sample_data, W wavelet) { auto b = sample.b[i]; if (b == 0.0) { sample.out[out_id] = wavelet(shm[b_id]); - } - else { + } else { sample.out[out_id] = wavelet(shm[b_id] - b); } if (a != 0.0) { @@ -66,7 +64,8 @@ __global__ void ComputeWavelet(const SampleDesc* sample_data, W wavelet) { // translate input range information to input samples template __global__ void ComputeInputSamples(const SampleDesc* sample_data) { - const int64_t t_id = blockDim.x * blockDim.y * blockIdx.x + threadIdx.y * blockDim.x + threadIdx.x; + const int64_t block_size = blockDim.x * blockDim.y; + const int64_t t_id = block_size * blockIdx.x + threadIdx.y * blockDim.x + threadIdx.x; auto& sample = sample_data[blockIdx.y]; if (t_id >= sample.size_in) return; sample.in[t_id] = sample.span.begin + (T)t_id / sample.span.sampling_rate; @@ -107,7 +106,8 @@ DLL_PUBLIC void WaveletGpu::Run(KernelContext &ctx, sample.b = b.tensor_data(i); sample.size_b = b.shape.tensor_size(i); sample.span = span; - sample.size_in = std::ceil((sample.span.end - sample.span.begin) * sample.span.sampling_rate) + 1; + sample.size_in = + std::ceil((sample.span.end - sample.span.begin) * sample.span.sampling_rate) + 1; sample.in = ctx.scratchpad->AllocateGPU(sample.size_in); max_size_in = std::max(max_size_in, sample.size_in); } @@ -133,9 +133,11 @@ TensorListShape<> WaveletGpu::GetOutputShape(const TensorListShape<> &a_sh TensorListShape<> out_shape(N, 3); TensorShape<> tshape; for (int i = 0; i < N; i++) { - // output tensor will be 3-dimensional of shape: + // output tensor will be 3-dimensional of shape: // a coeffs x b coeffs x signal samples - tshape = TensorShape<>({a_shape.tensor_shape(i).num_elements(), b_shape.tensor_shape(i).num_elements(), in_size}); + tshape = TensorShape<>({a_shape.tensor_shape(i).num_elements(), + b_shape.tensor_shape(i).num_elements(), + in_size}); out_shape.set_tensor_shape(i, tshape); } return out_shape; diff --git a/dali/kernels/signal/wavelet/wavelet_gpu.cuh b/dali/kernels/signal/wavelet/wavelet_gpu.cuh index 0026ff58c4..45c54e8b25 100644 --- a/dali/kernels/signal/wavelet/wavelet_gpu.cuh +++ b/dali/kernels/signal/wavelet/wavelet_gpu.cuh @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef DALI_KERNELS_SIGNAL_WAVELET_GPU_CUH_ -#define DALI_KERNELS_SIGNAL_WAVELET_GPU_CUH_ +#ifndef DALI_KERNELS_SIGNAL_WAVELET_WAVELET_GPU_CUH_ +#define DALI_KERNELS_SIGNAL_WAVELET_WAVELET_GPU_CUH_ #include #include +#include #include "dali/core/common.h" #include "dali/core/error_handling.h" #include "dali/core/format.h" @@ -26,13 +27,16 @@ // makes sure both tensors have the same number of samples and // that they're one-dimensional -#define ENFORCE_SHAPES(a_shape, b_shape) do { \ -DALI_ENFORCE(a_shape.num_samples() == b_shape.num_samples(),"a and b tensors must have the same amount of samples."); \ -for (int i = 0; i < a_shape.num_samples(); ++i) { \ - DALI_ENFORCE(a_shape.tensor_shape(i).size() == 1, "Tensor of a coeffs should be 1-dimensional."); \ - DALI_ENFORCE(b_shape.tensor_shape(i).size() == 1, "Tensor of b coeffs should be 1-dimensional."); \ -} \ -} while(0); +#define ENFORCE_SHAPES(a_shape, b_shape) do { \ +DALI_ENFORCE(a_shape.num_samples() == b_shape.num_samples(), \ + "a and b tensors must have the same amount of samples."); \ +for (int i = 0; i < a_shape.num_samples(); ++i) { \ + DALI_ENFORCE(a_shape.tensor_shape(i).size() == 1, \ + "Tensor of a coeffs should be 1-dimensional."); \ + DALI_ENFORCE(b_shape.tensor_shape(i).size() == 1, \ + "Tensor of b coeffs should be 1-dimensional."); \ +} \ +} while (0); namespace dali { namespace kernels { @@ -90,6 +94,7 @@ class DLL_PUBLIC WaveletGpu { static TensorListShape<> GetOutputShape(const TensorListShape<> &a_shape, const TensorListShape<> &b_shape, const WaveletSpan &span); + private: W wavelet_; }; @@ -98,4 +103,4 @@ class DLL_PUBLIC WaveletGpu { } // namespace kernels } // namespace dali -#endif // DALI_KERNELS_SIGNAL_WAVELET_GPU_CUH_ +#endif // DALI_KERNELS_SIGNAL_WAVELET_WAVELET_GPU_CUH_ diff --git a/dali/operators/signal/wavelet/wavelet_name.h b/dali/operators/signal/wavelet/wavelet_name.h index c9c94dff06..454022587f 100644 --- a/dali/operators/signal/wavelet/wavelet_name.h +++ b/dali/operators/signal/wavelet/wavelet_name.h @@ -29,6 +29,6 @@ enum DALIWaveletName { DALI_FBSP = 5 }; -} // namespace dali +} // namespace dali -#endif // DALI_OPERATORS_SIGNAL_WAVELET_NAME_H_ \ No newline at end of file +#endif // DALI_OPERATORS_SIGNAL_WAVELET_WAVELET_NAME_H_ diff --git a/dali/operators/signal/wavelet/wavelet_run.h b/dali/operators/signal/wavelet/wavelet_run.h index 7ce6b3a5f4..6218bbb9ba 100644 --- a/dali/operators/signal/wavelet/wavelet_run.h +++ b/dali/operators/signal/wavelet/wavelet_run.h @@ -57,28 +57,34 @@ void RunForName(const DALIWaveletName &name, const std::vector &args) { switch (name) { case DALIWaveletName::DALI_HAAR: - RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + using kernels::signal::HaarWavelet; + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); break; case DALIWaveletName::DALI_MEY: - RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + using kernels::signal::MeyerWavelet; + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); break; case DALIWaveletName::DALI_MEXH: - RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + using kernels::signal::MexicanHatWavelet; + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); break; case DALIWaveletName::DALI_MORL: - RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + using kernels::signal::MorletWavelet; + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); break; case DALIWaveletName::DALI_SHAN: - RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + using kernels::signal::ShannonWavelet; + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); break; case DALIWaveletName::DALI_FBSP: - RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + using kernels::signal::FbspWavelet; + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); break; default: throw new std::invalid_argument("Unknown wavelet name."); } } -} // namespace dali +} // namespace dali -#endif // DALI_OPERATORS_SIGNAL_WAVELET_RUN_H_ \ No newline at end of file +#endif // DALI_OPERATORS_SIGNAL_WAVELET_WAVELET_RUN_H_ diff --git a/dali/python/nvidia/dali/types.py b/dali/python/nvidia/dali/types.py index 405165847a..5d56077ee6 100644 --- a/dali/python/nvidia/dali/types.py +++ b/dali/python/nvidia/dali/types.py @@ -16,7 +16,8 @@ from enum import Enum, unique import re -from nvidia.dali.backend_impl.types import DALIDataType, DALIImageType, DALIInterpType, DALIWaveletName +from nvidia.dali.backend_impl.types import DALIDataType, DALIImageType, \ + DALIInterpType, DALIWaveletName # TODO: Handle forwarding imports from backend_impl from nvidia.dali.backend_impl.types import * # noqa: F401, F403 @@ -63,7 +64,8 @@ def _not_implemented(val): DALIDataType.DATA_TYPE: ("nvidia.dali.types.DALIDataType", lambda x: DALIDataType(int(x))), DALIDataType.INTERP_TYPE: ("nvidia.dali.types.DALIInterpType", lambda x: DALIInterpType(int(x))), - DALIDataType.WAVELET_NAME: ("nvidia.dali.types.DALIWaveletName", lambda x: DALIWaveletName(int(x))), + DALIDataType.WAVELET_NAME: + ("nvidia.dali.types.DALIWaveletName", lambda x: DALIWaveletName(int(x))), DALIDataType.TENSOR_LAYOUT: (":ref:`layout str`", lambda x: str(x)), DALIDataType.PYTHON_OBJECT: ("object", lambda x: x), DALIDataType._TENSOR_LAYOUT_VEC: From 2875c95a6190223eeee9994fcf885794fdf6f0f4 Mon Sep 17 00:00:00 2001 From: JakubO Date: Tue, 13 Jun 2023 03:22:12 +0200 Subject: [PATCH 09/14] replace MeyerWavelet with GaussianWavelet --- dali/kernels/signal/wavelet/mother_wavelet.cu | 47 ++++++++++++++----- .../kernels/signal/wavelet/mother_wavelet.cuh | 10 ++-- dali/kernels/signal/wavelet/wavelet_gpu.cu | 4 +- dali/operators/signal/wavelet/wavelet_name.h | 2 +- dali/operators/signal/wavelet/wavelet_run.h | 6 +-- dali/python/backend_impl.cc | 5 +- 6 files changed, 49 insertions(+), 25 deletions(-) diff --git a/dali/kernels/signal/wavelet/mother_wavelet.cu b/dali/kernels/signal/wavelet/mother_wavelet.cu index 66d83936ea..6e1c027996 100644 --- a/dali/kernels/signal/wavelet/mother_wavelet.cu +++ b/dali/kernels/signal/wavelet/mother_wavelet.cu @@ -43,24 +43,47 @@ template class HaarWavelet; template class HaarWavelet; template -MeyerWavelet::MeyerWavelet(const std::vector &args) { - if (args.size() != 0) { - throw new std::invalid_argument("MeyerWavelet doesn't accept any arguments."); +GaussianWavelet::GaussianWavelet(const std::vector &args) { + if (args.size() != 1) { + throw new std::invalid_argument("GaussianWavelet accepts exactly one argument - n."); } + if (args[0] < 1.0 || args[0] > 8.0) { + throw new std::invalid_argument( + "GaussianWavelet's argument n should be integer from range [1,8]."); + } + this->n = args[0]; } template -__device__ T MeyerWavelet::operator()(const T &t) const { - T tt = t - 0.5; - T psi1 = (4.0/(3.0*M_PI)*tt*std::cos((2.0*M_PI)/3.0*tt)-1.0/M_PI*std::sin((4.0*M_PI)/3.0*tt))/ - (tt-16.0/9.0*std::pow(tt, 3.0)); - T psi2 = (8.0/(3.0*M_PI)*tt*std::cos(8.0*M_PI/3.0*tt)+1.0/M_PI*std::sin((4.0*M_PI)/3.0)*tt)/ - (tt-64.0/9.0*std::pow(tt, 3.0)); - return psi1 + psi2; +__device__ T GaussianWavelet::operator()(const T &t) const { + T expTerm = std::exp(-std::pow(t, 2.0)); + T sqrtTerm = 1.2533141373155001; // std::sqrt(M_PI/2.0) + switch (static_cast(n)) { + case 1: + return -2.0*t*expTerm/std::sqrt(sqrtTerm); + case 2: + return (-4.0*std::pow(t, 2.0)+2.0)*expTerm/std::sqrt(3.0*sqrtTerm); + case 3: + return (8.0*std::pow(t, 3.0)-12.0*t)*expTerm/std::sqrt(15.0*sqrtTerm); + case 4: + return (-48.0*std::pow(t, 2.0)+16.0*std::pow(t, 4.0)+12.0)*expTerm/std::sqrt(105.0*sqrtTerm); + case 5: + return (-32.0*std::pow(t, 5.0)+160.0*std::pow(t, 3.0)-120.0*t)* + expTerm/std::sqrt(945.0*sqrtTerm); + case 6: + return (-64.0*std::pow(t, 6.0)+480.0*std::pow(t, 4.0)-720.0*std::pow(t, 2.0)+120.0)* + expTerm/std::sqrt(10395.0*sqrtTerm); + case 7: + return (128.0*std::pow(t, 7.0)-1344.0*std::pow(t, 5.0)+3360.0*std::pow(t, 3.0)-1680.0*t)* + expTerm/std::sqrt(135135.0*sqrtTerm); + case 8: + return (256.0*std::pow(t, 8.0)-3584.0*std::pow(t, 6.0)+13440.0*std::pow(t, 4.0)-13440.0* + std::pow(t, 2.0)+1680.0)*expTerm/std::sqrt(2027025.0*sqrtTerm); + } } -template class MeyerWavelet; -template class MeyerWavelet; +template class GaussianWavelet; +template class GaussianWavelet; template MexicanHatWavelet::MexicanHatWavelet(const std::vector &args) { diff --git a/dali/kernels/signal/wavelet/mother_wavelet.cuh b/dali/kernels/signal/wavelet/mother_wavelet.cuh index 70ae5d8ca2..1e618be69f 100644 --- a/dali/kernels/signal/wavelet/mother_wavelet.cuh +++ b/dali/kernels/signal/wavelet/mother_wavelet.cuh @@ -44,15 +44,17 @@ class HaarWavelet { }; template -class MeyerWavelet { +class GaussianWavelet { static_assert(std::is_floating_point::value, "Data type should be floating point"); public: - MeyerWavelet() = default; - explicit MeyerWavelet(const std::vector &args); - ~MeyerWavelet() = default; + GaussianWavelet() = default; + explicit GaussianWavelet(const std::vector &args); + ~GaussianWavelet() = default; __device__ T operator()(const T &t) const; + private: + T n; }; template diff --git a/dali/kernels/signal/wavelet/wavelet_gpu.cu b/dali/kernels/signal/wavelet/wavelet_gpu.cu index 8bbb9df961..a5ab81a5df 100644 --- a/dali/kernels/signal/wavelet/wavelet_gpu.cu +++ b/dali/kernels/signal/wavelet/wavelet_gpu.cu @@ -145,8 +145,8 @@ TensorListShape<> WaveletGpu::GetOutputShape(const TensorListShape<> &a_sh template class WaveletGpu; template class WaveletGpu; -template class WaveletGpu; -template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; template class WaveletGpu; template class WaveletGpu; template class WaveletGpu; diff --git a/dali/operators/signal/wavelet/wavelet_name.h b/dali/operators/signal/wavelet/wavelet_name.h index 454022587f..5e53713bba 100644 --- a/dali/operators/signal/wavelet/wavelet_name.h +++ b/dali/operators/signal/wavelet/wavelet_name.h @@ -22,7 +22,7 @@ namespace dali { */ enum DALIWaveletName { DALI_HAAR = 0, - DALI_MEY = 1, + DALI_GAUS = 1, DALI_MEXH = 2, DALI_MORL = 3, DALI_SHAN = 4, diff --git a/dali/operators/signal/wavelet/wavelet_run.h b/dali/operators/signal/wavelet/wavelet_run.h index 6218bbb9ba..def362cff7 100644 --- a/dali/operators/signal/wavelet/wavelet_run.h +++ b/dali/operators/signal/wavelet/wavelet_run.h @@ -60,9 +60,9 @@ void RunForName(const DALIWaveletName &name, using kernels::signal::HaarWavelet; RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); break; - case DALIWaveletName::DALI_MEY: - using kernels::signal::MeyerWavelet; - RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + case DALIWaveletName::DALI_GAUS: + using kernels::signal::GaussianWavelet; + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); break; case DALIWaveletName::DALI_MEXH: using kernels::signal::MexicanHatWavelet; diff --git a/dali/python/backend_impl.cc b/dali/python/backend_impl.cc index 504be8def0..1706532d8c 100644 --- a/dali/python/backend_impl.cc +++ b/dali/python/backend_impl.cc @@ -122,8 +122,7 @@ py::dict ArrayInterfaceRepr(Tensor &t) { d["shape"] = py::tuple(py_shape(t)); // tuple of (raw_data_pointer, if_data_is_read_only) tup[0] = py::reinterpret_borrow(PyLong_FromVoidPtr(t.raw_mutable_data())); - // if we make it readonly, it prevents us from sharing memory with PyTorch tensor - tup[1] = false; + tup[1] = true; d["data"] = tup; if (std::is_same::value) { // see https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html @@ -1721,7 +1720,7 @@ PYBIND11_MODULE(backend_impl, m) { // DALIWaveletName py::enum_(types_m, "DALIWaveletName", "Wavelet name\n") .value("HAAR", DALI_HAAR) - .value("MEY", DALI_MEY) + .value("GAUS", DALI_GAUS) .value("MEXH", DALI_MEXH) .value("MORL", DALI_MORL) .value("SHAN", DALI_SHAN) From 0efec3db243c9303c3d355e65df4e7cc4e5d3223 Mon Sep 17 00:00:00 2001 From: JakubO Date: Mon, 3 Jul 2023 02:55:11 +0200 Subject: [PATCH 10/14] Fix wavelet exceptions Wavelet constructor exceptions are now being handled correctly. Morlet wavelet C argument has been removed. --- dali/kernels/signal/wavelet/mother_wavelet.cu | 19 +++++++++---------- .../kernels/signal/wavelet/mother_wavelet.cuh | 3 --- dali/operators/signal/wavelet/wavelet_run.h | 2 +- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/dali/kernels/signal/wavelet/mother_wavelet.cu b/dali/kernels/signal/wavelet/mother_wavelet.cu index 6e1c027996..232c183a0c 100644 --- a/dali/kernels/signal/wavelet/mother_wavelet.cu +++ b/dali/kernels/signal/wavelet/mother_wavelet.cu @@ -24,7 +24,7 @@ namespace signal { template HaarWavelet::HaarWavelet(const std::vector &args) { if (args.size() != 0) { - throw new std::invalid_argument("HaarWavelet doesn't accept any arguments."); + throw std::invalid_argument("HaarWavelet doesn't accept any arguments."); } } @@ -45,10 +45,10 @@ template class HaarWavelet; template GaussianWavelet::GaussianWavelet(const std::vector &args) { if (args.size() != 1) { - throw new std::invalid_argument("GaussianWavelet accepts exactly one argument - n."); + throw std::invalid_argument("GaussianWavelet accepts exactly one argument - n."); } if (args[0] < 1.0 || args[0] > 8.0) { - throw new std::invalid_argument( + throw std::invalid_argument( "GaussianWavelet's argument n should be integer from range [1,8]."); } this->n = args[0]; @@ -88,7 +88,7 @@ template class GaussianWavelet; template MexicanHatWavelet::MexicanHatWavelet(const std::vector &args) { if (args.size() != 1) { - throw new std::invalid_argument("MexicanHatWavelet accepts exactly one argument - sigma."); + throw std::invalid_argument("MexicanHatWavelet accepts exactly one argument - sigma."); } this->sigma = args[0]; } @@ -104,15 +104,14 @@ template class MexicanHatWavelet; template MorletWavelet::MorletWavelet(const std::vector &args) { - if (args.size() != 1) { - throw new std::invalid_argument("MorletWavelet accepts exactly 1 argument - C."); + if (args.size() != 0) { + throw std::invalid_argument("MorletWavelet doesn't accept any arguments."); } - this->C = args[0]; } template __device__ T MorletWavelet::operator()(const T &t) const { - return C * std::exp(-std::pow(t, 2.0) / 2.0) * std::cos(5.0 * t); + return std::exp(-std::pow(t, 2.0) / 2.0) * std::cos(5.0 * t); } template class MorletWavelet; @@ -121,7 +120,7 @@ template class MorletWavelet; template ShannonWavelet::ShannonWavelet(const std::vector &args) { if (args.size() != 2) { - throw new std::invalid_argument( + throw std::invalid_argument( "ShannonWavelet accepts exactly 2 arguments -> fb, fc in that order."); } this->fb = args[0]; @@ -140,7 +139,7 @@ template class ShannonWavelet; template FbspWavelet::FbspWavelet(const std::vector &args) { if (args.size() != 3) { - throw new std::invalid_argument( + throw std::invalid_argument( "FbspWavelet accepts exactly 3 arguments -> m, fb, fc in that order."); } this->m = args[0]; diff --git a/dali/kernels/signal/wavelet/mother_wavelet.cuh b/dali/kernels/signal/wavelet/mother_wavelet.cuh index 1e618be69f..9cbd81592b 100644 --- a/dali/kernels/signal/wavelet/mother_wavelet.cuh +++ b/dali/kernels/signal/wavelet/mother_wavelet.cuh @@ -82,9 +82,6 @@ class MorletWavelet { ~MorletWavelet() = default; __device__ T operator()(const T &t) const; - - private: - T C; }; template diff --git a/dali/operators/signal/wavelet/wavelet_run.h b/dali/operators/signal/wavelet/wavelet_run.h index def362cff7..cf386f8f59 100644 --- a/dali/operators/signal/wavelet/wavelet_run.h +++ b/dali/operators/signal/wavelet/wavelet_run.h @@ -81,7 +81,7 @@ void RunForName(const DALIWaveletName &name, RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); break; default: - throw new std::invalid_argument("Unknown wavelet name."); + throw std::invalid_argument("Unknown wavelet name."); } } From 1ed22bcfc9f9b406d097454608d624770fc5e0e6 Mon Sep 17 00:00:00 2001 From: JakubO Date: Wed, 5 Jul 2023 01:06:10 +0200 Subject: [PATCH 11/14] Add CWT operator docstr --- dali/operators/signal/wavelet/cwt_op_gpu.cu | 27 +++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/dali/operators/signal/wavelet/cwt_op_gpu.cu b/dali/operators/signal/wavelet/cwt_op_gpu.cu index 3cea5427d6..5cb2ab626b 100644 --- a/dali/operators/signal/wavelet/cwt_op_gpu.cu +++ b/dali/operators/signal/wavelet/cwt_op_gpu.cu @@ -25,8 +25,31 @@ namespace dali { -DALI_SCHEMA(Cwt).DocStr("by MW").NumInput(1).NumOutput(1).AddArg("a", "costam", - type2id::value); +DALI_SCHEMA(Cwt) + .DocStr(R"(Performs continuous wavelet transform on a 1D signal (for example, audio). + +Result values of transform are computed for all specified scales. +Input data is expected to be one channel (shape being ``(nsamples,)``, ``(nsamples, 1)`` +) of type float32.)") + .NumInput(1) + .NumOutput(1) + .AddArg("a", R"(List of scale coefficients of type float32.)", DALIDataType::DALI_FLOAT_VEC) + .AddArg("wavelet", R"(Name of mother wavelet. Currently supported wavelets' names are: +- HAAR - Haar wavelet +- GAUS - Gaussian wavelet +- MEXH - Mexican hat wavelet +- MORL - Morlet wavelet +- SHAN - Shannon wavleet +- FBSP - Frequency B-spline wavelet)", DALIDataType::DALI_WAVELET_NAME) + .AddArg("wavelet_args", R"(Additional arguments for mother wavelet. They are passed +as list of float32 values. +- HAAR - none +- GAUS - n (order of derivative) +- MEXH - sigma +- MORL - none +- SHAN - fb (bandwidth parameter > 0), fc (center frequency > 0) +- FBSP - m (order parameter >= 1), fb (bandwidth parameter > 0), fc (center frequency > 0) +)", DALIDataType::DALI_FLOAT_VEC); template struct CwtImplGPU : public OpImplBase { From 1cdc5e7c7f318171c8be9a9004a4522be2c593c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Wdowski?= Date: Fri, 8 Sep 2023 22:28:03 +0200 Subject: [PATCH 12/14] WIP --- dali/kernels/signal/wavelet/cwt_args.h | 9 ++++-- dali/kernels/signal/wavelet/cwt_gpu.cu | 12 ++++---- dali/kernels/signal/wavelet/cwt_gpu.h | 4 +-- dali/operators/signal/fft/power_spectrum.h | 13 ++++---- dali/operators/signal/wavelet/cwt_op.h | 13 ++++++-- dali/operators/signal/wavelet/cwt_op_gpu.cu | 31 ++++++++++++-------- dali/operators/signal/wavelet/wavelet_name.h | 12 ++++---- 7 files changed, 54 insertions(+), 40 deletions(-) diff --git a/dali/kernels/signal/wavelet/cwt_args.h b/dali/kernels/signal/wavelet/cwt_args.h index b61d064a9e..9a38b8d006 100644 --- a/dali/kernels/signal/wavelet/cwt_args.h +++ b/dali/kernels/signal/wavelet/cwt_args.h @@ -15,17 +15,20 @@ #ifndef DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_ #define DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_ +#include +#include "dali/operators/signal/wavelet/wavelet_name.h" + namespace dali { namespace kernels { namespace signal { -namespace wavelet { template struct CwtArgs { - T a; + std::vector a; + dali::DALIWaveletName wavelet; + std::vector wavelet_args; }; -} // namespace wavelet } // namespace signal } // namespace kernels } // namespace dali diff --git a/dali/kernels/signal/wavelet/cwt_gpu.cu b/dali/kernels/signal/wavelet/cwt_gpu.cu index a15f82929a..cfca159483 100644 --- a/dali/kernels/signal/wavelet/cwt_gpu.cu +++ b/dali/kernels/signal/wavelet/cwt_gpu.cu @@ -19,13 +19,12 @@ #include "dali/core/error_handling.h" #include "dali/core/format.h" #include "dali/kernels/kernel.h" -#include "dali/kernels/signal/wavelets/cwt_args.h" -#include "dali/kernels/signal/wavelets/cwt_gpu.h" +#include "dali/kernels/signal/wavelet/cwt_args.h" +#include "dali/kernels/signal/wavelet/cwt_gpu.h" namespace dali { namespace kernels { namespace signal { -namespace wavelet { template struct SampleDesc { @@ -35,7 +34,7 @@ struct SampleDesc { }; template -__global__ void CwtKernel(const SampleDesc *sample_data, CwtArgs args) { +__global__ void CwtKernel(const SampleDesc *sample_data) { const int64_t block_size = blockDim.y * blockDim.x; const int64_t grid_size = gridDim.x * block_size; const int sample_idx = blockIdx.y; @@ -44,7 +43,7 @@ __global__ void CwtKernel(const SampleDesc *sample_data, CwtArgs args) { const int64_t tid = threadIdx.y * blockDim.x + threadIdx.x; for (int64_t idx = offset + tid; idx < sample.size; idx += grid_size) { - sample.out[idx] = sample.in[idx] * args.a; + sample.out[idx] = sample.in[idx]; } } @@ -86,13 +85,12 @@ void CwtGpu::Run(KernelContext &context, const OutListGPU<<>>(sample_data_gpu, args); + CwtKernel<<>>(sample_data_gpu); } template class CwtGpu; template class CwtGpu; -} // namespace wavelet } // namespace signal } // namespace kernels } // namespace dali diff --git a/dali/kernels/signal/wavelet/cwt_gpu.h b/dali/kernels/signal/wavelet/cwt_gpu.h index 62f9cef738..35a494aca6 100644 --- a/dali/kernels/signal/wavelet/cwt_gpu.h +++ b/dali/kernels/signal/wavelet/cwt_gpu.h @@ -21,12 +21,11 @@ #include "dali/core/format.h" #include "dali/core/util.h" #include "dali/kernels/kernel.h" -#include "dali/kernels/signal/wavelets/cwt_args.h" +#include "dali/kernels/signal/wavelet/cwt_args.h" namespace dali { namespace kernels { namespace signal { -namespace wavelet { template class DLL_PUBLIC CwtGpu { @@ -42,7 +41,6 @@ class DLL_PUBLIC CwtGpu { const InListGPU &in, const CwtArgs &args); }; -} // namespace wavelet } // namespace signal } // namespace kernels } // namespace dali diff --git a/dali/operators/signal/fft/power_spectrum.h b/dali/operators/signal/fft/power_spectrum.h index 170818187a..65117ef1c8 100644 --- a/dali/operators/signal/fft/power_spectrum.h +++ b/dali/operators/signal/fft/power_spectrum.h @@ -28,8 +28,7 @@ namespace dali { template class PowerSpectrum : public Operator { public: - explicit PowerSpectrum(const OpSpec &spec) - : Operator(spec) { + explicit PowerSpectrum(const OpSpec &spec) : Operator(spec) { fft_args_.nfft = spec.HasArgument("nfft") ? spec.GetArgument("nfft") : -1; fft_args_.transform_axis = spec.GetArgument("axis"); int power = spec.GetArgument("power"); @@ -41,13 +40,17 @@ class PowerSpectrum : public Operator { fft_args_.spectrum_type = kernels::signal::fft::FFT_SPECTRUM_POWER; break; default: - DALI_FAIL(make_string("Power argument should be either `2` for power spectrum or `1` " - "for complex magnitude. Received: ", power)); + DALI_FAIL( + make_string("Power argument should be either `2` for power spectrum or `1` " + "for complex magnitude. Received: ", + power)); } } protected: - bool CanInferOutputs() const override { return true; } + bool CanInferOutputs() const override { + return true; + } bool SetupImpl(std::vector &output_desc, const Workspace &ws) override; void RunImpl(Workspace &ws) override; diff --git a/dali/operators/signal/wavelet/cwt_op.h b/dali/operators/signal/wavelet/cwt_op.h index 3d6e439d49..9a3ecc169b 100644 --- a/dali/operators/signal/wavelet/cwt_op.h +++ b/dali/operators/signal/wavelet/cwt_op.h @@ -18,8 +18,10 @@ #include #include #include "dali/core/common.h" -#include "dali/kernels/signal/wavelets/cwt_args.h" +#include "dali/kernels/kernel_manager.h" +#include "dali/kernels/signal/wavelet/cwt_args.h" #include "dali/pipeline/operator/common.h" +#include "dali/pipeline/operator/op_spec.h" #include "dali/pipeline/operator/operator.h" #include "dali/pipeline/util/operator_impl_utils.h" @@ -32,7 +34,12 @@ class Cwt : public Operator { if (!spec.HasArgument("a")) { DALI_ENFORCE("`a` argument must be provided."); } - args_.a = spec.GetArgument("a"); + args_.a = spec.GetRepeatedArgument("a"); + if (!spec.HasArgument("wavelet")) { + DALI_ENFORCE("`wavelet` argument must be provided."); + } + args_.wavelet = spec.GetArgument("wavelet"); + args_.wavelet_args = spec.GetRepeatedArgument("wavelet_args"); } protected: @@ -54,7 +61,7 @@ class Cwt : public Operator { using Operator::RunImpl; kernels::KernelManager kmgr_; - kernels::signal::wavelets::CwtArgs args_; + kernels::signal::CwtArgs args_; std::unique_ptr> impl_; DALIDataType type_ = DALI_NO_TYPE; diff --git a/dali/operators/signal/wavelet/cwt_op_gpu.cu b/dali/operators/signal/wavelet/cwt_op_gpu.cu index 5cb2ab626b..535079d39f 100644 --- a/dali/operators/signal/wavelet/cwt_op_gpu.cu +++ b/dali/operators/signal/wavelet/cwt_op_gpu.cu @@ -18,30 +18,33 @@ #include "dali/core/static_switch.h" #include "dali/kernels/kernel_manager.h" #include "dali/kernels/kernel_params.h" -#include "dali/kernels/signal/wavelets/cwt_args.h" -#include "dali/kernels/signal/wavelets/cwt_gpu.h" -#include "dali/operators/signal/wavelets/cwt_op.h" +#include "dali/kernels/signal/wavelet/cwt_args.h" +#include "dali/kernels/signal/wavelet/cwt_gpu.h" +#include "dali/operators/signal/wavelet/cwt_op.h" +#include "dali/pipeline/data/types.h" #include "dali/pipeline/data/views.h" +#include "dali/pipeline/operator/op_schema.h" namespace dali { DALI_SCHEMA(Cwt) - .DocStr(R"(Performs continuous wavelet transform on a 1D signal (for example, audio). + .DocStr(R"(Performs continuous wavelet transform on a 1D signal (for example, audio). Result values of transform are computed for all specified scales. Input data is expected to be one channel (shape being ``(nsamples,)``, ``(nsamples, 1)`` ) of type float32.)") - .NumInput(1) - .NumOutput(1) - .AddArg("a", R"(List of scale coefficients of type float32.)", DALIDataType::DALI_FLOAT_VEC) - .AddArg("wavelet", R"(Name of mother wavelet. Currently supported wavelets' names are: + .NumInput(1) + .NumOutput(1) + .AddArg("a", R"(List of scale coefficients of type float32.)", DALIDataType::DALI_FLOAT_VEC) + .AddArg("wavelet", R"(Name of mother wavelet. Currently supported wavelets' names are: - HAAR - Haar wavelet - GAUS - Gaussian wavelet - MEXH - Mexican hat wavelet - MORL - Morlet wavelet - SHAN - Shannon wavleet -- FBSP - Frequency B-spline wavelet)", DALIDataType::DALI_WAVELET_NAME) - .AddArg("wavelet_args", R"(Additional arguments for mother wavelet. They are passed +- FBSP - Frequency B-spline wavelet)", + DALIDataType::DALI_WAVELET_NAME) + .AddArg("wavelet_args", R"(Additional arguments for mother wavelet. They are passed as list of float32 values. - HAAR - none - GAUS - n (order of derivative) @@ -49,13 +52,15 @@ as list of float32 values. - MORL - none - SHAN - fb (bandwidth parameter > 0), fc (center frequency > 0) - FBSP - m (order parameter >= 1), fb (bandwidth parameter > 0), fc (center frequency > 0) -)", DALIDataType::DALI_FLOAT_VEC); +)", + DALIDataType::DALI_FLOAT_VEC); template struct CwtImplGPU : public OpImplBase { public: - using CwtArgs = kernels::signal::wavelets::CwtArgs; - using CwtKernel = kernels::signal::wavelets::CwtGpu; + using CwtArgs = kernels::signal::CwtArgs; + using CwtKernel = kernels::signal::CwtGpu; + using WaveletKernel = kernels::signal::CwtGpu; explicit CwtImplGPU(CwtArgs args) : args_(std::move(args)) { kmgr_cwt_.Resize(1); diff --git a/dali/operators/signal/wavelet/wavelet_name.h b/dali/operators/signal/wavelet/wavelet_name.h index 5e53713bba..e101040b14 100644 --- a/dali/operators/signal/wavelet/wavelet_name.h +++ b/dali/operators/signal/wavelet/wavelet_name.h @@ -21,12 +21,12 @@ namespace dali { * @brief Supported wavelet names */ enum DALIWaveletName { - DALI_HAAR = 0, - DALI_GAUS = 1, - DALI_MEXH = 2, - DALI_MORL = 3, - DALI_SHAN = 4, - DALI_FBSP = 5 + DALI_HAAR = 0, + DALI_GAUS = 1, + DALI_MEXH = 2, + DALI_MORL = 3, + DALI_SHAN = 4, + DALI_FBSP = 5 }; } // namespace dali From 101efc4db4fe585fbfc714c7dcc36c72af84a9d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Wdowski?= Date: Tue, 12 Sep 2023 14:02:46 +0200 Subject: [PATCH 13/14] Good size but full of zeros --- dali/kernels/signal/wavelet/wavelet_gpu.cuh | 39 ++--- dali/operators/signal/wavelet/cwt_op.h | 17 +- dali/operators/signal/wavelet/cwt_op_gpu.cu | 127 ++++++++++---- dali/operators/signal/wavelet/wavelet_run.h | 176 ++++++++++++++------ dali/test/python/operator_2/test_cwt.py | 42 +++++ 5 files changed, 293 insertions(+), 108 deletions(-) create mode 100644 dali/test/python/operator_2/test_cwt.py diff --git a/dali/kernels/signal/wavelet/wavelet_gpu.cuh b/dali/kernels/signal/wavelet/wavelet_gpu.cuh index 45c54e8b25..49a03d8c7b 100644 --- a/dali/kernels/signal/wavelet/wavelet_gpu.cuh +++ b/dali/kernels/signal/wavelet/wavelet_gpu.cuh @@ -27,16 +27,17 @@ // makes sure both tensors have the same number of samples and // that they're one-dimensional -#define ENFORCE_SHAPES(a_shape, b_shape) do { \ -DALI_ENFORCE(a_shape.num_samples() == b_shape.num_samples(), \ - "a and b tensors must have the same amount of samples."); \ -for (int i = 0; i < a_shape.num_samples(); ++i) { \ - DALI_ENFORCE(a_shape.tensor_shape(i).size() == 1, \ - "Tensor of a coeffs should be 1-dimensional."); \ - DALI_ENFORCE(b_shape.tensor_shape(i).size() == 1, \ - "Tensor of b coeffs should be 1-dimensional."); \ -} \ -} while (0); +#define ENFORCE_SHAPES(a_shape, b_shape) \ + do { \ + DALI_ENFORCE(a_shape.num_samples() == b_shape.num_samples(), \ + "a and b tensors must have the same amount of samples."); \ + for (int i = 0; i < a_shape.num_samples(); ++i) { \ + DALI_ENFORCE(a_shape.tensor_shape(i).size() == 1, \ + "Tensor of a coeffs should be 1-dimensional."); \ + DALI_ENFORCE(b_shape.tensor_shape(i).size() == 1, \ + "Tensor of b coeffs should be 1-dimensional."); \ + } \ + } while (0); namespace dali { namespace kernels { @@ -70,26 +71,20 @@ struct SampleDesc { WaveletSpan span; }; -template class W > +template class W> class DLL_PUBLIC WaveletGpu { public: - static_assert(std::is_floating_point::value, - "Only floating point types are supported"); + static_assert(std::is_floating_point::value, "Only floating point types are supported"); DLL_PUBLIC WaveletGpu() = default; DLL_PUBLIC ~WaveletGpu() = default; - DLL_PUBLIC KernelRequirements Setup(KernelContext &context, - const InListGPU &a, - const InListGPU &b, - const WaveletSpan &span, + DLL_PUBLIC KernelRequirements Setup(KernelContext &context, const InListGPU &a, + const InListGPU &b, const WaveletSpan &span, const std::vector &args); - DLL_PUBLIC void Run(KernelContext &ctx, - OutListGPU &out, - const InListGPU &a, - const InListGPU &b, - const WaveletSpan &span); + DLL_PUBLIC void Run(KernelContext &ctx, OutListGPU &out, const InListGPU &a, + const InListGPU &b, const WaveletSpan &span); static TensorListShape<> GetOutputShape(const TensorListShape<> &a_shape, const TensorListShape<> &b_shape, diff --git a/dali/operators/signal/wavelet/cwt_op.h b/dali/operators/signal/wavelet/cwt_op.h index 9a3ecc169b..59c211cc7f 100644 --- a/dali/operators/signal/wavelet/cwt_op.h +++ b/dali/operators/signal/wavelet/cwt_op.h @@ -18,8 +18,15 @@ #include #include #include "dali/core/common.h" +#include "dali/core/static_switch.h" #include "dali/kernels/kernel_manager.h" +#include "dali/kernels/kernel_params.h" #include "dali/kernels/signal/wavelet/cwt_args.h" +#include "dali/kernels/signal/wavelet/cwt_gpu.h" +#include "dali/operators/signal/wavelet/cwt_op.h" +#include "dali/operators/signal/wavelet/wavelet_name.h" +#include "dali/pipeline/data/types.h" +#include "dali/pipeline/data/views.h" #include "dali/pipeline/operator/common.h" #include "dali/pipeline/operator/op_spec.h" #include "dali/pipeline/operator/operator.h" @@ -47,15 +54,9 @@ class Cwt : public Operator { return true; } - bool SetupImpl(std::vector &output_desc, const Workspace &ws) override { - assert(impl_ != nullptr); - return impl_->SetupImpl(output_desc, ws); - } + bool SetupImpl(std::vector &output_desc, const Workspace &ws) override; - void RunImpl(Workspace &ws) override { - assert(impl_ != nullptr); - impl_->RunImpl(ws); - } + void RunImpl(Workspace &ws) override; USE_OPERATOR_MEMBERS(); using Operator::RunImpl; diff --git a/dali/operators/signal/wavelet/cwt_op_gpu.cu b/dali/operators/signal/wavelet/cwt_op_gpu.cu index 535079d39f..7d3fad3f95 100644 --- a/dali/operators/signal/wavelet/cwt_op_gpu.cu +++ b/dali/operators/signal/wavelet/cwt_op_gpu.cu @@ -15,12 +15,16 @@ #include #include #include +#include "dali/core/dev_buffer.h" #include "dali/core/static_switch.h" +#include "dali/core/tensor_shape.h" #include "dali/kernels/kernel_manager.h" #include "dali/kernels/kernel_params.h" #include "dali/kernels/signal/wavelet/cwt_args.h" #include "dali/kernels/signal/wavelet/cwt_gpu.h" +#include "dali/kernels/signal/wavelet/wavelet_gpu.cuh" #include "dali/operators/signal/wavelet/cwt_op.h" +#include "dali/operators/signal/wavelet/wavelet_run.h" #include "dali/pipeline/data/types.h" #include "dali/pipeline/data/views.h" #include "dali/pipeline/operator/op_schema.h" @@ -60,48 +64,113 @@ struct CwtImplGPU : public OpImplBase { public: using CwtArgs = kernels::signal::CwtArgs; using CwtKernel = kernels::signal::CwtGpu; - using WaveletKernel = kernels::signal::CwtGpu; + + template