diff --git a/dali/kernels/reduce/reduce_axes_gpu_impl.cuh b/dali/kernels/reduce/reduce_axes_gpu_impl.cuh
index 21f896a0246..0432360aeff 100644
--- a/dali/kernels/reduce/reduce_axes_gpu_impl.cuh
+++ b/dali/kernels/reduce/reduce_axes_gpu_impl.cuh
@@ -1,4 +1,4 @@
-// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -68,6 +68,31 @@ struct UniformPreprocessorBank {
 
 }  // namespace reduce_impl
 
+
+
+/**
+ * @brief This function is used when the reduction fills the output with the neutral element
+ *
+ * This function will apply postprocessing to the neutral value.
+ *
+ * @param pre_bank preprocessor bank, providing possibly distinct processing
+ *                 per output sample
+ * @param post     postprocessing unary functor
+ */
+template 
+__device__ void ReduceNeutral(Out *out, Reduction reduce, int64_t n,
+                              Postprocessor post) {
+  const int64_t blk_size = blockDim.x * blockDim.y;  // no restriction on block size
+  const int64_t grid_stride = static_cast<int64_t>(gridDim.x) * blk_size;
+  const int flat_tid = threadIdx.x + threadIdx.y * blockDim.x;
+  int64_t base_idx = static_cast<int64_t>(blockIdx.x) * blk_size + flat_tid;
+  auto out_val = ConvertSat<Out>(reduce.template neutral());
+  for (int64_t index = base_idx; index < n; index += grid_stride) {
+    out[index] = out_val;
+  }
+}
+
+
 /**
  * @brief This function is used when the reduction is no-op (reduced extent is 1)
  *
@@ -299,7 +324,9 @@ __device__ void ReduceInner(const ReduceSampleDesc &sample,
   Out *out = sample.out;
   const In *in = sample.in;
 
-  if (n_reduced == 1) {
+  if (n_reduced == 0) {
+    ReduceNeutral(out, reduce, n_outer, post);
+  } else if (n_reduced == 1) {
     ReduceNone(out, in, n_outer, pre_bank, post);
   } else if (n_reduced < 32 && sample.num_macroblocks == 1) {
     ReduceInnerSmall(out, in, n_outer, n_reduced, reduce, pre_bank, post);
@@ -630,6 +657,56 @@ __global__ void ReduceMiddleKernel(const ReduceSampleDesc *samples,
   }
 }
 
+
+
+template ,
+          typename Postprocessor = identity>
+__global__ void ReduceMiddleSmallKernel(const ReduceSampleDesc *samples,
+                                        Reduction reduce = {},
+                                        const PreprocessorBank *pre = nullptr,
+                                        const Postprocessor *post = nullptr) {
+  auto sample = samples[blockIdx.y];
+
+  PreprocessorBank pre_bank = pre ? pre[blockIdx.y] : PreprocessorBank();
+  Postprocessor postprocessor = post ? post[blockIdx.y] : Postprocessor();
+
+  ReduceMiddleSmall(sample, reduce, pre_bank, postprocessor);
+}
+
+template ,
+          typename Postprocessor = identity>
+__global__ void ReduceMiddleLargeInnerSmallKernel(const ReduceSampleDesc *samples,
+                                                  Reduction reduce = {},
+                                                  const PreprocessorBank *pre = nullptr,
+                                                  const Postprocessor *post = nullptr) {
+  auto sample = samples[blockIdx.y];
+
+  PreprocessorBank pre_bank = pre ? pre[blockIdx.y] : PreprocessorBank();
+  Postprocessor postprocessor = post ? post[blockIdx.y] : Postprocessor();
+
+  ReduceMiddleLargeInnerSmall(sample, reduce, pre_bank, postprocessor);
+}
+
+template ,
+          typename Postprocessor = identity>
+__global__ void ReduceMiddleLargeInnerMediumKernel(const ReduceSampleDesc *samples,
+                                                   Reduction reduce = {},
+                                                   const PreprocessorBank *pre = nullptr,
+                                                   const Postprocessor *post = nullptr) {
+  auto sample = samples[blockIdx.y];
+
+  PreprocessorBank pre_bank = pre ? pre[blockIdx.y] : PreprocessorBank();
+  Postprocessor postprocessor = post ? post[blockIdx.y] : Postprocessor();
+
+  ReduceMiddleLargeInnerMedium(sample, reduce, pre_bank, postprocessor);
+}
+
 
 }  // namespace kernels
 }  // namespace dali
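The new `ReduceNeutral` path above handles reductions over an empty extent: there is nothing to accumulate, so every output element is filled with the reduction's neutral element using a grid-stride loop. A minimal, self-contained sketch of that fill pattern, with a hypothetical kernel name and a concrete `float` output in place of the templated `Out`/accumulator types used in the actual code:

```cpp
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Writes a constant "neutral" value (e.g. 0 for sum) to out[0..n).
// The grid-stride loop makes the kernel correct for any grid/block shape:
// each thread starts at its flat global index and advances by the total
// number of threads in the grid.
__global__ void FillNeutralKernel(float *out, int64_t n, float neutral) {
  int64_t blk_size = blockDim.x * blockDim.y;
  int64_t grid_stride = static_cast<int64_t>(gridDim.x) * blk_size;
  int flat_tid = threadIdx.x + threadIdx.y * blockDim.x;
  for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blk_size + flat_tid;
       idx < n; idx += grid_stride) {
    out[idx] = neutral;
  }
}

int main() {
  const int64_t n = 1 << 20;
  float *out = nullptr;
  cudaMalloc(&out, n * sizeof(float));
  // 0.0f is the neutral element of a sum reduction.
  FillNeutralKernel<<<128, dim3(32, 8)>>>(out, n, 0.0f);
  cudaDeviceSynchronize();
  float first = -1.0f;
  cudaMemcpy(&first, out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("out[0] = %f\n", first);
  cudaFree(out);
  return 0;
}
```

Because the stride equals the total number of launched threads, the launch configuration can be chosen purely for occupancy; correctness does not depend on the grid covering `n` exactly.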
diff --git a/dali/kernels/reduce/reduce_gpu_impl.cuh b/dali/kernels/reduce/reduce_gpu_impl.cuh
index febc604318c..7dd862a67fc 100644
--- a/dali/kernels/reduce/reduce_gpu_impl.cuh
+++ b/dali/kernels/reduce/reduce_gpu_impl.cuh
@@ -96,6 +96,7 @@ struct ReductionStage {
 
   vector shape;
   vector input_offsets, output_offsets;
+  vector sample_indices;
 
   int num_samples() const {
     return shape.size();
@@ -821,10 +822,14 @@ class ReduceImplGPU {
       }
       reduced[i] = sample_shape[axis];
       outer[i] *= new_outer;
-      if (axis < in_dim - 1)
-        inner[i] /= new_outer * reduced[i];
-      else
+      if (axis < in_dim - 1) {
+        if (new_outer * reduced[i] > 0)
+          inner[i] /= new_outer * reduced[i];
+        else
+          inner[i] = 0;
+      } else {
         inner[i] = 1;
+      }
     }
 
     prev_axis = axis;
@@ -894,7 +899,7 @@ class ReduceImplGPU {
   /**
    * @brief Launches a reduction stage in environment given by `ctx`
    */
-  void LaunchStage(Context &ctx, const ReductionStage &stage) {
+  void LaunchStage(Context &ctx, ReductionStage &stage) {
     ctx.work_area.BeginStage(stage.index);
     VALUE_SWITCH(stage.kind, kind, (
       ReductionKind::All,
@@ -927,7 +932,7 @@ class ReduceImplGPU {
    * @return Host-side parameter buffer containing the input pointers.
    */
   template 
-  const StageIn *const *InputPtrs(Context &ctx, const ReductionStage &stage) const {
+  const StageIn *const *InputPtrs(Context &ctx, ReductionStage &stage) const {
     WorkArea &wa = ctx.work_area;
     auto *ptrs = wa.ParamBuffer(stage.num_samples());
     if (stage.index > 0) {
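The guarded division added above (`if (new_outer * reduced[i] > 0) inner[i] /= ...; else inner[i] = 0;`) keeps the outer/reduced/inner factorization well defined when a sample has a zero extent. A stand-alone illustration of the idea, under the simplifying assumptions of a single reduced axis and a factorization computed from scratch rather than updated incrementally as the real code does:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> sample_shape = {32, 0, 7};  // degenerate middle extent
  int axis = 1;                                    // the axis being reduced

  // "outer" collects the extents to the left of the reduced axis.
  int64_t outer = 1;
  for (int d = 0; d < axis; d++)
    outer *= sample_shape[d];

  int64_t reduced = sample_shape[axis];

  // "inner" starts as the total volume and has the other factors divided out.
  int64_t inner = 1;
  for (size_t d = 0; d < sample_shape.size(); d++)
    inner *= sample_shape[d];

  if (outer * reduced > 0)
    inner /= outer * reduced;   // normal case
  else
    inner = 0;                  // empty sample: avoid division by zero

  std::printf("outer=%lld reduced=%lld inner=%lld\n",
              (long long)outer, (long long)reduced, (long long)inner);
  // prints: outer=32 reduced=0 inner=0
  return 0;
}
```

Forcing the inner extent to 0 avoids the division by zero while still describing an empty sample.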
@@ -1004,8 +1009,26 @@ class ReduceImplGPU {
     return samples;
   }
 
+  /**
+   * @brief Permutes a sequence in place
+   *
+   * @remarks The function will return an incorrect result or hang if `idx` contains repetitions.
+   */
+  template 
+  void permute_in_place(Seq &inout, Indices &&idx) {
+    using index_type = std::remove_reference_t;
+    using size_type = decltype(dali::size(idx));
+    for (size_type i = 0, n = dali::size(idx); i < n; i++) {
+      size_type src_idx = idx[i];
+      while (src_idx < i)
+        src_idx = idx[src_idx];
+      if (src_idx != i)
+        std::swap(inout[i], inout[src_idx]);
+    }
+  }
+
   template 
-  void LaunchStage(Context &ctx, const ReductionStage &stage,
+  void LaunchStage(Context &ctx, ReductionStage &stage,
                    ReductionKindTag) {
     assert(!is_first || ctx.input.is_contiguous());
     WorkArea &wa = ctx.work_area;
@@ -1029,7 +1052,7 @@ class ReduceImplGPU {
   }
 
   template 
-  void LaunchStage(Context &ctx, const ReductionStage &stage,
+  void LaunchStage(Context &ctx, ReductionStage &stage,
                    ReductionKindTag) {
     assert(!is_last || ctx.output.is_contiguous());
     WorkArea &wa = ctx.work_area;
@@ -1065,7 +1088,7 @@ class ReduceImplGPU {
 
 
   template 
-  void LaunchStage(Context &ctx, const ReductionStage &stage,
+  void LaunchStage(Context &ctx, ReductionStage &stage,
                    ReductionKindTag) {
     assert(!is_last || ctx.output.is_contiguous());
     assert(!is_first && "Block reduction is never the first stage");
@@ -1095,12 +1118,12 @@ class ReduceImplGPU {
   }
 
   template 
-  void LaunchStage(Context &ctx, const ReductionStage &stage,
+  void LaunchStage(Context &ctx, ReductionStage &stage,
                    ReductionKindTag) {
     using SampleDesc = ReduceSampleDesc;
     WorkArea &wa = ctx.work_area;
+    int num_samples = stage.num_samples();
     SampleDesc *cpu_samples = PrepareSampleDescs(ctx, stage);
-
     auto *pre = GetPreprocessorBanks(wa, stage.axis);
     auto *post = GetPostprocessors(wa);
 
@@ -1113,8 +1136,8 @@ class ReduceImplGPU {
 
     wa.CopyParamsToDevice(ctx.stream);
     dim3 block(32, max_block_size / 32);
-    int gridx = std::max(32, 512/stage.num_samples());
-    dim3 grid(gridx, stage.num_samples());
+    int gridx = std::max(32, 512/num_samples);
+    dim3 grid(gridx, num_samples);
 
     SampleDesc *gpu_samples = wa.GetDeviceParam(cpu_samples);
     auto *gpu_pre = wa.GetDeviceParam(pre);
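`permute_in_place` above reorders a sequence according to an index array without a scratch buffer: if the source element was already moved by an earlier iteration, the chain of indices is followed to its current position. A self-contained demo of the same cycle-chasing technique on a `std::vector` (a re-implementation for illustration; the helper in this patch works on DALI containers via `dali::size`):

```cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Applies the permutation `idx` to `data` in place, so that afterwards
// data[i] holds the element originally stored at data[idx[i]].
// Precondition: idx is a permutation of 0..n-1 (no repeated indices).
template <typename T>
void permute_in_place(std::vector<T> &data, const std::vector<int> &idx) {
  for (size_t i = 0; i < idx.size(); i++) {
    size_t src = idx[i];
    // If the source element was already moved in an earlier iteration,
    // follow the index chain to find where it lives now.
    while (src < i)
      src = idx[src];
    if (src != i)
      std::swap(data[i], data[src]);
  }
}

int main() {
  std::vector<std::string> data = {"a", "b", "c", "d"};
  std::vector<int> idx = {2, 0, 3, 1};
  permute_in_place(data, idx);
  for (const auto &s : data)
    std::cout << s << ' ';   // prints: c a d b
  std::cout << '\n';
  return 0;
}
```

With `idx = {2, 0, 3, 1}` the call realizes the gather `data[i] = original[idx[i]]`, which is what the reduction stage needs in order to reorder sample descriptors, preprocessor banks and postprocessors consistently.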
@@ -1127,40 +1150,126 @@ class ReduceImplGPU {
   }
 
   template 
-  void LaunchStage(Context &ctx, const ReductionStage &stage,
+  void LaunchStage(Context &ctx, ReductionStage &stage,
                    ReductionKindTag) {
     using SampleDesc = ReduceSampleDesc;
     WorkArea &wa = ctx.work_area;
     SampleDesc *cpu_samples = PrepareSampleDescs(ctx, stage);
 
+    // There are three cases, depending on the sizes of the inner and the reduced dimensions.
+    // We separate the samples into three bins by running a stable partition on them.
+
+    int num_samples = stage.num_samples();
+
+    auto &indices = stage.sample_indices;
+    indices.resize(num_samples);
+    std::iota(indices.begin(), indices.end(), 0);
+
+    auto middle_small_end = std::stable_partition(
+      indices.begin(),
+      indices.end(),
+      [&](int i) {
+        return cpu_samples[i].n_reduced < 1024 && cpu_samples[i].num_macroblocks == 1;
+      });
+
+    auto middle_large_inner_small_end = std::stable_partition(
+      middle_small_end,
+      indices.end(),
+      [&](int i) {
+        return cpu_samples[i].n_inner < 32;
+      });
+
+    int num_middle_small = middle_small_end - indices.begin();
+    int num_middle_large_inner_small = middle_large_inner_small_end - middle_small_end;
+    int num_middle_large_inner_medium = indices.end() - middle_large_inner_small_end;
+
     auto *pre = GetPreprocessorBanks(wa, stage.axis);
     auto *post = GetPostprocessors(wa);
 
-    using pre_bank_t = std::remove_cv_t>;;
-    using post_t = std::remove_cv_t>;;
-    using red_t = std::remove_reference_t;
-
-    int max_block_size = std::min(1024, MaxThreadsPerBlock(
-      ReduceMiddleKernel));
+    permute_in_place(cpu_samples, indices);
+    permute_in_place(pre, indices);
+    permute_in_place(post, indices);
 
     wa.CopyParamsToDevice(ctx.stream);
-    dim3 block(32, max_block_size / 32);
-    int gridx = std::max(32, 512/stage.num_samples());
-    dim3 grid(gridx, stage.num_samples());
-    const int shm_size = 0x8000;  // 32 kB shared mem
 
     SampleDesc *gpu_samples = wa.GetDeviceParam(cpu_samples);
     auto *gpu_pre = wa.GetDeviceParam(pre);
     auto *gpu_post = wa.GetDeviceParam(post);
 
-    ReduceMiddleKernel<<>>(
-      gpu_samples, This().GetReduction(), gpu_pre, gpu_post);
+    using pre_bank_t = std::remove_cv_t>;;
+    using post_t = std::remove_cv_t>;;
+    using red_t = std::remove_reference_t;
 
-    CUDA_CALL(cudaGetLastError());
+    auto launch_params = [&](auto kernel, int nsamples, int shm_size) {
+      int preferred_block_size = 256;
+      int preferred_grid_size;  // unused
+      CUDA_CALL(cudaOccupancyMaxPotentialBlockSize(
+        &preferred_grid_size,
+        &preferred_block_size,
+        kernel,
+        shm_size));
+
+      dim3 block(32, preferred_block_size / 32);
+      int gridx = std::max(32, 512/nsamples);
+      dim3 grid(gridx, nsamples);
+      return std::make_pair(grid, block);
+    };
+
+    dim3 grid, block;
+    int sample_offset = 0;
+
+    // MiddleSmall
+    if (num_middle_small) {
+      std::tie(grid, block) = launch_params(
+        ReduceMiddleSmallKernel,
+        num_middle_small, 0);
+
+      ReduceMiddleSmallKernel<<>>(
+        gpu_samples + sample_offset,
+        This().GetReduction(),
+        gpu_pre ? gpu_pre + sample_offset : nullptr,
+        gpu_post ? gpu_post + sample_offset : nullptr);
+
+      sample_offset += num_middle_small;
+    }
+
+    int shm_size = 0x8000;
+
+    // MiddleLargeInnerSmall
+
+    if (num_middle_large_inner_small) {
+      std::tie(grid, block) = launch_params(
+        ReduceMiddleLargeInnerSmallKernel,
+        num_middle_large_inner_small, shm_size);
+
+      ReduceMiddleLargeInnerSmallKernel<<>>(
+        gpu_samples + sample_offset,
+        This().GetReduction(),
+        gpu_pre ? gpu_pre + sample_offset : nullptr,
+        gpu_post ? gpu_post + sample_offset : nullptr);
+
+      sample_offset += num_middle_large_inner_small;
+    }
+
+    // MiddleLargeInnerMedium
+
+    if (num_middle_large_inner_medium) {
+      std::tie(grid, block) = launch_params(
+        ReduceMiddleLargeInnerMediumKernel,
+        num_middle_large_inner_medium, shm_size);
+
+      ReduceMiddleLargeInnerMediumKernel<<>>(
+        gpu_samples + sample_offset,
+        This().GetReduction(),
+        gpu_pre ? gpu_pre + sample_offset : nullptr,
+        gpu_post ? gpu_post + sample_offset : nullptr);
+    }
+
+    assert((sample_offset + num_middle_large_inner_medium) == num_samples);
   }
 
   template 
-  void LaunchStage(Context &ctx, const ReductionStage &stage,
+  void LaunchStage(Context &ctx, ReductionStage &stage,
                    ReductionKindTag) {
     assert(is_last);
     assert(!stage.shape.empty());
@@ -1196,7 +1305,7 @@ class ReduceImplGPU {
   }
 
   template 
-  void LaunchStage(Context &ctx, const ReductionStage &stage,
+  void LaunchStage(Context &ctx, ReductionStage &stage,
                    ReductionKindTag) {
     assert(is_last);
     assert(!stage.shape.empty());
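Instead of one launch for the whole batch, the Middle stage above sorts the sample indices into three bins with two stable partitions — small reduced extent, large reduced extent with a small inner extent, and large reduced extent with a medium inner extent — and then launches one specialized kernel per non-empty bin over a contiguous range of the permuted descriptors. A host-only sketch of the binning step, with a hypothetical `SampleInfo` stand-in for the real sample descriptors and the thresholds taken from the patch:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

// Hypothetical stand-in for the per-sample properties that drive kernel selection.
struct SampleInfo {
  int64_t n_reduced;
  int64_t n_inner;
  int num_macroblocks;
};

int main() {
  std::vector<SampleInfo> samples = {
    {100, 1000, 1}, {1 << 20, 4, 4}, {500, 8, 1}, {1 << 20, 200, 4},
  };
  int n = static_cast<int>(samples.size());

  // Reorder sample indices so that each kernel variant gets a contiguous range.
  std::vector<int> indices(n);
  std::iota(indices.begin(), indices.end(), 0);

  auto small_end = std::stable_partition(indices.begin(), indices.end(), [&](int i) {
    return samples[i].n_reduced < 1024 && samples[i].num_macroblocks == 1;
  });
  auto inner_small_end = std::stable_partition(small_end, indices.end(), [&](int i) {
    return samples[i].n_inner < 32;
  });

  int num_small        = static_cast<int>(small_end - indices.begin());
  int num_inner_small  = static_cast<int>(inner_small_end - small_end);
  int num_inner_medium = static_cast<int>(indices.end() - inner_small_end);

  // Each bin occupies [offset, offset + count) in the permuted sample array,
  // so a specialized kernel can be launched per bin.
  std::printf("small: %d  large/inner-small: %d  large/inner-medium: %d\n",
              num_small, num_inner_small, num_inner_medium);
  // prints: small: 2  large/inner-small: 1  large/inner-medium: 1
  return 0;
}
```

Stable partitioning preserves the relative order within each bin, so the resulting index order can also be used to permute the preprocessor and postprocessor arrays consistently.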
diff --git a/dali/kernels/reduce/reduce_gpu_test.cc b/dali/kernels/reduce/reduce_gpu_test.cc
index 04271f5a549..8e3fb86c2f1 100644
--- a/dali/kernels/reduce/reduce_gpu_test.cc
+++ b/dali/kernels/reduce/reduce_gpu_test.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -47,5 +47,26 @@ TEST(SumGPU, SplitStageBatch) {
   }
 }
 
+TEST(SumGPU, ReduceMiddleRandomBench) {
+  TensorListShape<> in_shape = {{
+    { 32, 3, 64000 },
+    { 15, 3, 128000 },
+    { 72000, 3, 7 }
+  }};
+  TensorListShape<> ref_out_shape = {{
+    TensorShape<>{3}
+  }};
+  int axes[] = { 0, 2 };
+
+  testing::ReductionKernelTest, uint64_t, uint8_t> test;
+  for (int iter = 0; iter < 3; iter++) {
+    test.Setup(in_shape, ref_out_shape, make_span(axes), false, true);
+    test.FillData(0, 255);
+    test.Run();
+    RefReduce(test.ref.cpu(), test.in.cpu(), make_span(axes), false, true, reductions::sum());
+    test.Check();
+  }
+}
+
 
 }  // namespace kernels
 }  // namespace dali
diff --git a/dali/test/python/operator/test_reduce.py b/dali/test/python/operator/test_reduce.py
index 270a7c5a77f..750bf966a92 100644
--- a/dali/test/python/operator/test_reduce.py
+++ b/dali/test/python/operator/test_reduce.py
@@ -14,6 +14,7 @@
 
 import nvidia.dali.fn as fn
 from nvidia.dali.pipeline import Pipeline
+from nose.tools import nottest
 
 import numpy as np
 
@@ -385,3 +386,40 @@ def get_batch():
     for reduction in reductions_with_mean_input:
         yield run_reduce_with_layout_with_mean_input, batch_size, get_batch, reduction, \
             axes, axis_names, batch_fn
+
+
+@nottest
+def _test_reduce_large_data(rank, axes, device):
+    batch_size = 16
+    num_batches = 2
+    data = []
+    for _ in range(num_batches):
+        batch = []
+        for _ in range(batch_size):
+            size = np.random.randint(1, 128, size=rank)
+            batch.append(np.random.random(size=size).astype(np.float32))
+        data.append(batch)
+
+    pipe = Pipeline(batch_size=batch_size, num_threads=4, device_id=0 if device == 'gpu' else None)
+    input = fn.external_source(data, cycle=True, device=device)
+    reduced = fn.reductions.sum(input, axes=axes)
+    pipe.set_outputs(reduced)
+    pipe.build()
+
+    for b, batch in enumerate(data):
+        out, = pipe.run()
+        if device == 'gpu':
+            out = out.as_cpu()
+        for i in range(batch_size):
+            ref = np.sum(batch[i], axis=axes)
+            assert np.allclose(out[i], ref, 1e-5, 1e-5)
+
+
+def test_reduce_large_data():
+    np.random.seed(1234)
+    for device in ['gpu']:
+        for rank in range(1, 4):
+            for axis_mask in range(1, 2**rank):
+                axes = tuple(filter(lambda x: x >= 0,
+                             (i if axis_mask & (1 << i) else -1 for i in range(rank))))
+                yield _test_reduce_large_data, rank, axes, device
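The Python test enumerates every non-empty subset of reduction axes with a bitmask, so each rank is exercised with all possible axis combinations. The same enumeration written out in C++, purely for illustration:

```cpp
#include <cstdio>
#include <vector>

// Enumerates every non-empty subset of axes {0, ..., rank-1}:
// bit i of the mask selects axis i, exactly like the Python test's axis_mask loop.
int main() {
  const int rank = 3;
  for (int mask = 1; mask < (1 << rank); mask++) {
    std::vector<int> axes;
    for (int i = 0; i < rank; i++)
      if (mask & (1 << i))
        axes.push_back(i);
    std::printf("mask %d -> axes:", mask);
    for (int a : axes)
      std::printf(" %d", a);
    std::printf("\n");
  }
  return 0;
}
```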