From 867d3cb5c08e34e98c6c620fed08edf47350f4e6 Mon Sep 17 00:00:00 2001
From: dzhwinter
Date: Sat, 24 Mar 2018 22:44:11 -0700
Subject: [PATCH] "fix"

---
 .../fluid/operators/math/sequence_pooling.cu | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/paddle/fluid/operators/math/sequence_pooling.cu b/paddle/fluid/operators/math/sequence_pooling.cu
index 31b6166432438..5e5fc47119f18 100644
--- a/paddle/fluid/operators/math/sequence_pooling.cu
+++ b/paddle/fluid/operators/math/sequence_pooling.cu
@@ -144,7 +144,7 @@ class SequencePoolFunctor<platform::CUDADeviceContext, T> {
                   framework::Tensor* output,
                   framework::Tensor* index = nullptr) {
     auto lod = input.lod()[0];
-    int item_dim = input.numel() / output->dims()[0];
+    const size_t item_dim = input.numel() / output->dims()[0];
     dim3 threads(1024, 1);
     dim3 grid(lod.size(), 1);
     if (pooltype == "MAX") {
@@ -152,13 +152,13 @@ class SequencePoolFunctor<platform::CUDADeviceContext, T> {
           T, MaxPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
           MaxPoolFunctor<T>(), input.data<T>(),
           lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
-          output->mutable_data<T>(context.GetPlace()), index->data<T>());
+          output->mutable_data<T>(context.GetPlace()), index->data<int>());
     } else if (pooltype == "AVG") {
       sequence_pool_kernel<
           T, AvgPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
           AvgPoolFunctor<T>(), input.data<T>(),
           lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
-          output->mutable_data<T>(context.GetPlace()));
+          output->mutable_data<T>(context.GetPlace()), nullptr);
     }
   }
 };
@@ -167,7 +167,7 @@ template <typename T>
 struct MaxPoolGradFunctor {
   HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                              const size_t end, const size_t item_dim,
-                             T* in_grad, int* index) {
+                             T* in_grad, const int* index) {
     for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
       for (int i = start; i < end; ++i) {
         if (i == *index) {
@@ -184,7 +184,7 @@ template <typename T>
 struct AvgPoolGradFunctor {
   HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                              const size_t end, const size_t item_dim,
-                             T* in_grad, int* index) {
+                             T* in_grad, const int* index) {
     for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
       for (int i = start; i < end; ++i) {
         in_grad[item_dim * i + tid] = out_grad[tid] / (end - start);
@@ -198,7 +198,7 @@ __global__ void sequence_pool_grad_kernel(Range_OP op, const T* out_grad,
                                           const size_t* lod,
                                           const size_t lod_size,
                                           const size_t item_dim, T* in_grad,
-                                          int* index) {
+                                          const int* index) {
   int bid = blockIdx.x;
   if (bid >= lod_size) return;
   size_t start = lod[bid];
@@ -215,7 +215,7 @@ class SequencePoolGradFunctor<platform::CUDADeviceContext, T> {
                   /* max pool has index */
                   const framework::Tensor* index = nullptr) {
     auto lod = in_grad->lod()[0];
-    int item_dim = in_grad->numel() / out_grad.dims()[0];
+    const size_t item_dim = in_grad->numel() / out_grad.dims()[0];
     dim3 threads(1024, 1);
     dim3 block(lod.size(), 1);
     if (pooltype == "MAX") {
@@ -223,13 +223,13 @@ class SequencePoolGradFunctor<platform::CUDADeviceContext, T> {
           T, MaxPoolGradFunctor<T>><<<block, threads, 0, context.stream()>>>(
           MaxPoolGradFunctor<T>(), out_grad.data<T>(),
           lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
-          in_grad->mutable_data<T>(context.GetPlace()), index->data<T>());
+          in_grad->mutable_data<T>(context.GetPlace()), index->data<int>());
     } else if (pooltype == "AVG") {
       sequence_pool_grad_kernel<
           T, AvgPoolGradFunctor<T>><<<block, threads, 0, context.stream()>>>(
           AvgPoolGradFunctor<T>(), out_grad.data<T>(),
           lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
-          in_grad->mutable_data<T>(context.GetPlace()));
+          in_grad->mutable_data<T>(context.GetPlace()), nullptr);
     }
   }
 };
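
Reviewer note: every change above serves one calling convention. sequence_pool_kernel and sequence_pool_grad_kernel are each instantiated with both a MAX and an AVG functor, and the two functors share one parameter list, so the AVG call sites must pass an explicit nullptr for the index argument that only the MAX path uses, while the grad path reads the stored argmax through const int* because its index tensor is a read-only input. Below is a minimal standalone sketch of that pattern, not PaddlePaddle source; the toy_* and Toy* names are invented for illustration.

// toy_pool.cu : a hypothetical, self-contained illustration.
#include <cstdio>
#include <cuda_runtime.h>

// MAX-style backward: route each gradient to its argmax row. The index
// array is an input here, hence the const int* (mirroring the patch).
struct ToyMaxGrad {
  __device__ void operator()(const float* out_grad, size_t start, size_t end,
                             size_t item_dim, float* in_grad,
                             const int* index) const {
    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x)
      for (size_t i = start; i < end; ++i)
        in_grad[item_dim * i + tid] =
            (static_cast<int>(i) == index[tid]) ? out_grad[tid] : 0.f;
  }
};

// AVG-style backward: spread the gradient evenly. It accepts the index
// parameter only to match ToyMaxGrad's signature and ignores it, which
// is why the launch site still has to pass an explicit nullptr.
struct ToyAvgGrad {
  __device__ void operator()(const float* out_grad, size_t start, size_t end,
                             size_t item_dim, float* in_grad,
                             const int* /*index*/) const {
    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x)
      for (size_t i = start; i < end; ++i)
        in_grad[item_dim * i + tid] = out_grad[tid] / (end - start);
  }
};

// One kernel shared by both functors, one block per sequence: lod holds
// lod_size row offsets [0, len0, len0 + len1, ...], so block b owns rows
// lod[b]..lod[b+1).
template <typename Op>
__global__ void toy_pool_grad_kernel(Op op, const float* out_grad,
                                     const size_t* lod, size_t lod_size,
                                     size_t item_dim, float* in_grad,
                                     const int* index) {
  int bid = blockIdx.x;
  if (bid + 1 >= lod_size) return;
  op(&out_grad[bid * item_dim], lod[bid], lod[bid + 1], item_dim, in_grad,
     index ? &index[bid * item_dim] : nullptr);
}

int main() {
  // Two sequences (2 rows and 3 rows), item_dim = 1.
  const size_t lod_h[3] = {0, 2, 5};
  const float out_grad_h[2] = {10.f, 20.f};
  const int index_h[2] = {1, 4};  // argmax row of each sequence

  size_t* lod;
  float* out_grad;
  float* in_grad;
  int* index;
  cudaMalloc(&lod, sizeof(lod_h));
  cudaMalloc(&out_grad, sizeof(out_grad_h));
  cudaMalloc(&index, sizeof(index_h));
  cudaMalloc(&in_grad, 5 * sizeof(float));
  cudaMemcpy(lod, lod_h, sizeof(lod_h), cudaMemcpyHostToDevice);
  cudaMemcpy(out_grad, out_grad_h, sizeof(out_grad_h), cudaMemcpyHostToDevice);
  cudaMemcpy(index, index_h, sizeof(index_h), cudaMemcpyHostToDevice);

  // MAX passes the real index; AVG passes an explicit nullptr, as in the
  // patched call sites.
  toy_pool_grad_kernel<<<2, 32>>>(ToyMaxGrad(), out_grad, lod, 3, 1, in_grad,
                                  index);
  toy_pool_grad_kernel<<<2, 32>>>(ToyAvgGrad(), out_grad, lod, 3, 1, in_grad,
                                  nullptr);

  float result[5];
  cudaMemcpy(result, in_grad, sizeof(result), cudaMemcpyDeviceToHost);
  for (float v : result) std::printf("%g\n", v);  // AVG ran last: 5 5 6.67...
  return 0;
}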
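The sketch builds with a plain nvcc toy_pool.cu and prints the AVG result written last (5 5 6.66667 ...). The patch's two smaller changes follow the same logic: the host-side item_dim becomes const size_t so it matches the kernels' const size_t item_dim parameter instead of narrowing through int, and const int* index is what index->data<int>() should yield once the grad functor receives the index from the const framework::Tensor* input, so no cast is needed.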