New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add DCT 1D CPU kernel #1569
Add DCT 1D CPU kernel #1569
Changes from 6 commits
0330834
db9d0fb
5d2ba76
93418ec
e21c094
3952d38
2c94cea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
collect_headers(DALI_INST_HDRS PARENT_SCOPE) | ||
collect_sources(DALI_KERNEL_SRCS PARENT_SCOPE) | ||
collect_test_sources(DALI_KERNEL_TEST_SRCS PARENT_SCOPE) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef DALI_KERNELS_SIGNAL_DCT_DCT_ARGS_H_ | ||
#define DALI_KERNELS_SIGNAL_DCT_DCT_ARGS_H_ | ||
|
||
namespace dali { | ||
namespace kernels { | ||
namespace signal { | ||
namespace dct { | ||
|
||
/** | ||
* @brief DCT kernel arguments | ||
*/ | ||
struct DctArgs { | ||
/// @brief DCT type. Supported types are 1, 2, 3, 4 | ||
/// @remarks DCT type I requires the input data length to be > 1. | ||
int dct_type = 2; | ||
|
||
/// @brief Index of the dimension to be transformed. Last dimension by default | ||
int axis = -1; | ||
|
||
/// @brief If true, the output DCT matrix will be normalized to be orthogonal | ||
/// @remarks Normalization is not supported for DCT type I | ||
bool normalize = false; | ||
|
||
/// @brief Number of coefficients we are interested in calculating. | ||
/// By default, ndct = in_shape[axis] | ||
int ndct = -1; | ||
|
||
inline bool operator==(const DctArgs& oth) const { | ||
return dct_type == oth.dct_type && | ||
axis == oth.axis && | ||
normalize == oth.normalize; | ||
} | ||
|
||
inline bool operator!=(const DctArgs& oth) const { | ||
return !operator==(oth); | ||
} | ||
}; | ||
|
||
} // namespace dct | ||
} // namespace signal | ||
} // namespace kernels | ||
} // namespace dali | ||
|
||
#endif // DALI_KERNELS_SIGNAL_DCT_DCT_ARGS_H_ |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,214 @@ | ||||||
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | ||||||
// | ||||||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
// you may not use this file except in compliance with the License. | ||||||
// You may obtain a copy of the License at | ||||||
// | ||||||
// http://www.apache.org/licenses/LICENSE-2.0 | ||||||
// | ||||||
// Unless required by applicable law or agreed to in writing, software | ||||||
// distributed under the License is distributed on an "AS IS" BASIS, | ||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
// See the License for the specific language governing permissions and | ||||||
// limitations under the License. | ||||||
|
||||||
#include "dali/kernels/signal/dct/dct_cpu.h" | ||||||
#include <cmath> | ||||||
jantonguirao marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
#include "dali/core/common.h" | ||||||
#include "dali/core/convert.h" | ||||||
#include "dali/core/error_handling.h" | ||||||
#include "dali/core/format.h" | ||||||
#include "dali/core/util.h" | ||||||
#include "dali/kernels/common/for_axis.h" | ||||||
#include "dali/kernels/common/utils.h" | ||||||
#include "dali/kernels/kernel.h" | ||||||
|
||||||
namespace dali { | ||||||
namespace kernels { | ||||||
namespace signal { | ||||||
namespace dct { | ||||||
|
||||||
namespace { | ||||||
|
||||||
template <typename T> | ||||||
void FillCosineTableTypeI(T *table, int64_t input_length, int64_t ndct, bool normalize) { | ||||||
assert(input_length > 1); | ||||||
assert(!normalize); | ||||||
T phase_mul = M_PI / (input_length - 1); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
int64_t idx = 0; | ||||||
for (int64_t k = 0; k < ndct; k++) { | ||||||
table[idx++] = T(0.5); // n = 0 | ||||||
for (int64_t n = 1; n < input_length-1; n++) { | ||||||
table[idx++] = std::cos(phase_mul * k * n); | ||||||
} | ||||||
table[idx++] = k % 2 == 0 ? T(0.5) : -T(0.5); // n = input_length - 1 | ||||||
} | ||||||
} | ||||||
|
||||||
template <typename T> | ||||||
void FillCosineTableTypeII(T *table, int64_t input_length, int64_t ndct, bool normalize) { | ||||||
T phase_mul = M_PI / input_length; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
T factor_k_0 = 1, factor_k_i = 1; | ||||||
if (normalize) { | ||||||
factor_k_i = std::sqrt(2 / T(input_length)); | ||||||
factor_k_0 = factor_k_i / std::sqrt(T(2)); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't it just:
Suggested change
? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||
} | ||||||
int64_t idx = 0; | ||||||
for (int64_t k = 0; k < ndct; k++) { | ||||||
T norm_factor = (k == 0) ? factor_k_0 : factor_k_i; | ||||||
for (int64_t n = 0; n < input_length; n++) { | ||||||
table[idx++] = norm_factor * std::cos(phase_mul * (n + T(0.5)) * k); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it's a one-off calculation, then the code below give more precision:
Suggested change
|
||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
|
||||||
template <typename T> | ||||||
void FillCosineTableTypeIII(T *table, int64_t input_length, int64_t ndct, bool normalize) { | ||||||
T phase_mul = M_PI / input_length; | ||||||
T factor_n_0 = 0.5, factor_n_i = 1; | ||||||
if (normalize) { | ||||||
factor_n_i = std::sqrt(T(2) / input_length); | ||||||
factor_n_0 = factor_n_i / std::sqrt(T(2)); | ||||||
} | ||||||
int64_t idx = 0; | ||||||
for (int64_t k = 0; k < ndct; k++) { | ||||||
table[idx++] = factor_n_0; // n = 0 | ||||||
for (int64_t n = 1; n < input_length; n++) { | ||||||
table[idx++] = factor_n_i * std::cos(phase_mul * n * (k + T(0.5))); | ||||||
} | ||||||
} | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As above. |
||||||
} | ||||||
|
||||||
|
||||||
template <typename T> | ||||||
void FillCosineTableTypeIV(T *table, int64_t input_length, int64_t ndct, bool normalize) { | ||||||
T phase_mul = M_PI / input_length; | ||||||
T factor = normalize ? std::sqrt(T(2)/input_length) : T(1); | ||||||
int64_t idx = 0; | ||||||
for (int64_t k = 0; k < ndct; k++) { | ||||||
for (int64_t n = 0; n < input_length; n++) { | ||||||
table[idx++] = factor * std::cos(phase_mul * (n + T(0.5)) * (k + T(0.5))); | ||||||
} | ||||||
} | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ...and again. |
||||||
} | ||||||
|
||||||
|
||||||
template <typename T> | ||||||
void FillCosineTable(T *table, int64_t input_length, int64_t ndct, int dct_type, bool normalize) { | ||||||
switch (dct_type) { | ||||||
case 1: | ||||||
FillCosineTableTypeI(table, input_length, ndct, normalize); | ||||||
break; | ||||||
case 2: | ||||||
FillCosineTableTypeII(table, input_length, ndct, normalize); | ||||||
break; | ||||||
case 3: | ||||||
FillCosineTableTypeIII(table, input_length, ndct, normalize); | ||||||
break; | ||||||
case 4: | ||||||
FillCosineTableTypeIV(table, input_length, ndct, normalize); | ||||||
break; | ||||||
default: | ||||||
assert(false); | ||||||
} | ||||||
} | ||||||
|
||||||
} // namespace | ||||||
|
||||||
template <typename OutputType, typename InputType, int Dims> | ||||||
Dct1DCpu<OutputType, InputType, Dims>::~Dct1DCpu() = default; | ||||||
|
||||||
template <typename OutputType, typename InputType, int Dims> | ||||||
KernelRequirements | ||||||
Dct1DCpu<OutputType, InputType, Dims>::Setup(KernelContext &context, | ||||||
const InTensorCPU<InputType, Dims> &in, | ||||||
const DctArgs &original_args) { | ||||||
const auto &in_shape = in.shape; | ||||||
DALI_ENFORCE(in_shape.size() == Dims); | ||||||
|
||||||
auto args = original_args; | ||||||
args.axis = args.axis >= 0 ? args.axis : Dims - 1; | ||||||
DALI_ENFORCE(args.axis >= 0 && args.axis < Dims, | ||||||
make_string("Axis is out of bounds: ", args.axis)); | ||||||
int64_t n = in.shape[args.axis]; | ||||||
|
||||||
if (args.dct_type == 1) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have that check repeated in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's ok - this is a public API, so we throw - in FillCosineTable we assert (it should not be reachable with such parameter, regardless of what external caller does).. |
||||||
DALI_ENFORCE(n > 1, "DCT type I requires an input length > 1"); | ||||||
if (args.normalize) { | ||||||
DALI_WARN("DCT type-I does not support orthogonal normalization. Ignoring"); | ||||||
args.normalize = false; | ||||||
} | ||||||
} | ||||||
|
||||||
if (args.ndct <= 0 || args.ndct > n) { | ||||||
args.ndct = n; | ||||||
} | ||||||
|
||||||
auto out_shape = in.shape; | ||||||
out_shape[args.axis] = args.ndct; | ||||||
|
||||||
if (cos_table_.empty() || args != args_) { | ||||||
auto cos_table_sz = n * args.ndct; | ||||||
cos_table_.resize(cos_table_sz); | ||||||
FillCosineTable(cos_table_.data(), n, args.ndct, args.dct_type, args.normalize); | ||||||
args_ = args; | ||||||
} | ||||||
|
||||||
KernelRequirements req; | ||||||
req.output_shapes = {TensorListShape<DynamicDimensions>({out_shape})}; | ||||||
return req; | ||||||
} | ||||||
|
||||||
template <typename OutputType, typename InputType, int Dims> | ||||||
void Dct1DCpu<OutputType, InputType, Dims>::Run(KernelContext &context, | ||||||
const OutTensorCPU<OutputType, Dims> &out, | ||||||
const InTensorCPU<InputType, Dims> &in, | ||||||
const DctArgs &args) { | ||||||
(void)args; | ||||||
assert(args_.axis >= 0 && args_.axis < Dims); | ||||||
const auto n = in.shape[args_.axis]; | ||||||
|
||||||
assert(args_.dct_type >= 1 && args_.dct_type <= 4); | ||||||
|
||||||
auto in_shape = in.shape; | ||||||
auto in_strides = GetStrides(in_shape); | ||||||
auto out_shape = out.shape; | ||||||
auto out_strides = GetStrides(out_shape); | ||||||
|
||||||
ForAxis( | ||||||
out.data, in.data, out_shape.data(), out_strides.data(), in_shape.data(), in_strides.data(), | ||||||
args_.axis, out.dim(), | ||||||
[this]( | ||||||
OutputType *out_data, const InputType *in_data, int64_t out_size, int64_t out_stride, | ||||||
int64_t in_size, int64_t in_stride) { | ||||||
int64_t out_idx = 0; | ||||||
for (int64_t k = 0; k < out_size; k++) { | ||||||
OutputType out_val = 0; | ||||||
const auto *cos_table_row = cos_table_.data() + k * in_size; | ||||||
int64_t in_idx = 0; | ||||||
for (int64_t n = 0; n < in_size; n++) { | ||||||
OutputType in_val = in_data[in_idx]; | ||||||
in_idx += in_stride; | ||||||
out_val += in_val * cos_table_row[n]; | ||||||
} | ||||||
out_data[out_idx] = out_val; | ||||||
out_idx += out_stride; | ||||||
} | ||||||
}); | ||||||
} | ||||||
|
||||||
template class Dct1DCpu<float, float, 1>; | ||||||
template class Dct1DCpu<float, float, 2>; | ||||||
template class Dct1DCpu<float, float, 3>; | ||||||
template class Dct1DCpu<float, float, 4>; | ||||||
|
||||||
template class Dct1DCpu<double, double, 1>; | ||||||
template class Dct1DCpu<double, double, 2>; | ||||||
template class Dct1DCpu<double, double, 3>; | ||||||
template class Dct1DCpu<double, double, 4>; | ||||||
|
||||||
} // namespace dct | ||||||
} // namespace signal | ||||||
} // namespace kernels | ||||||
} // namespace dali |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef DALI_KERNELS_SIGNAL_DCT_DCT_CPU_H_ | ||
#define DALI_KERNELS_SIGNAL_DCT_DCT_CPU_H_ | ||
|
||
#include <memory> | ||
#include <vector> | ||
#include "dali/core/common.h" | ||
#include "dali/core/error_handling.h" | ||
#include "dali/core/format.h" | ||
#include "dali/core/util.h" | ||
#include "dali/kernels/kernel.h" | ||
#include "dali/kernels/signal/dct/dct_args.h" | ||
|
||
namespace dali { | ||
namespace kernels { | ||
namespace signal { | ||
namespace dct { | ||
|
||
/** | ||
* @brief Discrete Cosine Transform 1D CPU kernel. | ||
* Performs a DCT transformation over a single dimension in a multi-dimensional input. | ||
* | ||
* @remarks It supports DCT types I, II, III and IV decribed here: | ||
* https://en.wikipedia.org/wiki/Discrete_cosine_transform | ||
* DCT generally stands for type II and inverse DCT stands for DCT type III | ||
* | ||
* @see DCTArgs | ||
*/ | ||
template <typename OutputType = float, typename InputType = float, int Dims = 2> | ||
class DLL_PUBLIC Dct1DCpu { | ||
public: | ||
static_assert(std::is_floating_point<InputType>::value, | ||
"Data type should be floating point"); | ||
static_assert(std::is_same<OutputType, InputType>::value, | ||
"Data type conversion is not supported"); | ||
|
||
DLL_PUBLIC ~Dct1DCpu(); | ||
|
||
DLL_PUBLIC KernelRequirements Setup(KernelContext &context, | ||
const InTensorCPU<InputType, Dims> &in, | ||
const DctArgs &args); | ||
|
||
DLL_PUBLIC void Run(KernelContext &context, | ||
const OutTensorCPU<OutputType, Dims> &out, | ||
const InTensorCPU<InputType, Dims> &in, | ||
const DctArgs &args); | ||
private: | ||
std::vector<OutputType> cos_table_; | ||
DctArgs args_; | ||
}; | ||
|
||
} // namespace dct | ||
} // namespace signal | ||
} // namespace kernels | ||
} // namespace dali | ||
|
||
#endif // DALI_KERNELS_SIGNAL_DCT_DCT_CPU_H_ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What this value maps to?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It doesn't map to anything. See:
https://en.wikipedia.org/wiki/Discrete_cosine_transform#Formal_definition
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So this reference should be put here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's in the operator documentation, but ok, I'll add it here as well