Commit

Split middle reduction kernels.
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
mzient committed Oct 24, 2022
1 parent 7eda534 commit c4bbdc5
Showing 4 changed files with 276 additions and 31 deletions.
81 changes: 79 additions & 2 deletions dali/kernels/reduce/reduce_axes_gpu_impl.cuh
@@ -1,4 +1,4 @@
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -68,6 +68,31 @@ struct UniformPreprocessorBank {
} // namespace reduce_impl




/**
* @brief Fills the output with the neutral element of the reduction
*
* This function is used when the reduced extent is empty. It applies
* postprocessing to the neutral value and stores the result in all `n`
* output elements.
*
* @param post postprocessing unary functor
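*
* For example (illustrative): with reductions::sum, whose neutral element is 0,
* every output element becomes ConvertSat<Out>(post(0)).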
*/
template <typename Acc, typename Out, typename Reduction, typename Postprocessor>
__device__ void ReduceNeutral(Out *out, Reduction reduce, int64_t n,
Postprocessor post) {
const int64_t blk_size = blockDim.x * blockDim.y; // no restriction on block size
const int64_t grid_stride = static_cast<int64_t>(gridDim.x) * blk_size;
const int flat_tid = threadIdx.x + threadIdx.y * blockDim.x;
int64_t base_idx = static_cast<int64_t>(blockIdx.x) * blk_size + flat_tid;
auto out_val = ConvertSat<Out>(post(reduce.template neutral<Acc>()));
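// grid-stride loop: the whole grid cooperatively fills all n output elements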
for (int64_t index = base_idx; index < n; index += grid_stride) {
out[index] = out_val;
}
}


/**
* @brief This function is used when the reduction is no-op (reduced extent is 1)
*
@@ -299,7 +324,9 @@ __device__ void ReduceInner(const ReduceSampleDesc<Out, In> &sample,
Out *out = sample.out;
const In *in = sample.in;

if (n_reduced == 1) {
if (n_reduced == 0) {
ReduceNeutral<Acc>(out, reduce, n_outer, post);
} else if (n_reduced == 1) {
ReduceNone(out, in, n_outer, pre_bank, post);
} else if (n_reduced < 32 && sample.num_macroblocks == 1) {
ReduceInnerSmall<Acc>(out, in, n_outer, n_reduced, reduce, pre_bank, post);
@@ -630,6 +657,56 @@ __global__ void ReduceMiddleKernel(const ReduceSampleDesc<Out, In> *samples,
}
}



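/**
* @brief Invokes ReduceMiddleSmall for the sample selected by blockIdx.y,
*        with optional per-sample preprocessor bank and postprocessor.
*
* Intended for samples whose reduced extent is small enough to be processed
* in a single macroblock (see the Middle stage dispatch in reduce_gpu_impl.cuh).
*/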
template <typename Acc, typename Out, typename In,
typename Reduction = reductions::sum,
typename PreprocessorBank = reduce_impl::IdentityPreprocessor<2>,
typename Postprocessor = identity>
__global__ void ReduceMiddleSmallKernel(const ReduceSampleDesc<Out, In> *samples,
Reduction reduce = {},
const PreprocessorBank *pre = nullptr,
const Postprocessor *post = nullptr) {
auto sample = samples[blockIdx.y];

PreprocessorBank pre_bank = pre ? pre[blockIdx.y] : PreprocessorBank();
Postprocessor postprocessor = post ? post[blockIdx.y] : Postprocessor();

ReduceMiddleSmall<Acc>(sample, reduce, pre_bank, postprocessor);
}

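/**
* @brief Invokes ReduceMiddleLargeInnerSmall for the sample selected by blockIdx.y,
*        with optional per-sample preprocessor bank and postprocessor.
*
* Intended for samples with a large reduced extent and a small inner extent.
*/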
template <typename Acc, typename Out, typename In,
typename Reduction = reductions::sum,
typename PreprocessorBank = reduce_impl::IdentityPreprocessor<2>,
typename Postprocessor = identity>
__global__ void ReduceMiddleLargeInnerSmallKernel(const ReduceSampleDesc<Out, In> *samples,
Reduction reduce = {},
const PreprocessorBank *pre = nullptr,
const Postprocessor *post = nullptr) {
auto sample = samples[blockIdx.y];

PreprocessorBank pre_bank = pre ? pre[blockIdx.y] : PreprocessorBank();
Postprocessor postprocessor = post ? post[blockIdx.y] : Postprocessor();

ReduceMiddleLargeInnerSmall<Acc>(sample, reduce, pre_bank, postprocessor);
}

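/**
* @brief Invokes ReduceMiddleLargeInnerMedium for the sample selected by blockIdx.y,
*        with optional per-sample preprocessor bank and postprocessor.
*
* Intended for samples with a large reduced extent and a medium-sized inner extent.
*/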
template <typename Acc, typename Out, typename In,
typename Reduction = reductions::sum,
typename PreprocessorBank = reduce_impl::IdentityPreprocessor<2>,
typename Postprocessor = identity>
__global__ void ReduceMiddleLargeInnerMediumKernel(const ReduceSampleDesc<Out, In> *samples,
Reduction reduce = {},
const PreprocessorBank *pre = nullptr,
const Postprocessor *post = nullptr) {
auto sample = samples[blockIdx.y];

PreprocessorBank pre_bank = pre ? pre[blockIdx.y] : PreprocessorBank();
Postprocessor postprocessor = post ? post[blockIdx.y] : Postprocessor();

ReduceMiddleLargeInnerMedium<Acc>(sample, reduce, pre_bank, postprocessor);
}

} // namespace kernels
} // namespace dali

165 changes: 137 additions & 28 deletions dali/kernels/reduce/reduce_gpu_impl.cuh
@@ -96,6 +96,7 @@ struct ReductionStage {

vector<ReductionShape> shape;
vector<int64_t> input_offsets, output_offsets;
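// processing order of the samples, used by stages that partition samples into bins
// (see LaunchStage for ReductionKind::Middle)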
vector<int> sample_indices;

int num_samples() const {
return shape.size();
@@ -821,10 +822,14 @@ class ReduceImplGPU {
}
reduced[i] = sample_shape[axis];
outer[i] *= new_outer;
if (axis < in_dim - 1)
inner[i] /= new_outer * reduced[i];
else
if (axis < in_dim - 1) {
if (new_outer * reduced[i] > 0)
inner[i] /= new_outer * reduced[i];
else
inner[i] = 0;
} else {
inner[i] = 1;
}
}
prev_axis = axis;

@@ -894,7 +899,7 @@ class ReduceImplGPU {
/**
* @brief Launches a reduction stage in environment given by `ctx`
*/
void LaunchStage(Context &ctx, const ReductionStage &stage) {
void LaunchStage(Context &ctx, ReductionStage &stage) {
ctx.work_area.BeginStage(stage.index);
VALUE_SWITCH(stage.kind, kind, (
ReductionKind::All,
@@ -927,7 +932,7 @@ class ReduceImplGPU {
* @return Host-side parameter buffer containing the input pointers.
*/
template <typename StageIn>
const StageIn *const *InputPtrs(Context &ctx, const ReductionStage &stage) const {
const StageIn *const *InputPtrs(Context &ctx, ReductionStage &stage) const {
WorkArea &wa = ctx.work_area;
auto *ptrs = wa.ParamBuffer<const StageIn*>(stage.num_samples());
if (stage.index > 0) {
@@ -1004,8 +1009,26 @@ class ReduceImplGPU {
return samples;
}

/**
* @brief Permutes a sequence in place
*
* @remarks The function will produce an incorrect result or hang if `idx` contains repeated indices.
*/
template <typename Seq, typename Indices>
void permute_in_place(Seq &inout, Indices &&idx) {
using index_type = std::remove_reference_t<decltype(idx[0])>;
using size_type = decltype(dali::size(idx));
for (size_type i = 0, n = dali::size(idx); i < n; i++) {
size_type src_idx = idx[i];
while (src_idx < i)
src_idx = idx[src_idx];
if (src_idx != i)
std::swap(inout[i], inout[src_idx]);
}
}
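
// Usage sketch (illustrative values): the function places the element originally at
// index idx[i] at position i, i.e. it computes inout[i] = old_inout[idx[i]]:
//
//   std::vector<char> v   {'a', 'b', 'c', 'd'};
//   std::vector<int>  idx { 2,   0,   3,   1 };
//   permute_in_place(v, idx);
//   // v == {'c', 'a', 'd', 'b'}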

template <typename StageOut, typename StageIn, bool is_first, bool is_last>
void LaunchStage(Context &ctx, const ReductionStage &stage,
void LaunchStage(Context &ctx, ReductionStage &stage,
ReductionKindTag<ReductionKind::All>) {
assert(!is_first || ctx.input.is_contiguous());
WorkArea &wa = ctx.work_area;
@@ -1029,7 +1052,7 @@ class ReduceImplGPU {
}

template <typename StageOut, typename StageIn, bool is_first, bool is_last>
void LaunchStage(Context &ctx, const ReductionStage &stage,
void LaunchStage(Context &ctx, ReductionStage &stage,
ReductionKindTag<ReductionKind::Sample>) {
assert(!is_last || ctx.output.is_contiguous());
WorkArea &wa = ctx.work_area;
@@ -1065,7 +1088,7 @@ class ReduceImplGPU {


template <typename StageOut, typename StageIn, bool is_first, bool is_last>
void LaunchStage(Context &ctx, const ReductionStage &stage,
void LaunchStage(Context &ctx, ReductionStage &stage,
ReductionKindTag<ReductionKind::Block>) {
assert(!is_last || ctx.output.is_contiguous());
assert(!is_first && "Block reduction is never the first stage");
@@ -1095,12 +1118,12 @@ class ReduceImplGPU {
}

template <typename StageOut, typename StageIn, bool is_first, bool is_last>
void LaunchStage(Context &ctx, const ReductionStage &stage,
void LaunchStage(Context &ctx, ReductionStage &stage,
ReductionKindTag<ReductionKind::Inner>) {
using SampleDesc = ReduceSampleDesc<StageOut, StageIn>;
WorkArea &wa = ctx.work_area;
int num_samples = stage.num_samples();
SampleDesc *cpu_samples = PrepareSampleDescs<StageOut, StageIn>(ctx, stage);

auto *pre = GetPreprocessorBanks<is_first, 1>(wa, stage.axis);
auto *post = GetPostprocessors<is_last>(wa);

@@ -1113,8 +1136,8 @@ class ReduceImplGPU {

wa.CopyParamsToDevice(ctx.stream);
dim3 block(32, max_block_size / 32);
int gridx = std::max(32, 512/stage.num_samples());
dim3 grid(gridx, stage.num_samples());
int gridx = std::max(32, 512/num_samples);
dim3 grid(gridx, num_samples);

SampleDesc *gpu_samples = wa.GetDeviceParam(cpu_samples);
auto *gpu_pre = wa.GetDeviceParam(pre);
Expand All @@ -1127,40 +1150,126 @@ class ReduceImplGPU {
}

template <typename StageOut, typename StageIn, bool is_first, bool is_last>
void LaunchStage(Context &ctx, const ReductionStage &stage,
void LaunchStage(Context &ctx, ReductionStage &stage,
ReductionKindTag<ReductionKind::Middle>) {
using SampleDesc = ReduceSampleDesc<StageOut, StageIn>;
WorkArea &wa = ctx.work_area;
SampleDesc *cpu_samples = PrepareSampleDescs<StageOut, StageIn>(ctx, stage);

// There are three cases, depending on the sizes of the reduced and the inner dimensions:
// 1. small reduced extent, processed in a single macroblock -> ReduceMiddleSmallKernel
// 2. large reduced extent, small inner extent               -> ReduceMiddleLargeInnerSmallKernel
// 3. large reduced extent, medium inner extent              -> ReduceMiddleLargeInnerMediumKernel
// We separate the samples into these bins by running a stable partition on the sample indices.

int num_samples = stage.num_samples();

auto &indices = stage.sample_indices;
indices.resize(num_samples);
std::iota(indices.begin(), indices.end(), 0);

auto middle_small_end = std::stable_partition(
indices.begin(),
indices.end(),
[&](int i) {
return cpu_samples[i].n_reduced < 1024 && cpu_samples[i].num_macroblocks == 1;
});

auto middle_large_inner_small_end = std::stable_partition(
middle_small_end,
indices.end(),
[&](int i) {
return cpu_samples[i].n_inner < 32;
});

int num_middle_small = middle_small_end - indices.begin();
int num_middle_large_inner_small = middle_large_inner_small_end - middle_small_end;
int num_middle_large_inner_medium = indices.end() - middle_large_inner_small_end;

auto *pre = GetPreprocessorBanks<is_first, 2>(wa, stage.axis);
auto *post = GetPostprocessors<is_last>(wa);

using pre_bank_t = std::remove_cv_t<std::remove_reference_t<decltype(*pre)>>;;
using post_t = std::remove_cv_t<std::remove_reference_t<decltype(*post)>>;;
using red_t = std::remove_reference_t<decltype(This().GetReduction())>;

int max_block_size = std::min(1024, MaxThreadsPerBlock(
ReduceMiddleKernel<Acc, StageOut, StageIn, red_t, pre_bank_t, post_t>));
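// Reorder the sample descriptors and the per-sample pre/postprocessors so that each bin
// occupies a contiguous range; this lets each kernel below operate on a contiguous slice
// of the parameter arrays, addressed by sample_offset.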
permute_in_place(cpu_samples, indices);
permute_in_place(pre, indices);
permute_in_place(post, indices);

wa.CopyParamsToDevice(ctx.stream);
dim3 block(32, max_block_size / 32);
int gridx = std::max(32, 512/stage.num_samples());
dim3 grid(gridx, stage.num_samples());
const int shm_size = 0x8000; // 32 kB shared mem

SampleDesc *gpu_samples = wa.GetDeviceParam(cpu_samples);
auto *gpu_pre = wa.GetDeviceParam(pre);
auto *gpu_post = wa.GetDeviceParam(post);

ReduceMiddleKernel<Acc><<<grid, block, shm_size, ctx.stream>>>(
gpu_samples, This().GetReduction(), gpu_pre, gpu_post);
using pre_bank_t = std::remove_cv_t<std::remove_reference_t<decltype(*pre)>>;
using post_t = std::remove_cv_t<std::remove_reference_t<decltype(*post)>>;
using red_t = std::remove_reference_t<decltype(This().GetReduction())>;

CUDA_CALL(cudaGetLastError());
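// Helper that queries the occupancy-optimal block size for a given kernel and dynamic
// shared memory size, then shapes the launch grid as (x-blocks per sample, num samples).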
auto launch_params = [&](auto kernel, int nsamples, int shm_size) {
int preferred_block_size = 256;
int preferred_grid_size; // unused
CUDA_CALL(cudaOccupancyMaxPotentialBlockSize(
&preferred_grid_size,
&preferred_block_size,
kernel,
shm_size));

dim3 block(32, preferred_block_size / 32);
int gridx = std::max(32, 512/nsamples);
dim3 grid(gridx, nsamples);
return std::make_pair(grid, block);
};

dim3 grid, block;
int sample_offset = 0;

// MiddleSmall
if (num_middle_small) {
std::tie(grid, block) = launch_params(
ReduceMiddleSmallKernel<Acc, StageOut, StageIn, red_t, pre_bank_t, post_t>,
num_middle_small, 0);

ReduceMiddleSmallKernel<Acc><<<grid, block, 0, ctx.stream>>>(
gpu_samples + sample_offset,
This().GetReduction(),
gpu_pre ? gpu_pre + sample_offset : nullptr,
gpu_post ? gpu_post + sample_offset : nullptr);

sample_offset += num_middle_small;
}

int shm_size = 0x8000;  // 32 kB shared mem

// MiddleLargeInnerSmall

if (num_middle_large_inner_small) {
std::tie(grid, block) = launch_params(
ReduceMiddleLargeInnerSmallKernel<Acc, StageOut, StageIn, red_t, pre_bank_t, post_t>,
num_middle_large_inner_small, shm_size);

ReduceMiddleLargeInnerSmallKernel<Acc><<<grid, block, shm_size, ctx.stream>>>(
gpu_samples + sample_offset,
This().GetReduction(),
gpu_pre ? gpu_pre + sample_offset : nullptr,
gpu_post ? gpu_post + sample_offset : nullptr);

sample_offset += num_middle_large_inner_small;
}

// MiddleLargeInnerMedium

if (num_middle_large_inner_medium) {
std::tie(grid, block) = launch_params(
ReduceMiddleLargeInnerMediumKernel<Acc, StageOut, StageIn, red_t, pre_bank_t, post_t>,
num_middle_large_inner_medium, shm_size);

ReduceMiddleLargeInnerMediumKernel<Acc><<<grid, block, shm_size, ctx.stream>>>(
gpu_samples + sample_offset,
This().GetReduction(),
gpu_pre ? gpu_pre + sample_offset : nullptr,
gpu_post ? gpu_post + sample_offset : nullptr);
}

assert((sample_offset + num_middle_large_inner_medium) == num_samples);
}

template <typename StageOut, typename StageIn, bool is_first, bool is_last>
void LaunchStage(Context &ctx, const ReductionStage &stage,
void LaunchStage(Context &ctx, ReductionStage &stage,
ReductionKindTag<ReductionKind::Fold>) {
assert(is_last);
assert(!stage.shape.empty());
@@ -1196,7 +1305,7 @@ class ReduceImplGPU {
}

template <typename StageOut, typename StageIn, bool is_first, bool is_last>
void LaunchStage(Context &ctx, const ReductionStage &stage,
void LaunchStage(Context &ctx, ReductionStage &stage,
ReductionKindTag<ReductionKind::None>) {
assert(is_last);
assert(!stage.shape.empty());
