Skip to content

Commit

Permalink
test=huawei_ascend_npu
Browse files Browse the repository at this point in the history
  • Loading branch information
shentanyue committed Oct 19, 2021
1 parent ebaac47 commit 26c7b73
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 30 deletions.
5 changes: 2 additions & 3 deletions lite/backends/nnadapter/nnadapter/core/operation/range.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,8 @@ int PrepareRange(hal::Operation* operation) {
auto start_data = reinterpret_cast<float*>(start_operand->buffer)[0];
auto limit_data = reinterpret_cast<float*>(limit_operand->buffer)[0];
auto delta_data = reinterpret_cast<float*>(delta_operand->buffer)[0];
int64_t size = 0;
GetSize(start_data, limit_data, delta_data, &size);
output_type.dimensions.data[0] = size;
output_type.dimensions.data[0] =
GetSpanCount(start_data, limit_data, delta_data);
} else {
output_type.dimensions.data[0] = NNADAPTER_UNKNOWN;
}
Expand Down
8 changes: 4 additions & 4 deletions lite/backends/nnadapter/nnadapter/utility/utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,10 @@ inline int64_t GetCurrentUS() {
}

template <typename T>
void GetSize(T start, T end, T step, int64_t* size) {
*size = std::is_integral<T>::value
? ((std::abs(end - start) + std::abs(step) - 1) / std::abs(step))
: std::ceil(std::abs((end - start) / step));
int64_t GetSpanCount(T start, T end, T step) {
return std::is_integral<T>::value
? ((std::abs(end - start) + std::abs(step) - 1) / std::abs(step))
: std::ceil(std::abs((end - start) / step));
}

} // namespace nnadapter
11 changes: 5 additions & 6 deletions lite/core/optimizer/mir/elimination/range_calc_offline_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ namespace lite {
namespace mir {

template <typename T>
void GetSize(T start, T end, T step, int64_t* size) {
*size = std::is_integral<T>::value
? ((std::abs(end - start) + std::abs(step) - 1) / std::abs(step))
: std::ceil(std::abs((end - start) / step));
int64_t GetSpanCount(T start, T end, T step) {
return std::is_integral<T>::value
? ((std::abs(end - start) + std::abs(step) - 1) / std::abs(step))
: std::ceil(std::abs((end - start) / step));
}

void RangeCalcOfflinePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
Expand Down Expand Up @@ -70,8 +70,7 @@ void RangeCalcOfflinePass::RemoveRangePattern(
auto out_t = out_var->GetMutable<lite::Tensor>();

// Calc range
int64_t size = 0;
GetSize(start, end, step, &size);
int64_t size = GetSpanCount(start, end, step);

out_t->Resize(DDim({size}));
auto out_data = out_t->mutable_data<float>();
Expand Down
16 changes: 0 additions & 16 deletions lite/core/optimizer/mir/pattern_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -473,22 +473,6 @@ PMNode *PMNode::assert_is_op_nth_input(const std::string &op_type,
return this;
}

PMNode *PMNode::assert_is_not_op_input(const std::string &op_type) {
assert_is_var();
asserts_.emplace_back([=](const Node *x) {
for (auto *op : x->outlinks) {
if (op && op->IsStmt()) {
auto *op_info = op->stmt()->op_info();
if (op_info->Type() == op_type) {
return false;
}
}
}
return true;
});
return this;
}

PMNode *PMNode::assert_is_op_output(const std::string &op_type,
const std::string &argument) {
assert_is_var();
Expand Down
1 change: 0 additions & 1 deletion lite/core/optimizer/mir/pattern_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ struct PMNode {
PMNode* assert_is_op_nth_output(const std::string& op_type,
const std::string& argument,
int nth);
PMNode* assert_is_not_op_input(const std::string& op_type);

template <typename T>
PMNode* assert_op_attr_satisfied(
Expand Down

0 comments on commit 26c7b73

Please sign in to comment.