Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JPEG color conversion and chroma subsampling kernel #2771

Merged
merged 17 commits into from
Mar 15, 2021

Conversation

jantonguirao
Copy link
Contributor

@jantonguirao jantonguirao commented Mar 8, 2021

Signed-off-by: Joaquin Anton janton@nvidia.com

Why we need this PR?

Pick one, remove the rest

  • It adds new feature needed for the JPEG artifact augmentation

What happened in this PR?

Fill relevant points, put NA otherwise. Replace anything inside []

  • What solution was applied:
    Added a CUDA kernel for JPEG RGB to YCbCr conversion plus chroma subsampling
  • Affected modules and functionalities:
    New functionality
  • Key points relevant for the review:
    CUDA kernel, performance
  • Validation and testing:
    C++ tests added
  • Documentation (including examples):
    NA

JIRA TASK: [DALI-1905]

template <typename T>
__inline__ __device__ vec<3, T> rgb_to_ycbcr(const vec<3, uint8_t> rgb) {
vec<3, T> ycbcr;
ycbcr.x = rgb_to_y<T>(rgb);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need here as you are not using it when rgb_to_ycbcr is called?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not really. Fixed

@jantonguirao jantonguirao changed the title [WIP] Chroma subsampling kernel Add JPEG color converssion and chroma subsampling kernel Mar 10, 2021
@jantonguirao jantonguirao marked this pull request as ready for review March 10, 2021 16:17
@jantonguirao jantonguirao changed the title Add JPEG color converssion and chroma subsampling kernel Add JPEG color conversion and chroma subsampling kernel Mar 10, 2021
return ycbcr;
}

template <bool horz_subsample, bool vert_subsample, typename T = uint8_t, int in_nchannels = 3>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
template <bool horz_subsample, bool vert_subsample, typename T = uint8_t, int in_nchannels = 3>
template <bool horz_subsample, bool vert_subsample, typename T = uint8_t>

I don't think that RGBToYCbCr makes sense for anything other than 3 channels.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed

@jantonguirao jantonguirao force-pushed the chroma_subsample branch 2 times, most recently from fe4eddd to 2fd52c8 Compare March 10, 2021 16:27
Signed-off-by: Joaquin Anton <janton@nvidia.com>
Signed-off-by: Joaquin Anton <janton@nvidia.com>
Signed-off-by: Joaquin Anton <janton@nvidia.com>
Signed-off-by: Joaquin Anton <janton@nvidia.com>
Signed-off-by: Joaquin Anton <janton@nvidia.com>
Signed-off-by: Joaquin Anton <janton@nvidia.com>
Signed-off-by: Joaquin Anton <janton@nvidia.com>
mzient and others added 4 commits March 10, 2021 19:09
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
Signed-off-by: Joaquin Anton <janton@nvidia.com>
Comment on lines 306 to 308
// chroma_subsample_params_t<uint8_t, false, true>,
// chroma_subsample_params_t<uint8_t, true, false>,
// chroma_subsample_params_t<uint8_t, false, false>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Comment on lines 129 to 126
for (int i = 0; i < N; i++)
*ptr++ = v[i];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for (int i = 0; i < N; i++)
*ptr++ = v[i];
for (int i = 0; i < N; i++)
ptr[i] = v[i];

I think that 2 kinds of indexing are more confusing both for the human reader and compiler.

// Assuming CUDA block has:
// - width 32, leads to 4 horizontal blocks of 8
// - height 8, so a single block 8x8 fits vertically
__shared__ T luma_blk[4][luma_blk_h][luma_blk_w];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
__shared__ T luma_blk[4][luma_blk_h][luma_blk_w];
__shared__ T luma_blk[4][luma_blk_h][luma_blk_strides[1]];

Comment on lines 192 to 189
__shared__ T cb_blk[4][8][8];
__shared__ T cr_blk[4][8][8];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
__shared__ T cb_blk[4][8][8];
__shared__ T cr_blk[4][8][8];
__shared__ T cb_blk[4][chroma_blk_sz[1]][chroma_blk_strides[1]];
__shared__ T cr_blk[4][chroma_blk_sz[1]][chroma_blk_strides[1]];

rgb_to_ycbcr_chroma_subsample<horz_subsample, vert_subsample>(
blk_offset, offset, luma, cb, cr, in);

__syncthreads();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this kernel, the synchronization is not necessary - you're accessing only the elements produced by this thread.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am aware of that. I just put it here, because I was planning to build on top of that. I also don't need the shared memory.

Comment on lines 191 to 193
__shared__ T luma_blk[4][luma_blk_h][luma_blk_w];
__shared__ T cb_blk[4][8][8];
__shared__ T cr_blk[4][8][8];

