Skip to content
Permalink
Browse files

[backend][nvdla] Add support for Identity ONNX operator

  • Loading branch information...
Po Yen Chen
Po Yen Chen committed Sep 12, 2019
1 parent 89ad4b3 commit c5e2e9122610f604621ecec18ad95b84661ac2ac
@@ -303,6 +303,7 @@ nobase_nodist_include_HEADERS = \
onnc/Transforms/BuildInitializers.h \
onnc/Transforms/BuildInputOperators.h \
onnc/Transforms/Optimizations/DivideGlobalAPIntoAPs.h \
onnc/Transforms/Optimizations/EliminateIdentity.h \
onnc/Transforms/Optimizations/ExpandBatchNormalization.h \
onnc/Transforms/Optimizations/OptimizationsUtils.h \
onnc/Transforms/Optimizations/OptimizationOptions.h \
@@ -0,0 +1,30 @@
//===- EliminateIdentity.h ------------------------------------------------===//
//
// The ONNC Project
//
// See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
#ifndef ONNC_ELIMINATE_IDENTITY
#define ONNC_ELIMINATE_IDENTITY
#include <onnc/Core/CustomPass.h>

namespace onnc {

/** \class EliminateIdentity
* \brief Remove identity nodes in the graph
*/
class EliminateIdentity: public CustomPass<EliminateIdentity>
{
public:
EliminateIdentity() = default;

ReturnType runOnModule(Module& pModule) override;

private:
ReturnType runOnComputeGraph(ComputeGraph& pCG);
};

} // namespace of onnc

