Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for CWT operator #4860

Draft
wants to merge 25 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
937b963
add MotherWavelet helper and WaveletGpu kernel
May 18, 2023
cf7b6a6
Cwt WIP
mwdowski May 18, 2023
68bb330
Merge branch 'NVIDIA:main' into wavelet-computing
kubo11 May 18, 2023
9d6e0b0
Merge pull request #2 from mwdowski/wavelet-computing
mwdowski May 18, 2023
359d79c
Merge pull request #1 from mwdowski/mwdowski
mwdowski May 18, 2023
b034619
Rename namespace
mwdowski May 18, 2023
6bb49f5
Merge branch 'main' into mwdowski
mwdowski May 18, 2023
5eed0c5
add WaveletArgs class
May 22, 2023
09196c6
Merge pull request #3 from mwdowski/wavelet-computing
kubo11 May 29, 2023
279e61b
Improve wavelet computing kernel
Jun 5, 2023
c4814f9
Optimize and remove discrete wavelets
Jun 7, 2023
11df6aa
Merge pull request #4 from mwdowski/wavelet-computing-improvements
kubo11 Jun 7, 2023
d3a8d6a
add DALIWaveletName enum
Jun 11, 2023
27cedd3
fix linting errors
Jun 11, 2023
2875c95
replace MeyerWavelet with GaussianWavelet
Jun 13, 2023
20d5d7e
Merge pull request #5 from mwdowski/wavelet-computing-improvements
kubo11 Jun 13, 2023
0efec3d
Fix wavelet exceptions
Jul 3, 2023
1ed22bc
Add CWT operator docstr
Jul 4, 2023
3c36192
Merge pull request #6 from mwdowski/wavelet-fixes
kubo11 Jul 6, 2023
1cdc5e7
WIP
mwdowski Sep 8, 2023
e99099e
Merge branch 'NVIDIA:main' into main
mwdowski Sep 8, 2023
15ce332
Merge branch 'main' into mwdowski2
mwdowski Sep 8, 2023
101efc4
Good size but full of zeros
mwdowski Sep 12, 2023
276f87e
WIP
mwdowski Sep 12, 2023
1849a30
Merge pull request #7 from mwdowski/mwdowski2
mwdowski Sep 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dali/kernels/signal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_subdirectory(decibel)
# The fft subdirectory is only added when BUILD_FFTS is enabled.
if (BUILD_FFTS)
add_subdirectory(fft)
endif()
# The wavelet and window subdirectories are added unconditionally.
add_subdirectory(wavelet)
add_subdirectory(window)

# collect_headers is a project helper defined elsewhere; presumably it gathers
# this directory's headers into DALI_INST_HDRS in the parent scope.
collect_headers(DALI_INST_HDRS PARENT_SCOPE)
Expand Down
17 changes: 17 additions & 0 deletions dali/kernels/signal/wavelet/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# The collect_* helpers are defined elsewhere in the project; presumably they
# gather this directory's headers, sources and test sources into the named
# variables in the parent scope.
collect_headers(DALI_INST_HDRS PARENT_SCOPE)
collect_sources(DALI_KERNEL_SRCS PARENT_SCOPE)
collect_test_sources(DALI_KERNEL_TEST_SRCS PARENT_SCOPE)
36 changes: 36 additions & 0 deletions dali/kernels/signal/wavelet/cwt_args.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_
#define DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_

#include <vector>
#include "dali/operators/signal/wavelet/wavelet_name.h"

namespace dali {
namespace kernels {
namespace signal {

// Argument bundle for the CWT kernel; CwtGpu<T>::Run takes a `const CwtArgs<T> &`.
// T is the floating-point type used for the numeric parameters (defaults to float).
template <typename T = float>
struct CwtArgs {
// NOTE(review): presumably the list of wavelet scales ("a" in CWT notation) — confirm.
std::vector<T> a;
// Identifier of the wavelet to use (enum declared in wavelet_name.h).
dali::DALIWaveletName wavelet;
// Numeric parameters for the selected wavelet.
// NOTE(review): presumably forwarded to the wavelet's constructor, which takes a
// std::vector<T> — confirm against the operator code.
std::vector<T> wavelet_args;
};

} // namespace signal
} // namespace kernels
} // namespace dali

