resolve conflicts
chenfeiyu committed Aug 25, 2020
2 parents 779e226 + 6f69fbc commit a6fd466
Showing 13 changed files with 506 additions and 76 deletions.
17 changes: 14 additions & 3 deletions paddle/fluid/operators/activation_op.h
@@ -1134,9 +1134,20 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>() +
dout * static_cast<T>(alpha) * x.exp() *
(x <= static_cast<T>(0)).template cast<T>();
auto temp_a_pos = static_cast<T>(alpha > 0);
auto temp_a_neg = static_cast<T>(alpha <= 0);
auto temp_x_pos = (x > static_cast<T>(0)).template cast<T>();
auto temp_x_neg = (x <= static_cast<T>(0)).template cast<T>();

// dx = dout, if alpha > 0 and x > 0
// dx = dout * alpha * x.exp(), if alpha > 0 and x <= 0
// dx = dout * (1 + alpha * x.exp()), if alpha <= 0 and x > 0
// dx = 0, if alpha <= 0 and x <=0
dx.device(d) =
dout * temp_a_pos * temp_x_pos +
dout * static_cast<T>(alpha) * x.exp() * temp_a_pos * temp_x_neg +
dout * (static_cast<T>(1) + static_cast<T>(alpha) * x.exp()) *
temp_a_neg * temp_x_pos;
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
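
For readability, the four-way split in the comments above can be mirrored by a small NumPy sketch (illustrative only, not Paddle code; the helper name elu_grad is hypothetical):

import numpy as np

def elu_grad(x, dout, alpha):
    # Mirrors the four cases listed in the functor's comments.
    x_pos = (x > 0).astype(x.dtype)
    x_neg = (x <= 0).astype(x.dtype)
    if alpha > 0:
        # dx = dout for x > 0; dx = dout * alpha * exp(x) for x <= 0
        return dout * x_pos + dout * alpha * np.exp(x) * x_neg
    # alpha <= 0: dx = dout * (1 + alpha * exp(x)) for x > 0; dx = 0 for x <= 0
    return dout * (1.0 + alpha * np.exp(x)) * x_pos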
1 change: 1 addition & 0 deletions paddle/fluid/operators/grid_sampler_op.cu
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/grid_sampler_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"

namespace paddle {
95 changes: 65 additions & 30 deletions paddle/fluid/operators/pixel_shuffle_op.cc
@@ -28,40 +28,59 @@ class PixelShuffleOp : public framework::OperatorWithKernel {
"Output(Out) of PixelShuffleOp should not be null."));

auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(
input_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W], but got %u.",
input_dims.size()));
PADDLE_ENFORCE_EQ(input_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
input_dims.size()));

auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");

PADDLE_ENFORCE_EQ(input_dims[1] % (upscale_factor * upscale_factor), 0,
platform::errors::InvalidArgument(
"The square of upscale_factor[%u] should divide the "
"number of channel[%u]",
input_dims[1], upscale_factor * upscale_factor));

const std::string data_format =
ctx->Attrs().Get<std::string>("data_format");
const bool channel_last = (data_format == "NHWC");

if (!channel_last) {
PADDLE_ENFORCE_EQ(
input_dims[1] % (upscale_factor * upscale_factor), 0,
platform::errors::InvalidArgument(
"The square of upscale_factor[%u] should divide the "
"number of channel[%u]",
input_dims[1], upscale_factor * upscale_factor));
} else {
PADDLE_ENFORCE_EQ(
input_dims[3] % (upscale_factor * upscale_factor), 0,
platform::errors::InvalidArgument(
"The square of upscale_factor[%u] should divide the "
"number of channel[%u]",
input_dims[3], upscale_factor * upscale_factor));
}
auto output_dims = input_dims;
output_dims[0] = input_dims[0];
output_dims[1] = input_dims[1] / (upscale_factor * upscale_factor);
output_dims[2] = input_dims[2] * upscale_factor;
output_dims[3] = input_dims[3] * upscale_factor;
if (!channel_last) {
output_dims[1] = input_dims[1] / (upscale_factor * upscale_factor);
output_dims[2] = input_dims[2] * upscale_factor;
output_dims[3] = input_dims[3] * upscale_factor;
} else {
output_dims[1] = input_dims[1] * upscale_factor;
output_dims[2] = input_dims[2] * upscale_factor;
output_dims[3] = input_dims[3] / (upscale_factor * upscale_factor);
}
ctx->SetOutputDim("Out", output_dims);
}
};
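
A quick sketch of the shape rule implemented by InferShape above (illustrative Python only, assuming a 4-D input and a valid upscale_factor; the helper name is hypothetical):

def pixel_shuffle_out_shape(in_shape, r, data_format="NCHW"):
    # [N, C, H, W] -> [N, C/r^2, H*r, W*r], or
    # [N, H, W, C] -> [N, H*r, W*r, C/r^2]
    n, d1, d2, d3 = in_shape
    if data_format != "NHWC":
        return [n, d1 // (r * r), d2 * r, d3 * r]
    return [n, d1 * r, d2 * r, d3 // (r * r)]

# e.g. pixel_shuffle_out_shape([2, 9, 4, 4], 3) -> [2, 1, 12, 12]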

class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"X",
"(Tensor, default Tensor<float>), "
"the input feature data of PixelShuffleOp, the layout is [N C H W].");
AddOutput(
"Out",
"(Tensor, default Tensor<float>), the output of "
"PixelShuffleOp. The layout is [N,C/factor^2,H*factor,W*factor].");
AddInput("X",
"(Tensor, default Tensor<float>), "
"the input feature data of PixelShuffleOp, the layout is [N, C, "
"H, W] or [N, H, W, C].");
AddOutput("Out",
"(Tensor, default Tensor<float>), the output of "
"PixelShuffleOp. The layout is [N, C/factor^2, H*factor, "
"W*factor] or [N, H*factor, W*factor, C/factor^2].");
AddAttr<int>("upscale_factor",
"the factor to increase spatial resolution by.")
.SetDefault(1)
@@ -70,6 +89,11 @@ class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
platform::errors::InvalidArgument(
"upscale_factor should be larger than 0."));
});
AddAttr<std::string>(
"data_format",
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\", Specify the data format of the input data.")
.SetDefault("NCHW");

AddComment(R"DOC(
Pixel Shuffle operator
@@ -114,19 +138,30 @@ class PixelShuffleGradOp : public framework::OperatorWithKernel {
platform::errors::NotFound("Output(X@Grad) should not be null"));

auto do_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(
do_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W], but got %u.",
do_dims.size()));
PADDLE_ENFORCE_EQ(do_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
do_dims.size()));

auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");

const std::string data_format =
ctx->Attrs().Get<std::string>("data_format");
const bool channel_last = (data_format == "NHWC");

auto dx_dims = do_dims;
dx_dims[0] = do_dims[0];
dx_dims[1] = do_dims[1] * (upscale_factor * upscale_factor);
dx_dims[2] = do_dims[2] / upscale_factor;
dx_dims[3] = do_dims[3] / upscale_factor;

if (!channel_last) {
dx_dims[1] = do_dims[1] * (upscale_factor * upscale_factor);
dx_dims[2] = do_dims[2] / upscale_factor;
dx_dims[3] = do_dims[3] / upscale_factor;
} else {
dx_dims[1] = do_dims[1] / upscale_factor;
dx_dims[2] = do_dims[2] / upscale_factor;
dx_dims[3] = do_dims[3] * (upscale_factor * upscale_factor);
}
ctx->SetOutputDim(framework::GradVarName("X"), dx_dims);
}
};
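
The gradient op's shape rule above simply inverts the forward mapping; a sketch under the same assumptions (hypothetical helper name):

def pixel_shuffle_grad_in_shape(dout_shape, r, data_format="NCHW"):
    # Inverse of the forward rule: recover the X shape from the Out@Grad shape.
    n, d1, d2, d3 = dout_shape
    if data_format != "NHWC":
        return [n, d1 * r * r, d2 // r, d3 // r]
    return [n, d1 // r, d2 // r, d3 * r * r]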
40 changes: 32 additions & 8 deletions paddle/fluid/operators/pixel_shuffle_op.h
@@ -11,6 +11,7 @@ limitations under the License. */

#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
@@ -24,23 +25,33 @@ class PixelShuffleOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");

out->mutable_data<T>(ctx.GetPlace());

int factor = ctx.Attr<int>("upscale_factor");

std::string data_format = ctx.Attr<std::string>("data_format");
bool channel_last = (data_format == "NHWC");

auto in_dims = in->dims();
auto o_dims = out->dims();

framework::Tensor t;
t.ShareDataWith(*in);
t.Resize({in_dims[0], o_dims[1], factor, factor, in_dims[2], in_dims[3]});

if (!channel_last) {
t.Resize({in_dims[0], o_dims[1], factor, factor, in_dims[2], in_dims[3]});
} else {
t.Resize({in_dims[0], in_dims[1], in_dims[2], o_dims[3], factor, factor});
}
std::vector<int> axis = {0, 1, 4, 2, 5, 3};

framework::Tensor o;
o.ShareDataWith(*out);
o.Resize({in_dims[0], o_dims[1], in_dims[2], factor, in_dims[3], factor});

if (!channel_last) {
o.Resize({in_dims[0], o_dims[1], in_dims[2], factor, in_dims[3], factor});
} else {
o.Resize({in_dims[0], in_dims[1], factor, in_dims[2], factor, o_dims[3]});
}
math::Transpose<DeviceContext, T, 6> trans;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
trans(dev_ctx, t, &o, axis);
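
The Compute method above realizes pixel shuffle as a reshape, a 6-D transpose with axis order {0, 1, 4, 2, 5, 3}, and a final reshape into the output dims. A NumPy sketch of the same sequence (illustrative only; function names are hypothetical):

import numpy as np

def pixel_shuffle_nchw(x, r):
    # x: [N, C, H, W] with C divisible by r*r
    n, c, h, w = x.shape
    oc = c // (r * r)
    t = x.reshape(n, oc, r, r, h, w)
    o = t.transpose(0, 1, 4, 2, 5, 3)   # -> [N, oc, H, r, W, r]
    return o.reshape(n, oc, h * r, w * r)

def pixel_shuffle_nhwc(x, r):
    # x: [N, H, W, C] with C divisible by r*r
    n, h, w, c = x.shape
    oc = c // (r * r)
    t = x.reshape(n, h, w, oc, r, r)
    o = t.transpose(0, 1, 4, 2, 5, 3)   # -> [N, H, r, W, r, oc]
    return o.reshape(n, h * r, w * r, oc)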
@@ -58,19 +69,32 @@ class PixelShuffleGradOpKernel : public framework::OpKernel<T> {

int factor = ctx.Attr<int>("upscale_factor");

std::string data_format = ctx.Attr<std::string>("data_format");
bool channel_last = (data_format == "NHWC");

auto do_dims = dout->dims();
auto dx_dims = dx->dims();

framework::Tensor t;
t.ShareDataWith(*dout);
t.Resize({do_dims[0], do_dims[1], dx_dims[2], factor, dx_dims[3], factor});

if (!channel_last) {
t.Resize(
{do_dims[0], do_dims[1], dx_dims[2], factor, dx_dims[3], factor});
} else {
t.Resize(
{do_dims[0], dx_dims[1], factor, dx_dims[2], factor, do_dims[3]});
}
std::vector<int> axis = {0, 1, 3, 5, 2, 4};

framework::Tensor o;
o.ShareDataWith(*dx);
o.Resize({do_dims[0], do_dims[1], factor, factor, dx_dims[2], dx_dims[3]});

if (!channel_last) {
o.Resize(
{do_dims[0], do_dims[1], factor, factor, dx_dims[2], dx_dims[3]});
} else {
o.Resize(
{do_dims[0], dx_dims[1], dx_dims[2], do_dims[3], factor, factor});
}
math::Transpose<DeviceContext, T, 6> trans;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
trans(dev_ctx, t, &o, axis);
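
Likewise, the gradient kernel above undoes the shuffle with the transpose axis order {0, 1, 3, 5, 2, 4}; a NumPy sketch for the NCHW case (illustrative only, hypothetical helper name):

import numpy as np

def pixel_shuffle_grad_nchw(dout, r):
    # dout: [N, C/r^2, H*r, W*r]; returns the gradient w.r.t. the NCHW input X.
    n, oc, hr, wr = dout.shape
    h, w = hr // r, wr // r
    t = dout.reshape(n, oc, h, r, w, r)
    o = t.transpose(0, 1, 3, 5, 2, 4)   # -> [N, oc, r, r, H, W]
    return o.reshape(n, oc * r * r, h, w)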
4 changes: 2 additions & 2 deletions paddle/scripts/paddle_build.bat
@@ -144,7 +144,7 @@ call "C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\vcvarsall.bat" amd6
set build_times=1
:build_tp
echo Build third_party for %build_times% time:
msbuild /m /p:Configuration=Release /verbosity:minimal third_party.vcxproj
msbuild /m /p:Configuration=Release /verbosity:quiet third_party.vcxproj
if %ERRORLEVEL% NEQ 0 (
set /a build_times=%build_times%+1
if %build_times% GTR 3 (
@@ -159,7 +159,7 @@ echo Build third_party successfully!
set build_times=1
:build_paddle
echo Build Paddle for %build_times% time:
msbuild /m /p:Configuration=Release /verbosity:quiet paddle.sln
msbuild /m /p:Configuration=Release /verbosity:minimal paddle.sln
if %ERRORLEVEL% NEQ 0 (
set /a build_times=%build_times%+1
if %build_times% GTR 2 (
60 changes: 60 additions & 0 deletions python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -307,6 +307,30 @@ def amp(self, flag):

@property
def amp_configs(self):
"""
Set automatic mixed precision training configurations. In general, amp has several configurable
settings that can be set through a dict.
**Notes**:
**init_loss_scaling(float)**: The initial loss scaling factor. Default 32768.
**use_dynamic_loss_scaling(bool)**: Whether to use dynamic loss scaling. Default True.
**incr_every_n_steps(int)**: Increases loss scaling every n consecutive steps with finite gradients. Default 1000.
**decr_every_n_nan_or_inf(int)**: Decreases loss scaling every n accumulated steps with nan or inf gradients. Default 2.
**incr_ratio(float)**: The multiplier to use when increasing the loss scaling. Default 2.0.
**decr_ratio(float)**: The less-than-one-multiplier to use when decreasing the loss scaling. Default 0.5.
**custom_white_list(list[str])**: Users' custom white list of ops that always execute in fp16.
**custom_black_list(list[str])**: Users' custom black list of ops that are forbidden to execute in fp16.
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.amp = True
strategy.amp_configs = {
"init_loss_scaling": 32768,
"custom_white_list": ['conv2d']}
"""
return get_msg_dict(self.strategy.amp_configs)

@amp_configs.setter
@@ -620,6 +644,20 @@ def localsgd_configs(self, configs):

@property
def dgc(self):
"""
Indicates whether Deep Gradient Compression training is enabled. For more details, please refer to
[Deep Gradient Compression](https://arxiv.org/abs/1712.01887).
Default Value: False
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.dgc = True # by default this is false
"""
return self.strategy.dgc

@dgc.setter
@@ -631,6 +669,28 @@ def dgc(self, flag):

@property
def dgc_configs(self):
"""
Set Deep Gradient Compression training configurations. In general, dgc has several configurable
settings that can be set through a dict.
**Notes**:
**rampup_begin_step(int)**: The beginning step from which gradient compression is implemented. Default 0.
**rampup_step(int)**: Time steps used in the sparsity warm-up period. Default is 1.
For example, if the sparsity is [0.75, 0.9375, 0.984375, 0.996, 0.999] and the rampup_step is 100,
it will use 0.75 at steps 0~19, 0.9375 at steps 20~39, and so on. Once the end of the sparsity
array is reached, it keeps using 0.999 from then on.
**sparsity(list[float])**: Keeps only the most important elements of the gradient tensor; the kept
ratio is (1 - sparsity). Default is [0.999]. For example, if the sparsity is [0.99, 0.999], the top
[1%, 0.1%] most important elements will be transmitted.
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.dgc = True
strategy.dgc_configs = {"rampup_begin_step": 1252}
"""
return get_msg_dict(self.strategy.dgc_configs)
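
As an illustrative aside on the rampup_step note above (a sketch only, not part of distributed_strategy.py; the helper name is hypothetical, and behavior before rampup_begin_step may differ in the actual implementation), the sparsity used at a given step can be read as:

def dgc_sparsity_at(step, rampup_begin_step, rampup_step, sparsity):
    # Each sparsity value is held for rampup_step / len(sparsity) steps,
    # then the last value is kept from then on.
    if step < rampup_begin_step:
        return sparsity[0]
    period = max(rampup_step // len(sparsity), 1)
    idx = min((step - rampup_begin_step) // period, len(sparsity) - 1)
    return sparsity[idx]

# e.g. with sparsity=[0.75, 0.9375, 0.984375, 0.996, 0.999] and rampup_step=100,
# steps 0~19 use 0.75, steps 20~39 use 0.9375, and after step 99 it stays at 0.999.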

@dgc_configs.setter