#endif // ONNC_ELIMINATE_IDENTITY
@@ -12,6 +12,7 @@
#include <onnc/Support/Enum.h>
#include <onnc/Support/TypeTraits.h>
#include <onnc/Transforms/Optimizations/DivideGlobalAPIntoAPs.h>
#include <onnc/Transforms/Optimizations/EliminateIdentity.h>
#include <onnc/Transforms/Optimizations/ExpandBatchNormalization.h>
#include <onnc/Transforms/Optimizations/PropagateConstWithDiffShape.h>
#include <onnc/Transforms/Optimizations/ReplaceGemmByConv.h>
@@ -27,6 +28,7 @@ namespace onnc {
enum class OptimizationOption : unsigned
{
divide_globalap_into_aps,
eliminate_identity,
propagate_const_with_diff_shape,
expand_batch_normalization,
replace_gemm_by_conv,
@@ -84,6 +86,9 @@ class OptimizationOptions
case OptimizationOption::divide_globalap_into_aps:
passManager.add<DivideGlobalAPIntoAPs>();
break;
case OptimizationOption::eliminate_identity:
passManager.add<EliminateIdentity>();
break;
case OptimizationOption::propagate_const_with_diff_shape:
passManager.add<PropagateConstWithDiffShape>();
break;
@@ -147,6 +147,7 @@ ONNC_SOURCES = \
Transforms/BuildOutputOperators.cpp \
Transforms/OnnxOptPass.cpp \
Transforms/Optimizations/DivideGlobalAPIntoAPs.cpp \
Transforms/Optimizations/EliminateIdentity.cpp \
Transforms/Optimizations/ExpandBatchNormalization.cpp \
Transforms/Optimizations/OptimizationsUtils.cpp \
Transforms/Optimizations/PropagateConstWithDiffShape.cpp \
@@ -138,6 +138,7 @@ void NvDlaBackend::addOnncIrOptimization(PassManager& passManager, OptimizationO
options.defaultEnable(OptimizationOption::propagate_const_with_diff_shape);
options.defaultEnable(OptimizationOption::expand_batch_normalization);
options.defaultEnable(OptimizationOption::replace_gemm_by_conv);
options.defaultEnable(OptimizationOption::eliminate_identity);

TargetBackend::addOnncIrOptimization(passManager, options);

@@ -1,5 +1,6 @@
add_libonnc_src(
DivideGlobalAPIntoAPs.cpp
EliminateIdentity.cpp
ExpandBatchNormalization.cpp
OptimizationsUtils.cpp
PropagateConstWithDiffShape.cpp
@@ -0,0 +1,63 @@
//===- EliminateIdentity.cpp ----------------------------------------------===//
//
// The ONNC Project
//
// See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
#include <onnc/Core/PassSupport.h>
#include <onnc/IR/Compute/Identity.h>
#include <onnc/Transforms/Optimizations/EliminateIdentity.h>

#include <vector>

using namespace onnc;

//===----------------------------------------------------------------------===//
// EliminateIdentity
//===----------------------------------------------------------------------===//
Pass::ReturnType EliminateIdentity::runOnModule(Module& pModule)
{
Pass::ReturnType ret = Pass::kModuleNoChanged;
Module::cg_iterator cg, cgEnd = pModule.cgEnd();
for (cg = pModule.cgBegin(); cg != cgEnd; ++cg) {
ret |= runOnComputeGraph(*cg->value());
}

if (ret != Pass::kModuleNoChanged) {
pModule.eraseUnusedValues();
}

return ret;
}

Pass::ReturnType EliminateIdentity::runOnComputeGraph(ComputeGraph& pCG)
{
std::vector<ComputeOperator*> listOfIdentityNodes;
for (ComputeOperator& node : pCG) {
if (isa<Identity>(&node)) {
listOfIdentityNodes.emplace_back(&node);
}
}

// Early return
if (listOfIdentityNodes.empty()) {
return Pass::kModuleNoChanged;
}

// Erase nodes here
for (auto* pIdentityNode : listOfIdentityNodes) {
assert(pIdentityNode->getNumOfInputs() == 1 && "Identity must have exactly one input");
assert(pIdentityNode->getNumOfOutputs() == 1 && "Identity must have exactly one output");

Value* inV = pIdentityNode->getInput(0);
Value* outV = pIdentityNode->getOutput(0);
outV->replaceAllUsesWith(*inV);

pIdentityNode->removeAllInputs();
pIdentityNode->removeAllOutputs();
pCG.erase(*pIdentityNode);
}

return Pass::kModuleChanged;
}
@@ -1,4 +1,5 @@
add_onnc_test(DivideGlobalAPIntoAPs DivideGlobalAPIntoAPsTest.cpp)
add_onnc_test(EliminateIdentityTest EliminateIdentityTest.cpp)
add_onnc_test(PropagateConstWithDiffShape PropagateConstWithDiffShapeTest.cpp)
add_onnc_test(ReplaceGemmByConv ReplaceGemmByConvTest.cpp)
add_onnc_test(SplitConv SplitConvTest.cpp)
@@ -0,0 +1,49 @@
//===- EliminateIdentityTest.cpp ------------------------------------------===//
//
// The ONNC Project
//
// See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
#include <onnc/IR/Compute/Identity.h>
#include <onnc/IR/Compute/Transpose.h>
#include <onnc/Transforms/Optimizations/EliminateIdentity.h>
#include <skypat/skypat.h>

#include "GraphUtils.h"
#include "TestUtils.h"

static void createEmptyNetwork0(Module &pM) {
ComputeGraph &cg = BuildGraph(pM, "empty_network_0");
AddInput(cg, "input_0", {3, 1, 2});
AddOutput(cg, {"input_0"});
}

static void createOneIdentity0(Module &pM) {
ComputeGraph &cg = BuildGraph(pM, "one_identity_0");
AddInput(cg, "input_0", {3, 1, 2});
AddOperator<Identity>(cg, {"input_0"}, "output_0", {3, 1, 2});
AddOutput(cg, {"output_0"});
}

static void createOneTranspose0(Module &pM) {
ComputeGraph &cg = BuildGraph(pM, "one_transpose_0");
AddInput(cg, "input_0", {3, 1, 2});
AddOperator<Transpose>(cg, {"input_0"}, "output_0", {2, 3, 1}, IntsAttr::VectorType{1, 2, 0});
AddOutput(cg, {"output_0"});
}

//===----------------------------------------------------------------------===//
// EliminateIdentity
//===----------------------------------------------------------------------===//
SKYPAT_F(EliminateIdentity, no_identity) {
std::string ans = getNetworkString(createOneTranspose0);
testOptPassOnNetwork<EliminateIdentity>(createOneTranspose0,
Pass::kModuleNoChanged, ans);
}

SKYPAT_F(EliminateIdentity, one_identity) {
std::string ans = getNetworkString(createEmptyNetwork0);
testOptPassOnNetwork<EliminateIdentity>(createOneIdentity0,
Pass::kModuleChanged, ans);
}

0 comments on commit c5e2e91

Please sign in to comment.
You can’t perform that action at this time.