Add maxout operator. #5571
Changes from 21 commits
paddle/operators/math/maxouting.cc (new file)
@@ -0,0 +1,110 @@

```cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/maxouting.h"

namespace paddle {
namespace operators {
namespace math {

/*
 * All tensors are in NCHW format.
 * groups must be > 1
 */
```

Review thread on the comment block above:
Reviewer: `// All tensors are in NCHW format, and the groups must be greater than 1.`
Author: done

```cpp
template <typename T>
class MaxOutFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  framework::Tensor* output,
                  int groups) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
```

Review thread on `output_channels`:
Reviewer: You should check whether …
Author: The caller already constructs it that way, so it is not re-checked here.

```cpp
    int fea_size = input_height * input_width;
    // c_size means the output size of each sample
    int c_size = fea_size * output_channels;
```

Review thread on `c_size`:
Reviewer: c_size -> out_size ?
Author: done

```cpp
    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    for (int i = 0; i < batch_size; ++i) {
      int new_bindex = c_size * i;
      for (int c = 0; c < output_channels; ++c) {
        int new_cindex = fea_size * c;
        for (int f = 0; f < fea_size; ++f) {
          // T ele = maxout_process.initial();
          T ele = static_cast<T>(-FLT_MAX);
          for (int ph = 0; ph < groups; ++ph) {
            T x = input_data[(new_bindex + new_cindex) * groups +
                             ph * fea_size + f];
            ele = ele > x ? ele : x;
          }
          output_data[new_bindex + new_cindex + f] = ele;
        }
      }
    }
  }
};
```

Review thread on `// T ele = maxout_process.initial();`:
Reviewer: remove this line.
Author: done

Review thread on the `T x = input_data[...]` line:
Reviewer: It seems you did not install pre-commit. Install it with `pip install pre-commit`.
Author: done

Review thread on `output_data[...] = ele;`:
Reviewer: Suggest dropping maxout_process and comparing values directly here.
Author: done
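To make the index arithmetic above easier to check, here is a minimal standalone sketch of the same forward pass in plain C++ (no Paddle types; the function name `maxout_forward` and the flat-vector interface are mine, not part of this PR):

```cpp
#include <algorithm>
#include <cfloat>
#include <vector>

// NCHW maxout: the input has out_channels * groups channels; each output
// element is the max over `groups` channel slices at the same spatial offset.
void maxout_forward(const std::vector<float>& in, std::vector<float>* out,
                    int batch, int out_channels, int groups, int height,
                    int width) {
  const int fea_size = height * width;         // H * W
  const int c_size = fea_size * out_channels;  // output size of one sample
  for (int i = 0; i < batch; ++i) {
    for (int c = 0; c < out_channels; ++c) {
      for (int f = 0; f < fea_size; ++f) {
        float ele = -FLT_MAX;
        for (int ph = 0; ph < groups; ++ph) {
          // Identical indexing to the functor: the per-sample and per-channel
          // offsets are computed in output coordinates, then scaled by groups.
          ele = std::max(ele, in[(c_size * i + fea_size * c) * groups +
                                 ph * fea_size + f]);
        }
        (*out)[c_size * i + fea_size * c + f] = ele;
      }
    }
  }
}
```

Feeding, say, a (1, 4, 1, 1) input with groups = 2 should yield a (1, 2, 1, 1) output holding max(in[0], in[1]) and max(in[2], in[3]).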
```cpp
template <class T>
class MaxOutGradFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  framework::Tensor& input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  int groups) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    int fea_size = input_height * input_width;
    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());

    for (int i = 0; i < batch_size; ++i) {
      int blen = fea_size * output_channels * i;
      for (int c = 0; c < output_channels; ++c) {
        int clen = fea_size * c;
        for (int f = 0; f < fea_size; ++f) {
          int input_idx0 = (blen + clen) * groups + f;
          bool continue_match = true;
          int output_idx = blen + clen + f;
          for (int g = 0; g < groups && continue_match; ++g) {
            int input_idx = input_idx0 + fea_size * g;
            input_grad_data[input_idx] = 0;
            if (input_data[input_idx] == output_data[output_idx]) {
              input_grad_data[input_idx] += output_grad_data[output_idx];
              continue_match = false;
            }
          }
        }
      }
    }
  }
};

template class MaxOutGradFunctor<platform::CPUPlace, float>;
template class MaxOutGradFunctor<platform::CPUPlace, double>;
template class MaxOutFunctor<platform::CPUPlace, float>;
template class MaxOutFunctor<platform::CPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
```

Review thread on `input_grad_data[input_idx] = 0;`:
Reviewer: Please remove line 88.
Author: It is initialized to a value deliberately; why remove it? Memory often holds dirty data.
Reviewer: You could initialize it outside the loop, like this …
Author: input_idx changes inside the for loop, so that only works with an extra memset outside.
Author: done
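The thread above debates where to zero `input_grad`. As a point of comparison only (my sketch, not the PR's approach), zeroing the whole buffer up front lets the inner loop stop at the first match without leaving stale values in the groups it never visits:

```cpp
#include <algorithm>
#include <vector>

// Winner-takes-all routing: the output gradient flows only to the first
// input element that equals the max. Layout matches maxout_forward above.
void maxout_backward(const std::vector<float>& in,
                     const std::vector<float>& out,
                     const std::vector<float>& out_grad,
                     std::vector<float>* in_grad, int batch, int out_channels,
                     int groups, int height, int width) {
  const int fea_size = height * width;
  const int c_size = fea_size * out_channels;
  // Zero everything once, so breaking out early cannot leave dirty data.
  std::fill(in_grad->begin(), in_grad->end(), 0.0f);
  for (int i = 0; i < batch; ++i) {
    for (int c = 0; c < out_channels; ++c) {
      for (int f = 0; f < fea_size; ++f) {
        const int out_idx = c_size * i + fea_size * c + f;
        const int in_base = (c_size * i + fea_size * c) * groups + f;
        for (int g = 0; g < groups; ++g) {
          const int in_idx = in_base + g * fea_size;
          if (in[in_idx] == out[out_idx]) {
            (*in_grad)[in_idx] += out_grad[out_idx];
            break;  // route to the first matching group only
          }
        }
      }
    }
  }
}
```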
paddle/operators/math/maxouting.cu (new file)
@@ -0,0 +1,153 @@

```cuda
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/maxouting.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
__global__ void KernelMaxOut(const int nthreads, const T* input_data,
                             const int channels, const int input_height,
                             const int input_width, int groups,
                             T* output_data) {
  const int size = input_height * input_width * channels / groups;
  const int feat_len = input_height * input_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int batch_idx = i / size;
    int batch_offset = i % size;
    int channel_idx = batch_offset / feat_len;
    int feat_idx = batch_offset % feat_len;
    int data_idx =
        (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
    T ele = static_cast<T>(-FLT_MAX);
    for (int g = 0; g < groups; ++g) {
      T x = input_data[data_idx + g * feat_len];
      ele = ele > x ? ele : x;
    }
    output_data[i] = ele;
  }
}
```
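The kernel assigns one output element per loop iteration via the grid-stride pattern; its decomposition of the flat index can be checked on the host with a small mirror function (the name `maxout_input_base` is mine):

```cpp
// Host-side mirror of KernelMaxOut's index arithmetic: for flat output index
// i of an (N, C, H, W) output computed from an (N, C * groups, H, W) input,
// where C = channels / groups, return the input offset of group 0.
int maxout_input_base(int i, int channels, int groups, int input_height,
                      int input_width) {
  const int size = input_height * input_width * channels / groups;
  const int feat_len = input_height * input_width;
  const int batch_idx = i / size;                   // which sample
  const int batch_offset = i % size;                // position in the sample
  const int channel_idx = batch_offset / feat_len;  // which output channel
  const int feat_idx = batch_offset % feat_len;     // spatial offset
  return (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
}
```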
```cuda
template <typename T>
__global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
                                 const T* output_data, const T* output_grad,
                                 T* input_grad, const int channels,
                                 const int input_height,
                                 const int input_width, int groups) {
  const int size = input_height * input_width * channels / groups;
  const int feat_len = input_height * input_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int batch_idx = i / size;
    int batch_offset = i % size;
    int channel_idx = batch_offset / feat_len;
    int feat_idx = batch_offset % feat_len;
    int data_idx =
        (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
    int max_index = -1;
    bool continue_match = true;
    for (int g = 0; g < groups && continue_match; ++g) {
      if (input_data[data_idx + g * feat_len] == output_data[i]) {
        max_index = data_idx + g * feat_len;
        continue_match = false;
      }
    }
    if (max_index != -1) {
      // atomic add; use the grid-stride loop index i, not the thread id
      platform::CudaAtomicAdd(input_grad + max_index, output_grad[i]);
    }
  }
}

/*
 * All tensors are in NCHW format.
 */
template <typename T>
class MaxOutFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* output,
                  int groups) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];

    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
    int nthreads = output->numel();
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxOut<T><<<grid, threads, 0,
                      reinterpret_cast<const platform::CUDADeviceContext&>(
                          context).stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width,
        groups, output_data);
  }
};

/*
 * All tensors are in NCHW format.
 */
template <typename T>
class MaxOutGradFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  framework::Tensor& input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, int groups) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
    int nthreads = output.numel();
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxoutGrad<T><<<grid, threads, 0,
                          reinterpret_cast<const platform::CUDADeviceContext&>(
                              context).stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_grad_data,
        input_channels, input_height, input_width, groups);
  }
};

template class MaxOutGradFunctor<platform::GPUPlace, float>;
template class MaxOutGradFunctor<platform::GPUPlace, double>;

template class MaxOutFunctor<platform::GPUPlace, float>;
template class MaxOutFunctor<platform::GPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
```
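Both functors use the same ceil-division launch shape, one thread per output element in 1024-thread blocks; a small host helper (my naming) makes the arithmetic explicit:

```cpp
#include <utility>

// Ceil-division grid sizing used by both launches above: enough 1024-thread
// blocks to cover nthreads output elements, with the tail handled by the
// grid-stride loop's bounds check.
std::pair<int, int> launch_shape(int nthreads) {
  const int threads_per_block = 1024;
  const int blocks = (nthreads + threads_per_block - 1) / threads_per_block;
  return {blocks, threads_per_block};
}
```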
paddle/operators/math/maxouting.h (new file)
@@ -0,0 +1,47 @@

```cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/hostdevice.h"
```

Review thread on the `hostdevice.h` include:
Reviewer: If you remove line 26 - line 48, please also remove this header.
Author: done

```cpp
namespace paddle {
namespace operators {
namespace math {

#define FLT_MAX __FLT_MAX__

template <typename Place, typename T>
class MaxOutFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* output,
                  int groups);
};

template <typename Place, class T>
class MaxOutGradFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  framework::Tensor& input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, int groups);
};

}  // namespace math
}  // namespace operators
}  // namespace paddle
```
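The header follows the place-parameterized functor pattern used elsewhere in Paddle's math library: the primary template is only declared, and each device gets a partial specialization in its own translation unit (.cc for CPU, .cu for GPU). A self-contained toy mirror of the pattern, with stand-in names, looks like this:

```cpp
#include <iostream>

struct CPUPlace {};  // stand-in for platform::CPUPlace
struct GPUPlace {};  // stand-in for platform::GPUPlace

// Primary template: declared but never defined, so using an unspecialized
// (Place, T) pair is a compile-time error.
template <typename Place, typename T>
class MaxOutFunctor;

// Per-place partial specialization; the real one lives in the .cc/.cu files.
template <typename T>
class MaxOutFunctor<CPUPlace, T> {
 public:
  void operator()() { std::cout << "CPU maxout for this T\n"; }
};

int main() {
  MaxOutFunctor<CPUPlace, float> maxout;
  maxout();  // resolves to the CPU specialization at compile time
  // MaxOutFunctor<GPUPlace, float> would not compile here, because no GPU
  // specialization is defined in this toy.
}
```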
General review thread:
Reviewer: The comments are not right, same as following.
Author: done