Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop' into doc/api1
Browse files Browse the repository at this point in the history
  • Loading branch information
dzhwinter committed Jun 15, 2018
2 parents 7ad46ec + 566a940 commit 4970414
Show file tree
Hide file tree
Showing 32 changed files with 1,058 additions and 446 deletions.
2 changes: 1 addition & 1 deletion doc/v2/dev/contribute_to_paddle_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ no changes added to commit (use "git add" and/or "git commit -a")
➜ docker run -it -v $(pwd):/paddle paddle:latest-dev bash -c "cd /paddle/build && ctest"
```

关于构建和测试的更多信息,请参见[这篇文档](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/getstarted/build_and_install/docker_install_cn.rst)
关于构建和测试的更多信息,请参见[使用Docker安装运行](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/v2/build_and_install/docker_install_cn.rst)

## 提交(commit)

Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/inference/tensorrt/convert/op_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ class OpConverter {
(*it)(op, scope, test_mode);
}

// convert fluid block to tensorrt network
// Convert a fluid block to tensorrt network, NOTE it just convert operators,
// the INetwork's inputs and outputs should specified in some other modules.
void ConvertBlock(const framework::proto::BlockDesc& block,
const std::unordered_set<std::string>& parameters,
const framework::Scope& scope, TensorRTEngine* engine) {
Expand Down
32 changes: 23 additions & 9 deletions paddle/fluid/inference/tensorrt/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ class TensorRTEngine : public EngineBase {
nvinfer1::Weights w_;
};

TensorRTEngine(int max_batch, int max_workspace, cudaStream_t* stream,
TensorRTEngine(int max_batch, int max_workspace,
cudaStream_t* stream = nullptr,
nvinfer1::ILogger& logger = NaiveLogger::Global())
: max_batch_(max_batch),
max_workspace_(max_workspace),
stream_(stream),
stream_(stream ? stream : &default_stream_),
logger_(logger) {}

virtual ~TensorRTEngine();
Expand Down Expand Up @@ -121,6 +122,8 @@ class TensorRTEngine : public EngineBase {
// the max memory size the engine uses
int max_workspace_;
cudaStream_t* stream_;
// If stream_ is not set from outside, hold its own stream.
cudaStream_t default_stream_;
nvinfer1::ILogger& logger_;

std::vector<Buffer> buffers_;
Expand Down Expand Up @@ -165,20 +168,31 @@ class TensorRTEngine : public EngineBase {
*/
class TRT_EngineManager {
public:
TensorRTEngine* Create(int max_batch, int max_workspace,
cudaStream_t* stream) {
engines_.emplace_back(new TensorRTEngine(max_batch, max_workspace, stream));
return engines_.back().get();
bool HasEngine(const std::string& name) const {
return engines_.count(name) != 0;
}

// Get an engine called `name`.
TensorRTEngine* Get(const std::string& name) const {
return engines_.at(name).get();
}

// Create or get an engine called `name`
TensorRTEngine* Create(int max_batch, int max_workspace, cudaStream_t* stream,
const std::string& name) {
auto* p = new TensorRTEngine(max_batch, max_workspace, stream);
engines_[name].reset(p);
return p;
}

void DeleteALl() {
for (auto& ptr : engines_) {
ptr.reset(nullptr);
for (auto& item : engines_) {
item.second.reset(nullptr);
}
}

private:
std::vector<std::unique_ptr<TensorRTEngine>> engines_;
std::unordered_map<std::string, std::unique_ptr<TensorRTEngine>> engines_;
};

} // namespace tensorrt
Expand Down
51 changes: 25 additions & 26 deletions paddle/fluid/operators/activation_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ Sigmoid Activation Operator
__attribute__((unused)) constexpr char LogSigmoidDoc[] = R"DOC(
Logsigmoid Activation Operator
$$out = \log \frac{1}{1 + e^{-x}}$$
$$out = \\log \\frac{1}{1 + e^{-x}}$$
)DOC";

Expand Down Expand Up @@ -252,15 +252,14 @@ class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "Output of Softshrink operator");
AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
AddComment(R"DOC(
Softshrink Activation Operator.
:strong:`Softshrink Activation Operator`
$$
out = \begin{cases}
x - \lambda, \text{if } x > \lambda \\
x + \lambda, \text{if } x < -\lambda \\
0, \text{otherwise}
\end{cases}
$$
.. math::
out = \begin{cases}
x - \lambda, \text{if } x > \lambda \\
x + \lambda, \text{if } x < -\lambda \\
0, \text{otherwise}
\end{cases}
)DOC");
}
Expand All @@ -271,18 +270,18 @@ class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X", "Input of HardShrink operator");
AddOutput("Out", "Output of HardShrink operator");
AddAttr<float>("threshold", "The value of threshold for HardShrink")
AddAttr<float>("threshold",
"The value of threshold for HardShrink. [default: 0.5]")
.SetDefault(0.5f);
AddComment(R"DOC(
HardShrink Activation Operator.
:strong:`HardShrink activation operator`
$$
out = \begin{cases}
x, \text{if } x > \lambda \\
x, \text{if } x < -\lambda \\
0, \text{otherwise}
\end{cases}
$$
.. math::
out = \begin{cases}
x, \text{if } x > \lambda \\
x, \text{if } x < -\lambda \\
0, \text{otherwise}
\end{cases}
)DOC");
}
Expand Down Expand Up @@ -394,18 +393,18 @@ class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X", "Input of ThresholdedRelu operator");
AddOutput("Out", "Output of ThresholdedRelu operator");
AddAttr<float>("threshold", "The threshold location of activation")
AddAttr<float>("threshold",
"The threshold location of activation. [default 1.0].")
.SetDefault(1.0f);
AddComment(R"DOC(
ThresholdedRelu Activation Operator.
:strong:`ThresholdedRelu activation operator`
$$
out = \begin{cases}
x, \text{if } x > threshold \\
0, \text{otherwise}
\end{cases}
$$
.. math::
out = \begin{cases}
x, \text{if } x > threshold \\
0, \text{otherwise}
\end{cases}
)DOC");
}
};
Expand Down
34 changes: 15 additions & 19 deletions paddle/fluid/operators/compare_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,26 @@ class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
OpComment comment;
AddInput("X",
string::Sprintf("(LoDTensor) the left hand operand of %s operator",
comment.type));
AddInput("Y", string::Sprintf(
"(LoDTensor) the right hand operand of %s operator",
comment.type));
AddInput("X", string::Sprintf("the left hand operand of %s operator",
comment.type));
AddInput("Y", string::Sprintf("the right hand operand of %s operator",
comment.type));
AddAttr<bool>("force_cpu",
"(bool, default false) Force fill output variable to cpu "
"Force fill output variable to cpu "
"memory. Otherwise, fill output variable to the running "
"device")
.SetDefault(false);
AddOutput("Out", string::Sprintf(
"(LoDTensor) n-dim bool tensor. Each element is %s",
comment.equation));
AddComment(string::Sprintf(R"DOC(%s Operator
"device [default true].")
.SetDefault(true);
AddOutput("Out", string::Sprintf("n-dim bool tensor. Each element is %s",
comment.equation));
AddComment(string::Sprintf(R"DOC(
It operates element-wise on X and Y, and returns the Out. Each of them is a
N-dim tensor. X and Y could be any type. The each element of the Out tensor is
calculated by %s
calculated by $%s$
)DOC",
comment.type, comment.equation));
AddAttr<int>("axis",
"(int, default -1). The start dimension index "
"for broadcasting Y onto X.")
comment.equation));
AddAttr<int>(
"axis",
"The start dimension index for broadcasting Y onto X. [default -1]")
.SetDefault(-1)
.EqualGreaterThan(-1);
}
Expand Down
14 changes: 7 additions & 7 deletions paddle/fluid/operators/cumsum_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,19 @@ class CumOp : public framework::OperatorWithKernel {
class CumsumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of Cumsum operator");
AddOutput("Out", "Output of Cumsum operator");
AddInput("X", "Input of cumsum operator");
AddOutput("Out", "Output of cumsum operator");
AddAttr<int>("axis",
"(int, default -1). The dimenstion to accumulate along. "
"-1 means the last dimenstion")
"The dimenstion to accumulate along. -1 means the last "
"dimenstion [default -1].")
.SetDefault(-1)
.EqualGreaterThan(-1);
AddAttr<bool>("exclusive",
"bool, default false). Whether to perform exclusive cumsum")
"Whether to perform exclusive cumsum. [default false].")
.SetDefault(false);
AddAttr<bool>("reverse",
"bool, default false). If true, the cumsum is performed in "
"the reversed direction")
"If true, the cumsum is performed in the reversed direction. "
"[default false].")
.SetDefault(false);
AddComment(R"DOC(
The cumulative sum of the elements along a given axis.
Expand Down
41 changes: 27 additions & 14 deletions paddle/fluid/operators/detection/box_coder_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,23 +106,36 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
"and M represents the number of deocded boxes.");

AddComment(R"DOC(
Bounding Box Coder Operator.
Bounding Box Coder.
Encode/Decode the target bounding box with the priorbox information.
The Encoding schema described below:
ox = (tx - px) / pw / pxv
oy = (ty - py) / ph / pyv
ow = log(abs(tw / pw)) / pwv
oh = log(abs(th / ph)) / phv
ox = (tx - px) / pw / pxv
oy = (ty - py) / ph / pyv
ow = log(abs(tw / pw)) / pwv
oh = log(abs(th / ph)) / phv
The Decoding schema described below:
ox = (pw * pxv * tx * + px) - tw / 2
oy = (ph * pyv * ty * + py) - th / 2
ow = exp(pwv * tw) * pw + tw / 2
oh = exp(phv * th) * ph + th / 2
where tx, ty, tw, th denote the target box's center coordinates, width and
height respectively. Similarly, px, py, pw, ph denote the priorbox's(anchor)
center coordinates, width and height. pxv, pyv, pwv, phv denote the variance
of the priorbox and ox, oy, ow, oh denote the encoded/decoded coordinates,
width and height.
ox = (pw * pxv * tx * + px) - tw / 2
oy = (ph * pyv * ty * + py) - th / 2
ow = exp(pwv * tw) * pw + tw / 2
oh = exp(phv * th) * ph + th / 2
where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates, width
and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote the
priorbox's (anchor) center coordinates, width and height. `pxv`, `pyv`, `pwv`,
`phv` denote the variance of the priorbox and `ox`, `oy`, `ow`, `oh` denote the
encoded/decoded coordinates, width and height.
)DOC");
}
};
Expand Down
9 changes: 6 additions & 3 deletions paddle/fluid/operators/gaussian_random_batch_size_like_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@ class GaussianRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
void Apply() override {
AddAttr<float>("mean",
"(float, default 0.0) "
"mean of random tensor.")
"The mean (or center) of the gaussian distribution.")
.SetDefault(.0f);
AddAttr<float>("std",
"(float, default 1.0) "
"std of random tensor.")
"The standard deviation (std, or spread) of the "
"gaussian distribution.")
.SetDefault(1.0f);
AddAttr<int>("seed",
"(int, default 0) "
Expand All @@ -55,9 +56,11 @@ class GaussianRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
.SetDefault(framework::proto::VarType::FP32);

AddComment(R"DOC(
GaussianRandom Operator.
Used to initialize tensors with gaussian random generator.
The defalut mean of the distribution is 0. and defalut standard
deviation (std) of the distribution is 1.. Uers can set mean and std
by input arguments.
)DOC");
}
};
Expand Down
33 changes: 17 additions & 16 deletions paddle/fluid/operators/layer_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,47 +62,48 @@ class LayerNormOp : public framework::OperatorWithKernel {
class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(LoDTensor) The input tensor.");
AddInput("X", "The input tensor.");
AddInput("Scale",
"(Tensor, optional) Scale is a 1-dimensional tensor of size "
"(optional) Scale is a 1-dimensional tensor of size "
"H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])."
"It is applied to the output.")
.AsDispensable();
AddInput("Bias",
"(Tensor, optional) Bias is a 1-dimensional tensor of size "
"(optional) Bias is a 1-dimensional tensor of size "
"H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])."
"It is applied to the output.")
.AsDispensable();
AddOutput("Y", "(LoDTensor) Result after normalization.");
AddOutput("Mean", "(Tensor) Mean of the current mini batch.")
.AsIntermediate();
AddOutput("Variance", "(Tensor) Variance of the current mini batch.")
AddOutput("Y", "Result after normalization.");
AddOutput("Mean", "Mean of the current mini batch.").AsIntermediate();
AddOutput("Variance", "Variance of the current mini batch.")
.AsIntermediate();

AddAttr<float>("epsilon",
"(float, default 1e-5) Constant for "
"numerical stability")
"Constant for numerical stability [default 1e-5].")
.SetDefault(1e-5)
.AddCustomChecker([](const float &epsilon) {
PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
"'epsilon' should be between 0.0 and 0.001.");
});
AddAttr<int>("begin_norm_axis",
"(int default:1), the "
"axis of `begin_norm_axis ... Rank(X) - 1` will be "
"the axis of `begin_norm_axis ... Rank(X) - 1` will be "
"normalized. `begin_norm_axis` splits the tensor(`X`) to a "
"matrix [N,H].")
"matrix [N,H]. [default 1].")
.SetDefault(1)
.AddCustomChecker([](const int &begin_norm_axis) {
PADDLE_ENFORCE_GT(begin_norm_axis, 0,
"'begin_norm_axis' should be greater than zero.");
});

AddComment(R"DOC(
Layer Normalization.
Layer Norm has been implemented as discussed in the paper:
https://arxiv.org/abs/1607.06450
...
Assume feature vectors exist on dimensions
:attr:`begin_norm_axis ... rank(input)` and calculate the moment statistics
along these dimensions for each feature vector :math:`a` with size
:math:`H`, then normalize each feature vector using the corresponding
statistics. After that, apply learnable gain and bias on the normalized
tensor to scale and shift if :attr:`scale` and :attr:`shift` are set.
Refer to `Layer Normalization <https://arxiv.org/pdf/1607.06450v1.pdf>`_
)DOC");
}
};
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/operators/listen_and_serv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
};

void SignalHandler::StopAndExit(int signal_num) {
VLOG(3) << "Catch interrupt signal: " << signal_num << ", program will exit";
// Do not use VLOG here for the device for printing maybe already released.
// exit will release interal allocated resoureces.
exit(0);
}

Expand Down
Loading

0 comments on commit 4970414

Please sign in to comment.