
Coordinate Flip GPU operator #1895

Merged: jantonguirao merged 6 commits into NVIDIA:master on May 4, 2020

Conversation

@jantonguirao (Contributor) commented Apr 23, 2020

Signed-off-by: Joaquin Anton <janton@nvidia.com>

Why do we need this PR?

  • It adds a new feature, coordinate flip on GPU, needed to complete the MaskRCNN pipeline.

What happened in this PR?


  • What solution was applied:
    Added Coordinate Flip GPU operator
  • Affected modules and functionalities:
    New operator
  • Key points relevant for the review:
    The operator implementation
  • Validation and testing:
    Python operator tests added
  • Documentation (including examples):
    NA

JIRA TASK: [DALI-1392]


void CoordFlipGPU::RunImpl(workspace_t<GPUBackend> &ws) {
  const auto &input = ws.InputRef<GPUBackend>(0);
  DALI_ENFORCE(input.type().id() == DALI_FLOAT, "Input is expected to be float");
Contributor:

Please move to SetupImpl (same as CPU).

Contributor (Author):

done
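
Per the thread above, the check was moved out of RunImpl. A minimal sketch of how that can look in SetupImpl, mirroring the CPU operator; the exact signature and the output-descriptor lines are assumptions based on the DALI operator API of that era, not taken from this diff:

    // Hedged sketch: validate the input type once, during setup, instead of on every run.
    bool CoordFlipGPU::SetupImpl(std::vector<OutputDesc> &output_desc,
                                 const workspace_t<GPUBackend> &ws) {
      const auto &input = ws.InputRef<GPUBackend>(0);
      DALI_ENFORCE(input.type().id() == DALI_FLOAT, "Input is expected to be float");
      output_desc.resize(1);
      output_desc[0] = {input.shape(), input.type()};  // output mirrors the input shape and type
      return true;
    }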

CUDA_CALL(
    cudaMemcpyAsync(sample_descs_gpu_, sample_descs_.data(), sz, cudaMemcpyHostToDevice, stream));

dim3 block(32, 32);
auto blocks_per_sample = std::max(32, 1024 / batch_size_);
Contributor:

What's the reason for this gridDim.x?

Contributor (Author):

It's just a good-enough number. The second term is there to reduce the total number of blocks when the batch has many samples.
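
To make the heuristic concrete: each sample always gets at least 32 blocks, and for small batches it gets 1024 / batch_size_ blocks, so the total grid stays around a thousand blocks until the 32-block floor takes over. A sketch of the arithmetic; the 2D grid layout with one row per sample is an assumption, not shown in this diff:

    // blocks_per_sample = std::max(32, 1024 / batch_size_):
    //   batch_size_ =  4  -> max(32, 256) = 256 blocks per sample (1024 blocks total)
    //   batch_size_ = 16  -> max(32,  64) =  64 blocks per sample (1024 blocks total)
    //   batch_size_ = 64  -> max(32,  16) =  32 blocks per sample (floor; 2048 blocks total)
    dim3 grid(blocks_per_sample, batch_size_);  // assumed layout: grid.y indexes the sample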

CUDA_CALL(
cudaMemcpyAsync(sample_descs_gpu_, sample_descs_.data(), sz, cudaMemcpyHostToDevice, stream));

dim3 block(32, 32);
Contributor:

Is there any benefit in using 2D blocks instead of 1D of the same volume?
That would probably simplify the addressing a bit.

Contributor (Author):

done
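
The change suggested above replaces the 32x32 block with a 1D block of the same volume, so the thread id no longer needs to be flattened from two components. A hedged sketch of the simplified addressing, based on the loop shown below:

    dim3 block(1024);  // same volume as dim3 block(32, 32)
    // Inside the kernel, the two-component flattening
    //   int64_t tid = threadIdx.y * blockDim.x + threadIdx.x;
    // collapses to a single component:
    int64_t tid = threadIdx.x;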

jantonguirao force-pushed the coord_flip_gpu branch 4 times, most recently from 819bf25 to 57cee51 on Apr 27, 2020, 16:43
int64_t tid = threadIdx.y * blockDim.x + threadIdx.x;
for (int64_t idx = offset + tid; idx < sample.size; idx += grid_size) {
int d = idx % ndim;
bool flip = static_cast<bool>(sample.flip_dim_mask & (1 << d));
Contributor:

I think implicit conversion would do just fine here, like it would in an if statement.

Contributor (Author):

done
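
With the suggestion applied, the masked bit is left to convert to bool implicitly, exactly as it would inside an if condition:

    bool flip = sample.flip_dim_mask & (1 << d);  // non-zero converts to true, no cast needed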

for (int64_t idx = offset + tid; idx < sample.size; idx += grid_size) {
int d = idx % ndim;
bool flip = static_cast<bool>(sample.flip_dim_mask & (1 << d));
sample.out[idx] = flip ? T(1) - sample.in[idx] : sample.in[idx];
Contributor:

Flip center? That is, should the flip happen around a configurable center rather than being hard-coded as T(1) - x?

Contributor (Author):

done
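
For context on the fix: T(1) - x mirrors a coordinate around 0.5, which only makes sense for coordinates normalized to [0, 1]. Flipping around a configurable center c generalizes this to 2c - x. A hedged sketch, where flip_center is a hypothetical per-dimension field standing in for whatever the final code uses:

    T center = sample.flip_center[d];  // hypothetical field; center = 0.5 recovers T(1) - x
    sample.out[idx] = flip ? T(2) * center - sample.in[idx] : sample.in[idx];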

Signed-off-by: Joaquin Anton <janton@nvidia.com>
Signed-off-by: Joaquin Anton <janton@nvidia.com>
Signed-off-by: Joaquin Anton <janton@nvidia.com>
Signed-off-by: Joaquin Anton <janton@nvidia.com>
jantonguirao force-pushed the coord_flip_gpu branch 2 times, most recently from 09c81ee to 6c6834a on Apr 29, 2020, 07:51
Signed-off-by: Joaquin Anton <janton@nvidia.com>
Signed-off-by: Joaquin Anton <janton@nvidia.com>
@jantonguirao (Contributor, Author):

!build

@dali-automaton (Collaborator):

CI MESSAGE: [1291137]: BUILD STARTED

@dali-automaton (Collaborator):

CI MESSAGE: [1291137]: BUILD PASSED

jantonguirao merged commit 52a984a into NVIDIA:master on May 4, 2020