Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CPP implementation of L2Norm and LRN ops #1157

Merged
merged 12 commits into from
Jun 22, 2018
35 changes: 35 additions & 0 deletions nnvm/include/nnvm/top/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,41 @@ struct NMSParam : public dmlc::Parameter<NMSParam> {
}
};

/*!
 * \brief Attributes of the `lrn` operator, declared through dmlc::Parameter.
 *
 * None of the fields declares a default value in this struct.
 */
struct LRNParam : public dmlc::Parameter<LRNParam> {
  /*! \brief size of the local region considered for normalization */
  int size;
  /*! \brief channel axis of the input data layout */
  int axis;
  /*! \brief scaling parameter */
  float alpha;
  /*! \brief exponent parameter */
  float beta;
  /*! \brief offset parameter */
  float bias;

  DMLC_DECLARE_PARAMETER(LRNParam) {
    DMLC_DECLARE_FIELD(size).describe(
        "The size of the local region to be considered for normalization.");
    DMLC_DECLARE_FIELD(axis).describe("input data layout channel axis");
    DMLC_DECLARE_FIELD(alpha).describe("The scaling parameter.");
    DMLC_DECLARE_FIELD(beta).describe("The exponent parameter.");
    DMLC_DECLARE_FIELD(bias).describe("The offset parameter.");
  }

  // constants
  // NOTE(review): presumably the index of the data input -- confirm with users of kData.
  static constexpr int kData = 0;
};

/*!
 * \brief Attributes of the `l2_normalize` operator, declared through dmlc::Parameter.
 */
struct L2NormalizeParam : public dmlc::Parameter<L2NormalizeParam> {
  /*! \brief epsilon value */
  float eps;
  /*! \brief axes over which the normalization is applied */
  Tuple<int> axis;

  // Neither field declares a default value here.
  DMLC_DECLARE_PARAMETER(L2NormalizeParam) {
    DMLC_DECLARE_FIELD(eps)
    .describe("float type epsilon value.");
    DMLC_DECLARE_FIELD(axis)
    .describe("axis over the normalization applied");
  }
};

} // namespace top
} // namespace nnvm

Expand Down
33 changes: 33 additions & 0 deletions nnvm/python/nnvm/top/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,36 @@ def schedule_upsampling(_, outs, target):
return topi.generic.schedule_injective(outs)

reg.register_pattern("upsampling", OpPattern.INJECTIVE)

@reg.register_compute("lrn")
def compute_lrn(attrs, inputs, _):
    """Build the TOPI compute for ``lrn`` from the node attributes.

    Reads ``size``, ``axis``, ``alpha``, ``beta`` and ``bias`` from ``attrs``
    and forwards them, in that order, together with the first input tensor
    to ``topi.nn.lrn``.
    """
    lrn_args = (
        attrs.get_int("size"),
        attrs.get_int("axis"),
        attrs.get_float("alpha"),
        attrs.get_float("beta"),
        attrs.get_float("bias"),
    )
    data = inputs[0]
    return topi.nn.lrn(data, *lrn_args)

@reg.register_schedule("lrn")
def schedule_lrn(attrs, outs, target):
    """Return the generic TOPI schedule for ``lrn``, created under ``target``.

    ``attrs`` is accepted for the registration signature but is not read.
    """
    target_scope = tvm.target.create(target)
    with target_scope:
        schedule = topi.generic.schedule_lrn(outs)
        return schedule

# Register the operator pattern of "lrn" as OpPattern.OPAQUE.
# NOTE(review): presumably this keeps lrn out of operator fusion -- confirm
# against the OpPattern documentation.
reg.register_pattern("lrn", OpPattern.OPAQUE)

@reg.register_compute("l2_normalize")
def compute_l2_normalize(attrs, inputs, _):
    """Build the TOPI compute for ``l2_normalize`` from the node attributes.

    Reads ``eps`` (float) and ``axis`` (int tuple) from ``attrs`` and passes
    them with the first input tensor to ``topi.nn.l2_normalize``.
    """
    epsilon = attrs.get_float("eps")
    norm_axes = attrs.get_int_tuple("axis")
    data = inputs[0]
    return topi.nn.l2_normalize(data, epsilon, norm_axes)

@reg.register_schedule("l2_normalize")
def schedule_l2_normalize(attrs, outs, target):
    """Return the generic TOPI schedule for ``l2_normalize``, created under ``target``.

    ``attrs`` is accepted for the registration signature but is not read.
    """
    target_scope = tvm.target.create(target)
    with target_scope:
        schedule = topi.generic.schedule_l2_normalize(outs)
        return schedule

# Register the operator pattern of "l2_normalize" as OpPattern.OUT_ELEMWISE_FUSABLE.
# NOTE(review): presumably this allows element-wise ops to be fused onto its
# output (the tests follow it with relu) -- confirm against the OpPattern docs.
reg.register_pattern("l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE)
47 changes: 47 additions & 0 deletions nnvm/src/top/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -712,5 +712,52 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
})
.set_support_level(1);

// Register LRNParam (the attribute struct parsed by the `lrn` operator below) with dmlc.
DMLC_REGISTER_PARAMETER(LRNParam);

/*!
 * \brief Shape inference for `lrn`: output 0 is assigned the shape of input 0.
 *
 * \param attrs The node attributes (only forwarded to the assignment macro).
 * \param in_shape The input shapes; element 0 is read.
 * \param out_shape The output shapes; element 0 is assigned.
 *
 * \return Always true.
 */
inline bool LRNInferShape(const nnvm::NodeAttrs& attrs,
                          std::vector<TShape>* in_shape,
                          std::vector<TShape>* out_shape) {
  // The operator is shape-preserving, so the data shape is forwarded as-is.
  const TShape& data_shape = in_shape->front();
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, data_shape);
  return true;
}

// Registration of the `lrn` operator: one input, one output, attributes parsed
// into LRNParam, shape inferred by LRNInferShape and type by ElemwiseType<1, 1>.
// The LRNParam fields are added to the argument list so that size/axis/alpha/
// beta/bias show up in the operator's registered arguments and generated docs.
// NOTE(review): unlike l2_normalize, no FCorrectLayout attribute is registered
// here -- confirm this is intentional.
NNVM_REGISTER_OP(lrn)
.describe(R"code(LRN layer)code" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tensor", "Input data.")
.add_arguments(LRNParam::__FIELDS__())
.set_attr_parser(ParamParser<LRNParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<LRNParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", LRNInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_support_level(1);

// Register L2NormalizeParam (the attribute struct parsed by the `l2_normalize`
// operator below) with dmlc.
DMLC_REGISTER_PARAMETER(L2NormalizeParam);

/*!
 * \brief Shape inference for `l2_normalize`: output 0 is assigned the shape of input 0.
 *
 * \param attrs The node attributes (only forwarded to the assignment macro).
 * \param in_shape The input shapes; element 0 is read.
 * \param out_shape The output shapes; element 0 is assigned.
 *
 * \return Always true.
 */
inline bool L2NormalizeInferShape(const nnvm::NodeAttrs& attrs,
                                  std::vector<TShape>* in_shape,
                                  std::vector<TShape>* out_shape) {
  // The operator is shape-preserving, so the data shape is forwarded as-is.
  const TShape& data_shape = in_shape->front();
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, data_shape);
  return true;
}

// Registration of the `l2_normalize` operator: one input, one output,
// attributes parsed into L2NormalizeParam, shape inferred by
// L2NormalizeInferShape and type by ElemwiseType<1, 1>.
// The L2NormalizeParam fields are added to the argument list so that eps/axis
// show up in the operator's registered arguments and generated docs.
// NOTE(review): the layout is registered as ElemwiseArbitraryLayout although
// the `axis` attribute refers to specific dimensions -- confirm that a layout
// change cannot silently alter which axes are normalized.
NNVM_REGISTER_OP(l2_normalize)
.describe(R"code(L2NORMALIZE layer)code" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tensor", "Input data.")
.add_arguments(L2NormalizeParam::__FIELDS__())
.set_attr_parser(ParamParser<L2NormalizeParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<L2NormalizeParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", L2NormalizeInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
.set_support_level(1);

} // namespace top
} // namespace nnvm
62 changes: 61 additions & 1 deletion nnvm/tests/python/compiler/test_top_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import nnvm.compiler
from nnvm.testing.config import ctx_list


def helper(symbol, inputs, dtype,
np_forward, np_backward=None, need_input=True, need_head_grads=True):
ishapes = {}
Expand Down Expand Up @@ -365,6 +364,65 @@ def forward(x):
inputs = [('x', (1, 3, 28, 28), x)]
helper(y, inputs, dtype, forward)

def verify_lrn(ishape, size, axis, bias, alpha, beta):
    """Build ``lrn`` for every target in ``ctx_list()`` and compare the result
    with ``topi.testing.lrn_python``; then repeat with a ``relu`` applied on top.
    """
    dtype = "float32"
    x = sym.Variable("x")
    y = sym.lrn(x, size=size, axis=axis, bias=bias, alpha=alpha, beta=beta)

    def _check(symbol, data, post=None):
        # Compile and run `symbol` on each target, comparing against the
        # Python reference (optionally post-processed).
        for target, ctx in ctx_list():
            graph, lib, _ = nnvm.compiler.build(symbol, target, {"x": ishape})
            module = graph_runtime.create(graph, lib, ctx)
            module.run(x=data)
            actual = module.get_output(0, tvm.nd.empty(ishape))
            expected = topi.testing.lrn_python(data, size, axis, bias, alpha, beta)
            if post is not None:
                expected = post(expected)
            np.testing.assert_allclose(actual.asnumpy(), expected, atol=1e-5, rtol=1e-5)

    _check(y, np.random.uniform(size=ishape).astype(dtype))

    # Checking LRN op followed by elementwise op relu
    z = sym.relu(y)
    signed_input = np.random.uniform(low=-10.0, high=10.0, size=ishape).astype(dtype)
    _check(z, signed_input, post=lambda ref: (ref > 0) * ref)

def verify_l2_normalize(ishape, eps, axis):
    """Build ``l2_normalize`` for every target in ``ctx_list()`` and compare the
    result with ``topi.testing.l2_normalize_python``; then repeat with a
    ``relu`` applied on top.
    """
    dtype = "float32"
    x = sym.Variable("x")
    y = sym.l2_normalize(x, eps=eps, axis=axis)

    def _check(symbol, data, post=None):
        # Compile and run `symbol` on each target, comparing against the
        # Python reference (optionally post-processed).
        for target, ctx in ctx_list():
            graph, lib, _ = nnvm.compiler.build(symbol, target, {"x": ishape})
            module = graph_runtime.create(graph, lib, ctx)
            module.run(x=data)
            actual = module.get_output(0, tvm.nd.empty(ishape))
            expected = topi.testing.l2_normalize_python(data, eps, axis)
            if post is not None:
                expected = post(expected)
            np.testing.assert_allclose(actual.asnumpy(), expected, atol=1e-5, rtol=1e-5)

    _check(y, np.random.uniform(size=ishape).astype(dtype))

    # Checking L2 normalization op followed by elementwise op relu
    z = sym.relu(y)
    signed_input = np.random.uniform(low=-10.0, high=10.0, size=ishape).astype(dtype)
    _check(z, signed_input, post=lambda ref: (ref > 0) * ref)

def test_lrn():
    """Run verify_lrn on a (1, 3, 20, 20) input with two (bias, beta) settings."""
    for bias, beta in ((1.0, 0.5), (2.0, 0.75)):
        verify_lrn((1, 3, 20, 20), 3, 1, bias, 1.0, beta)

def test_l2_normalize():
    """Run verify_l2_normalize on a (1, 3, 20, 20) input over axes (1,) and (1, 2)."""
    for axes in ((1,), (1, 2)):
        verify_l2_normalize((1, 3, 20, 20), 0.001, axes)

if __name__ == "__main__":
test_split()
Expand All @@ -384,3 +442,5 @@ def forward(x):
test_softmax()
test_squeeze()
test_pad()
test_lrn()
test_l2_normalize()
106 changes: 106 additions & 0 deletions topi/include/topi/cuda/normalization.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*!
* Copyright (c) 2018 by Contributors
* \file cuda/normalization.h
* \brief CUDA schedule for LRN and l2 normalization operations
*/
#ifndef TOPI_CUDA_NORMALIZATION_H_
#define TOPI_CUDA_NORMALIZATION_H_

#include "tvm/tvm.h"
#include "tvm/build_module.h"
#include "topi/tags.h"

namespace topi {
using namespace tvm;
namespace cuda {
/*!
 * \brief Create a CUDA schedule for LRN
 *
 * The stages are located by walking back from outs[0] through fixed input
 * positions, so the schedule depends on the exact stage structure of the LRN
 * compute definition. Every stage reached this way is dereferenced as a
 * ComputeOpNode without a null check.
 *
 * \param target The target to generate a schedule for (not read in this function).
 * \param outs The output tensors; outs[0] is taken as the LRN output.
 *
 * \return A schedule for the given ops.
 */
inline Schedule schedule_lrn(const Target &target, const Array<Tensor>& outs) {
  // Collect the operations of all outputs and create the schedule from them.
  Array<Operation> out_ops;
  for (auto t : outs) {
    out_ops.push_back(t->op);
  }
  Schedule s = create_schedule(out_ops);
  int num_thread = 64;
  IterVar block_x = tvm::thread_axis(Range(), "blockIdx.x");
  IterVar thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
  // Walk back from the output through fixed input positions.
  // NOTE(review): going by the names these are the padded input, the squared
  // sum and its post-processed form -- confirm against the LRN compute definition.
  Tensor lrn = outs[0];
  Tensor sqr_sum_up = lrn->op->InputTensors()[1];
  Tensor sqr_sum = sqr_sum_up->op->InputTensors()[0];
  Tensor set_pad = sqr_sum->op->InputTensors()[0];
  s[set_pad].bind(set_pad->op.as<ComputeOpNode>()->axis[0], block_x);
  // Split the first reduction axis of sqr_sum by num_thread and rfactor on the
  // inner part; srf is the first tensor returned by rfactor.
  IterVar rxk = sqr_sum->op.as<ComputeOpNode>()->reduce_axis[0];
  IterVar xko, xki;
  s[sqr_sum].split(rxk, num_thread, &xko, &xki);
  Tensor srf = s.rfactor(sqr_sum, xki)[0];
  // From here on the op is read from the stage (s[sqr_sum]->op) rather than
  // from the tensor (sqr_sum->op).
  // NOTE(review): presumably because rfactor replaces the stage's op -- confirm.
  s[sqr_sum].bind(s[sqr_sum]->op.as<ComputeOpNode>()->axis[0], block_x);
  s[sqr_sum].bind(s[sqr_sum]->op.as<ComputeOpNode>()->reduce_axis[0], thread_x);
  s[srf].compute_at(s[sqr_sum], s[sqr_sum]->op.as<ComputeOpNode>()->reduce_axis[0]);
  s[sqr_sum_up].bind(sqr_sum_up->op.as<ComputeOpNode>()->axis[0], block_x);
  // Output stage: axis 0 is bound to block_x; axis 1 is split into num_thread
  // parts and the outer part is bound to thread_x.
  IterVar xto, xti;
  s[lrn].split_by_nparts(lrn->op.as<ComputeOpNode>()->axis[1], num_thread, &xto, &xti);
  s[lrn].bind(lrn->op.as<ComputeOpNode>()->axis[0], block_x);
  s[lrn].bind(xto, thread_x);

  return s;
}

/*!
 * \brief Create a CUDA schedule for L2 normalization
 *
 * Starting from outs[0], the producer graph is traversed recursively:
 * injective ops and ops tagged "l2_normalize" are inlined unless they are
 * schedule outputs, ops tagged "comm_reduce" are handed to ScheduleReduce,
 * and any other tag is reported through LOG(ERROR). Afterwards the output
 * stage itself is bound to CUDA block and thread axes.
 *
 * \param target The target to generate a schedule for.
 * \param outs The output tensors; outs[0] is taken as the normalization output.
 *
 * \return A schedule for the given ops.
 */
inline Schedule schedule_l2_normalize(const Target &target, const Array<Tensor>& outs) {
  // Collect the operations of all outputs and create the schedule from them.
  Array<Operation> out_ops;
  for (auto t : outs) {
    out_ops.push_back(t->op);
  }
  Schedule s = create_schedule(out_ops);

  // Declared before assignment so that the lambda can call itself recursively.
  std::function<void(Operation)> traverse;
  traverse = [&](const Operation& op) {
    // Inline all one-to-one-mapping operators except the last stage (output)
    if (is_injective(op->tag) || op->tag == "l2_normalize") {
      if (!detail::contains(s->outputs, op)) {
        s[op].compute_inline();
      }
      // Only producers that themselves have input tensors are visited.
      for (auto tensor : op->InputTensors()) {
        if (tensor->op->InputTensors().size() > 0) {
          traverse(tensor->op);
        }
      }
    } else if (op->tag == "comm_reduce") {
      // NOTE(review): the meaning of the trailing `false` argument is not
      // visible here -- check ScheduleReduce's declaration.
      ScheduleReduce(target, op, s, false);
      // Unlike the branch above, every input is visited here, including ops
      // without input tensors.
      for (auto tensor : op->InputTensors()) {
        traverse(tensor->op);
      }
    } else {
      // Unknown tags are only logged; scheduling continues.
      LOG(ERROR) << "Unsupported operator " << op->tag;
    }
  };

  traverse(outs[0]->op);
  // Output stage: axis 0 is bound to block_x; axis 1 is split into num_thread
  // parts and the outer part is bound to thread_x. The output op is
  // dereferenced as a ComputeOpNode without a null check.
  int num_thread = 64;
  Tensor l2_normalize = outs[0];
  IterVar block_x = tvm::thread_axis(Range(), "blockIdx.x");
  IterVar thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
  IterVar xto, xti;
  s[l2_normalize].split_by_nparts(l2_normalize->op.as<ComputeOpNode>()->axis[1],
                                  num_thread, &xto, &xti);
  s[l2_normalize].bind(l2_normalize->op.as<ComputeOpNode>()->axis[0], block_x);
  s[l2_normalize].bind(xto, thread_x);
  return s;
}
} // namespace cuda
} // namespace topi
#endif // TOPI_CUDA_NORMALIZATION_H_

46 changes: 46 additions & 0 deletions topi/include/topi/nn/l2_normalize.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*!
* Copyright (c) 2018 by Contributors
* \brief l2 normalization op constructions
* \file nn/l2_normalize.h
*/
#ifndef TOPI_NN_L2_NORMALIZE_H_
#define TOPI_NN_L2_NORMALIZE_H_

#include <string>
#include <algorithm>
#include "topi/tags.h"
#include "tvm/tvm.h"
namespace topi {
namespace nn {
using namespace tvm;

/*!
 * \brief L2 normalization inference operator
 *
 * Computes data / sqrt(max(sum(data^2, axis), eps)), where the sum keeps the
 * reduced dimensions and is broadcast back to the input shape.
 *
 * \param data The input tensor. 4-D with shape [batch, channel, height, width]
 * \param eps Lower bound applied to the sum of squares to prevent div by 0
 * \param axis Axes over which the normalization is applied
 * \param name The name given to the intermediate stage that clamps the sum
 * of squares with eps
 * \param tag The tag given to that same intermediate stage
 *
 * \return The normalized tensor, i.e. the result of broadcast_div; `name`
 * and `tag` label the clamp stage feeding it, not the returned tensor's op.
 */
inline Tensor l2_normalize(const Tensor& data,
                           float eps,
                           const Array<Expr>& axis,
                           std::string name = "tensor",
                           std::string tag = "l2_normalize") {
  CHECK_EQ(data->shape.size(), 4) << "L2 normalization requires 4-D input";
  auto input_shape = data->shape;
  Tensor dot_value = pow(data, static_cast<float>(2.0));
  // keepdims = true so that the sum can be broadcast back to the input shape.
  Tensor sum_value = topi::sum(dot_value, axis, true);
  Tensor expand_sum = topi::broadcast_to(sum_value, input_shape);
  // Clamp the sum of squares from below by eps. `name` and `tag` are passed as
  // ordinary positional arguments (C++ has no keyword arguments; the previous
  // `name = name, tag = tag` spelling was a pair of self-assignments).
  Tensor clamped_sum = tvm::compute(
      expand_sum->shape,
      [&](const Array<Var>& i) { return max(expand_sum(i), eps); },
      name, tag);
  return topi::broadcast_div(data, topi::sqrt(clamped_sum));
}
} // namespace nn
} // namespace topi
#endif // TOPI_NN_L2_NORMALIZE_H_