Skip to content

Commit

Permalink
Merge pull request #5571 from sweetsky0901/my_maxout_op
Browse files Browse the repository at this point in the history
Add maxout operator.
  • Loading branch information
sweetsky0901 committed Nov 20, 2017
2 parents d5be1d4 + 9cb2ff6 commit 7ce06c8
Show file tree
Hide file tree
Showing 9 changed files with 541 additions and 0 deletions.
2 changes: 2 additions & 0 deletions paddle/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ set(DEPS_OPS
sequence_softmax_op
sum_op
pool_op
maxout_op
pool_with_index_op
conv_op
conv_transpose_op
Expand All @@ -210,6 +211,7 @@ op_library(sgd_op DEPS selected_rows_functor)
op_library(adagrad_op DEPS selected_rows_functor)
op_library(conv_op DEPS vol2col)
op_library(pool_op DEPS pooling)
op_library(maxout_op DEPS maxouting)
op_library(pool_with_index_op DEPS pooling)
op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table)
op_library(lod_tensor_to_array_op SRCS lod_tensor_to_array_op.cc DEPS lod_rank_table_op)
Expand Down
2 changes: 2 additions & 0 deletions paddle/operators/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ if(WITH_GPU)
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
else()
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto)
cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
Expand All @@ -26,6 +27,7 @@ else()
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
cc_library(maxouting SRCS maxouting.cc DEPS device_context)
endif()

cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
Expand Down
106 changes: 106 additions & 0 deletions paddle/operators/math/maxouting.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/maxouting.h"

namespace paddle {
namespace operators {
namespace math {

// All tensors are in NCHW format, and the groups must be greater than 1
template <typename T>
class MaxOutFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
framework::Tensor * output,
int groups) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output->dims()[1];
int fea_size = input_height * input_width;
// c_size means the output size of each sample
int c_size = fea_size * output_channels;
const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());

for (int i = 0; i < batch_size; ++i) {
int new_bindex = c_size * i;
for (int c = 0; c < output_channels; ++c) {
int new_cindex = fea_size * c;
for (int f = 0; f < fea_size; ++f) {
T ele = static_cast<T>(-FLT_MAX);
for (int ph = 0; ph < groups; ++ph) {
T x = input_data[(new_bindex + new_cindex) * groups
+ ph * fea_size + f];
ele = ele > x ? ele : x;
}
output_data[(new_bindex+new_cindex+f)] = ele;
}
}
}
}
};



template <class T>
class MaxOutGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
framework::Tensor * input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad,
int groups) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
int fea_size = input_height * input_width;
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

for (int i = 0; i < batch_size; ++i) {
int blen = fea_size * output_channels * i;
for (int c = 0; c < output_channels; ++c) {
int clen = fea_size * c;
for (int f = 0; f < fea_size; ++f) {
int input_idx0 = (blen + clen) * groups + f;
bool continue_match = true;
int output_idx = blen + clen + f;
for (int g = 0; g < groups && continue_match; ++g) {
int input_idx = input_idx0 + fea_size * g;
if (input_data[input_idx] == output_data[output_idx]) {
input_grad_data[input_idx] += output_grad_data[output_idx];
continue_match = false;
}
}
}
}
}
}
};

template class MaxOutGradFunctor<platform::CPUPlace, float>;
template class MaxOutGradFunctor<platform::CPUPlace, double>;
template class MaxOutFunctor<platform::CPUPlace, float>;
template class MaxOutFunctor<platform::CPUPlace, double>;

} // namespace math
} // namespace operators
} // namespace paddle
154 changes: 154 additions & 0 deletions paddle/operators/math/maxouting.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/maxouting.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
__global__ void KernelMaxOut(const int nthreads, const T* input_data,
const int channels,
const int input_height, const int input_width,
int groups, T* output_data ) {
const int size = input_height * input_width * channels / groups;
const int feat_len = input_height * input_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int batch_idx = i / size;
int batch_offset = i % size;
int channel_idx = batch_offset / feat_len;
int feat_idx = batch_offset % feat_len;
int data_idx =
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
T ele = static_cast<T>(-FLT_MAX);
for (int g = 0; g < groups; ++g) {
T x = input_data[data_idx + g * feat_len];
ele = ele > x ? ele : x;
}
output_data[i] = ele;
}
}
template <typename T>
__global__ void KernelMaxoutGrad(
const int nthreads, const T* input_data, const T* output_data,
const T* output_grad, T* input_grad, const int channels,
const int input_height, const int input_width, int groups) {
const int size = input_height * input_width * channels / groups;
const int feat_len = input_height * input_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int batch_idx = i / size;
int batch_offset = i % size;
int channel_idx = batch_offset / feat_len;
int feat_idx = batch_offset % feat_len;
int data_idx =
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
int max_index = -1;
bool continue_match = true;
for (int g = 0; g < groups && continue_match; ++g) {
if (input_data[data_idx + g * feat_len] == output_data[i]) {
max_index = data_idx + g * feat_len;
continue_match = false;
break;
}
}
if (max_index != -1) {
input_grad[max_index] += output_grad[index];
}
}
}
/*
* All tensors are in NCHW format.
*/
template <typename T>
class MaxOutFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor * output,
int groups) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];

const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
int nthreads = output->numel();
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);

KernelMaxOut<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(nthreads, input_data, input_channels,
input_height, input_width, groups,
output_data);
}
};
/*
* All tensors are in NCHW format.
*/
template <typename T>
class MaxOutGradFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
framework::Tensor * input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad,
int groups) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];

const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
int nthreads = output.numel();
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);

KernelMaxoutGrad<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_grad_data,
input_channels, input_height, input_width, groups);
}
};

template class MaxOutGradFunctor<platform::GPUPlace, float>;
template class MaxOutGradFunctor<platform::GPUPlace, double>;

template class MaxOutFunctor<platform::GPUPlace, float>;
template class MaxOutFunctor<platform::GPUPlace, double>;

} // namespace math
} // namespace operators
} // namespace paddle
47 changes: 47 additions & 0 deletions paddle/operators/math/maxouting.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/hostdevice.h"

namespace paddle {
namespace operators {
namespace math {

#define FLT_MAX \
__FLT_MAX__

template <typename Place, typename T>

class MaxOutFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor * output,
int groups);
};

template <typename Place, class T>
class MaxOutGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
framework::Tensor * input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, int groups);
};
} // namespace math
} // namespace operators
} // namespace paddle
Loading

0 comments on commit 7ce06c8

Please sign in to comment.