Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add maxout operator. #5571

Merged
merged 31 commits into from
Nov 20, 2017
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
4a428c8
this for maxout op new add
Nov 11, 2017
058bdd3
this for maxout op new add
Nov 11, 2017
784fd82
resolve conflicts
Nov 11, 2017
6c7e136
Merge branch 'develop' into my_maxout_op
sweetsky0901 Nov 13, 2017
fe1e16b
Merge branch 'develop' into my_maxout_op
sweetsky0901 Nov 13, 2017
ab9c71d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 13, 2017
bd773b9
modify for maxoutop code review
Nov 14, 2017
494edc6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 14, 2017
bb1be5d
merge cmakelist
Nov 14, 2017
9954496
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 14, 2017
f57cd1e
del a err comments
Nov 14, 2017
f319fb1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 14, 2017
8d9babf
maxout code review 2nd
Nov 15, 2017
3ef776e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 15, 2017
5802880
update maxoutop for code review 3
Nov 19, 2017
63f8c5f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 19, 2017
a6a01c1
add test_maxout_op framework to fluis
Nov 19, 2017
4c113cc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 20, 2017
25d76bc
modify for add a space in maxout op
Nov 20, 2017
2d7a652
del framework test_maxout_op
Nov 20, 2017
13d39ea
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 20, 2017
c645d06
add a space + *
Nov 20, 2017
52f2366
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 20, 2017
76fc1a8
for code review 4
Nov 20, 2017
6ac4237
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 20, 2017
4e5c989
rename back
sweetsky0901 Nov 20, 2017
350cc61
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 20, 2017
3fbff1e
for code review 5
Nov 20, 2017
95cbbd7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 20, 2017
04fd989
for code review 6
Nov 20, 2017
9cb2ff6
del num_channels
Nov 20, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ set(DEPS_OPS
sequence_softmax_op
sum_op
pool_op
maxout_op
pool_with_index_op
conv_op
conv_transpose_op
Expand All @@ -210,6 +211,7 @@ op_library(sgd_op DEPS selected_rows_functor)
op_library(adagrad_op DEPS selected_rows_functor)
op_library(conv_op DEPS vol2col)
op_library(pool_op DEPS pooling)
op_library(maxout_op DEPS maxouting)
op_library(pool_with_index_op DEPS pooling)
op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table)
op_library(lod_tensor_to_array_op SRCS lod_tensor_to_array_op.cc DEPS lod_rank_table_op)
Expand Down
2 changes: 2 additions & 0 deletions paddle/operators/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ if(WITH_GPU)
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
else()
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto)
cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
Expand All @@ -26,6 +27,7 @@ else()
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
cc_library(maxouting SRCS maxouting.cc DEPS device_context)
endif()

cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
Expand Down
111 changes: 111 additions & 0 deletions paddle/operators/math/maxouting.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/maxouting.h"

namespace paddle {
namespace operators {
namespace math {

/*
* All tensors are in NCHW format.
* groups mustbe > 1
*/
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comments are not right, same as following.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// All tensors are in NCHW format, and the groups must be greater than 1.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

template <typename T>
class MaxOutFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
framework::Tensor * output,
int groups) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output->dims()[1];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should check whether output_channels and input.dims()[1] / groups are equal.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个在外面是这么初始化出来的,所以才没有再检查

int fea_size = input_height * input_width;
// c_size means the output size of each sample
int c_size = fea_size * output_channels;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

c_size -> out_size ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());

for (int i = 0; i < batch_size; ++i) {
int new_bindex = c_size * i;
for (int c = 0; c < output_channels; ++c) {
int new_cindex = fea_size * c;
for (int f = 0; f < fea_size; ++f) {
// T ele = maxout_process.initial();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this line.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

T ele = static_cast<T>(-FLT_MAX);
for (int ph = 0; ph < groups; ++ph) {
T x = input_data[(new_bindex + new_cindex) * groups
+ ph * fea_size + f];
ele = ele > x ? ele : x;
}
output_data[(new_bindex+new_cindex+f)] = ele;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议去掉maxout_process, 这里直接比大小~

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}
}
}
}
};



template <class T>
class MaxOutGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
framework::Tensor& input_grad,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

framework::Tensor& input_grad -> framework::Tensor* input_grad, Maybe better to put the output as the last arguments.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

const framework::Tensor& output,
const framework::Tensor& output_grad,
int groups) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
int fea_size = input_height * input_width;
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());

