Add optimized variant of CMN for HWC to HWC pad FP16 case #4993

Merged · 3 commits · Aug 24, 2023
259 changes: 250 additions & 9 deletions dali/kernels/slice/slice_hwc2chw_normalize_gpu.cu
@@ -253,6 +253,107 @@ __device__ __forceinline__ Tile *slice_load_linear_tile(
return tile;
}

/**
* @brief Load a slice of the linear tile into planar smem buffers.
*
* During loading, the values are distributed into separate planes in smem (keeping the same
* sequential XY coordinates/offsets). This allows faster access when building the padded HWC
* output. Each smem plane must hold kBlockSize / kStaticChannels elements.
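* For kStaticChannels == 3, the interleaved input R0 G0 B0 R1 G1 B1 ... ends up as
* tile[0] = {R0, R1, ...}, tile[1] = {G0, G1, ...}, tile[2] = {B0, B1, ...}; that is,
* tile[c][xy] holds the input value at flat offset xy * kStaticChannels + c.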
*
* @tparam kBlockSize Tile size
* @tparam kStaticChannels Number of input channels
* @tparam Tile Type of the data kept after loading in the smem tile.
* @tparam Out Output data type
* @tparam In Input data type
* @tparam kLoadAlign Alignment (in bytes) of the main-loop loads.
* @param tile Shared memory planes to load the data into.
* @param sample Sample description.
*/
template <int kBlockSize, int kStaticChannels, typename Tile, typename Out, typename In,
int kLoadAlign = 4>
__device__ __forceinline__ void load_planar_tile(Tile tile[][kBlockSize / kStaticChannels],
const Hwc2HwcChwSampleDesc<Out, In> sample) {
static_assert(std::is_same_v<In, uint8_t>, "Only uint8_t types allowed now.");
static_assert(kStaticChannels == 3, "Only 3 input channels allowed now.");
static_assert(kLoadAlign % 4 == 0, "The loading alignment should be divisible by 4.");

int64_t start_x = (blockIdx.x - sample.first_block) * kBlockSize;
int64_t end_x = ::min(start_x + kBlockSize, sample.sample_size);

auto in_start = reinterpret_cast<std::uintptr_t>(sample.in + start_x);
auto aligned_in_start = align_up(in_start, kLoadAlign);
uint32_t bytes_to_alignment = ::min(aligned_in_start - in_start, end_x - start_x);
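// bytes_to_alignment is the number of leading values handled one by one by the prologue,
// i.e. up to the first kLoadAlign-aligned address (capped at the tile size); the rest of the
// tile is read with vectorized uchar4 loads.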

const In *prologue_in = sample.in + start_x;

const uchar4 *aligned_in_char4 =
reinterpret_cast<const uchar4 *>(sample.in + start_x + bytes_to_alignment);

// The tiles are a multiple of 3, so we always read from the start of a pixel.

fast_div<uint32_t> channel(kStaticChannels);
// prologue
for (uint32_t idx = threadIdx.x; idx < bytes_to_alignment; idx += blockDim.x) {
uint32_t xy, c;
xy = div_mod(c, idx, channel);
tile[c][xy] = prologue_in[idx];
}

// This might be 0, as the prologue may cover the full extent of the tile.
uint32_t left_after_prologue = end_x - start_x - bytes_to_alignment;


// We read 4 values in each iteration
uint32_t main_loop_length = left_after_prologue >> 2;

// main loop: aligned load and unpacking
for (uint32_t idx = threadIdx.x; idx < main_loop_length; idx += blockDim.x) {
uint32_t flat_idx = idx * 4 + bytes_to_alignment;
uint32_t xy, c;
xy = div_mod(c, flat_idx, channel);
uchar4 in = aligned_in_char4[idx];

tile[c][xy] = in.x;

c++;
if (c == kStaticChannels) {
c = 0;
xy++;
}
tile[c][xy] = in.y;

c++;
if (c == kStaticChannels) {
c = 0;
xy++;
}
tile[c][xy] = in.z;


c++;
if (c == kStaticChannels) {
c = 0;
xy++;
}
tile[c][xy] = in.w;
}

uint32_t processed_in_main = left_after_prologue & -4; // equivalent to (x / 4) * 4
uint32_t left_after_main = left_after_prologue - processed_in_main;

// epilogue
const In *epilogue_in = reinterpret_cast<const In *>(aligned_in_char4 + main_loop_length);

for (uint32_t idx = threadIdx.x; idx < left_after_main; idx += blockDim.x) {
uint32_t flat_idx = processed_in_main + bytes_to_alignment + idx;
uint32_t xy, c;
xy = div_mod(c, flat_idx, channel);
tile[c][xy] = epilogue_in[idx];
}
}


/** @} */ // end of Hwc2HwcChwLoad


@@ -401,6 +502,94 @@ __device__ __forceinline__ void store_hwc(Tile *tile, const Hwc2HwcChwSampleDesc
}
}

/**
* @brief Store a smem tile that is kept as separate planes into the padded HWC output.
*
* This version is specialized for uint8_t inputs and FP16 outputs with padding from 3 to 4
* channels. The output samples are expected to be aligned to at least 4 bytes, allowing
* vectorized __half2 stores.
* @tparam Compute Type in which the computations are carried out.
* TODO(klecki): vectorized __half2 math could be considered; float is fine.
* @tparam Tile smem tile storage type
*/
template <int kBlockSize, int kStaticChannels, bool enable_mirror, typename Compute, typename Tile>
__device__ __forceinline__ void store_planar_hwc_pad(
Tile tile[][kBlockSize / kStaticChannels],
const Hwc2HwcChwSampleDesc<float16, uint8_t> sample) {
constexpr int kOutChannels = kStaticChannels + 1;

int64_t start_x = (blockIdx.x - sample.first_block) * kBlockSize;
int64_t end_x = ::min(start_x + kBlockSize, sample.sample_size);

const auto *__restrict__ fill_values = static_cast<const float16 *>(sample.fill_values);

// Preload the norm values so they are accessed via registers and not from gmem via pointer.
Compute norm_mul[kOutChannels], norm_add[kOutChannels];

#pragma unroll kStaticChannels
for (int c = 0; c < kStaticChannels; c++) {
norm_mul[c] = sample.norm_mul[c];
norm_add[c] = sample.norm_add[c];
}

// Put the fill value in the coefficients, so it is produced as a result of the FMA.
norm_mul[3] = 0;
norm_add[3] = fill_values[3];
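// With norm_mul[3] == 0, an FMA with these coefficients yields the fill value for any input;
// the store below simply reads norm_add[3] back.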

// Assuming all samples are padded
int64_t block_4 = (kBlockSize / kStaticChannels) * kOutChannels;
int64_t sample_size_4 = (sample.sample_size / kStaticChannels) * kOutChannels;
int64_t start_x_padded = static_cast<int64_t>(blockIdx.x - sample.first_block) * block_4;
int64_t end_x_padded = ::min(start_x_padded + block_4, sample_size_4);
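// start_x_padded and end_x_padded address the output in the padded (4-channel) layout.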


// TODO(klecki): in the version without mirror, we could keep a single offset, as we can point
// the output pointer at the start of the output tile.
auto *out_aligned = sample.out;
auto *out_h2 = reinterpret_cast<__half2 *>(sample.out);
uint32_t to_write = end_x_padded - start_x_padded;

// The loop bound is divided by two, as each thread writes two values (one __half2) per iteration.
for (uint32_t base_x = threadIdx.x; base_x < to_write / 2; base_x += blockDim.x) {
int base_offset = base_x / 2;
int c = base_x & 1;
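// Each base_x corresponds to one __half2 store: c == 0 writes channels {0, 1} of the pixel at
// base_offset, c == 1 writes channel 2 and the padded fill channel.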

int64_t out_offset;
if constexpr (enable_mirror) {
if (sample.flip_x) {
int64_t idx = start_x_padded + base_x * 2;
int y = idx / (sample.W * kOutChannels);
int xc = idx - (int64_t)y * sample.W * kOutChannels;
int x = xc / kOutChannels;
int target_x = sample.W - 1 - x;
// Effectively we divide out_offset by two; `c` is either 0 or 1.
out_offset = (int64_t)y * sample.W * (kOutChannels / 2) + target_x * (kOutChannels / 2) + c;
} else {
out_offset = start_x_padded / 2 + base_x;
}
} else {
out_offset = start_x_padded / 2 + base_x;
}

if (c == 0) {
Compute fpin0 = tile[0][base_offset];
Compute fpin1 = tile[1][base_offset];

Compute fpout0 = fmaf(fpin0, norm_mul[0], norm_add[0]);
Compute fpout1 = fmaf(fpin1, norm_mul[1], norm_add[1]);
out_h2[out_offset] = make_half2(ConvertSat<float16>(fpout0), ConvertSat<float16>(fpout1));
} else {
Compute fpin0 = tile[2][base_offset];

Compute fpout0 = fmaf(fpin0, norm_mul[2], norm_add[2]);
// With a more generic implementation we could do the FMA for this value as well,
// but here we just need to pad it.
Compute fpout1 = norm_add[3];
out_h2[out_offset] = make_half2(ConvertSat<float16>(fpout0), ConvertSat<float16>(fpout1));
}
}
}


/** @} */ // end of Hwc2HwcChwStore

@@ -523,6 +712,34 @@ __global__ void SliceHwc2HwcNormalize(const Hwc2HwcChwSampleDesc<Out, In> *sampl
store_hwc<kBlockSize, kStaticChannels, enable_mirror, enable_pad, Out>(loaded_tile, sample);
}

/**
* @brief Hwc2Hwc Normalize [Mirror-x] Pad-channel-always kernel for FP16.
*
* This kernel uses 4-byte reads and writes. The intermediate smem tile uses a planar layout
* for better access to the image values when writing the output.
* The output samples are assumed to be aligned to an address that is a multiple of 4; thanks
* to the padding to 4 channels, this holds for every batch that is laid out contiguously in
* memory with an aligned start, which is the case in DALI for the foreseeable future.
*/
template <typename Out, typename In, bool enable_mirror, int kBlockSize, int kStaticChannels>
__global__ void Hwc2HwcNormalizePadFp16(const Hwc2HwcChwSampleDesc<Out, In> *samples,
uint32_t *first_blocks, uint32_t num_samples) {
static_assert(std::is_same<In, uint8_t>::value, "Only uint8_t supported as input");

constexpr int kOutChannels = kStaticChannels + 1;

int sample_idx = FindSampleIdx(first_blocks, num_samples);
const auto sample = samples[sample_idx];

__shared__ float tile[kStaticChannels][kBlockSize / kStaticChannels];
load_planar_tile<kBlockSize, kStaticChannels>(tile, sample);

__syncthreads();

store_planar_hwc_pad<kBlockSize, kStaticChannels, enable_mirror, float>(tile, sample);
}


/** @} */ // end of Hwc2HwcChw

template <typename Out>
@@ -662,6 +879,12 @@ void SliceHwc2HwcChwNormalizeGPU<Out>::Run(KernelContext &ctx,
bool need_pad = out_nchannels_ != nchannels_;
bool need_crop_x = false;
bool need_flip_x = false;
// Check if all the outputs are aligned to 4 bytes; this is required by the specialized
// FP16 pad HWC -> HWC implementation. With the current state of DALI, the start of the output
// allocation is aligned (to an even higher power of two) and every sample's length is a
// multiple of 4 elements (padded to 4 channels), so if the samples lie in a contiguous
// allocation, every output sample is still aligned to a multiple of 4.
bool outputs_aligned_4 = true;

uint32_t offset_blk = 0;
int nonempty_samples = 0;
@@ -682,6 +905,9 @@
auto &first_block = first_blocks_cpu[nonempty_samples++];
sample_desc.in = in_sample.data;
sample_desc.out = out.tensor_data(sample_id);
if (reinterpret_cast<std::uintptr_t>(sample_desc.out) % 4) {
outputs_aligned_4 = false;
}

first_block = offset_blk;
sample_desc.first_block = offset_blk;
@@ -750,30 +976,45 @@
}
} else {
auto dispatch = [samples = sample_descs_gpu, blocks = first_blocks_gpu, &ctx, need_crop_x,
-offset_blk, nonempty_samples](auto pad_v, auto flip_x_v) {
offset_blk, nonempty_samples](auto pad_v, auto flip_x_v, auto out_aligned_v) {
if (need_crop_x) {
SliceHwc2HwcNormalize<Out, In, flip_x_v.value, pad_v.value, kBlockSizeMul * kBlockWidth,
kStaticChannels><<<offset_blk, kThreadBlockSize, 0, ctx.gpu.stream>>>(
samples, blocks, nonempty_samples);
} else {
-Hwc2HwcNormalize<Out, In, flip_x_v.value, pad_v.value, kBlockSizeMul * kBlockWidth,
-kStaticChannels><<<offset_blk, kThreadBlockSize, 0, ctx.gpu.stream>>>(
-samples, blocks, nonempty_samples);
if constexpr (std::is_same_v<Out, float16> && pad_v.value && out_aligned_v.value) {
Hwc2HwcNormalizePadFp16<Out, In, flip_x_v.value, kBlockSizeMul * kBlockWidth,
kStaticChannels>
<<<offset_blk, kThreadBlockSize, 0, ctx.gpu.stream>>>(samples, blocks,
nonempty_samples);
} else {
Hwc2HwcNormalize<Out, In, flip_x_v.value, pad_v.value, kBlockSizeMul * kBlockWidth,
kStaticChannels><<<offset_blk, kThreadBlockSize, 0, ctx.gpu.stream>>>(
samples, blocks, nonempty_samples);
}
}
};

-auto dispatch_flip = [&](auto pad_v, bool flip_x) {
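// Lift the runtime alignment flag to a compile-time constant, so the specialized FP16 kernel
// is only instantiated and launched for the aligned case.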
auto dispatch_aligned = [&](auto pad_v, auto flip_x_v, bool out_aligned) {
if (out_aligned) {
dispatch(pad_v, flip_x_v, std::true_type{});
} else {
dispatch(pad_v, flip_x_v, std::false_type{});
}
};

auto dispatch_flip = [&](auto pad_v, bool flip_x, bool out_aligned) {
if (flip_x) {
-dispatch(pad_v, std::true_type{});
dispatch_aligned(pad_v, std::true_type{}, out_aligned);
} else {
-dispatch(pad_v, std::false_type{});
dispatch_aligned(pad_v, std::false_type{}, out_aligned);
}
};

if (need_pad) {
-dispatch_flip(std::true_type{}, need_flip_x);
dispatch_flip(std::true_type{}, need_flip_x, outputs_aligned_4);
} else {
-dispatch_flip(std::false_type{}, need_flip_x);
dispatch_flip(std::false_type{}, need_flip_x, outputs_aligned_4);
}
}

13 changes: 12 additions & 1 deletion include/dali/core/float16.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -410,4 +410,15 @@ inline __device__ dali::float16 __ldg(const dali::float16 *mem) {
}
#endif

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 350 // this is for clang-only build
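// Builds a __half2 from two dali::float16 values in device code, enabling vectorized __half2
// stores of dali::float16 results.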
inline __device__ __half2 make_half2(const dali::float16 x, const dali::float16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
return make_half2(x.impl, y.impl);
#else
assert(!"Unreachable code!");
return {};
#endif
}
#endif

#endif // DALI_CORE_FLOAT16_H_