Skip to content

Commit

Permalink
adding missing shape propagation for clamp operation
Browse files Browse the repository at this point in the history
  • Loading branch information
Viktor Gyenes committed Oct 19, 2018
1 parent ad34b69 commit 429b392
Showing 1 changed file with 23 additions and 17 deletions.
40 changes: 23 additions & 17 deletions parser/cpp/common/propagation.h
Expand Up @@ -115,6 +115,7 @@ namespace nnef
std::make_pair("max_reduce", PropagationGroup::Reduce),
std::make_pair("mean_reduce", PropagationGroup::Reduce),
std::make_pair("argmax_reduce", PropagationGroup::Reduce),
std::make_pair("argmin_reduce", PropagationGroup::Reduce),
std::make_pair("moments", PropagationGroup::Reduce),

std::make_pair("nearest_downsample", PropagationGroup::DownSample),
Expand Down Expand Up @@ -152,6 +153,7 @@ namespace nnef
std::make_pair("softmax", PropagationGroup::Unique),
std::make_pair("copy_n", PropagationGroup::Unique),
std::make_pair("add_n", PropagationGroup::Unique),
std::make_pair("clamp", PropagationGroup::Unique),
};

auto it = operationGroups.find(name);
Expand Down Expand Up @@ -266,7 +268,7 @@ namespace nnef
}
else if ( op == "select" )
{
propagateShapesSelect(proto, args, shapes);
propagateShapesTernary(proto, args, shapes);
}
else if ( op == "matmul" )
{
Expand All @@ -292,6 +294,10 @@ namespace nnef
{
propagateShapesAddN(proto, args, shapes);
}
else if ( op == "clamp" )
{
propagateShapesTernary(proto, args, shapes);
}
else
{
assert(false);
Expand Down Expand Up @@ -1141,32 +1147,32 @@ namespace nnef
setShape(output, shapes, outputShape);
}

void propagateShapesSelect( const Prototype& proto, const Dictionary<Value>& args, Dictionary<Shape>& shapes )
void propagateShapesTernary( const Prototype& proto, const Dictionary<Value>& args, Dictionary<Shape>& shapes )
{
auto& condition = args["condition"];
auto& trueValue = args["true_value"];
auto& falseValue = args["false_value"];
auto& output = args["output"];
auto& arg0 = args[proto.param(0).name()];
auto& arg1 = args[proto.param(1).name()];
auto& arg2 = args[proto.param(2).name()];
auto& output = args[proto.result(0).name()];

auto& conditionShape = getShape(condition, shapes);
auto& trueShape = getShape(trueValue, shapes);
auto& falseShape = getShape(falseValue, shapes);
auto& shape0 = getShape(arg0, shapes);
auto& shape1 = getShape(arg1, shapes);
auto& shape2 = getShape(arg2, shapes);

if ( !isBroadcastCompatible(trueShape, falseShape) )
if ( !isBroadcastCompatible(shape1, shape2) )
{
throw Error("incompatible tensor shapes in select operation (%s vs %s)",
trueShape.toString().c_str(), falseShape.toString().c_str());
throw Error("incompatible tensor shapes in ternary operation (%s vs %s)",
shape1.toString().c_str(), shape2.toString().c_str());
}

Shape valueShape = broadcastShape(trueShape, falseShape);
Shape shape12 = broadcastShape(shape1, shape2);

if ( !isBroadcastCompatible(conditionShape, valueShape) )
if ( !isBroadcastCompatible(shape0, shape12) )
{
throw Error("condition shape incompatible with result shape in select operation (%s vs %s)",
conditionShape.toString().c_str(), valueShape.toString().c_str());
throw Error("incompatible tensor shapes in ternary operation (%s vs %s)",
shape0.toString().c_str(), shape12.toString().c_str());
}

Shape outputShape = broadcastShape(conditionShape, valueShape);
Shape outputShape = broadcastShape(shape0, shape12);

setShape(output, shapes, outputShape);
}
Expand Down

0 comments on commit 429b392

Please sign in to comment.