Skip to content

Commit

Permalink
[LITE][PASS] Fix static kernel pick pass, if op is not int8, but kern…
Browse files Browse the repository at this point in the history
…el is int8. test=develop (#2526) (#2537)
  • Loading branch information
ysh329 committed Nov 30, 2019
1 parent 3e83962 commit c4f6a1d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
4 changes: 2 additions & 2 deletions lite/core/mir/static_kernel_pick_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
<< instruct.op_type();
VLOG(4) << "instruct.kernels().size():" << instruct.kernels().size();
for (auto&& kernel : instruct.kernels()) {
float score = KernelGrade(*kernel, graph->valid_places());
float score = KernelGrade(instruct, *kernel, graph->valid_places());
VLOG(4) << "kernel->summary():" << kernel->summary()
<< " score:" << score;
scored.emplace_back(score, std::move(kernel));
Expand Down Expand Up @@ -99,7 +99,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
instruct.ResetOp(update_desc, graph->valid_places());
scored.clear();
for (auto&& kernel : instruct.kernels()) {
float score = KernelGrade(*kernel, graph->valid_places());
float score = KernelGrade(instruct, *kernel, graph->valid_places());
scored.emplace_back(score, std::move(kernel));
}
std::sort(scored.begin(), scored.end(), KernelScoreCmp);
Expand Down
13 changes: 9 additions & 4 deletions lite/core/mir/static_kernel_pick_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ class StaticKernelPickPass : public mir::StmtPass {

private:
// Score the kernel.
size_t KernelGrade(const lite::KernelBase& kernel,
size_t KernelGrade(const lite::mir::Node::Stmt& instruct,
const lite::KernelBase& kernel,
const std::vector<Place>& places) {
CHECK_GT(places.size(), 0) << "valid_places is empty.";
float final_score{-1.};
Expand All @@ -66,7 +67,7 @@ class StaticKernelPickPass : public mir::StmtPass {
// valid_places.size() as default.
// where i is the place's index in valid_places array.
// score: score is the weighted sum of target、percision and layout
for (int i = 0; i < place_size; ++i) {
for (size_t i = 0; i < place_size; ++i) {
const auto& place = places[i];
float weight = static_cast<float>(place_size - i) / place_size;
size_t score{};
Expand All @@ -83,8 +84,12 @@ class StaticKernelPickPass : public mir::StmtPass {
(place.precision == kernel.precision() ||
kernel.precision() == PRECISION(kAny) ||
place.precision == PRECISION(kAny))) {
score += kMax / static_cast<int>(
core::KernelPickFactor::Factor::PrecisionFirst);
// score skipped, if kernel is int8, but op is not int8
if (!(kernel.precision() == PRECISION(kInt8) &&
!instruct.op_info()->HasAttr("enable_int8"))) {
score += kMax / static_cast<int>(
core::KernelPickFactor::Factor::PrecisionFirst);
}
}
VLOG(4) << "[score s2]:" << score;
if (kernel_pick_factors_.IsDataLayoutConsidered() &&
Expand Down

0 comments on commit c4f6a1d

Please sign in to comment.