for (int i = 0; i < batch_size; ++i) {
int blen = fea_size * output_channels * i;
for (int c = 0; c < output_channels; ++c) {
int clen = fea_size * c;
for (int f = 0; f < fea_size; ++f) {
int input_idx0 = (blen + clen) * groups + f;
bool continue_match = true;
int output_idx = blen + clen + f;
for (int g = 0; g < groups && continue_match; ++g) {
int input_idx = input_idx0 + fea_size * g;
input_grad_data[input_idx] = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove line 88.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

初始化为一个值,为什么需要去掉,有时候内存里往往有脏数据呢

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

你可在循环外面像这样初始化

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

input_idx在for循环里是变化的。除非在外面再memset了。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

if (input_data[input_idx] == output_data[output_idx]) {
input_grad_data[input_idx] += output_grad_data[output_idx];
continue_match = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. I don't think you should do cumulative operations here.
    input_grad_data[input_idx] += output_grad_data[output_idx]
  2. You can replace continue_match = false; with break;.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}
}
}
}
}
}
};

template class MaxOutGradFunctor<platform::CPUPlace, float>;
template class MaxOutGradFunctor<platform::CPUPlace, double>;
template class MaxOutFunctor<platform::CPUPlace, float>;
template class MaxOutFunctor<platform::CPUPlace, double>;

} // namespace math
} // namespace operators
} // namespace paddle
153 changes: 153 additions & 0 deletions paddle/operators/math/maxouting.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/maxouting.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
__global__ void KernelMaxOut(const int nthreads, const T* input_data,
const int channels,
const int input_height, const int input_width,
int groups, T* output_data ) {
const int size = input_height * input_width * channels / groups;
const int feat_len = input_height * input_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int batch_idx = i / size;
int batch_offset = i % size;
int channel_idx = batch_offset / feat_len;
int feat_idx = batch_offset % feat_len;
int data_idx =
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
T ele = static_cast<T>(-FLT_MAX);
for (int g = 0; g < groups; ++g) {
T x = input_data[data_idx + g * feat_len];
ele = ele > x ? ele : x;
}
output_data[i] = ele;
}
}
template <typename T>
__global__ void KernelMaxoutGrad(
const int nthreads, const T* input_data, const T* output_data,
const T* output_grad, T* input_grad, const int channels,
const int input_height, const int input_width, int groups) {
const int size = input_height * input_width * channels / groups;
const int feat_len = input_height * input_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int batch_idx = i / size;
int batch_offset = i % size;
int channel_idx = batch_offset / feat_len;
int feat_idx = batch_offset % feat_len;
int data_idx =
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
int max_index = -1;
bool continue_match = true;
for (int g = 0; g < groups && continue_match; ++g) {
if (input_data[data_idx + g * feat_len] == output_data[i]) {
max_index = data_idx + g * feat_len;
continue_match = false;
}
}
if (max_index != -1) {
// atomic add
platform::CudaAtomicAdd(input_grad + max_index, output_grad[index]);
}
}
}
/*
* All tensors are in NCHW format.
*/
template <typename T>
class MaxOutFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor * output,
int groups) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];

const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
int nthreads = output->numel();
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);

KernelMaxOut<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(nthreads, input_data, input_channels,
input_height, input_width, groups,
output_data);
}
};
/*
* All tensors are in NCHW format.
*/
template <typename T>
class MaxOutGradFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad,
int groups) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];

const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
int nthreads = output.numel();
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);

KernelMaxoutGrad<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_grad_data,
input_channels, input_height, input_width, groups);
}
};

template class MaxOutGradFunctor<platform::GPUPlace, float>;
template class MaxOutGradFunctor<platform::GPUPlace, double>;

template class MaxOutFunctor<platform::GPUPlace, float>;
template class MaxOutFunctor<platform::GPUPlace, double>;

} // namespace math
} // namespace operators
} // namespace paddle
47 changes: 47 additions & 0 deletions paddle/operators/math/maxouting.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/hostdevice.h"
Copy link
Contributor

@qingqing01 qingqing01 Nov 15, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If remove line 26 - line 48, please also remove this header.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


namespace paddle {
namespace operators {
namespace math {

#define FLT_MAX \
__FLT_MAX__

template <typename Place, typename T>

class MaxOutFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor * output,
int groups);
};

template <typename Place, class T>
class MaxOutGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, int groups);
};
} // namespace math
} // namespace operators
} // namespace paddle
Loading