diff --git a/dali/kernels/signal/CMakeLists.txt b/dali/kernels/signal/CMakeLists.txt index 431ae39629..74ca2e8970 100644 --- a/dali/kernels/signal/CMakeLists.txt +++ b/dali/kernels/signal/CMakeLists.txt @@ -17,6 +17,7 @@ add_subdirectory(decibel) if (BUILD_FFTS) add_subdirectory(fft) endif() +add_subdirectory(wavelet) add_subdirectory(window) collect_headers(DALI_INST_HDRS PARENT_SCOPE) diff --git a/dali/kernels/signal/wavelet/CMakeLists.txt b/dali/kernels/signal/wavelet/CMakeLists.txt new file mode 100644 index 0000000000..f3a24faa7c --- /dev/null +++ b/dali/kernels/signal/wavelet/CMakeLists.txt @@ -0,0 +1,17 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +collect_headers(DALI_INST_HDRS PARENT_SCOPE) +collect_sources(DALI_KERNEL_SRCS PARENT_SCOPE) +collect_test_sources(DALI_KERNEL_TEST_SRCS PARENT_SCOPE) \ No newline at end of file diff --git a/dali/kernels/signal/wavelet/cwt_args.h b/dali/kernels/signal/wavelet/cwt_args.h new file mode 100644 index 0000000000..9a38b8d006 --- /dev/null +++ b/dali/kernels/signal/wavelet/cwt_args.h @@ -0,0 +1,36 @@ +// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_ +#define DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_ + +#include +#include "dali/operators/signal/wavelet/wavelet_name.h" + +namespace dali { +namespace kernels { +namespace signal { + +template +struct CwtArgs { + std::vector a; + dali::DALIWaveletName wavelet; + std::vector wavelet_args; +}; + +} // namespace signal +} // namespace kernels +} // namespace dali + +#endif // DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_ diff --git a/dali/kernels/signal/wavelet/cwt_gpu.cu b/dali/kernels/signal/wavelet/cwt_gpu.cu new file mode 100644 index 0000000000..cfca159483 --- /dev/null +++ b/dali/kernels/signal/wavelet/cwt_gpu.cu @@ -0,0 +1,96 @@ +// Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "dali/core/common.h" +#include "dali/core/error_handling.h" +#include "dali/core/format.h" +#include "dali/kernels/kernel.h" +#include "dali/kernels/signal/wavelet/cwt_args.h" +#include "dali/kernels/signal/wavelet/cwt_gpu.h" + +namespace dali { +namespace kernels { +namespace signal { + +template +struct SampleDesc { + const T *in = nullptr; + T *out = nullptr; + int64_t size = 0; +}; + +template +__global__ void CwtKernel(const SampleDesc *sample_data) { + const int64_t block_size = blockDim.y * blockDim.x; + const int64_t grid_size = gridDim.x * block_size; + const int sample_idx = blockIdx.y; + const auto sample = sample_data[sample_idx]; + const int64_t offset = block_size * blockIdx.x; + const int64_t tid = threadIdx.y * blockDim.x + threadIdx.x; + + for (int64_t idx = offset + tid; idx < sample.size; idx += grid_size) { + sample.out[idx] = sample.in[idx]; + } +} + +template +CwtGpu::~CwtGpu() = default; + +template +KernelRequirements CwtGpu::Setup(KernelContext &context, + const InListGPU &in) { + auto out_shape = in.shape; + const size_t num_samples = in.size(); + ScratchpadEstimator se; + se.add>(num_samples); + se.add>(num_samples); + KernelRequirements req; + req.scratch_sizes = se.sizes; + req.output_shapes = {out_shape}; + return req; +} + +template +void CwtGpu::Run(KernelContext &context, const OutListGPU &out, + const InListGPU &in, const CwtArgs &args) { + auto num_samples = in.size(); + auto *sample_data = context.scratchpad->AllocateHost>(num_samples); + + for (int i = 0; i < num_samples; i++) { + auto &sample = sample_data[i]; + sample.out = out.tensor_data(i); + sample.in = in.tensor_data(i); + sample.size = volume(in.tensor_shape(i)); + assert(sample.size == volume(out.tensor_shape(i))); + } + + auto *sample_data_gpu = context.scratchpad->AllocateGPU>(num_samples); + CUDA_CALL(cudaMemcpyAsync(sample_data_gpu, sample_data, num_samples * sizeof(SampleDesc), + cudaMemcpyHostToDevice, context.gpu.stream)); + + dim3 block(32, 32); + auto blocks_per_sample = std::max(32, 1024 / num_samples); + dim3 grid(blocks_per_sample, num_samples); + CwtKernel<<>>(sample_data_gpu); +} + +template class CwtGpu; +template class CwtGpu; + +} // namespace signal +} // namespace kernels +} // namespace dali diff --git a/dali/kernels/signal/wavelet/cwt_gpu.h b/dali/kernels/signal/wavelet/cwt_gpu.h new file mode 100644 index 0000000000..35a494aca6 --- /dev/null +++ b/dali/kernels/signal/wavelet/cwt_gpu.h @@ -0,0 +1,48 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_ +#define DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_ + +#include +#include "dali/core/common.h" +#include "dali/core/error_handling.h" +#include "dali/core/format.h" +#include "dali/core/util.h" +#include "dali/kernels/kernel.h" +#include "dali/kernels/signal/wavelet/cwt_args.h" + +namespace dali { +namespace kernels { +namespace signal { + +template +class DLL_PUBLIC CwtGpu { + public: + static_assert(std::is_floating_point::value, "Only floating point types are supported"); + + DLL_PUBLIC ~CwtGpu(); + + DLL_PUBLIC KernelRequirements Setup(KernelContext &context, + const InListGPU &in); + + DLL_PUBLIC void Run(KernelContext &context, const OutListGPU &out, + const InListGPU &in, const CwtArgs &args); +}; + +} // namespace signal +} // namespace kernels +} // namespace dali + +#endif // DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_ diff --git a/dali/kernels/signal/wavelet/mother_wavelet.cu b/dali/kernels/signal/wavelet/mother_wavelet.cu new file mode 100644 index 0000000000..232c183a0c --- /dev/null +++ b/dali/kernels/signal/wavelet/mother_wavelet.cu @@ -0,0 +1,161 @@ +// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "dali/kernels/signal/wavelet/mother_wavelet.cuh" +#include "dali/core/math_util.h" + +namespace dali { +namespace kernels { +namespace signal { + +template +HaarWavelet::HaarWavelet(const std::vector &args) { + if (args.size() != 0) { + throw std::invalid_argument("HaarWavelet doesn't accept any arguments."); + } +} + +template +__device__ T HaarWavelet::operator()(const T &t) const { + if (0.0 <= t && t < 0.5) { + return 1.0; + } + if (0.5 <= t && t < 1.0) { + return -1.0; + } + return 0.0; +} + +template class HaarWavelet; +template class HaarWavelet; + +template +GaussianWavelet::GaussianWavelet(const std::vector &args) { + if (args.size() != 1) { + throw std::invalid_argument("GaussianWavelet accepts exactly one argument - n."); + } + if (args[0] < 1.0 || args[0] > 8.0) { + throw std::invalid_argument( + "GaussianWavelet's argument n should be integer from range [1,8]."); + } + this->n = args[0]; +} + +template +__device__ T GaussianWavelet::operator()(const T &t) const { + T expTerm = std::exp(-std::pow(t, 2.0)); + T sqrtTerm = 1.2533141373155001; // std::sqrt(M_PI/2.0) + switch (static_cast(n)) { + case 1: + return -2.0*t*expTerm/std::sqrt(sqrtTerm); + case 2: + return (-4.0*std::pow(t, 2.0)+2.0)*expTerm/std::sqrt(3.0*sqrtTerm); + case 3: + return (8.0*std::pow(t, 3.0)-12.0*t)*expTerm/std::sqrt(15.0*sqrtTerm); + case 4: + return (-48.0*std::pow(t, 2.0)+16.0*std::pow(t, 4.0)+12.0)*expTerm/std::sqrt(105.0*sqrtTerm); + case 5: + return (-32.0*std::pow(t, 5.0)+160.0*std::pow(t, 3.0)-120.0*t)* + expTerm/std::sqrt(945.0*sqrtTerm); + case 6: + return (-64.0*std::pow(t, 6.0)+480.0*std::pow(t, 4.0)-720.0*std::pow(t, 2.0)+120.0)* + expTerm/std::sqrt(10395.0*sqrtTerm); + case 7: + return (128.0*std::pow(t, 7.0)-1344.0*std::pow(t, 5.0)+3360.0*std::pow(t, 3.0)-1680.0*t)* + expTerm/std::sqrt(135135.0*sqrtTerm); + case 8: + return (256.0*std::pow(t, 8.0)-3584.0*std::pow(t, 6.0)+13440.0*std::pow(t, 4.0)-13440.0* + std::pow(t, 2.0)+1680.0)*expTerm/std::sqrt(2027025.0*sqrtTerm); + } +} + +template class GaussianWavelet; +template class GaussianWavelet; + +template +MexicanHatWavelet::MexicanHatWavelet(const std::vector &args) { + if (args.size() != 1) { + throw std::invalid_argument("MexicanHatWavelet accepts exactly one argument - sigma."); + } + this->sigma = args[0]; +} + +template +__device__ T MexicanHatWavelet::operator()(const T &t) const { + return 2.0/(std::sqrt(3.0*sigma)*std::pow(M_PI, 0.25))*(1.0-std::pow(t/sigma, 2.0))* + std::exp(-std::pow(t, 2.0)/(2.0*std::pow(sigma, 2.0))); +} + +template class MexicanHatWavelet; +template class MexicanHatWavelet; + +template +MorletWavelet::MorletWavelet(const std::vector &args) { + if (args.size() != 0) { + throw std::invalid_argument("MorletWavelet doesn't accept any arguments."); + } +} + +template +__device__ T MorletWavelet::operator()(const T &t) const { + return std::exp(-std::pow(t, 2.0) / 2.0) * std::cos(5.0 * t); +} + +template class MorletWavelet; +template class MorletWavelet; + +template +ShannonWavelet::ShannonWavelet(const std::vector &args) { + if (args.size() != 2) { + throw std::invalid_argument( + "ShannonWavelet accepts exactly 2 arguments -> fb, fc in that order."); + } + this->fb = args[0]; + this->fc = args[1]; +} + +template +__device__ T ShannonWavelet::operator()(const T &t) const { + auto res = std::cos((T)(2.0*M_PI)*fc*t)*std::sqrt(fb); + return t == 0.0 ? res : res*std::sin(t*fb*(T)(M_PI))/(t*fb*(T)(M_PI)); +} + +template class ShannonWavelet; +template class ShannonWavelet; + +template +FbspWavelet::FbspWavelet(const std::vector &args) { + if (args.size() != 3) { + throw std::invalid_argument( + "FbspWavelet accepts exactly 3 arguments -> m, fb, fc in that order."); + } + this->m = args[0]; + this->fb = args[1]; + this->fc = args[2]; +} + +template +__device__ T FbspWavelet::operator()(const T &t) const { + auto res = std::cos((T)(2.0*M_PI)*fc*t)*std::sqrt(fb); + return t == 0.0 ? res : res*std::pow(std::sin((T)(M_PI)*t*fb/m)/((T)(M_PI)*t*fb/m), m); +} + +template class FbspWavelet; +template class FbspWavelet; + +} // namespace signal +} // namespace kernels +} // namespace dali diff --git a/dali/kernels/signal/wavelet/mother_wavelet.cuh b/dali/kernels/signal/wavelet/mother_wavelet.cuh new file mode 100644 index 0000000000..9cbd81592b --- /dev/null +++ b/dali/kernels/signal/wavelet/mother_wavelet.cuh @@ -0,0 +1,124 @@ +// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_KERNELS_SIGNAL_WAVELET_MOTHER_WAVELET_CUH_ +#define DALI_KERNELS_SIGNAL_WAVELET_MOTHER_WAVELET_CUH_ + +#include + +#include "dali/core/common.h" +#include "dali/core/error_handling.h" +#include "dali/core/format.h" +#include "dali/core/util.h" +#include "dali/kernels/kernel.h" + +namespace dali { +namespace kernels { +namespace signal { + +// wavelets are represented by functors +// they can store any necessary parameters +// they must overload () operator + +template +class HaarWavelet { + static_assert(std::is_floating_point::value, + "Data type should be floating point"); + public: + HaarWavelet() = default; + explicit HaarWavelet(const std::vector &args); + ~HaarWavelet() = default; + + __device__ T operator()(const T &t) const; +}; + +template +class GaussianWavelet { + static_assert(std::is_floating_point::value, + "Data type should be floating point"); + public: + GaussianWavelet() = default; + explicit GaussianWavelet(const std::vector &args); + ~GaussianWavelet() = default; + + __device__ T operator()(const T &t) const; + private: + T n; +}; + +template +class MexicanHatWavelet { + static_assert(std::is_floating_point::value, + "Data type should be floating point"); + public: + MexicanHatWavelet() = default; + explicit MexicanHatWavelet(const std::vector &args); + ~MexicanHatWavelet() = default; + + __device__ T operator()(const T &t) const; + + private: + T sigma; +}; + +template +class MorletWavelet { + static_assert(std::is_floating_point::value, + "Data type should be floating point"); + public: + MorletWavelet() = default; + explicit MorletWavelet(const std::vector &args); + ~MorletWavelet() = default; + + __device__ T operator()(const T &t) const; +}; + +template +class ShannonWavelet { + static_assert(std::is_floating_point::value, + "Data type should be floating point"); + public: + ShannonWavelet() = default; + explicit ShannonWavelet(const std::vector &args); + ~ShannonWavelet() = default; + + __device__ T operator()(const T &t) const; + + private: + T fb; + T fc; +}; + +template +class FbspWavelet { + static_assert(std::is_floating_point::value, + "Data type should be floating point"); + public: + FbspWavelet() = default; + explicit FbspWavelet(const std::vector &args); + ~FbspWavelet() = default; + + __device__ T operator()(const T &t) const; + + private: + T m; + T fb; + T fc; +}; + +} // namespace signal +} // namespace kernels +} // namespace dali + +#endif // DALI_KERNELS_SIGNAL_WAVELET_MOTHER_WAVELET_CUH_ diff --git a/dali/kernels/signal/wavelet/wavelet_gpu.cu b/dali/kernels/signal/wavelet/wavelet_gpu.cu new file mode 100644 index 0000000000..a5ab81a5df --- /dev/null +++ b/dali/kernels/signal/wavelet/wavelet_gpu.cu @@ -0,0 +1,161 @@ +// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dali/kernels/signal/wavelet/wavelet_gpu.cuh" +#include +#include +#include +#include "dali/core/common.h" +#include "dali/core/error_handling.h" +#include "dali/core/format.h" +#include "dali/kernels/kernel.h" +#include "dali/kernels/signal/wavelet/mother_wavelet.cuh" +#include "dali/core/tensor_shape.h" + +namespace dali { +namespace kernels { +namespace signal { + +// computes wavelet value for each sample in specified range, +// and each a and b coeff +template class W > +__global__ void ComputeWavelet(const SampleDesc* sample_data, W wavelet) { + // id inside block + const int64_t b_id = threadIdx.y * blockDim.x + threadIdx.x; + // wavelet sample id + const int64_t t_id = blockDim.x * blockDim.y * blockIdx.x + b_id; + auto& sample = sample_data[blockIdx.z]; + if (t_id >= sample.size_in) return; + __shared__ T shm[1025]; + auto a = sample.a[blockIdx.y]; + auto x = std::pow(2.0, a); + if (a == 0.0) { + shm[b_id] = sample.in[t_id]; + } else { + shm[b_id] = x * sample.in[t_id]; + shm[1024] = std::pow(2.0, a / 2.0); + } + __syncthreads(); + for (int i = 0; i < sample.size_b; ++i) { + const int64_t out_id = blockIdx.y * sample.size_b * sample.size_in + i * sample.size_in + t_id; + auto b = sample.b[i]; + if (b == 0.0) { + sample.out[out_id] = wavelet(shm[b_id]); + } else { + sample.out[out_id] = wavelet(shm[b_id] - b); + } + if (a != 0.0) { + sample.out[out_id] *= shm[1024]; + } + } +} + +// translate input range information to input samples +template +__global__ void ComputeInputSamples(const SampleDesc* sample_data) { + const int64_t block_size = blockDim.x * blockDim.y; + const int64_t t_id = block_size * blockIdx.x + threadIdx.y * blockDim.x + threadIdx.x; + auto& sample = sample_data[blockIdx.y]; + if (t_id >= sample.size_in) return; + sample.in[t_id] = sample.span.begin + (T)t_id / sample.span.sampling_rate; +} + +template class W > +DLL_PUBLIC KernelRequirements WaveletGpu::Setup(KernelContext &context, + const InListGPU &a, + const InListGPU &b, + const WaveletSpan &span, + const std::vector &args) { + ENFORCE_SHAPES(a.shape, b.shape); + auto out_shape = this->GetOutputShape(a.shape, b.shape, span); + KernelRequirements req; + req.output_shapes = {out_shape}; + wavelet_ = W(args); + return req; +} + +template class W > +DLL_PUBLIC void WaveletGpu::Run(KernelContext &ctx, + OutListGPU &out, + const InListGPU &a, + const InListGPU &b, + const WaveletSpan &span) { + ENFORCE_SHAPES(a.shape, b.shape); + + auto num_samples = a.num_samples(); + std::vector> sample_data = std::vector>(num_samples); + int64_t max_size_in = 0, max_size_a = 0; + + for (int i = 0; i < num_samples; i++) { + auto &sample = sample_data[i]; + sample.out = out.tensor_data(i); + sample.a = a.tensor_data(i); + sample.size_a = a.shape.tensor_size(i); + max_size_a = std::max(max_size_a, sample.size_a); + sample.b = b.tensor_data(i); + sample.size_b = b.shape.tensor_size(i); + sample.span = span; + sample.size_in = + std::ceil((sample.span.end - sample.span.begin) * sample.span.sampling_rate) + 1; + sample.in = ctx.scratchpad->AllocateGPU(sample.size_in); + max_size_in = std::max(max_size_in, sample.size_in); + } + + auto* sample_data_gpu = std::get<0>(ctx.scratchpad->ToContiguousGPU(ctx.gpu.stream, sample_data)); + + dim3 block(32, 32); + const int64_t block_size = block.x * block.y; + dim3 grid1((max_size_in + block_size - 1) / block_size, num_samples); + dim3 grid2((max_size_in + block_size - 1) / block_size, max_size_a, num_samples); + + ComputeInputSamples<<>>(sample_data_gpu); + auto shared_mem_size = (block_size + 1) * sizeof(T); + ComputeWavelet<<>>(sample_data_gpu, wavelet_); +} + +template class W > +TensorListShape<> WaveletGpu::GetOutputShape(const TensorListShape<> &a_shape, + const TensorListShape<> &b_shape, + const WaveletSpan &span) { + int N = a_shape.num_samples(); + int in_size = std::ceil((span.end - span.begin) * span.sampling_rate) + 1; + TensorListShape<> out_shape(N, 3); + TensorShape<> tshape; + for (int i = 0; i < N; i++) { + // output tensor will be 3-dimensional of shape: + // a coeffs x b coeffs x signal samples + tshape = TensorShape<>({a_shape.tensor_shape(i).num_elements(), + b_shape.tensor_shape(i).num_elements(), + in_size}); + out_shape.set_tensor_shape(i, tshape); + } + return out_shape; +} + +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; + +} // namespace signal +} // namespace kernels +} // namespace dali diff --git a/dali/kernels/signal/wavelet/wavelet_gpu.cuh b/dali/kernels/signal/wavelet/wavelet_gpu.cuh new file mode 100644 index 0000000000..49a03d8c7b --- /dev/null +++ b/dali/kernels/signal/wavelet/wavelet_gpu.cuh @@ -0,0 +1,101 @@ +// Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_KERNELS_SIGNAL_WAVELET_WAVELET_GPU_CUH_ +#define DALI_KERNELS_SIGNAL_WAVELET_WAVELET_GPU_CUH_ + +#include +#include +#include +#include "dali/core/common.h" +#include "dali/core/error_handling.h" +#include "dali/core/format.h" +#include "dali/core/util.h" +#include "dali/kernels/kernel.h" +#include "dali/kernels/signal/wavelet/mother_wavelet.cuh" + +// makes sure both tensors have the same number of samples and +// that they're one-dimensional +#define ENFORCE_SHAPES(a_shape, b_shape) \ + do { \ + DALI_ENFORCE(a_shape.num_samples() == b_shape.num_samples(), \ + "a and b tensors must have the same amount of samples."); \ + for (int i = 0; i < a_shape.num_samples(); ++i) { \ + DALI_ENFORCE(a_shape.tensor_shape(i).size() == 1, \ + "Tensor of a coeffs should be 1-dimensional."); \ + DALI_ENFORCE(b_shape.tensor_shape(i).size() == 1, \ + "Tensor of b coeffs should be 1-dimensional."); \ + } \ + } while (0); + +namespace dali { +namespace kernels { +namespace signal { + +// stores data needed to reconstruct wavelet input arguments +template +struct WaveletSpan { + // lower limit of wavelet samples + T begin = -1.0; + + // upper limit of wavelet samples + T end = 1.0; + + // wavelet sampling rate (samples/s) + T sampling_rate = 1000.0; +}; + +template class WaveletSpan; +template class WaveletSpan; + +template +struct SampleDesc { + const T *a = nullptr; + int64_t size_a = 0; + const T *b = nullptr; + int64_t size_b = 0; + T *in = nullptr; + int64_t size_in = 0; + T *out = nullptr; + WaveletSpan span; +}; + +template class W> +class DLL_PUBLIC WaveletGpu { + public: + static_assert(std::is_floating_point::value, "Only floating point types are supported"); + + DLL_PUBLIC WaveletGpu() = default; + DLL_PUBLIC ~WaveletGpu() = default; + + DLL_PUBLIC KernelRequirements Setup(KernelContext &context, const InListGPU &a, + const InListGPU &b, const WaveletSpan &span, + const std::vector &args); + + DLL_PUBLIC void Run(KernelContext &ctx, OutListGPU &out, const InListGPU &a, + const InListGPU &b, const WaveletSpan &span); + + static TensorListShape<> GetOutputShape(const TensorListShape<> &a_shape, + const TensorListShape<> &b_shape, + const WaveletSpan &span); + + private: + W wavelet_; +}; + +} // namespace signal +} // namespace kernels +} // namespace dali + +#endif // DALI_KERNELS_SIGNAL_WAVELET_WAVELET_GPU_CUH_ diff --git a/dali/operators/signal/CMakeLists.txt b/dali/operators/signal/CMakeLists.txt index 217f785aa2..44d93c05ba 100644 --- a/dali/operators/signal/CMakeLists.txt +++ b/dali/operators/signal/CMakeLists.txt @@ -16,6 +16,7 @@ add_subdirectory(decibel) if (BUILD_FFTS) add_subdirectory(fft) endif() +add_subdirectory(wavelet) collect_headers(DALI_INST_HDRS PARENT_SCOPE) collect_sources(DALI_OPERATOR_SRCS PARENT_SCOPE) diff --git a/dali/operators/signal/fft/power_spectrum.h b/dali/operators/signal/fft/power_spectrum.h index 170818187a..65117ef1c8 100644 --- a/dali/operators/signal/fft/power_spectrum.h +++ b/dali/operators/signal/fft/power_spectrum.h @@ -28,8 +28,7 @@ namespace dali { template class PowerSpectrum : public Operator { public: - explicit PowerSpectrum(const OpSpec &spec) - : Operator(spec) { + explicit PowerSpectrum(const OpSpec &spec) : Operator(spec) { fft_args_.nfft = spec.HasArgument("nfft") ? spec.GetArgument("nfft") : -1; fft_args_.transform_axis = spec.GetArgument("axis"); int power = spec.GetArgument("power"); @@ -41,13 +40,17 @@ class PowerSpectrum : public Operator { fft_args_.spectrum_type = kernels::signal::fft::FFT_SPECTRUM_POWER; break; default: - DALI_FAIL(make_string("Power argument should be either `2` for power spectrum or `1` " - "for complex magnitude. Received: ", power)); + DALI_FAIL( + make_string("Power argument should be either `2` for power spectrum or `1` " + "for complex magnitude. Received: ", + power)); } } protected: - bool CanInferOutputs() const override { return true; } + bool CanInferOutputs() const override { + return true; + } bool SetupImpl(std::vector &output_desc, const Workspace &ws) override; void RunImpl(Workspace &ws) override; diff --git a/dali/operators/signal/wavelet/CMakeLists.txt b/dali/operators/signal/wavelet/CMakeLists.txt new file mode 100644 index 0000000000..0dba230abf --- /dev/null +++ b/dali/operators/signal/wavelet/CMakeLists.txt @@ -0,0 +1,17 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +collect_headers(DALI_INST_HDRS PARENT_SCOPE) +collect_sources(DALI_OPERATOR_SRCS PARENT_SCOPE) +collect_test_sources(DALI_OPERATOR_TEST_SRCS PARENT_SCOPE) diff --git a/dali/operators/signal/wavelet/cwt_op.h b/dali/operators/signal/wavelet/cwt_op.h new file mode 100644 index 0000000000..59c211cc7f --- /dev/null +++ b/dali/operators/signal/wavelet/cwt_op.h @@ -0,0 +1,73 @@ +// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_OPERATORS_SIGNAL_WAVELETS_CWT_H_ +#define DALI_OPERATORS_SIGNAL_WAVELETS_CWT_H_ + +#include +#include +#include "dali/core/common.h" +#include "dali/core/static_switch.h" +#include "dali/kernels/kernel_manager.h" +#include "dali/kernels/kernel_params.h" +#include "dali/kernels/signal/wavelet/cwt_args.h" +#include "dali/kernels/signal/wavelet/cwt_gpu.h" +#include "dali/operators/signal/wavelet/cwt_op.h" +#include "dali/operators/signal/wavelet/wavelet_name.h" +#include "dali/pipeline/data/types.h" +#include "dali/pipeline/data/views.h" +#include "dali/pipeline/operator/common.h" +#include "dali/pipeline/operator/op_spec.h" +#include "dali/pipeline/operator/operator.h" +#include "dali/pipeline/util/operator_impl_utils.h" + +namespace dali { + +template +class Cwt : public Operator { + public: + explicit Cwt(const OpSpec &spec) : Operator(spec) { + if (!spec.HasArgument("a")) { + DALI_ENFORCE("`a` argument must be provided."); + } + args_.a = spec.GetRepeatedArgument("a"); + if (!spec.HasArgument("wavelet")) { + DALI_ENFORCE("`wavelet` argument must be provided."); + } + args_.wavelet = spec.GetArgument("wavelet"); + args_.wavelet_args = spec.GetRepeatedArgument("wavelet_args"); + } + + protected: + bool CanInferOutputs() const override { + return true; + } + + bool SetupImpl(std::vector &output_desc, const Workspace &ws) override; + + void RunImpl(Workspace &ws) override; + + USE_OPERATOR_MEMBERS(); + using Operator::RunImpl; + + kernels::KernelManager kmgr_; + kernels::signal::CwtArgs args_; + + std::unique_ptr> impl_; + DALIDataType type_ = DALI_NO_TYPE; +}; + +} // namespace dali + +#endif // DALI_OPERATORS_SIGNAL_WAVELETS_CWT_H_ diff --git a/dali/operators/signal/wavelet/cwt_op_gpu.cu b/dali/operators/signal/wavelet/cwt_op_gpu.cu new file mode 100644 index 0000000000..7d3fad3f95 --- /dev/null +++ b/dali/operators/signal/wavelet/cwt_op_gpu.cu @@ -0,0 +1,177 @@ +// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "dali/core/dev_buffer.h" +#include "dali/core/static_switch.h" +#include "dali/core/tensor_shape.h" +#include "dali/kernels/kernel_manager.h" +#include "dali/kernels/kernel_params.h" +#include "dali/kernels/signal/wavelet/cwt_args.h" +#include "dali/kernels/signal/wavelet/cwt_gpu.h" +#include "dali/kernels/signal/wavelet/wavelet_gpu.cuh" +#include "dali/operators/signal/wavelet/cwt_op.h" +#include "dali/operators/signal/wavelet/wavelet_run.h" +#include "dali/pipeline/data/types.h" +#include "dali/pipeline/data/views.h" +#include "dali/pipeline/operator/op_schema.h" + +namespace dali { + +DALI_SCHEMA(Cwt) + .DocStr(R"(Performs continuous wavelet transform on a 1D signal (for example, audio). + +Result values of transform are computed for all specified scales. +Input data is expected to be one channel (shape being ``(nsamples,)``, ``(nsamples, 1)`` +) of type float32.)") + .NumInput(1) + .NumOutput(1) + .AddArg("a", R"(List of scale coefficients of type float32.)", DALIDataType::DALI_FLOAT_VEC) + .AddArg("wavelet", R"(Name of mother wavelet. Currently supported wavelets' names are: +- HAAR - Haar wavelet +- GAUS - Gaussian wavelet +- MEXH - Mexican hat wavelet +- MORL - Morlet wavelet +- SHAN - Shannon wavleet +- FBSP - Frequency B-spline wavelet)", + DALIDataType::DALI_WAVELET_NAME) + .AddArg("wavelet_args", R"(Additional arguments for mother wavelet. They are passed +as list of float32 values. +- HAAR - none +- GAUS - n (order of derivative) +- MEXH - sigma +- MORL - none +- SHAN - fb (bandwidth parameter > 0), fc (center frequency > 0) +- FBSP - m (order parameter >= 1), fb (bandwidth parameter > 0), fc (center frequency > 0) +)", + DALIDataType::DALI_FLOAT_VEC); + +template +struct CwtImplGPU : public OpImplBase { + public: + using CwtArgs = kernels::signal::CwtArgs; + using CwtKernel = kernels::signal::CwtGpu; + + template