int blk_idx = threadIdx.x / 8;
int local_x = threadIdx.x % 8;
int local_y = threadIdx.y;
Copy link
Contributor

@mzient mzient Mar 12, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that having 8x8 blocks for both luma and chroma may simplify the DCT & quantization steps.

Suggested change
__shared__ T luma_blk[4][luma_blk_h][luma_blk_w];
__shared__ T cb_blk[4][8][8];
__shared__ T cr_blk[4][8][8];
int blk_idx = threadIdx.x / 8;
int local_x = threadIdx.x % 8;
int local_y = threadIdx.y;
__shared__ float Cb[2][4][8][9]; // yes, 9 - to reduce bank conflicts!
__shared__ float Cr[2][4][8][9];
__shared__ float Y[2 << vert_subsample][4 << horz_subsample][8][9];
int cx = threadIdx.x;
int cy = threadIdx.y;
chroma_x = cx & 7;
chroma_y = cy & 7;
chroma_bx = cx >> 3;
chroma_by = cy >> 3;
int lx = threadIdx.x << horz_subsample;
int ly = threadIdx.y << vert_subsample;
luma_x = lx & 7;
luma_y = ly & 7;
luma_bx = lx >> 3;
luma_by = ly >> 3;

With 16x32 block size we'd have
Cr 2489 * 4 bytes = 2304 B
Cb 2
489 * 4 bytes = 2304 B
Y 488*9 * 4 bytes = 9216 B
total 13824 B - well within acceptable limits, we still have quite a lot of shared mem for DCT step

Signed-off-by: Joaquin Anton <janton@nvidia.com>
Signed-off-by: Joaquin Anton <janton@nvidia.com>
Signed-off-by: Joaquin Anton <janton@nvidia.com>
Signed-off-by: Joaquin Anton <janton@nvidia.com>
Comment on lines 232 to 240
if (horz_subsample && vert_subsample) {
luma[luma_y][luma_x + 1] = ycbcr.luma[1];
luma[luma_y + 1][luma_x] = ycbcr.luma[2];
luma[luma_y + 1][luma_x + 1] = ycbcr.luma[3];
} else if (horz_subsample) {
luma[luma_y][luma_x + 1] = ycbcr.luma[1];
} else if (vert_subsample) {
luma[luma_y + 1][luma_x] = ycbcr.luma[1];
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be shorter - and hopefully the compiler would be smart enough to unroll it.

Suggested change
if (horz_subsample && vert_subsample) {
luma[luma_y][luma_x + 1] = ycbcr.luma[1];
luma[luma_y + 1][luma_x] = ycbcr.luma[2];
luma[luma_y + 1][luma_x + 1] = ycbcr.luma[3];
} else if (horz_subsample) {
luma[luma_y][luma_x + 1] = ycbcr.luma[1];
} else if (vert_subsample) {
luma[luma_y + 1][luma_x] = ycbcr.luma[1];
}
for (int i = 0, k = 0; i < vert_subsample+1; i++)
for (int j = 0; j < horz_subsample+1; j++, k++)
luma[luma_y + i][luma_x + j] = ycbcr.luma[k];

ivec2 offset{x, y};

auto ycbcr = rgb_to_ycbcr_subsampled<horz_subsample, vert_subsample, T>(offset, in);
luma[luma_y][luma_x] = ycbcr.luma[0];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this line closer to other luma writes - or follow the suggestion below.

Signed-off-by: Joaquin Anton <janton@nvidia.com>
Comment on lines +198 to +199
int chroma_x = threadIdx.x & 7; // % 8
int chroma_y = threadIdx.y & 7; // % 8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked and for both & 7 and % 8 NVCC generates the same code:

mov.u32         %r3, %tid.x;
and.b32         %r4, %r3, 7;

Maybe we should stick to what is intended here then? Unless there is something here I'm not getting.

This goes for every time this trick was used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mzient This was originally your suggestion. What do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same for /8 vs >>8. For both it is:

mov.u32        %r3, %tid.x;
shr.u32         %r4, %r3, 3;

Copy link
Contributor

@mzient mzient Mar 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My original suggestion was inside a function where these values were passed as signed integers. For signed integers, the division rounds towards zero and gives a negative remainder. It's not only slower (additional math), but also potentially dangerous.

@jantonguirao
Copy link
Contributor Author

!build

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [2167738]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [2167738]: BUILD FAILED

Signed-off-by: Joaquin Anton <janton@nvidia.com>
@jantonguirao
Copy link
Contributor Author

!build

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [2168116]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [2168116]: BUILD PASSED

@jantonguirao jantonguirao merged commit 1ab45a2 into NVIDIA:master Mar 15, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants