From 1b4e9aee456fadc3586e9ad491a5236d30bcccaa Mon Sep 17 00:00:00 2001
From: guohe
Date: Fri, 25 May 2018 10:49:47 +0800
Subject: [PATCH] ctc batch inference, change im2sequence_op

---
 paddle/fluid/operators/im2sequence_op.cc |  17 +-
 paddle/fluid/operators/im2sequence_op.h  | 201 ++++++++--
 paddle/fluid/operators/math/im2col.cc    |  78 ++--
 paddle/fluid/operators/math/im2col.cu    |   4 +-
 paddle/fluid/operators/math/im2col.h     |   1 +
 python/paddle/fluid/layers/nn.py         | 490 ++++++++++++++++++++---
 6 files changed, 664 insertions(+), 127 deletions(-)

diff --git a/paddle/fluid/operators/im2sequence_op.cc b/paddle/fluid/operators/im2sequence_op.cc
index 048391549dd8d..0908d31edda01 100644
--- a/paddle/fluid/operators/im2sequence_op.cc
+++ b/paddle/fluid/operators/im2sequence_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/im2sequence_op.h"
+#include <vector>
 
 namespace paddle {
 namespace operators {
@@ -53,14 +54,14 @@ class Im2SequenceOp : public framework::OperatorWithKernel {
 
 class Im2SequenceOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  Im2SequenceOpMaker(OpProto* proto, OpAttrChecker* op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  void Make() override {
     AddInput("X",
              "(Tensor) The input tensor has NCHW format."
              "N: batch size"
             "C: channels"
             "H: height"
             "W: width");
+    AddInput("Image_real_size",
+             "(Tensor) The real (unpadded) size of each input image, "
+             "one (height, width) pair per image.");
     AddOutput("Out", "(LodTensor) The output data of im2sequence op.");
     AddAttr<std::vector<int>>("kernels",
                               "(vector<int>), the "
@@ -73,6 +74,13 @@ class Im2SequenceOpMaker : public framework::OpProtoAndCheckerMaker {
                               "(vector<int> default:{0, 0, 0, 0}), the "
                               "paddings(up_pad, left_pad, down_pad, right_pad)")
         .SetDefault({0, 0, 0, 0});
+    AddAttr<std::vector<int>>("out_stride",
+                              "(vector<int> default:{1, 1}), the out_stride "
+                              "(out_stride_height, out_stride_width)")
+        .SetDefault({1, 1});
+    AddAttr<bool>("is_inference",
+                  "(bool default: false), true for inference, "
+                  "false for training")
+        .SetDefault(false);
     AddComment(R"DOC(
This op uses kernels to scan images and converts these images to sequences.
After expanding, the number of time steps is output_height * output_width
@@ -147,8 +155,9 @@ class Im2SequenceGradOp : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(im2sequence, ops::Im2SequenceOp, ops::Im2SequenceOpMaker,
-            im2sequence_grad, ops::Im2SequenceGradOp);
+REGISTER_OPERATOR(im2sequence, ops::Im2SequenceOp, ops::Im2SequenceOpMaker,
+                  paddle::framework::DefaultGradOpDescMaker<true>);
+REGISTER_OPERATOR(im2sequence_grad, ops::Im2SequenceGradOp);
 REGISTER_OP_CPU_KERNEL(
     im2sequence,
     ops::Im2SequenceKernel<paddle::platform::CPUDeviceContext, float>);
diff --git a/paddle/fluid/operators/im2sequence_op.h b/paddle/fluid/operators/im2sequence_op.h
index a6a83fefbc626..da721f1833028 100644
--- a/paddle/fluid/operators/im2sequence_op.h
+++ b/paddle/fluid/operators/im2sequence_op.h
@@ -13,7 +13,8 @@ limitations under the License.
 */
 
 #pragma once
-
+#include <string>
+#include <vector>
 #include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
@@ -38,8 +39,9 @@ class Im2SequenceKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     const Tensor* in = ctx.Input<Tensor>("X");
+    // TODO(fuhailong): add a new data layer to support multi-batch inference
+    const Tensor* imgRealSize = ctx.Input<Tensor>("Image_real_size");
     LoDTensor* out = ctx.Output<LoDTensor>("Out");
-    out->mutable_data<T>(ctx.GetPlace());
     // TODO(wanghaoshuang): Add layout checker after 'set_layout'
     // being available for python API
     // PADDLE_ENFORCE_EQ(in->layout(), framework::DataLayout::kNCHW,
@@ -49,41 +51,178 @@ class Im2SequenceKernel : public framework::OpKernel<T> {
     int img_channels = in_dim[1];
     int img_height = in_dim[2];
     int img_width = in_dim[3];
-
+    auto imgRealSize_vec = imgRealSize->data<float>();
+    auto imgRealSize_dim = imgRealSize->dims();
     auto kernels = ctx.Attr<std::vector<int>>("kernels");
     auto strides = ctx.Attr<std::vector<int>>("strides");
     auto paddings = ctx.Attr<std::vector<int>>("paddings");
-    int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0],
-                                         paddings[2], strides[0]);
-    int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1],
-                                        paddings[3], strides[1]);
+    auto out_stride = ctx.Attr<std::vector<int>>("out_stride");
+    auto is_inference = ctx.Attr<bool>("is_inference");
+    if (is_inference) {
+      if (batch_size == 1) {
+        out->mutable_data<T>(ctx.GetPlace());
+        int output_height = Im2SeqOutputSize(img_height,
+                                             kernels[0],
+                                             paddings[0],
+                                             paddings[2],
+                                             strides[0]);
+        int output_width = Im2SeqOutputSize(img_width,
+                                            kernels[1],
+                                            paddings[1],
+                                            paddings[3],
+                                            strides[1]);
+        const std::vector<int> dilations({1, 1});
+        auto out_dims = out->dims();
+        out->Resize({batch_size, out->numel() / batch_size});
+        for (int i = 0; i < batch_size; i++) {
+          const Tensor src = in->Slice(i, i + 1).Resize(
+              {img_channels, img_height, img_width});
+          Tensor dst = out->Slice(i, i + 1).Resize(
+              {output_height, output_width, img_channels, kernels[0],
+               kernels[1]});
-    const std::vector<int> dilations({1, 1});
+          math::Im2ColFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
+          auto& dev_ctx = ctx.template device_context<DeviceContext>();
+          f(dev_ctx, src, dilations, strides, paddings, &dst);
+        }
+        out->Resize(out_dims);
+        // set lod information
+        // TODO(wanghaoshuang): Move this to InferShape
+        framework::LoD lod(1);
+        lod[0].reserve(batch_size + 1);
+        int offset = 0;
+        lod[0].push_back(offset);
+        for (int i = 0; i < batch_size; ++i) {
+          offset += output_height * output_width;
+          lod[0].push_back(offset);
+        }
+        out->set_lod(lod);
+      } else {
+        std::vector<int> imgReal_H;
+        std::vector<int> imgReal_W;
+        for (int i = 0; i < batch_size; i++) {
+          int tmp_real_H = static_cast<int>(imgRealSize_vec[2 * i]);
+          int tmp_real_W = static_cast<int>(imgRealSize_vec[2 * i + 1]);
+          // Each out_stride level halves the real size, rounding up.
+          for (int j = 0; j < out_stride[0]; j++) {
+            tmp_real_H = tmp_real_H / 2 + tmp_real_H % 2;
+            tmp_real_W = tmp_real_W / 2 + tmp_real_W % 2;
+          }
+          imgReal_H.push_back(tmp_real_H);
+          imgReal_W.push_back(tmp_real_W);
+        }
-    auto out_dims = out->dims();
-    out->Resize({batch_size, out->numel() / batch_size});
-    for (int i = 0; i < batch_size; i++) {
-      const Tensor src =
-          in->Slice(i, i + 1).Resize({img_channels, img_height, img_width});
-      Tensor dst = out->Slice(i, i + 1).Resize(
-          {output_height, output_width, img_channels, kernels[0], kernels[1]});
+        // TODO(fuhailong): for loop to compute real output size
+        std::vector<int> output_height;
+        std::vector<int> output_width;
+        for (int i = 0; i < batch_size; i++) {
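+          // Per-image output size; this assumes Im2SeqOutputSize computes
+          //   out = (in + pad_begin + pad_end - kernel) / stride + 1,
+          // the same relation the disabled checks in im2col.cc enforce.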
+          output_height.push_back(Im2SeqOutputSize(imgReal_H[i],
+                                                   kernels[0],
+                                                   paddings[0],
+                                                   paddings[2],
+                                                   strides[0]));
+          output_width.push_back(Im2SeqOutputSize(imgReal_W[i],
+                                                  kernels[1],
+                                                  paddings[1],
+                                                  paddings[3],
+                                                  strides[1]));
+        }
+        // TODO(fuhailong): compute dims of output
+        // call: out->mutable_data<T>(ctx.GetPlace(), output_dims);
+        int result = 0;
+        for (int i = 0; i < batch_size; i++) {
+          result += output_height[i] * output_width[i];
+        }
-      math::Im2ColFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
-      auto& dev_ctx = ctx.template device_context<DeviceContext>();
-      f(dev_ctx, src, dilations, strides, paddings, &dst);
-    }
-    out->Resize(out_dims);
-
-    // set lod information
-    // TODO(wanghaoshuang): Move this to InferShape
-    framework::LoD lod(1);
-    lod[0].reserve(batch_size + 1);
-    for (int i = 0, offset = 0; i < batch_size + 1; ++i) {
-      lod[0].push_back(offset);
-      offset += output_height * output_width;
-    }
-    out->set_lod(lod);
-  }
+        out->mutable_data<T>({result, img_channels * kernels[0] * kernels[1]},
+                             ctx.GetPlace());
+        const std::vector<int> dilations({1, 1});
+        // TODO(fuhailong): out_dims has two indices, out_dims[0] and
+        // out_dims[1]: {batch_size * output_height * output_width,
+        // channels * kernels[0] * kernels[1]}; with multiple batches the
+        // first dimension accumulates output_height[i] * output_width[i].
+        auto out_dims = out->dims();
+        int offset_out = 0;
+
+        for (int i = 0; i < batch_size; i++) {
+          const Tensor src = in->Slice(i, i + 1).Resize(
+              {img_channels, img_height, img_width});
+          // TODO(fuhailong): add image real size
+          Tensor dst =
+              out->Slice(offset_out,
+                         offset_out + output_height[i] * output_width[i])
+                  .Resize({output_height[i], output_width[i], img_channels,
+                           kernels[0], kernels[1]});
+          offset_out += output_height[i] * output_width[i];
+
+          math::Im2ColFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
+          // i.e. kOCF, CNN-to-RNN format
+          auto& dev_ctx = ctx.template device_context<DeviceContext>();
+          f(dev_ctx, src, dilations, strides, paddings, &dst);
+        }
+        out->Resize(out_dims);
+        // set lod information
+        // TODO(wanghaoshuang): Move this to InferShape
+        framework::LoD lod(1);
+        lod[0].reserve(batch_size + 1);
+        int offset = 0;
+        lod[0].push_back(offset);
+        for (int i = 0; i < batch_size; ++i) {
+          offset += output_height[i] * output_width[i];
+          lod[0].push_back(offset);
+        }
+        out->set_lod(lod);
+      }
+    } else {
+      out->mutable_data<T>(ctx.GetPlace());
+      int output_height = Im2SeqOutputSize(img_height,
+                                           kernels[0],
+                                           paddings[0],
+                                           paddings[2],
+                                           strides[0]);
+      int output_width = Im2SeqOutputSize(img_width,
+                                          kernels[1],
+                                          paddings[1],
+                                          paddings[3],
+                                          strides[1]);
+
+      const std::vector<int> dilations({1, 1});
+      auto out_dims = out->dims();
+      out->Resize({batch_size, out->numel() / batch_size});
+      for (int i = 0; i < batch_size; i++) {
+        const Tensor src = in->Slice(i, i + 1).Resize(
+            {img_channels, img_height, img_width});
+        Tensor dst = out->Slice(i, i + 1).Resize(
+            {output_height, output_width, img_channels, kernels[0],
+             kernels[1]});
+
+        math::Im2ColFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
+        auto& dev_ctx = ctx.template device_context<DeviceContext>();
+        f(dev_ctx, src, dilations, strides, paddings, &dst);
+      }
+      out->Resize(out_dims);
+      // set lod information
+      // TODO(wanghaoshuang): Move this to InferShape
+      framework::LoD lod(1);
+      lod[0].reserve(batch_size + 1);
+      int offset = 0;
+      lod[0].push_back(offset);
+      for (int i = 0; i < batch_size; ++i) {
+        offset += output_height * output_width;
+        lod[0].push_back(offset);
+      }
+      out->set_lod(lod);
+    }
+  }
 };
 
 template <typename DeviceContext, typename T>
diff --git a/paddle/fluid/operators/math/im2col.cc b/paddle/fluid/operators/math/im2col.cc
index 123e10586f60f..5d3f0459a6932 100644
--- a/paddle/fluid/operators/math/im2col.cc
+++ b/paddle/fluid/operators/math/im2col.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/math/im2col.h"
+#include <vector>
 
 namespace paddle {
 namespace operators {
@@ -42,21 +43,21 @@ class Im2ColFunctor<kCFO, platform::CPUDeviceContext, T> {
     int col_height = col->dims()[3];
     int col_width = col->dims()[4];
 
-    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
-                       ((dilation[0] * (filter_height - 1) + 1))) /
-                              stride[0] +
-                          1,
-                      col_height,
-                      "Output_height and padding(padding_up, padding_down) are "
-                      "inconsistent.");
-    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
-                       ((dilation[1] * (filter_width - 1) + 1))) /
-                              stride[1] +
-                          1,
-                      col_width,
-                      "Output_height and padding(padding_up, padding_down) are "
-                      "inconsistent.");
-
+/*  Disabled for variable-size inference: the column shape is now derived
+ *  from per-image real sizes, so these checks no longer hold.
+ *  PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
+ *                     ((dilation[0] * (filter_height - 1) + 1))) /
+ *                            stride[0] +
+ *                        1,
+ *                    col_height,
+ *                    "Output_height and padding(padding_up, padding_down) are "
+ *                    "inconsistent.");
+ *  PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
+ *                     ((dilation[1] * (filter_width - 1) + 1))) /
+ *                            stride[1] +
+ *                        1,
+ *                    col_width,
+ *                    "Output_width and padding(padding_left, padding_right) are "
+ *                    "inconsistent.");
+ */
     int channels_col = im_channels * filter_height * filter_width;
 
     const T* im_data = im.data<T>();
@@ -177,22 +178,22 @@ class Im2ColFunctor<kOCF, platform::CPUDeviceContext, T> {
     int col_height = col->dims()[0];
     int col_width = col->dims()[1];
 
-    PADDLE_ENFORCE_EQ(
-        (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1,
-        col_height,
-        "Output_height and padding(padding_up, padding_down) are "
-        "inconsistent.");
-    PADDLE_ENFORCE_EQ(
-        (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1,
-        col_width,
-        "col_width and padding(padding_left, padding_right) are "
-        "inconsistent.");
-
+/*  PADDLE_ENFORCE_EQ(
+ *      (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1,
+ *      col_height,
+ *      "Output_height and padding(padding_up, padding_down) are "
+ *      "inconsistent.");
+ *  PADDLE_ENFORCE_EQ(
+ *      (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1,
+ *      col_width,
+ *      "col_width and padding(padding_left, padding_right) are "
+ *      "inconsistent.");
+ */
     const T* im_data = im.data<T>();
     T* col_data = col->data<T>();
 
     for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
       for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
         for (int channel = 0; channel < im_channels; ++channel) {
           for (int filter_row_idx = 0; filter_row_idx < filter_height;
                ++filter_row_idx) {
@@ -250,17 +251,17 @@ class Col2ImFunctor<kOCF, platform::CPUDeviceContext, T> {
-    PADDLE_ENFORCE_EQ(
-        (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1,
-        col_height,
-        "Output_height and padding(padding_up, padding_down) are "
-        "inconsistent.");
-    PADDLE_ENFORCE_EQ(
-        (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1,
-        col_width,
-        "col_width and padding(padding_left, padding_right) are "
-        "inconsistent.");
-
+/*  PADDLE_ENFORCE_EQ(
+ *      (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1,
+ *      col_height,
+ *      "Output_height and padding(padding_up, padding_down) are "
+ *      "inconsistent.");
+ *  PADDLE_ENFORCE_EQ(
+ *      (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1,
+ *      col_width,
+ *      "col_width and padding(padding_left, padding_right) are "
+ *      "inconsistent.");
+ */
     T* im_data = im->data<T>();
     const T* col_data = col.data<T>();
diff --git a/paddle/fluid/operators/math/im2col.cu b/paddle/fluid/operators/math/im2col.cu
index f41c78140fb60..eecb233d22cea 100644
--- a/paddle/fluid/operators/math/im2col.cu
+++ b/paddle/fluid/operators/math/im2col.cu
@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
+#include <algorithm>
+#include <vector>
 #include "paddle/fluid/operators/math/im2col.h"
-#include "paddle/fluid/platform/cuda_helper.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/math/im2col.h b/paddle/fluid/operators/math/im2col.h
index 451ec9d53498c..26d94e0f2e616 100644
--- a/paddle/fluid/operators/math/im2col.h
+++ b/paddle/fluid/operators/math/im2col.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include <vector>
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/platform/device_context.h"
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index bf161d6618b10..7b8b4bdd4e8fd 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -73,6 +73,10 @@
     'smooth_l1',
     'one_hot',
     'autoincreased_step_counter',
+    'reshape',
+    'lod_reset',
+    'lrn',
+    'pad',
 ]
 
 
@@ -81,6 +85,7 @@ def fc(input,
        num_flatten_dims=1,
        param_attr=None,
        bias_attr=None,
+       use_mkldnn=False,
        act=None,
        name=None):
     """
@@ -128,6 +133,8 @@ def fc(input,
         bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias
             of this layer. If it is set to None, no bias will be added to the output units.
         act (str, default None): Activation to be applied to the output of this layer.
+        use_mkldnn(bool): Use mkldnn kernel or not, it is valid only when the mkldnn
+            library is installed. Default: False
         name (str, default None): The name of this layer.
 
     Returns:
@@ -148,35 +155,64 @@ def fc(input,
     dtype = helper.input_dtype()
 
     mul_results = []
-    for input_var, param_attr in helper.iter_inputs_and_params():
-        input_shape = input_var.shape
+    if use_mkldnn:
+        tmp = helper.create_tmp_variable(dtype)
+        input_shape = input.shape
         param_shape = [
             reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1)
         ] + [size]
 
         w = helper.create_parameter(
-            attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False)
-        tmp = helper.create_tmp_variable(dtype)
+            attr=helper.param_attr,
+            shape=param_shape,
+            dtype=dtype,
+            is_bias=False)
+        if bias_attr is None or bias_attr is False:
+            bias_attr = False
+        else:
+            bias_attr = True
         helper.append_op(
-            type="mul",
-            inputs={"X": input_var,
-                    "Y": w},
+            type="fc",
+            inputs={"Input": input,
+                    "W": w},
             outputs={"Out": tmp},
-            attrs={"x_num_col_dims": num_flatten_dims,
-                   "y_num_col_dims": 1})
-        mul_results.append(tmp)
-
-    # sum
-    if len(mul_results) == 1:
-        pre_bias = mul_results[0]
+            attrs={"use_mkldnn": use_mkldnn,
+                   "bias_attr": bias_attr})
+        return helper.append_activation(tmp)
     else:
-        pre_bias = helper.create_tmp_variable(dtype)
-        helper.append_op(
-            type="sum", inputs={"X": mul_results}, outputs={"Out": pre_bias})
-    # add bias
-    pre_activation = helper.append_bias_op(pre_bias, dim_start=num_flatten_dims)
-    # add activation
-    return helper.append_activation(pre_activation)
+        for input_var, param_attr in helper.iter_inputs_and_params():
+            input_shape = input_var.shape
+            param_shape = [
+                reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1)
+            ] + [size]
+
+            w = helper.create_parameter(
+                attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False)
+            tmp = helper.create_tmp_variable(dtype)
+            helper.append_op(
+                type="mul",
+                inputs={"X": input_var,
+                        "Y": w},
+                outputs={"Out": tmp},
+                attrs={
+                    "x_num_col_dims": num_flatten_dims,
+                    "y_num_col_dims": 1,
+                })
+            mul_results.append(tmp)
+
+        if len(mul_results) == 1:
+            pre_bias = mul_results[0]
+        else:
+            pre_bias = helper.create_tmp_variable(dtype)
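+            # Accumulate the per-input "mul" partial results with a single
+            # "sum" op; bias and activation are applied to the summed result.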
+            helper.append_op(
+                type="sum",
+                inputs={"X": mul_results},
+                outputs={"Out": pre_bias})
+        # add bias
+        pre_activation = helper.append_bias_op(
+            pre_bias, dim_start=num_flatten_dims)
+        # add activation
+        return helper.append_activation(pre_activation)
 
 
 def embedding(input,
@@ -1116,12 +1152,14 @@ def conv2d(input,
               filter_size,
               stride=1,
               padding=0,
+              dilation=1,
               groups=None,
               param_attr=None,
               bias_attr=None,
               use_cudnn=True,
               use_mkldnn=False,
-              act=None):
+              act=None,
+              name=None):
     """
     **Convolution2D Layer**
@@ -1182,6 +1220,9 @@ def conv2d(input,
         padding(int|tuple): The padding size. If padding is a tuple, it must
             contain two integers, (padding_H, padding_W). Otherwise, the
            padding_H = padding_W = padding. Default: padding = 0.
+        dilation(int|tuple): The dilation size. If dilation is a tuple, it must
+            contain two integers, (dilation_H, dilation_W). Otherwise, the
+            dilation_H = dilation_W = dilation. Default: dilation = 1.
         groups(int): The groups number of the Conv2d Layer. According to grouped
             convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
             the first half of the filters is only connected to the first half
@@ -1192,6 +1233,8 @@ def conv2d(input,
         use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn
             library is installed. Default: True
         act(str): Activation type. Default: None
+        name(str|None): A name for this layer(optional). If set None, the layer
+            will be named automatically.
 
     Returns:
         Variable: The tensor variable storing the convolution and \
@@ -1232,6 +1275,7 @@ def conv2d(input,
     filter_size = utils.convert_to_list(filter_size, 2, 'filter_size')
     stride = utils.convert_to_list(stride, 2, 'stride')
     padding = utils.convert_to_list(padding, 2, 'padding')
+    dilation = utils.convert_to_list(dilation, 2, 'dilation')
 
     if not isinstance(use_cudnn, bool):
         raise ValueError("use_cudnn should be True or False")
@@ -1261,6 +1305,7 @@ def _get_default_param_initializer():
         attrs={
             'strides': stride,
             'paddings': padding,
+            'dilations': dilation,
             'groups': groups,
             'use_cudnn': use_cudnn,
             'use_mkldnn': use_mkldnn
@@ -1468,6 +1513,7 @@ def batch_norm(input,
                param_attr=None,
                bias_attr=None,
                data_layout='NCHW',
+               in_place=False,
                name=None,
                moving_mean_name=None,
                moving_variance_name=None):
@@ -1523,7 +1569,7 @@ def batch_norm(input,
     saved_mean = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
     saved_variance = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
 
-    batch_norm_out = helper.create_tmp_variable(dtype)
+    batch_norm_out = input if in_place else helper.create_tmp_variable(dtype)
 
     helper.append_op(
         type="batch_norm",
@@ -1669,7 +1715,9 @@ def conv2d_transpose(input,
                      stride=1,
                      dilation=1,
                      param_attr=None,
+                     bias_attr=None,
                      use_cudnn=True,
+                     act=None,
                      name=None):
     """
     **Convolution2D transpose layer**
@@ -1738,8 +1786,10 @@ def conv2d_transpose(input,
             dilation_H = dilation_W = dilation. Default: dilation = 1.
         param_attr(ParamAttr): The parameters to the Conv2d_transpose Layer.
             Default: None
+        bias_attr(ParamAttr): Bias parameter for the Conv2d layer. Default: None
         use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn
             library is installed. Default: True
+        act(str): Activation type. Default: None
         name(str|None): A name for this layer(optional). If set None, the layer
             will be named automatically.
@@ -1792,12 +1842,12 @@ def conv2d_transpose(input,
     img_filter = helper.create_parameter(
         dtype=input.dtype, shape=filter_shape, attr=helper.param_attr)
 
-    out = helper.create_tmp_variable(dtype=input.dtype)
+    pre_bias = helper.create_tmp_variable(dtype=input.dtype)
     helper.append_op(
         type='conv2d_transpose',
         inputs={'Input': [input],
                 'Filter': [img_filter]},
-        outputs={'Output': out},
+        outputs={'Output': pre_bias},
         attrs={
             'strides': stride,
             'paddings': padding,
@@ -1805,55 +1855,57 @@ def conv2d_transpose(input,
             'use_cudnn': use_cudnn
         })
 
+    pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
+    out = helper.append_activation(pre_act)
     return out
 
 
-def sequence_expand(x, y, name=None):
+def sequence_expand(x, y, ref_level=-1, name=None):
     """Sequence Expand Layer. This layer will expand the input variable **x**
-    according to LoD information of **y**. And the following examples will
-    explain how sequence_expand works:
+    according to specified level lod of **y**. Please note that lod level of
+    **x** is at most 1 and rank of **x** is at least 2. When rank of **x**
+    is greater than 2, then it would be viewed as a 2-D tensor.
+    Following examples will explain how sequence_expand works:
 
     .. code-block:: text
 
         * Case 1
             x is a LoDTensor:
-                x.lod = [[0, 2, 3],
-                         [0, 1, 3, 4]]
-                x.data = [a, b, c, d]
+                x.lod = [[0, 2, 4]]
+                x.data = [[a], [b], [c], [d]]
                 x.dims = [4, 1]
 
             y is a LoDTensor:
                 y.lod = [[0, 2, 4],
                          [0, 3, 6, 7, 8]]
 
-            with condition len(y.lod[-1]) - 1 == x.dims[0]
+            ref_level: 0
 
-            then output is a 2-level LoDTensor:
-                out.lod = [[0, 2, 4],
-                           [0, 3, 6, 7, 8]]
-                out.data = [a, a, a, b, b, b, c, d]
+            then output is a 1-level LoDTensor:
+                out.lod = [[0, 2, 4, 6, 8]]
+                out.data = [[a], [b], [a], [b], [c], [d], [c], [d]]
                 out.dims = [8, 1]
 
         * Case 2
             x is a Tensor:
-                x.data = [a, b, c]
+                x.data = [[a], [b], [c]]
                 x.dims = [3, 1]
 
             y is a LoDTensor:
-                y.lod = [[0, 2, 3, 6]]
-
-            with condition len(y.lod[-1]) - 1 == x.dims[0]
+                y.lod = [[0, 2, 2, 5]]
 
-            then output is a 1-level LoDTensor:
-                out.lod = [[0, 2, 3, 6]]
-                out.data = [a, a, b, c, c, c]
-                out.dims = [6, 1]
+            ref_level: -1
 
+            then output is a Tensor:
+                out.data = [[a], [a], [c], [c], [c]]
+                out.dims = [5, 1]
 
     Args:
         x (Variable): The input variable which is a Tensor or LoDTensor.
         y (Variable): The input variable which is a LoDTensor.
+        ref_level (int): Lod level of `y` to be referred by `x`. If set to -1,
+                         refer the last level of lod.
         name(str|None): A name for this layer(optional). If set None, the layer
-             will be named automatically.
+            will be named automatically.
 
     Returns:
         Variable: The expanded variable which is a LoDTensor.
@@ -1864,14 +1916,17 @@ def sequence_expand(x, y, name=None):
 
             x = fluid.layers.data(name='x', shape=[10], dtype='float32')
             y = fluid.layers.data(name='y', shape=[10, 20],
                              dtype='float32', lod_level=1)
-            out = layers.sequence_expand(x=x, y=y)
+            out = layers.sequence_expand(x=x, y=y, ref_level=0)
     """
     helper = LayerHelper('sequence_expand', input=x, **locals())
     dtype = helper.input_dtype()
     tmp = helper.create_tmp_variable(dtype)
     helper.append_op(
-        type='sequence_expand', inputs={'X': x,
-                                        'Y': y}, outputs={'Out': tmp})
+        type='sequence_expand',
+        inputs={'X': x,
+                'Y': y},
+        outputs={'Out': tmp},
+        attrs={'ref_level': ref_level})
     return tmp
 
 
@@ -2225,7 +2280,7 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None):
         keep_dim (bool|False): Whether to reserve the reduced dimension in the
             output Tensor. The result tensor will have one fewer dimension
             than the :attr:`input` unless :attr:`keep_dim` is true.
-        name(str|None): A name for this layer(optional). If set None, the 
+        name(str|None): A name for this layer(optional). If set None, the
             layer will be named automatically.
 
     Returns:
@@ -2241,7 +2296,7 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None):
             fluid.layers.reduce_prod(x)  # [0.0002268]
             fluid.layers.reduce_prod(x, dim=0)  # [0.02, 0.06, 0.3, 0.63]
             fluid.layers.reduce_prod(x, dim=-1)  # [0.027, 0.0084]
-            fluid.layers.reduce_prod(x, dim=1, 
+            fluid.layers.reduce_prod(x, dim=1,
                                      keep_dim=True)  # [[0.027], [0.0084]]
     """
     helper = LayerHelper('reduce_prod', **locals())
@@ -2875,8 +2930,9 @@ def transpose(x, perm, name=None):
         attrs={'axis': perm})
     return out
 
-
-def im2sequence(input, filter_size=1, stride=1, padding=0, name=None):
+# NOTE(fuhailong): add new input Image_real_size for batched CTC inference.
+def im2sequence(input, inputImgSize, filter_size=1, stride=1, padding=0,
+                out_stride=1, is_inference=False, name=None):
     """
     Extracts image patches from the input tensor to form a tensor of shape
     {input.batch_size * output_height * output_width, filter_size_H *
@@ -2987,17 +3043,21 @@ def im2sequence(input, filter_size=1, stride=1, padding=0, name=None):
     if len(padding) == 2:
         padding.append(padding[0])
         padding.append(padding[1])
-
+    if isinstance(out_stride, int):
+        out_stride = [out_stride, out_stride]
     helper = LayerHelper('im2sequence', **locals())
     out = helper.create_tmp_variable(dtype=helper.input_dtype())
     helper.append_op(
         type='im2sequence',
-        inputs={'X': input},
+        inputs={'X': input,
+                'Image_real_size': inputImgSize},
         outputs={'Out': out},
         attrs={
             'kernels': filter_size,
             'strides': stride,
             'paddings': padding,
+            'out_stride': out_stride,
+            'is_inference': is_inference
         })
     return out
 
@@ -3240,6 +3300,8 @@ def one_hot(input, depth):
         The one-hot tensor or LodTensor, same as input.
 
     Examples:
+        .. code-block:: python
+
         X is a LoDTensor:
           X.lod = [[0, 1, 4]]
           X.shape = [4, 1]
@@ -3292,3 +3354,323 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1):
     counter.stop_gradient = True
 
     return counter
+
+
+def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
+    """
+    Gives a new shape to the input Tensor without changing its data.
+
+    The target shape can be given by :attr:`shape` or :attr:`actual_shape`.
+    :attr:`shape` is a list of integers while :attr:`actual_shape` is a tensor
+    variable. :attr:`actual_shape` has a higher priority than :attr:`shape`
+    if it is provided, while :attr:`shape` still should be set correctly to
+    guarantee shape inference in compile-time.
+
+    Some tricks exist when specifying the target shape.
+
+    1. -1 means the value of this dimension is inferred from the total element
+    number of x and remaining dimensions. Thus one and only one dimension can
+    be set -1.
+
+    2. 0 means the actual dimension value is going to be copied from the
+    corresponding dimension of x. The indices of 0s in shape can not exceed
+    Rank(X).
+
+    Here are some examples to explain it.
+
+    1. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
+    is [6, 8], the reshape operator will transform x into a 2-D tensor with
+    shape [6, 8] and leaving x's data unchanged.
+
+    2. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
+    specified is [2, 3, -1, 2], the reshape operator will transform x into a
+    4-D tensor with shape [2, 3, 4, 2] and leaving x's data unchanged. In this
+    case, one dimension of the target shape is set to -1, the value of this
+    dimension is inferred from the total element number of x and remaining
+    dimensions.
+
+    3. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
+    is [-1, 0, 3, 2], the reshape operator will transform x into a 4-D tensor
+    with shape [2, 4, 3, 2] and leaving x's data unchanged. In this case,
+    besides -1, 0 means the actual dimension value is going to be copied from
+    the corresponding dimension of x.
+
+    Args:
+        x(variable): The input tensor.
+        shape(list): The new shape. At most one dimension of the new shape can
+                     be -1.
+        actual_shape(variable): An optional input. If provided, reshape
+                                according to this given shape rather than
+                                :attr:`shape` specifying shape. That is to
+                                say :attr:`actual_shape` has a higher priority
+                                than :attr:`shape`.
+        act (str): The non-linear activation to be applied to output variable.
+        inplace(bool): If this flag is set true, a new output tensor is created
+                       whose data is copied from input x, otherwise the output
+                       shares data with input without copying.
+
+    Returns:
+        Variable: The output tensor.
+
+    Examples:
+        .. code-block:: python
+
+            data = fluid.layers.data(
+                name='data', shape=[2, 4, 6], dtype='float32')
+            reshaped = fluid.layers.reshape(
+                x=data, shape=[-1, 0, 3, 2], act='tanh', inplace=True)
+    """
+
+    if not (isinstance(shape, list) or isinstance(shape, tuple)):
+        raise ValueError("Input shape must be a python list or tuple.")
+
+    # Validate the shape
+    unk_dim_idx = -1
+    for dim_idx, dim_size in enumerate(shape):
+        if dim_size == -1:
+            assert unk_dim_idx == -1, (
+                "Only one dimension in shape can be unknown.")
+            unk_dim_idx = dim_idx
+        elif dim_size == 0:
+            assert dim_idx < len(x.shape), (
+                "The indices of 0s in shape can not exceed Rank(X).")
+        else:
+            assert dim_size > 0, (
+                "Each dimension size given in shape must not be negative "
+                "except one unknown dimension.")
+
+    helper = LayerHelper("reshape", **locals())
+    reshaped = helper.create_tmp_variable(dtype=x.dtype)
+    helper.append_op(
+        type="reshape",
+        inputs={"X": x,
+                "Shape": actual_shape}
+        if isinstance(actual_shape, Variable) else {"X": x},
+        attrs={"shape": shape,
+               "inplace": inplace},
+        outputs={"Out": reshaped})
+
+    return helper.append_activation(reshaped)
+
+
+def lod_reset(x, y=None, target_lod=None):
+    """
+    LoD Reset Operator. Set LoD of **x** to a new one specified by **y** or
+    **target_lod**. When **y** is provided, **y.lod** would be considered as
+    the target LoD first, otherwise **y.data** would be considered as the
+    target LoD. If **y** is not provided, the target LoD should be specified
+    by **target_lod**. If the target LoD is specified by **y.data** or
+    **target_lod**, only one level of LoD is supported.
+
+    .. code-block:: text
+
+    * Example 1:
+
+        Given a 1-level LoDTensor x:
+            x.lod = [[ 0, 2, 5, 6 ]]
+            x.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
+            x.dims = [6, 1]
+
+        target_lod: [0, 4, 6]
+
+        then we get a 1-level LoDTensor:
+            out.lod = [[ 0, 4, 6 ]]
+            out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
+            out.dims = [6, 1]
+
+    * Example 2:
+
+        Given a 1-level LoDTensor x:
+            x.lod = [[ 0, 2, 5, 6 ]]
+            x.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
+            x.dims = [6, 1]
+
+        y is a Tensor:
+            y.data = [[0, 2, 6]]
+            y.dims = [1, 3]
+
+        then we get a 1-level LoDTensor:
+            out.lod = [[ 0, 2, 6 ]]
+            out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
+            out.dims = [6, 1]
+
+    * Example 3:
+
+        Given a 1-level LoDTensor x:
+            x.lod = [[ 0, 2, 5, 6 ]]
+            x.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
+            x.dims = [6, 1]
+
+        y is a 2-level LoDTensor:
+            y.lod = [[0, 2, 4], [0, 2, 5, 6]]
+            y.data = [[1.1], [2.1], [3.1], [4.1], [5.1], [6.1]]
+            y.dims = [6, 1]
+
+        then we get a 2-level LoDTensor:
+            out.lod = [[0, 2, 4], [0, 2, 5, 6]]
+            out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
+            out.dims = [6, 1]
+
+    Args:
+        x (Variable): Input variable which could be a Tensor or LodTensor.
+        y (Variable|None): If provided, output's LoD would be derived from y.
+        target_lod (list|tuple|None): One level LoD which should be considered
+            as target LoD when y not provided.
+
+    Returns:
+        Variable: Output variable with LoD specified by this operator.
+
+    Raises:
+        ValueError: If y and target_lod are both None.
+
+    Examples:
+        .. code-block:: python
+
+            x = layers.data(name='x', shape=[10])
+            y = layers.data(name='y', shape=[10, 20], lod_level=2)
+            out = layers.lod_reset(x=x, y=y)
+    """
+    helper = LayerHelper("lod_reset", **locals())
+    out = helper.create_tmp_variable(dtype=x.dtype)
+    if y is not None:
+        helper.append_op(
+            type="lod_reset", inputs={'X': x,
+                                      'Y': y}, outputs={'Out': out})
+    elif target_lod is not None:
+        helper.append_op(
+            type="lod_reset",
+            inputs={'X': x},
+            attrs={'target_lod': target_lod},
+            outputs={'Out': out})
+    else:
+        raise ValueError("y and target_lod should not be both None.")
+
+    return out
+
+
+def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None):
+    """
+    Local Response Normalization Layer. This layer performs a type of
+    "lateral inhibition" by normalizing over local input regions.
+
+    The formula is as follows:
+
+    .. math::
+
+        Output(i, x, y) = Input(i, x, y) / \\left(
+        k + \\alpha \\sum\\limits^{\\min(C, i + n/2)}_{j = \\max(0, i - n/2)}
+        (Input(j, x, y))^2 \\right)^{\\beta}
+
+    In the above equation:
+
+    * :math:`n`: The number of channels to sum over.
+    * :math:`k`: The offset (avoid being divided by 0).
+    * :math:`alpha`: The scaling parameter.
+    * :math:`beta`: The exponent parameter.
+
+    Refer to `ImageNet Classification with Deep Convolutional Neural Networks
+    <https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf>`_
+
+    Args:
+        input (Variable): The input tensor of this layer, and the dimension of input tensor must be 4.
+        n (int, default 5): The number of channels to sum over.
+        k (float, default 1.0): An offset (usually positive to avoid dividing by 0).
+        alpha (float, default 1e-4): The scaling parameter.
+        beta (float, default 0.75): The exponent.
+        name (str, default None): A name for this operation.
+
+    Raises:
+        ValueError: If rank of the input tensor is not 4.
+
+    Returns:
+        A tensor variable storing the transformation result.
+
+    Examples:
+        .. code-block:: python
+
+            data = fluid.layers.data(
+                name="data", shape=[3, 112, 112], dtype="float32")
+            lrn = fluid.layers.lrn(input=data)
+    """
+    helper = LayerHelper('lrn', **locals())
+    dtype = helper.input_dtype()
+    input_shape = input.shape
+    dims = len(input_shape)
+
+    if dims != 4:
+        raise ValueError(
+            "dims of input must be 4(not %d), and its order must be NCHW" %
+            (dims))
+
+    mid_out = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
+    lrn_out = helper.create_tmp_variable(dtype)
+    helper.append_op(
+        type="lrn",
+        inputs={"X": input},
+        outputs={
+            "Out": lrn_out,
+            "MidOut": mid_out,
+        },
+        attrs={"n": n,
+               "k": k,
+               "alpha": alpha,
+               "beta": beta})
+
+    return lrn_out
+
+
+def pad(x, paddings, pad_value=0., name=None):
+    """
+    Pads a tensor with a constant value given by :attr:`pad_value`, and the
+    padded width is specified by :attr:`paddings`.
+
+    Specifically, the number of values padded before the contents of :attr:`x`
+    in dimension :attr:`i` is indicated by :attr:`paddings[2*i]`, and the
+    number of values padded after the contents of :attr:`x` in dimension
+    :attr:`i` is indicated by :attr:`paddings[2*i+1]`.
+
+    See below for an example.
+
+    .. code-block:: text
+
+        Given:
+            x = [[1, 2], [3, 4]]
+
+            paddings = [0, 1, 1, 2]
+
+            pad_value = 0
+
+        Return:
+
+            out = [[0, 1, 2, 0, 0],
+                   [0, 3, 4, 0, 0],
+                   [0, 0, 0, 0, 0]]
+
+    Args:
+        x (Variable): The input tensor variable.
+        paddings (list): A list of integers. Its elements specify the padded
+                         width before and after for each dimension in turn.
+                         The length of :attr:`paddings` must be
+                         :math:`rank(x) \\times 2`.
+        pad_value (float): The constant value used to pad.
+        name(str|None): A name for this layer(optional). If set None, the layer
+            will be named automatically.
+
+    Returns:
+        Variable: The padded tensor variable.
+
+    Examples:
+        .. code-block:: python
+
+            # x is a rank 2 tensor variable.
+            out = fluid.layers.pad(
+                x=x, paddings=[0, 1, 1, 2], pad_value=0.)
+    """
+    helper = LayerHelper('pad', input=x, **locals())
+    dtype = helper.input_dtype()
+    out = helper.create_tmp_variable(dtype)
+    helper.append_op(
+        type='pad',
+        inputs={'X': x},
+        outputs={'Out': out},
+        attrs={'paddings': paddings,
+               'pad_value': float(pad_value)})
+    return out
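
Usage sketch (reviewer note, not part of the diff). The snippet below is a minimal sketch of how the batched CTC inference path added by this patch might be driven from the Python API. The layer name `im2sequence` and the new `inputImgSize`, `out_stride`, and `is_inference` parameters come from this patch; the network shapes, data layer names, and the choice to pass real sizes as float (height, width) pairs (matching how the kernel reads `Image_real_size[2*i]` and `Image_real_size[2*i+1]`) are illustrative assumptions.

    import paddle.fluid as fluid

    # Padded input batch in NCHW layout, e.g. 1-channel 48x512 OCR line images.
    images = fluid.layers.data(
        name='pixel', shape=[1, 48, 512], dtype='float32')
    # One (real_height, real_width) pair per image, before padding.
    real_sizes = fluid.layers.data(
        name='img_real_size', shape=[2], dtype='float32')

    # out_stride counts the halving levels applied by the conv/pool stack
    # above this op; the kernel shrinks each real size that many times
    # (h = h / 2 + h % 2) before computing per-image sequence lengths.
    seq = fluid.layers.im2sequence(
        input=images,
        inputImgSize=real_sizes,
        filter_size=1,
        stride=1,
        padding=0,
        out_stride=2,        # assumed: two downsampling levels upstream
        is_inference=True)

With `is_inference=True` and batch size greater than one, the LoD set on `seq` carries each image's own number of time steps into the CTC decoder, instead of one padded length shared by the whole batch.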