Commit "fix" (867d3cb)

dzhwinter committed Mar 25, 2018
1 parent 2b17df7

Showing 1 changed file with 9 additions and 9 deletions.

paddle/fluid/operators/math/sequence_pooling.cu  (18 changes: 9 additions & 9 deletions)
@@ -144,21 +144,21 @@ class SequencePoolFunctor<platform::CUDADeviceContext, T> {
                   framework::Tensor* output,
                   framework::Tensor* index = nullptr) {
     auto lod = input.lod()[0];
-    int item_dim = input.numel() / output->dims()[0];
+    const size_t item_dim = input.numel() / output->dims()[0];
     dim3 threads(1024, 1);
     dim3 grid(lod.size(), 1);
     if (pooltype == "MAX") {
       sequence_pool_kernel<
           T, MaxPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
           MaxPoolFunctor<T>(), input.data<T>(),
           lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
-          output->mutable_data<T>(context.GetPlace()), index->data<T>());
+          output->mutable_data<T>(context.GetPlace()), index->data<int>());
     } else if (pooltype == "AVG") {
       sequence_pool_kernel<
           T, AvgPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
           AvgPoolFunctor<T>(), input.data<T>(),
           lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
-          output->mutable_data<T>(context.GetPlace()));
+          output->mutable_data<T>(context.GetPlace()), nullptr);
     }
   }
 };
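This first hunk fixes two call-site issues in the forward pass. The index tensor filled by MAX pooling stores integer argmax positions, so reading it with index->data<T>() requested the wrong element type; the fix asks for data<int>(). AVG pooling produces no index, so the call now passes nullptr explicitly for the kernel's index parameter. item_dim also becomes const size_t, likely to match the kernel's size_t parameter. A minimal standalone sketch of the idea follows; max_pool_seq and its layout are illustrative assumptions, not Paddle's actual kernel:

// Simplified sketch: one block per sequence; each thread walks one feature
// column, writing the max value and recording the winning row as an int,
// which is why the caller must read the index back via data<int>().
template <typename T>
__global__ void max_pool_seq(const T* in, const size_t* lod, size_t num_seqs,
                             size_t item_dim, T* out, int* index) {
  int bid = blockIdx.x;
  if (static_cast<size_t>(bid) >= num_seqs) return;  // lod has num_seqs + 1 entries
  size_t start = lod[bid];
  size_t end = lod[bid + 1];
  for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
    T best = in[item_dim * start + tid];
    int best_row = static_cast<int>(start);
    for (size_t i = start + 1; i < end; ++i) {
      T v = in[item_dim * i + tid];
      if (v > best) {
        best = v;
        best_row = static_cast<int>(i);
      }
    }
    out[item_dim * bid + tid] = best;
    index[item_dim * bid + tid] = best_row;  // int offsets into the sequence
  }
}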
@@ -167,7 +167,7 @@ template <typename T>
 struct MaxPoolGradFunctor {
   HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                              const size_t end, const size_t item_dim,
-                             T* in_grad, int* index) {
+                             T* in_grad, const int* index) {
     for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
       for (int i = start; i < end; ++i) {
         if (i == *index) {
@@ -184,7 +184,7 @@ template <typename T>
 struct AvgPoolGradFunctor {
   HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                              const size_t end, const size_t item_dim,
-                             T* in_grad, int* index) {
+                             T* in_grad, const int* index) {
     for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
       for (int i = start; i < end; ++i) {
         in_grad[item_dim * i + tid] = out_grad[tid] / (end - start);
@@ -198,7 +198,7 @@ __global__ void sequence_pool_grad_kernel(Range_OP op, const T* out_grad,
                                           const size_t* lod,
                                           const size_t lod_size,
                                           const size_t item_dim, T* in_grad,
-                                          int* index) {
+                                          const int* index) {
   int bid = blockIdx.x;
   if (bid >= lod_size) return;
   size_t start = lod[bid];
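These three hunks make one change: the gradient path only reads the index, never writes it, so the parameter tightens from int* to const int*. That also matches the caller below, where the index arrives as const framework::Tensor*; assuming Paddle's Tensor follows the common pattern of a const-qualified data<T>() overload returning const T*, a read-only pointer is all that call site can produce. A hypothetical miniature of that accessor pattern:

// MiniTensor is not a Paddle type; it only shows why const int* is the
// natural match for an index handed in through a pointer-to-const tensor.
class MiniTensor {
 public:
  template <typename T>
  T* data() { return static_cast<T*>(buf_); }  // mutable view
  template <typename T>
  const T* data() const { return static_cast<const T*>(buf_); }  // read-only view
 private:
  void* buf_ = nullptr;
};

void grad_call_site(const MiniTensor* index) {
  const int* idx = index->data<int>();  // const overload is selected
  (void)idx;  // binding the result to a plain int* would not compile
}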
@@ -215,21 +215,21 @@ class SequencePoolGradFunctor<platform::CUDADeviceContext, T> {
                   /* max pool has index */
                   const framework::Tensor* index = nullptr) {
     auto lod = in_grad->lod()[0];
-    int item_dim = in_grad->numel() / out_grad.dims()[0];
+    const size_t item_dim = in_grad->numel() / out_grad.dims()[0];
     dim3 threads(1024, 1);
     dim3 block(lod.size(), 1);
     if (pooltype == "MAX") {
       sequence_pool_grad_kernel<
           T, MaxPoolGradFunctor<T>><<<block, threads, 0, context.stream()>>>(
           MaxPoolGradFunctor<T>(), out_grad.data<T>(),
           lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
-          in_grad->mutable_data<T>(context.GetPlace()), index->data<T>());
+          in_grad->mutable_data<T>(context.GetPlace()), index->data<int>());
     } else if (pooltype == "AVG") {
       sequence_pool_grad_kernel<
           T, AvgPoolGradFunctor<T>><<<block, threads, 0, context.stream()>>>(
           AvgPoolGradFunctor<T>(), out_grad.data<T>(),
           lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
-          in_grad->mutable_data<T>(context.GetPlace()));
+          in_grad->mutable_data<T>(context.GetPlace()), nullptr);
     }
   }
 };
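The backward hunk mirrors the forward one: MAX reads the int index recorded during the forward pass, now via data<int>() on the const tensor; AVG has no index and passes nullptr, which AvgPoolGradFunctor never dereferences; and item_dim becomes const size_t to match the kernel's parameter type. A compilable sketch of the max-pool gradient scatter, again an illustration rather than Paddle's exact kernel (it assumes the per-output index layout from the forward sketch above):

// Illustrative sketch: route out_grad only to the argmax row stored in
// index and zero every other timestep; index is read-only here, matching
// the const int* signature change in this commit.
template <typename T>
__global__ void max_pool_grad_seq(const T* out_grad, const size_t* lod,
                                  size_t num_seqs, size_t item_dim,
                                  T* in_grad, const int* index) {
  int bid = blockIdx.x;
  if (static_cast<size_t>(bid) >= num_seqs) return;
  size_t start = lod[bid];
  size_t end = lod[bid + 1];
  for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
    int winner = index[item_dim * bid + tid];  // row chosen in the forward pass
    for (size_t i = start; i < end; ++i) {
      in_grad[item_dim * i + tid] = (static_cast<int>(i) == winner)
                                        ? out_grad[item_dim * bid + tid]
                                        : static_cast<T>(0);
    }
  }
}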
