
Equalize kernel #4565

Merged: 4 commits into NVIDIA:main on Jan 19, 2023

Conversation

@stiepan (Member) commented Jan 13, 2023

Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>

Category:

New feature (non-breaking change which adds functionality)

Description:

This PR adds an equalize kernel. Equalization consists of the following steps:

  • computing the histogram
  • computing the cumulative of histogram
  • preparing a lookup table through which image values are remapped (the table is essentially the cumulative histogram scaled to the image's value range)
  • performing the lookup

The lookup table is different for each channel, so the existing lookup kernel did not seem to fit here.

The operation is needed for auto-augment pipelines. For now, it supports only images and videos of uint8 type.
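
For reference, a minimal host-side sketch of these steps for a single channel of interleaved uint8 data. This is illustrative only: the function and variable names are assumptions, the PR implements this as separate histogram, lookup-table, and lookup GPU kernels, and the exact scaling formula may differ.

#include <cstdint>

// Equalize channel `c` of interleaved data with `stride` channels and `n`
// elements per channel.
void EqualizeChannel(const uint8_t *in, uint8_t *out, int64_t n, int stride, int c) {
  // 1. Compute the histogram.
  int64_t hist[256] = {};
  for (int64_t i = 0; i < n; i++)
    hist[in[i * stride + c]]++;
  // 2. Compute the cumulative of the histogram.
  int64_t cdf[256];
  int64_t sum = 0;
  for (int v = 0; v < 256; v++)
    cdf[v] = (sum += hist[v]);
  // 3. Prepare the lookup table: the cumulative scaled to the uint8 range,
  //    anchored at the first non-zero bin.
  int first = 0;
  while (first < 256 && hist[first] == 0)
    first++;
  uint8_t lut[256];
  int64_t base = (first < 256) ? cdf[first] : 0;
  int64_t range = cdf[255] - base;
  for (int v = 0; v < 256; v++)
    lut[v] = (range > 0 && v >= first)
                 ? static_cast<uint8_t>((cdf[v] - base) * 255 / range)
                 : v;  // unused bins (v < first) and constant channels map to identity
  // 4. Perform the lookup.
  for (int64_t i = 0; i < n; i++)
    out[i * stride + c] = lut[in[i * stride + c]];
}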

Additional information:

Affected modules and functionalities:

No existing functionalities are affected.

Key points relevant for the review:

Tests:

  • Existing tests apply
  • New tests added
    • Python tests
    • GTests
    • Benchmark
    • Other
  • N/A

Checklist

Documentation

  • Existing documentation applies
  • Documentation updated
    • Docstring
    • Doxygen
    • RST
    • Jupyter
    • Other
  • N/A

DALI team only

Requirements

  • Implements new requirements
  • Affects existing requirements
  • N/A

REQ IDs: N/A

JIRA TASK: DALI-3187

Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
@stiepan (Member, Author) commented Jan 16, 2023

!build

@dali-automaton (Collaborator)

CI MESSAGE: [7027870]: BUILD STARTED

Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
@dali-automaton (Collaborator)

CI MESSAGE: [7027870]: BUILD FAILED

@stiepan (Member, Author) commented Jan 16, 2023

!build

@dali-automaton (Collaborator)

CI MESSAGE: [7028794]: BUILD STARTED

@dali-automaton (Collaborator)

CI MESSAGE: [7028794]: BUILD FAILED

@dali-automaton (Collaborator)

CI MESSAGE: [7028794]: BUILD PASSED

@stiepan mentioned this pull request Jan 16, 2023
static constexpr int hist_range = 256;

/**
* @brief Performs per-channel equalization.
Contributor:

This doesn't seem to do per-channel equalization, but equalization on a single-channel input.

Member Author:

The input can have multiple channels (the second extent), each of them will get different histogram and lookup tables.
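
An illustrative reading of that layout (a sketch; the helper name is hypothetical and not part of the PR):

#include <array>
#include <cstdint>

// Assumed convention from the discussion above: an H x W x C image is
// flattened to 2D with channels as the second extent, so each of the C
// columns gets its own histogram and lookup table.
std::array<int64_t, 2> FlattenForEqualize(int64_t h, int64_t w, int64_t c) {
  return {h * w, c};
}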

Comment on lines +27 to +30
__global__ void ZeroMem(const SampleDesc *sample_descs) {
auto sample_desc = sample_descs[blockIdx.y];
sample_desc.out[blockIdx.x * SampleDesc::range_size + threadIdx.x] = 0;
}
Contributor:

Is this better than cudaMemsetAsync?

Member Author:

I didn't try it, likely not.

Member Author:

For a single sample it seems to perform slightly slower, but the difference is so slim that I am not sure it is real. One concern is that I'd have to assume the tensor list is contiguous here (or make num_samples calls).
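
For comparison, the cudaMemsetAsync alternative discussed above would look roughly like this (a sketch; the pointer, counts, and helper name are assumptions, and it requires the histogram buffers to live in one contiguous allocation, otherwise it becomes one call per sample):

#include <cuda_runtime.h>
#include <cstdint>

// Zero all per-sample histogram buffers in one asynchronous call, assuming
// num_samples * num_channels * range_size contiguous uint64_t counters.
cudaError_t ZeroHistograms(uint64_t *hist_base, int num_samples,
                           int num_channels, int range_size,
                           cudaStream_t stream) {
  size_t bytes = static_cast<size_t>(num_samples) * num_channels *
                 range_size * sizeof(uint64_t);
  return cudaMemsetAsync(hist_base, 0, bytes, stream);
}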

static constexpr int64_t kMaxGridSize = 128;
static constexpr int64_t kShmPerChannelSize = SampleDesc::range_size * sizeof(uint64_t);

HistogramKernelGpu() : shared_mem_limit_{GetSharedMemPerBlock()}, sample_descs_{} {}
Contributor:

Suggested change:
- HistogramKernelGpu() : shared_mem_limit_{GetSharedMemPerBlock()}, sample_descs_{} {}
+ HistogramKernelGpu() : shared_mem_limit_{GetSharedMemPerBlock()} {}

^^ redundant?

Member Author:

Right!

const uint64_t *in;
};

struct LutKernelGpu {
Contributor:

no DLL_PUBLIC here?

Member Author:

Added it.

__shared__ uint64_t workspace[SampleDesc::range_size];
auto sample_desc = sample_descs[blockIdx.x];
PrefixSum(workspace, sample_desc.in);
int32_t first_idx = FirstNonZero(workspace);
Contributor:

Do you need every single thread finding the first non-zero? I wonder if it makes a difference.

Member Author:

As discussed elsewhere, we need that value in each thread anyway, and there does not seem to be an obvious alternative that would outperform it.
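
A rough sketch of what such a redundant per-thread search might look like (assumed, not the PR's exact implementation): because every thread consumes the result and the scan reads a small shared-memory array, letting each thread compute it can be competitive with computing it once and broadcasting.

// Each thread scans the (at most 256-entry) cumulative histogram for the
// first non-zero entry. The work is redundant across threads, but the
// reads hit shared memory and every thread needs the result anyway.
__device__ int32_t FirstNonZero(const uint64_t *cumulative, int32_t size) {
  int32_t idx = 0;
  while (idx < size && cumulative[idx] == 0)
    idx++;
  return idx;
}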

idx += blockDim.x * gridDim.x) {
const uint8_t *in = sample_desc.in;
uint8_t *out = sample_desc.out;
uint64_t channel_idx = idx % sample_desc.num_channels;
Contributor:

Why rely on the modulus to index the channels when you could use the y dimension (threadIdx.y)?

Member Author:

To keep the strided accesses in the small lookup table rather than in the global input and output.
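
A sketch of that trade-off (illustrative names, not the PR's code): a flat grid-stride loop over the interleaved data keeps the large global-memory reads and writes fully coalesced, while the channel-dependent indirection lands only in the small, cheap-to-access lookup tables.

// Consecutive threads touch consecutive global addresses; the per-channel
// tables (256 entries each, concatenated in `lut`) absorb the strided
// channel_idx accesses.
__global__ void Remap(const uint8_t *in, uint8_t *out, const uint8_t *lut,
                      int64_t num_elements, int num_channels) {
  constexpr int kRangeSize = 256;
  for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < num_elements;
       idx += static_cast<int64_t>(blockDim.x) * gridDim.x) {
    int64_t channel_idx = idx % num_channels;
    out[idx] = lut[channel_idx * kRangeSize + in[idx]];
  }
}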

for (int64_t idx = 0; idx < batch_shape[0].num_elements(); idx++) {
sample_in_view.data[idx] = 51 * sample_idx;
}
}
Contributor:

You are not running the test, just setting the data. Is that expected?

Member Author:

Right, I missed it, thanks.
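
A hedged sketch of the shape of the fix (the fixture and method names below are hypothetical, not the actual test API): after filling the input, the test still has to launch the kernel and validate the output.

#include <gtest/gtest.h>

// Hypothetical structure of the corrected test: setting the data alone is
// not enough; the kernel must be run and its result checked.
TEST_F(EqualizeKernelGpuTest, UniformChannels) {
  FillInput();            // the data-setting loop from the snippet above
  RunEqualize();          // launch histogram + LUT + lookup kernels
  CompareWithBaseline();  // e.g. against a CPU reference implementation
}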

Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
@stiepan (Member, Author) commented Jan 19, 2023

!build

@dali-automaton (Collaborator)

CI MESSAGE: [7056329]: BUILD STARTED

@dali-automaton (Collaborator)

CI MESSAGE: [7056329]: BUILD FAILED

@dali-automaton (Collaborator)

CI MESSAGE: [7056329]: BUILD PASSED

@stiepan stiepan merged commit 8d4f2a9 into NVIDIA:main Jan 19, 2023
aderylo pushed a commit to zpp-dali-2022/DALI that referenced this pull request Mar 17, 2023
* Adds equalization kernel for uint8 samples
* The kernel computes the histogram, builds the lookup table, and performs the lookup.

Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
@JanuszL mentioned this pull request Sep 6, 2023