#endif // DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_
96 changes: 96 additions & 0 deletions dali/kernels/signal/wavelet/cwt_gpu.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cmath>
#include <complex>
#include <vector>
#include "dali/core/common.h"
#include "dali/core/error_handling.h"
#include "dali/core/format.h"
#include "dali/kernels/kernel.h"
#include "dali/kernels/signal/wavelet/cwt_args.h"
#include "dali/kernels/signal/wavelet/cwt_gpu.h"

namespace dali {
namespace kernels {
namespace signal {

// Per-sample descriptor handed to CwtKernel; one entry is filled per sample in
// CwtGpu<T>::Run and the whole array is copied to the GPU.
template <typename T>
struct SampleDesc {
// Input data of the sample (set from in.tensor_data(i)).
const T *in = nullptr;
// Output data of the sample (set from out.tensor_data(i)).
T *out = nullptr;
// Number of elements in the sample (set from volume(in.tensor_shape(i))).
int64_t size = 0;
};

// Copies every element of each sample's input buffer to its output buffer.
//
// blockIdx.y selects the sample descriptor; the blocks along x share the work
// on that sample. Each thread starts at its flat position within the x-range
// of the grid and then advances by the total number of threads assigned to the
// sample (gridDim.x * threads per block) until the sample is exhausted.
//
// sample_data - array of sample descriptors, indexed by blockIdx.y.
template <typename T>
__global__ void CwtKernel(const SampleDesc<T> *sample_data) {
  const int64_t threads_per_block = blockDim.y * blockDim.x;
  const int64_t stride = gridDim.x * threads_per_block;
  const int which_sample = blockIdx.y;
  const auto desc = sample_data[which_sample];

  // Flat thread index inside the 2D block, then shifted by this block's offset.
  const int64_t thread_in_block = threadIdx.y * blockDim.x + threadIdx.x;
  int64_t pos = threads_per_block * blockIdx.x + thread_in_block;

  while (pos < desc.size) {
    desc.out[pos] = desc.in[pos];
    pos += stride;
  }
}

// Out-of-line defaulted destructor (declared in cwt_gpu.h).
template <typename T>
CwtGpu<T>::~CwtGpu() = default;

// Computes the kernel requirements for a batch.
//
// The output shape is a copy of the input shape, and scratch space is requested
// for one SampleDesc<T> per sample in both host and device memory.
//
// context - kernel context (not used here).
// in      - input batch; only its shape and sample count are read.
// Returns the requirements with scratch sizes and a single output shape filled in.
template <typename T>
KernelRequirements CwtGpu<T>::Setup(KernelContext &context,
                                    const InListGPU<T, DynamicDimensions> &in) {
  const size_t sample_count = in.size();

  ScratchpadEstimator estimator;
  estimator.add<mm::memory_kind::host, SampleDesc<T>>(sample_count);
  estimator.add<mm::memory_kind::device, SampleDesc<T>>(sample_count);

  auto result_shape = in.shape;
  KernelRequirements requirements;
  requirements.scratch_sizes = estimator.sizes;
  requirements.output_shapes = {result_shape};
  return requirements;
}

// Fills one SampleDesc<T> per sample, uploads the descriptors to the GPU and
// launches CwtKernel on context.gpu.stream.
//
// context - provides the scratchpad used for the descriptor arrays and the CUDA stream.
// out     - output batch; each sample must have the same volume as the matching input.
// in      - input batch.
// args    - CWT arguments; not read by this implementation yet (CwtKernel only
//           copies input to output).
//
// An empty batch is a no-op. Errors reported by the CUDA runtime go through CUDA_CALL.
template <typename T>
void CwtGpu<T>::Run(KernelContext &context, const OutListGPU<T, DynamicDimensions> &out,
                    const InListGPU<T, DynamicDimensions> &in, const CwtArgs<T> &args) {
  const auto num_samples = in.size();
  // With no samples there is nothing to launch; continuing would divide by zero
  // when computing blocks_per_sample and request a grid with a zero dimension.
  if (num_samples == 0) {
    return;
  }

  // NOTE(review): a PR reviewer commented on this allocation: "Same comment as
  // dali/kernels/signal/wavelet/wavelet_gpu.cu". The referenced remark is not
  // visible here; resolve both places together.
  auto *sample_data = context.scratchpad->AllocateHost<SampleDesc<T>>(num_samples);

  for (int i = 0; i < num_samples; i++) {
    auto &sample = sample_data[i];
    sample.out = out.tensor_data(i);
    sample.in = in.tensor_data(i);
    sample.size = volume(in.tensor_shape(i));
    assert(sample.size == volume(out.tensor_shape(i)));
  }

  auto *sample_data_gpu = context.scratchpad->AllocateGPU<SampleDesc<T>>(num_samples);
  CUDA_CALL(cudaMemcpyAsync(sample_data_gpu, sample_data, num_samples * sizeof(SampleDesc<T>),
                            cudaMemcpyHostToDevice, context.gpu.stream));

  // 32x32 threads per block; grid.y enumerates samples, grid.x splits the work
  // within a sample (at least 32 blocks per sample).
  dim3 block(32, 32);
  auto blocks_per_sample = std::max(32, 1024 / num_samples);
  dim3 grid(blocks_per_sample, num_samples);
  CwtKernel<T><<<grid, block, 0, context.gpu.stream>>>(sample_data_gpu);
  // A kernel launch returns no status by itself; query it so that launch
  // failures are not silently ignored.
  CUDA_CALL(cudaGetLastError());
}

// Explicit instantiations: the member definitions live in this .cu file, so the
// float and double variants are instantiated here.
template class CwtGpu<float>;
template class CwtGpu<double>;

} // namespace signal
} // namespace kernels
} // namespace dali
48 changes: 48 additions & 0 deletions dali/kernels/signal/wavelet/cwt_gpu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_
#define DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_

#include <memory>
#include "dali/core/common.h"
#include "dali/core/error_handling.h"
#include "dali/core/format.h"
#include "dali/core/util.h"
#include "dali/kernels/kernel.h"
#include "dali/kernels/signal/wavelet/cwt_args.h"

namespace dali {
namespace kernels {
namespace signal {

// GPU kernel class for the CWT operator (presumably "continuous wavelet
// transform" — confirm). Only floating-point T is accepted; the member
// functions are defined in cwt_gpu.cu.
template <typename T = float>
class DLL_PUBLIC CwtGpu {
public:
static_assert(std::is_floating_point<T>::value, "Only floating point types are supported");

DLL_PUBLIC ~CwtGpu();

// Returns the kernel requirements (scratch sizes and output shapes) for the
// input batch `in`.
DLL_PUBLIC KernelRequirements Setup(KernelContext &context,
const InListGPU<T, DynamicDimensions> &in);

// Runs the kernel on the input batch `in`, writing to `out`, using the CWT
// arguments in `args`.
DLL_PUBLIC void Run(KernelContext &context, const OutListGPU<T, DynamicDimensions> &out,
const InListGPU<T, DynamicDimensions> &in, const CwtArgs<T> &args);
};

} // namespace signal
} // namespace kernels
} // namespace dali

#endif // DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_
161 changes: 161 additions & 0 deletions dali/kernels/signal/wavelet/mother_wavelet.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cmath>
#include <vector>
#include "dali/kernels/signal/wavelet/mother_wavelet.cuh"
#include "dali/core/math_util.h"

namespace dali {
namespace kernels {
namespace signal {

// Constructs a Haar wavelet. The wavelet has no parameters, so `args` must be
// empty; otherwise std::invalid_argument is thrown.
template <typename T>
HaarWavelet<T>::HaarWavelet(const std::vector<T> &args) {
  const bool got_arguments = !args.empty();
  if (got_arguments) {
    throw std::invalid_argument("HaarWavelet doesn't accept any arguments.");
  }
}

// Evaluates the Haar wavelet at t:
//   1 for t in [0, 0.5), -1 for t in [0.5, 1), 0 everywhere else.
template <typename T>
__device__ T HaarWavelet<T>::operator()(const T &t) const {
  // Negated "inside support" test so that NaN also falls through to 0,
  // just like a chain of failed range comparisons would.
  const bool inside_support = (t >= 0.0 && t < 1.0);
  if (!inside_support) {
    return 0.0;
  }
  return (t < 0.5) ? 1.0 : -1.0;
}

// Explicit instantiations of HaarWavelet for float and double.
template class HaarWavelet<float>;
template class HaarWavelet<double>;

// Constructs a Gaussian wavelet of order n.
//
// args - must contain exactly one value, n, which has to be an integer in [1, 8].
// Throws std::invalid_argument if the argument count is wrong or if n is out of
// range, fractional or NaN.
template <typename T>
GaussianWavelet<T>::GaussianWavelet(const std::vector<T> &args) {
  if (args.size() != 1) {
    throw std::invalid_argument("GaussianWavelet accepts exactly one argument - n.");
  }
  const T order = args[0];
  // Positive range test: NaN fails both comparisons and is therefore rejected
  // (a `< 1 || > 8` test would let it through).
  const bool in_range = (order >= 1.0 && order <= 8.0);
  // operator() truncates n to int, so a fractional order would silently select
  // a different wavelet than requested; reject it as the message promises.
  if (!in_range || std::floor(order) != order) {
    throw std::invalid_argument(
        "GaussianWavelet's argument n should be integer from range [1,8].");
  }
  this->n = order;
}

// Evaluates the Gaussian wavelet of order n at t.
//
// The value is P_n(t) * exp(-t^2) / sqrt(c_n * sqrt(pi/2)), where P_n is the
// polynomial listed per case below and c_n = 1, 3, 15, 105, 945, 10395,
// 135135, 2027025 for n = 1..8.
//
// Powers of t are formed by multiplication (Horner form in t^2) instead of
// std::pow, and intermediates are kept in double, matching the promotion that
// the double literals already forced.
//
// Returns 0 if n is not in 1..8. The constructor rejects such values, so the
// default case only guarantees that every path returns a value.
template <typename T>
__device__ T GaussianWavelet<T>::operator()(const T &t) const {
  constexpr double kSqrtHalfPi = 1.2533141373155001;  // std::sqrt(M_PI/2.0)
  const double x = t;
  const double x2 = x * x;
  const T expTerm = std::exp(-x2);

  double poly = 0.0;
  double norm = 1.0;
  switch (static_cast<int>(n)) {
    case 1:  // -2t
      poly = -2.0 * x;
      norm = 1.0;
      break;
    case 2:  // -4t^2 + 2
      poly = -4.0 * x2 + 2.0;
      norm = 3.0;
      break;
    case 3:  // 8t^3 - 12t
      poly = x * (8.0 * x2 - 12.0);
      norm = 15.0;
      break;
    case 4:  // 16t^4 - 48t^2 + 12
      poly = (16.0 * x2 - 48.0) * x2 + 12.0;
      norm = 105.0;
      break;
    case 5:  // -32t^5 + 160t^3 - 120t
      poly = x * ((-32.0 * x2 + 160.0) * x2 - 120.0);
      norm = 945.0;
      break;
    case 6:  // -64t^6 + 480t^4 - 720t^2 + 120
      poly = ((-64.0 * x2 + 480.0) * x2 - 720.0) * x2 + 120.0;
      norm = 10395.0;
      break;
    case 7:  // 128t^7 - 1344t^5 + 3360t^3 - 1680t
      poly = x * (((128.0 * x2 - 1344.0) * x2 + 3360.0) * x2 - 1680.0);
      norm = 135135.0;
      break;
    case 8:  // 256t^8 - 3584t^6 + 13440t^4 - 13440t^2 + 1680
      poly = (((256.0 * x2 - 3584.0) * x2 + 13440.0) * x2 - 13440.0) * x2 + 1680.0;
      norm = 2027025.0;
      break;
    default:
      // Without this, control would flow off the end of a non-void function.
      return 0.0;
  }
  return poly * expTerm / std::sqrt(norm * kSqrtHalfPi);
}

// Explicit instantiations of GaussianWavelet for float and double.
template class GaussianWavelet<float>;
template class GaussianWavelet<double>;

// Constructs a Mexican hat wavelet.
//
// args - must contain exactly one value, sigma; it is stored without further checks.
// Throws std::invalid_argument if the number of arguments is not 1.
template <typename T>
MexicanHatWavelet<T>::MexicanHatWavelet(const std::vector<T> &args) {
  const bool single_argument = (args.size() == 1);
  if (!single_argument) {
    throw std::invalid_argument("MexicanHatWavelet accepts exactly one argument - sigma.");
  }
  this->sigma = args.front();
}

// Evaluates the Mexican hat wavelet at t:
//   2 / (sqrt(3*sigma) * pi^(1/4)) * (1 - (t/sigma)^2) * exp(-t^2 / (2*sigma^2))
template <typename T>
__device__ T MexicanHatWavelet<T>::operator()(const T &t) const {
  // Constant factor that depends only on sigma.
  const auto amplitude = 2.0/(std::sqrt(3.0*sigma)*std::pow(M_PI, 0.25));
  // Polynomial part: 1 - (t/sigma)^2.
  const auto parabola = 1.0-std::pow(t/sigma, 2.0);
  // Gaussian envelope.
  const auto envelope = std::exp(-std::pow(t, 2.0)/(2.0*std::pow(sigma, 2.0)));
  return amplitude*parabola*envelope;
}

// Explicit instantiations of MexicanHatWavelet for float and double.
template class MexicanHatWavelet<float>;
template class MexicanHatWavelet<double>;

// Constructs a Morlet wavelet. The wavelet has no parameters, so `args` must be
// empty; otherwise std::invalid_argument is thrown.
template <typename T>
MorletWavelet<T>::MorletWavelet(const std::vector<T> &args) {
  const bool got_arguments = !args.empty();
  if (got_arguments) {
    throw std::invalid_argument("MorletWavelet doesn't accept any arguments.");
  }
}

// Evaluates the Morlet wavelet at t: exp(-t^2 / 2) * cos(5 * t).
template <typename T>
__device__ T MorletWavelet<T>::operator()(const T &t) const {
  const auto envelope = std::exp(-std::pow(t, 2.0) / 2.0);
  const auto oscillation = std::cos(5.0 * t);
  return envelope * oscillation;
}

// Explicit instantiations of MorletWavelet for float and double.
template class MorletWavelet<float>;
template class MorletWavelet<double>;

// Constructs a Shannon wavelet.
//
// args - must contain exactly two values, fb and fc, in that order; both are
//        stored without further checks.
// Throws std::invalid_argument if the number of arguments is not 2.
template <typename T>
ShannonWavelet<T>::ShannonWavelet(const std::vector<T> &args) {
  const bool two_arguments = (args.size() == 2);
  if (!two_arguments) {
    throw std::invalid_argument(
        "ShannonWavelet accepts exactly 2 arguments -> fb, fc in that order.");
  }
  auto it = args.begin();
  this->fb = *it;
  ++it;
  this->fc = *it;
}

// Evaluates the Shannon wavelet at t:
//   sqrt(fb) * cos(2*pi*fc*t) * sin(pi*fb*t) / (pi*fb*t)
// At t == 0 the sin(x)/x factor is skipped and only the cosine term is returned.
template <typename T>
__device__ T ShannonWavelet<T>::operator()(const T &t) const {
  const auto scaled_cos = std::cos(static_cast<T>(2.0*M_PI)*fc*t)*std::sqrt(fb);
  if (t == 0.0) {
    return scaled_cos;
  }
  const auto phase = t*fb*static_cast<T>(M_PI);
  return scaled_cos*std::sin(phase)/phase;
}

// Explicit instantiations of ShannonWavelet for float and double.
template class ShannonWavelet<float>;
template class ShannonWavelet<double>;

// Constructs an Fbsp wavelet.
//
// args - must contain exactly three values, m, fb and fc, in that order; all
//        are stored without further checks.
// Throws std::invalid_argument if the number of arguments is not 3.
template <typename T>
FbspWavelet<T>::FbspWavelet(const std::vector<T> &args) {
  const bool three_arguments = (args.size() == 3);
  if (!three_arguments) {
    throw std::invalid_argument(
        "FbspWavelet accepts exactly 3 arguments -> m, fb, fc in that order.");
  }
  auto it = args.begin();
  this->m = *it;
  ++it;
  this->fb = *it;
  ++it;
  this->fc = *it;
}

// Evaluates the Fbsp wavelet at t:
//   sqrt(fb) * cos(2*pi*fc*t) * (sin(pi*fb*t/m) / (pi*fb*t/m))^m
// At t == 0 the sin(x)/x factor is skipped and only the cosine term is returned.
template <typename T>
__device__ T FbspWavelet<T>::operator()(const T &t) const {
  const auto scaled_cos = std::cos(static_cast<T>(2.0*M_PI)*fc*t)*std::sqrt(fb);
  if (t == 0.0) {
    return scaled_cos;
  }
  const auto phase = static_cast<T>(M_PI)*t*fb/m;
  return scaled_cos*std::pow(std::sin(phase)/phase, m);
}

// Explicit instantiations of FbspWavelet for float and double.
template class FbspWavelet<float>;
template class FbspWavelet<double>;

} // namespace signal
} // namespace kernels
} // namespace dali
Loading