Refine bernoulli and unsqueeze op (#26842) (#26885)
* add check for bernoulli and register bool for unsqueeze

* follow comments
zhiqiu committed Sep 2, 2020
1 parent 7495b28 commit dd28cad
Showing 5 changed files with 22 additions and 6 deletions.
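
As a quick, hedged illustration of the user-facing effect of this commit (the probability check in bernoulli and the new bool kernels for unsqueeze), here is a minimal Python sketch. It assumes a Paddle 2.x dygraph environment; paddle.full and paddle.to_tensor are used only for illustration and are not part of this diff.

```python
import paddle

# bernoulli: probabilities must lie in [0, 1]; out-of-range values now hit
# the PADDLE_ENFORCE checks added in this commit instead of sampling silently.
probs = paddle.full(shape=[2, 3], fill_value=0.3, dtype='float32')
samples = paddle.bernoulli(probs)          # each entry is 0.0 or 1.0

bad = paddle.full(shape=[2, 3], fill_value=1.5, dtype='float32')
# paddle.bernoulli(bad)                    # now raises an OutOfRange error (probability should be <= 1)

# unsqueeze: bool kernels are now registered, so inserting a dimension into a
# boolean mask works with both the CPU and CUDA kernels.
mask = paddle.to_tensor([True, False, True])
expanded = paddle.unsqueeze(mask, axis=0)  # shape [1, 3], dtype bool
print(expanded.shape, expanded.dtype)
```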
4 changes: 4 additions & 0 deletions paddle/fluid/operators/bernoulli_op.cu
@@ -31,6 +31,10 @@ struct BernoulliCudaFunctor {
__host__ __device__ BernoulliCudaFunctor(int seed) : seed_(seed) {}

__host__ __device__ T operator()(const unsigned int n, const T p) const {
// NOTE(zhiqiu): currently, PADDLE_ENFORCE in a CUDA kernel may print several
// lines of error messages when it fails, and it should be refined.
PADDLE_ENFORCE(p >= 0.0 && p <= 1.0,
"The probability should be >=0 and <= 1, but got %f", p);
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(0.0, 1.0);
10 changes: 6 additions & 4 deletions paddle/fluid/operators/bernoulli_op.h
@@ -25,10 +25,12 @@ namespace operators {

template <typename T>
inline HOSTDEVICE T BernoulliFunctor(T p, T rand) {
PADDLE_ENFORCE_LE(p, 1, platform::errors::OutOfRange(
"The probability should be <= 1, but got %f", p));
PADDLE_ENFORCE_GE(p, 0, platform::errors::OutOfRange(
"The probability should be >= 1, but got %f", p));
PADDLE_ENFORCE_LE(p, 1.0,
platform::errors::OutOfRange(
"The probability should be <= 1, but got %f", p));
PADDLE_ENFORCE_GE(p, 0.0,
platform::errors::OutOfRange(
"The probability should be >= 0, but got %f", p));
return static_cast<T>(rand < p);
}

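For reference, the sampling semantics that BernoulliFunctor implements (validate p, draw a uniform value per element, and return rand < p) can be sketched in NumPy. This is an illustrative reference only, not Paddle code; bernoulli_reference is a hypothetical helper name.

```python
import numpy as np

def bernoulli_reference(p, seed=0):
    # Mirrors the functor above: validate 0 <= p <= 1, draw u ~ Uniform(0, 1)
    # per element, and return 1.0 where u < p, else 0.0.
    p = np.asarray(p, dtype=np.float32)
    if np.any((p < 0.0) | (p > 1.0)):
        raise ValueError("The probability should be >= 0 and <= 1")
    rng = np.random.default_rng(seed)
    u = rng.uniform(0.0, 1.0, size=p.shape).astype(np.float32)
    return (u < p).astype(np.float32)

# Example: roughly 30% of entries are expected to be 1.0.
print(bernoulli_reference(np.full((8,), 0.3)))
```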
6 changes: 6 additions & 0 deletions paddle/fluid/operators/unsqueeze_op.cc
@@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/unsqueeze_op.h"

#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
@@ -327,26 +329,30 @@ REGISTER_OPERATOR(unsqueeze2_grad, ops::Unsqueeze2GradOp,
REGISTER_OP_CPU_KERNEL(
unsqueeze, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
unsqueeze_grad,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
unsqueeze2, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
unsqueeze2_grad,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, double>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
4 changes: 4 additions & 0 deletions paddle/fluid/operators/unsqueeze_op.cu.cc
@@ -21,6 +21,7 @@ REGISTER_OP_CUDA_KERNEL(
unsqueeze, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
@@ -30,6 +31,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
@@ -38,6 +40,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
@@ -47,6 +50,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, double>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
4 changes: 2 additions & 2 deletions python/paddle/tensor/manipulation.py
@@ -433,8 +433,8 @@ def stack(x, axis=0, name=None):
[5.0, 6.0] ] ]
Args:
x (Tensor|list[Tensor]): Input ``x`` can be a single tensor, or a ``list`` of tensors.
If ``x`` is a ``list``, the Tensors in ``x``
x (Tensor|list[Tensor]|tuple[Tensor]): Input ``x`` can be a single tensor, or a ``list`` or ``tuple`` of tensors.
If ``x`` is a ``list`` or ``tuple`` , the Tensors in ``x``
must be of the same shape and dtype. Supported data types: float32, float64, int32, int64.
axis (int, optional): The axis along which all inputs are stacked. ``axis`` range is ``[-(R+1), R+1)``,
where ``R`` is the number of dimensions of the first input tensor ``x[0]``.
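
A minimal usage sketch for the tuple input now documented above (assuming a Paddle 2.x dygraph environment; the tensors are illustrative):

```python
import paddle

x1 = paddle.to_tensor([1.0, 2.0])
x2 = paddle.to_tensor([3.0, 4.0])

# Per the updated docstring, ``x`` may be a list or a tuple of tensors with
# the same shape and dtype.
out = paddle.stack((x1, x2), axis=0)
print(out.shape)  # [2, 2]
```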