Skip to content

Commit

Permalink
Fix missing layouts in operators (#2136)
Browse files Browse the repository at this point in the history
Signed-off-by: Joaquin Anton <janton@nvidia.com>
  • Loading branch information
jantonguirao committed Jul 31, 2020
1 parent b5e6370 commit 0362ac4
Show file tree
Hide file tree
Showing 47 changed files with 192 additions and 72 deletions.
3 changes: 2 additions & 1 deletion dali/benchmark/operator_bench.h
Expand Up @@ -29,7 +29,8 @@ class OperatorBench : public DALIBenchmark {
template <typename OutputContainer, typename OperatorPtr, typename Workspace>
void Setup(OperatorPtr &op_ptr, const OpSpec &spec, Workspace &ws, int batch_size) {
std::vector<OutputDesc> outputs;
if (op_ptr->CanInferOutputs() && op_ptr->Setup(outputs, ws)) {
bool can_infer_outs = op_ptr->CanInferOutputs();
if (op_ptr->Setup(outputs, ws) && can_infer_outs) {
int num_out = outputs.size();
for (int i = 0; i < num_out; i++) {
auto data_out = std::make_shared<OutputContainer>(batch_size);
Expand Down
17 changes: 17 additions & 0 deletions dali/core/tensor_layout_test.cc
Expand Up @@ -273,4 +273,21 @@ TEST(TensorLayout, GetLayoutMapping) {
}
}

// Exercises TensorLayout::resize. The expectations below pin that growing a layout keeps
// the existing dimension names and pads the new positions with '?' unless a fill
// character is passed explicitly.
TEST(TensorLayout, Resize) {
TensorLayout layout;
// A default-constructed layout is expected to compare equal to the empty string.
EXPECT_EQ(layout, "");

// Growing an empty layout with no fill argument: every position is expected to be '?'.
layout.resize(3);
EXPECT_EQ(layout, "???");

layout = "HW";
EXPECT_EQ(layout, "HW");

// Growing a non-empty layout: existing names stay, only the new tail is padded.
layout.resize(3);
EXPECT_EQ(layout, "HW?");

// An explicit fill character is expected to apply only to the newly added position.
layout.resize(4, '#');
EXPECT_EQ(layout, "HW?#");
}

} // namespace dali
7 changes: 4 additions & 3 deletions dali/operators/generic/flip.cc
Expand Up @@ -60,16 +60,17 @@ template <>
void Flip<CPUBackend>::RunImpl(Workspace<CPUBackend> &ws) {
const auto &input = ws.Input<CPUBackend>(0);
auto &output = ws.Output<CPUBackend>(0);
output.SetLayout(input.GetLayout());
auto layout = input.GetLayout();
output.SetLayout(layout);
output.set_type(input.type());
output.ResizeLike(input);
auto _horizontal = GetHorizontal(ws, ws.data_idx());
auto _vertical = GetVertical(ws, ws.data_idx());
auto _depthwise = GetDepthwise(ws, ws.data_idx());
if (!_horizontal && !_vertical) {
if (!_horizontal && !_vertical && !_depthwise) {
output.Copy(input, nullptr);
} else {
RunFlip(output, input, InputLayout(ws, 0), _horizontal, _vertical, _depthwise);
RunFlip(output, input, layout, _horizontal, _vertical, _depthwise);
}
}

Expand Down
1 change: 1 addition & 0 deletions dali/operators/generic/lookup_table.cc
Expand Up @@ -20,6 +20,7 @@ template<>
void LookupTable<CPUBackend>::RunImpl(SampleWorkspace &ws) {
const auto &input = ws.Input<CPUBackend>(0);
auto &output = ws.Output<CPUBackend>(0);
output.SetLayout(input.GetLayout());
auto data_size = input.size();
TYPE_SWITCH(input.type().id(), dali::type2id, InputType,
(uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t), (
Expand Down
1 change: 1 addition & 0 deletions dali/operators/generic/lookup_table.cu
Expand Up @@ -55,6 +55,7 @@ template<>
void LookupTable<GPUBackend>::RunImpl(DeviceWorkspace &ws) {
const auto &input = ws.Input<GPUBackend>(0);
auto &output = ws.Output<GPUBackend>(0);
output.SetLayout(input.GetLayout());
auto data_size = input.size();
constexpr int kThreads = 512;
const int blocks = (data_size + kThreads - 1) / kThreads;
Expand Down
1 change: 1 addition & 0 deletions dali/operators/generic/pad.cc
Expand Up @@ -217,6 +217,7 @@ template <>
void Pad<CPUBackend>::RunImpl(workspace_t<CPUBackend> &ws) {
const auto &input = ws.InputRef<CPUBackend>(0);
auto &output = ws.OutputRef<CPUBackend>(0);
output.SetLayout(input.GetLayout());
int nsamples = input.size();
int ndim = input.shape().sample_dim();
auto& thread_pool = ws.GetThreadPool();
Expand Down
1 change: 1 addition & 0 deletions dali/operators/generic/pad.cu
Expand Up @@ -57,6 +57,7 @@ template <>
void Pad<GPUBackend>::RunImpl(workspace_t<GPUBackend> &ws) {
const auto &input = ws.Input<GPUBackend>(0);
auto &output = ws.Output<GPUBackend>(0);
output.SetLayout(input.GetLayout());
int ndim = input.shape().sample_dim();
TYPE_SWITCH(input.type().id(), type2id, T, PAD_SUPPORTED_TYPES, (
VALUE_SWITCH(ndim, Dims, PAD_SUPPORTED_NDIMS, (
Expand Down
3 changes: 2 additions & 1 deletion dali/operators/generic/reshape.cc
Expand Up @@ -324,7 +324,8 @@ TensorLayout Reshape<Backend>::GetOutputLayout(const Workspace &ws) const {
// layout was explicitly cleared
return TensorLayout();
}
auto in_layout = this->InputLayout(ws, 0);
auto &in = ws.template InputRef<Backend>(0);
auto in_layout = in.GetLayout();
return in_layout.ndim() == output_shape_.sample_dim() ? in_layout : TensorLayout();
}

Expand Down
2 changes: 1 addition & 1 deletion dali/operators/image/color/brightness_contrast.cc
Expand Up @@ -81,7 +81,7 @@ bool BrightnessContrastCpu::SetupImpl(std::vector<OutputDesc> &output_desc,
void BrightnessContrastCpu::RunImpl(workspace_t<CPUBackend> &ws) {
const auto &input = ws.template InputRef<CPUBackend>(0);
auto &output = ws.template OutputRef<CPUBackend>(0);
output.SetLayout(InputLayout(ws, 0));
output.SetLayout(input.GetLayout());
auto out_shape = output.shape();
auto& tp = ws.GetThreadPool();
TYPE_SWITCH(input.type().id(), type2id, InputType, (uint8_t, int16_t, int32_t, float), (
Expand Down
2 changes: 1 addition & 1 deletion dali/operators/image/color/brightness_contrast.cu
Expand Up @@ -55,7 +55,7 @@ bool BrightnessContrastGpu::SetupImpl(std::vector<OutputDesc> &output_desc,
void BrightnessContrastGpu::RunImpl(workspace_t<GPUBackend> &ws) {
const auto &input = ws.template Input<GPUBackend>(0);
auto &output = ws.template Output<GPUBackend>(0);
output.SetLayout(InputLayout(ws, 0));
output.SetLayout(input.GetLayout());
TYPE_SWITCH(input.type().id(), type2id, InputType, (uint8_t, int16_t, int32_t, float), (
TYPE_SWITCH(output_type_, type2id, OutputType, (uint8_t, int16_t, int32_t, float), (
{
Expand Down
1 change: 1 addition & 0 deletions dali/operators/image/color/color_space_conversion.cc
Expand Up @@ -36,6 +36,7 @@ void ColorSpaceConversion<CPUBackend>::RunImpl(SampleWorkspace &ws) {
auto &output = ws.Output<CPUBackend>(0);
const auto &input_shape = input.shape();
auto output_shape = input_shape;
output.SetLayout(input.GetLayout());

const auto H = input_shape[0];
const auto W = input_shape[1];
Expand Down
4 changes: 3 additions & 1 deletion dali/operators/image/color/color_space_conversion.cu
Expand Up @@ -96,6 +96,8 @@ void ColorSpaceConversion<GPUBackend>::RunImpl(DeviceWorkspace &ws) {
DALI_ENFORCE(IsType<uint8_t>(input.type()),
"Color space conversion accept only uint8 tensors");
auto &output = ws.Output<GPUBackend>(0);
auto layout = input.GetLayout();
output.SetLayout(layout);

TensorList<CPUBackend> attr_output_cpu;

Expand All @@ -117,7 +119,7 @@ void ColorSpaceConversion<GPUBackend>::RunImpl(DeviceWorkspace &ws) {
auto stream = ws.stream();
DALI_CHECK_NPP(nppSetStream(stream));

if (InputLayout(ws, 0) == "HWC") {
if (layout == "HWC") {
// RGB -> BGR || BGR -> RGB
for (unsigned int i = 0; i < input.ntensor(); ++i) {
// image dimensions
Expand Down
2 changes: 1 addition & 1 deletion dali/operators/image/color/color_twist.cc
Expand Up @@ -168,7 +168,7 @@ void ColorTwistCpu::RunImpl(workspace_t<CPUBackend> &ws) {
const auto &input = ws.template InputRef<CPUBackend>(0);
auto &output = ws.template OutputRef<CPUBackend>(0);
auto out_shape = output.shape();
output.SetLayout(InputLayout(ws, 0));
output.SetLayout(input.GetLayout());
auto &tp = ws.GetThreadPool();
TYPE_SWITCH(input.type().id(), type2id, InputType, (uint8_t, int16_t, int32_t, float, float16), (
TYPE_SWITCH(output_type_, type2id, OutputType, (uint8_t, int16_t, int32_t, float, float16), (
Expand Down
2 changes: 1 addition & 1 deletion dali/operators/image/color/color_twist.cu
Expand Up @@ -54,7 +54,7 @@ bool ColorTwistGpu::SetupImpl(std::vector<OutputDesc> &output_desc, const Device
void ColorTwistGpu::RunImpl(workspace_t<GPUBackend> &ws) {
const auto &input = ws.template Input<GPUBackend>(0);
auto &output = ws.template Output<GPUBackend>(0);
output.SetLayout(InputLayout(ws, 0));
output.SetLayout(input.GetLayout());
TYPE_SWITCH(input.type().id(), type2id, InputType, (uint8_t, int16_t, int32_t, float), (
TYPE_SWITCH(output_type_, type2id, OutputType, (uint8_t, int16_t, int32_t, float), (
{
Expand Down
4 changes: 2 additions & 2 deletions dali/operators/image/color/old_color_twist.cc
Expand Up @@ -217,7 +217,7 @@ void OldColorTwistBase<GPUBackend>::RunImpl(DeviceWorkspace &ws) {
DALI_ENFORCE(IsType<uint8_t>(input.type()), "Color augmentations accept only uint8 tensors");
auto &output = ws.Output<GPUBackend>(0);
output.ResizeLike(input);
output.SetLayout(InputLayout(ws, 0));
output.SetLayout(input.GetLayout());

cudaStream_t old_stream = nppGetStream();
nppSetStream(ws.stream());
Expand Down Expand Up @@ -259,7 +259,7 @@ void OldColorTwistBase<CPUBackend>::RunImpl(SampleWorkspace &ws) {
const auto C = input_shape[2];

output.ResizeLike(input);
output.SetLayout(InputLayout(ws, 0));
output.SetLayout(input.GetLayout());

auto pImgInp = input.template data<uint8>();
auto pImgOut = output.template mutable_data<uint8>();
Expand Down
2 changes: 1 addition & 1 deletion dali/operators/image/crop/crop_mirror_normalize.h
Expand Up @@ -164,7 +164,7 @@ class CropMirrorNormalize : public Operator<Backend> {
output_type_ = input_type_;

auto in_shape = input.shape();
input_layout_ = this->InputLayout(ws, 0);
input_layout_ = input.GetLayout();
DALI_ENFORCE(ImageLayoutInfo::IsImage(input_layout_),
("Unsupported layout: '" + input_layout_.str() + "' for input 0 '" +
this->spec_.InputName(0) + "'"));
Expand Down
2 changes: 1 addition & 1 deletion dali/operators/image/remap/displacement_filter_impl_cpu.h
Expand Up @@ -165,7 +165,7 @@ class DisplacementFilter<CPUBackend, Displacement, per_channel_transform>
auto &input = ws.Input<CPUBackend>(0);
auto &output = ws.Output<CPUBackend>(0);
output.ResizeLike(input);
output.SetLayout(InputLayout(ws, 0));
output.SetLayout(input.GetLayout());
}

void SetupSharedSampleParams(SampleWorkspace &ws) override {
Expand Down
Expand Up @@ -248,7 +248,7 @@ class DisplacementFilter<GPUBackend, Displacement,
auto &input = ws.Input<GPUBackend>(0);
auto &output = ws.Output<GPUBackend>(0);
output.ResizeLike(input);
output.SetLayout(InputLayout(ws, 0));
output.SetLayout(input.GetLayout());
}

template <typename U = Displacement>
Expand Down
4 changes: 3 additions & 1 deletion dali/operators/image/remap/warp.h
Expand Up @@ -331,7 +331,9 @@ class Warp : public Operator<Backend> {
void RunImpl(Workspace &ws) override {
assert(impl_);
impl_->Run(ws);
ws.template OutputRef<Backend>(0).SetLayout(this->InputLayout(ws, 0));
auto &out = ws.template OutputRef<Backend>(0);
auto &in = ws.template InputRef<Backend>(0);
out.SetLayout(in.GetLayout());
}

protected:
Expand Down
2 changes: 1 addition & 1 deletion dali/operators/image/resize/random_resized_crop.cc
Expand Up @@ -56,7 +56,7 @@ void RandomResizedCrop<CPUBackend>::RunImpl(SampleWorkspace &ws) {
auto &output = ws.Output<CPUBackend>(0);

RunCPU(output, input, ws.thread_idx());
output.SetLayout(InputLayout(ws, 0));
output.SetLayout(input.GetLayout());
}

template<>
Expand Down
2 changes: 1 addition & 1 deletion dali/operators/image/resize/random_resized_crop.cu
Expand Up @@ -38,7 +38,7 @@ void RandomResizedCrop<GPUBackend>::RunImpl(DeviceWorkspace &ws) {

auto &output = ws.Output<GPUBackend>(0);
RunGPU(output, input, ws.stream());
output.SetLayout(InputLayout(ws, 0));
output.SetLayout(input.GetLayout());
}

template<>
Expand Down
2 changes: 1 addition & 1 deletion dali/operators/image/resize/resize.cc
Expand Up @@ -112,7 +112,7 @@ void Resize<CPUBackend>::RunImpl(SampleWorkspace &ws) {
}

RunCPU(output, input, thread_idx);
output.SetLayout(InputLayout(ws, 0));
output.SetLayout(input.GetLayout());

if (save_attrs_) {
auto &attr_output = ws.Output<CPUBackend>(1);
Expand Down
2 changes: 1 addition & 1 deletion dali/operators/image/resize/resize.cu
Expand Up @@ -62,7 +62,7 @@ void Resize<GPUBackend>::RunImpl(DeviceWorkspace &ws) {
auto &output = ws.Output<GPUBackend>(0);

RunGPU(output, input, ws.stream());
output.SetLayout(InputLayout(ws, 0));
output.SetLayout(input.GetLayout());

// Setup and output the resize attributes if necessary
if (save_attrs_) {
Expand Down
2 changes: 1 addition & 1 deletion dali/operators/image/resize/resize_crop_mirror.h
Expand Up @@ -239,7 +239,7 @@ class ResizeCropMirror : public Operator<CPUBackend>, protected ResizeCropMirror
// Resize the output & run
output.Resize(
std::vector<Index>{crop_height_[ws.data_idx()], crop_width_[ws.data_idx()], meta.C});
output.SetLayout(this->InputLayout(ws, 0));
output.SetLayout(input.GetLayout());

tl_workspace_[ws.thread_idx()].resize(meta.rsz_h*meta.rsz_w*meta.C);
DALI_CALL((*func)(
Expand Down
1 change: 1 addition & 0 deletions dali/operators/math/normalize/normalize.cc
Expand Up @@ -236,6 +236,7 @@ void Normalize<CPUBackend>::RunTyped(HostWorkspace &ws) {

auto &output = ws.OutputRef<CPUBackend>(0);
TensorListView<StorageCPU, OutputType> out_view = view<OutputType>(output);
output.SetLayout(input.GetLayout());

int nsamples = input.ntensor();
int nthreads = ws.GetThreadPool().size();
Expand Down
3 changes: 2 additions & 1 deletion dali/operators/math/normalize/normalize.h
Expand Up @@ -132,7 +132,8 @@ class NormalizeBase : public Operator<Backend> {
GetParamShapeFromAxes();
} else if (has_axis_names_arg_) {
TensorLayout names = spec.GetArgument<TensorLayout>("axis_names");
auto dim_idx = GetDimIndices(this->InputLayout(ws, 0), names);
const auto &input = ws.template InputRef<Backend>(0);
auto dim_idx = GetDimIndices(input.GetLayout(), names);
axes_ = dim_idx.to_vector();
SetAxisMask();
if (!has_tensor_mean_ && !has_tensor_stddev_)
Expand Down
1 change: 1 addition & 0 deletions dali/operators/math/normalize/normalize_gpu.cu
Expand Up @@ -207,6 +207,7 @@ void Normalize<GPUBackend>::RunTyped(DeviceWorkspace &ws) {

auto &output = ws.OutputRef<GPUBackend>(0);
TensorListView<StorageGPU, OutputType> out_view = view<OutputType>(output);
output.SetLayout(input.GetLayout());

int nsamples = input.ntensor();

Expand Down
1 change: 1 addition & 0 deletions dali/operators/random/normal_distribution_op.h
Expand Up @@ -141,6 +141,7 @@ class NormalDistributionCpu : public NormalDistribution<CPUBackend> {

protected:
void RunImpl(workspace_t<CPUBackend> &ws) override;
using NormalDistribution<CPUBackend>::RunImpl;

private:
void AssignTensorToOutput(workspace_t<CPUBackend> &ws);
Expand Down
4 changes: 2 additions & 2 deletions dali/operators/sequence/optical_flow/optical_flow.cc
Expand Up @@ -69,7 +69,7 @@ void OpticalFlow<GPUBackend>::RunImpl(Workspace<GPUBackend> &ws) {
const auto &input = ws.Input<GPUBackend>(0);
const auto &hints = ws.Input<GPUBackend>(1);
auto &output = ws.Output<GPUBackend>(0);

output.SetLayout("HWC"); // Channels represent the two flow vector components (x and y)
// Extract calculation params
ExtractParams(input, hints);

Expand Down Expand Up @@ -109,7 +109,7 @@ void OpticalFlow<GPUBackend>::RunImpl(Workspace<GPUBackend> &ws) {
// Input is a TensorList, where every Tensor is a sequence
const auto &input = ws.Input<GPUBackend>(0);
auto &output = ws.Output<GPUBackend>(0);

output.SetLayout(input.GetLayout());

// Extract calculation params
ExtractParams(input);
Expand Down
3 changes: 2 additions & 1 deletion dali/operators/ssd/random_crop.cc
Expand Up @@ -273,6 +273,7 @@ void SSDRandomCrop<CPUBackend>::RunImpl(SampleWorkspace &ws) {
// now we know how many output bboxes there will be, we can allocate
// the output.
auto &img_out = ws.Output<CPUBackend>(0);
img_out.SetLayout(img.GetLayout());
auto &bbox_out = ws.Output<CPUBackend>(1);
auto &label_out = ws.Output<CPUBackend>(2);

Expand Down Expand Up @@ -312,7 +313,7 @@ void SSDRandomCrop<CPUBackend>::RunImpl(SampleWorkspace &ws) {

// perform the crop
detail::crop(img, {left_idx, top_idx, right_idx, bottom_idx},
ws.Output<CPUBackend>(0));
img_out);

return;
} // end num_attempts loop
Expand Down
9 changes: 7 additions & 2 deletions dali/pipeline/data/tensor_vector.h
Expand Up @@ -226,10 +226,15 @@ class TensorVector {

inline TensorLayout GetLayout() const {
if (state_ == State::contiguous) {
return tl_->GetLayout();
auto layout = tl_->GetLayout();
if (!layout.empty())
return layout;
}
if (tensors_.size() > 0) {
return tensors_[0]->GetLayout();
auto layout = tensors_[0]->GetLayout();
for (size_t i = 1; i < tensors_.size(); i++)
assert(layout == tensors_[i]->GetLayout());
return layout;
}
return {};
}
Expand Down
35 changes: 35 additions & 0 deletions dali/pipeline/executor/executor.h
Expand Up @@ -339,11 +339,36 @@ class DLL_PUBLIC Executor : public ExecutorBase, public WorkspacePolicy, public
std::mutex gpu_memory_stats_mutex_;

private:
/**
 * @brief Assigns a schema-provided layout to an input that currently has an empty layout.
 *
 * @param in      input whose layout is inspected and, if empty, possibly overwritten
 * @param schema  operator schema queried through GetInputLayout
 * @param in_idx  index of the input, forwarded to the schema query
 * @return true if a layout was set on `in`; false if `in` already had a non-empty layout
 *         or the schema query returned an empty layout (in both cases `in` is untouched)
 */
template <typename InputRef>
static bool SetDefaultLayoutIfNeeded(InputRef &in, const OpSchema &schema, int in_idx) {
// Inputs that already carry a layout are left as they are.
if (!in.GetLayout().empty())
return false;
// The layout passed as the third argument is empty here (checked above); the schema is
// asked for a layout matching this input index and the sample dimensionality.
auto default_layout = schema.GetInputLayout(in_idx, in.shape().sample_dim(), in.GetLayout());
// Nothing to apply if the schema does not provide a layout for this input.
if (default_layout.empty())
return false;
in.SetLayout(default_layout);
return true;
}

template <typename Workspace>
void RunHelper(OpNode &op_node, Workspace &ws) {
auto &output_desc = op_node.output_desc;
auto &op = *op_node.op;
output_desc.clear();
const auto &spec = op.GetSpec();
const auto &schema = spec.GetSchema();
SmallVector<int, 16> empty_layout_in_idxs;
for (int i = 0; i < spec.NumRegularInput(); i++) {
bool had_empty_layout = false;
if (ws.template InputIsType<CPUBackend>(i)) {
had_empty_layout = SetDefaultLayoutIfNeeded(ws.template InputRef<CPUBackend>(i), schema, i);
} else {
had_empty_layout = SetDefaultLayoutIfNeeded(ws.template InputRef<GPUBackend>(i), schema, i);
}
if (had_empty_layout)
empty_layout_in_idxs.push_back(i);
}

if (op.Setup(output_desc, ws)) {
DALI_ENFORCE(
static_cast<size_t>(ws.NumOutput()) == output_desc.size(),
Expand All @@ -369,6 +394,16 @@ class DLL_PUBLIC Executor : public ExecutorBase, public WorkspacePolicy, public
"always return false.");
}
op.Run(ws);

for (int i : empty_layout_in_idxs) {
if (ws.template InputIsType<CPUBackend>(i)) {
auto &in = ws.template InputRef<CPUBackend>(i);
in.SetLayout({});
} else {
auto &in = ws.template InputRef<GPUBackend>(i);
in.SetLayout({});
}
}
}
};

Expand Down

0 comments on commit 0362ac4

Please sign in to comment.