Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

conv shift op: change to CamelCase & fix bug #5558

Merged
merged 3 commits into from
Nov 16, 2017

Conversation

mkliegl
Copy link
Contributor

@mkliegl mkliegl commented Nov 10, 2017

Addressing part of issue #5549 .

@mkliegl mkliegl force-pushed the conv_shift_fix_camel_case branch 2 times, most recently from dffdfa7 to 77b400b Compare November 14, 2017 01:39
@mkliegl
Copy link
Contributor Author

mkliegl commented Nov 14, 2017

@chengduoZH The tests pass now after rebasing. Could you please take a look?

int batch_size) {
__global__ void ConvShiftForward(const T *x, const T *y, T *out, int x_width,
int y_width, int y_half_width,
int batch_size) {
extern __shared__ T mem[];

int tx = threadIdx.x;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although the problem mentioned in issue has been solved, there is a problem with the ConvShiftForward function.
Please recheck line 62~68, you should not use return before __syncthreads();

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Let me fix this now quickly.

@chengduoZH
Copy link
Contributor

@mkliegl
I can approve it, but you should write a new PR to solve above problem. Because the problem is a little bit serious.

@chengduoZH
Copy link
Contributor

If some functions are used in Op, the functions are generally written in the form of functor, such as LSTM.

chengduoZH
chengduoZH previously approved these changes Nov 14, 2017
@mkliegl mkliegl changed the title conv shift op: change to CamelCase conv shift op: change to CamelCase & fix bug Nov 14, 2017
@mkliegl
Copy link
Contributor Author

mkliegl commented Nov 14, 2017

@chengduoZH Thank you for your review! I added a fix for the issue you pointed out now rather than making another PR. Could you please check one more time?

Regarding making functors: It doesn't seem likely that functors for the conv shift forward and backward kernels are going to be needed anywhere else, so I'm not sure it is worth making this change now. But if you still prefer this for consistency or other reasons, let me know and I can do it. Or else please feel free to just make the changes yourself.

int batch_size) {
__global__ void ConvShiftForward(const T *x, const T *y, T *out, int x_width,
int y_width, int y_half_width,
int batch_size) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please reorder parameter.
Function_Parameter_Ordering
The following function parameters are the same.

@@ -160,20 +160,20 @@ class ConvShiftGradKernel<platform::GPUPlace, T>
auto stream = context.cuda_device_context().stream();

const int x_per_block = 256;
int num_x_blocks = div_up(x_width, x_per_block);
int num_x_blocks = DivUp(x_width, x_per_block);
dim3 grid_dim(num_x_blocks, y_width, batch_size);

if (dX) {
T *dx_data = dX->mutable_data<T>(context.GetPlace());
cudaMemsetAsync(dx_data, 0, dX->numel() * sizeof(T), stream);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can clean memory in this way.

@chengduoZH
Copy link
Contributor

@mkliegl
In that case, I think you can write it in this way(not define functor) now. If these functions are needed in other Op, we can change them.

Copy link
Contributor

@chengduoZH chengduoZH left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@chengduoZH chengduoZH merged commit dc78f3c into PaddlePaddle:develop Nov 16, 2017
@mkliegl mkliegl deleted the conv_shift_fix_camel_case branch November 16, 2017 03:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants