Skip to content

Commit

Permalink
Add tuning statistics.
Browse files Browse the repository at this point in the history
  • Loading branch information
zxybazh committed May 28, 2022
1 parent 45bed88 commit e9626bc
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 1 deletion.
33 changes: 33 additions & 0 deletions include/tvm/meta_schedule/search_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,39 @@
#include <tvm/meta_schedule/runner.h>
#include <tvm/tir/schedule/schedule.h>

#include <chrono> // NOLINT [build/c++11]
#include <map>
#include <string>

namespace tvm {
namespace meta_schedule {

/*!
 * \brief Named stopwatch accumulator used to report where tuning time is spent.
 *
 * Each timer is identified by a string. `start_timer` records the current time for a name,
 * `end_timer` adds the time elapsed since that record to the running total for the same name.
 * A timer named "_global" is started when the object is constructed.
 *
 * \note `start_timer` overwrites any previous start point stored under the same name, so
 *       starting a timer that is already running discards the earlier start point.
 * \note NOTE(review): std::chrono::high_resolution_clock is not guaranteed to be monotonic on
 *       every standard library; consider std::chrono::steady_clock if `tick` has no external
 *       users that depend on its exact type.
 */
struct TuningStatistics {
  /*! \brief Clock used for every measurement taken by this counter. */
  using Clock = std::chrono::high_resolution_clock;

  /*! \brief Accumulated duration in nanoseconds, keyed by timer name. */
  std::map<std::string, int64_t> total;
  /*! \brief Most recent start point, keyed by timer name. */
  std::map<std::string, std::chrono::time_point<std::chrono::high_resolution_clock>> tick;

  /*! \brief Constructor. Starts the "_global" timer. */
  TuningStatistics() { tick["_global"] = Clock::now(); }

  /*!
   * \brief Record the current time as the start point of a timer.
   * \param name The timer name.
   */
  void start_timer(const std::string& name) { tick[name] = Clock::now(); }

  /*!
   * \brief Time elapsed since the given timer was last started.
   * \param name The timer name; it must have been started before.
   * \return The elapsed time in nanoseconds.
   */
  int64_t count_duration_ns(const std::string& name) const {
    // Sample the clock first so that the lookup cost is not attributed to the timer.
    const Clock::time_point current = Clock::now();
    const auto it = tick.find(name);
    ICHECK(it != tick.end()) << "Timer \"" << name << "\" was never started";
    return std::chrono::duration_cast<std::chrono::nanoseconds>(current - it->second).count();
  }

  /*!
   * \brief Add the time elapsed since the timer was last started to its running total.
   * \param name The timer name; it must have been started before.
   */
  void end_timer(const std::string& name) {
    // Measure before touching `total` so a failed check leaves no empty entry behind.
    const int64_t elapsed_ns = count_duration_ns(name);
    // operator[] value-initializes a missing entry to 0, so first use needs no special case.
    total[name] += elapsed_ns;
  }
};

// Forward declaration
class TuneContext;
class CostModel;
Expand Down Expand Up @@ -101,6 +131,9 @@ class MeasureCandidate : public runtime::ObjectRef {
*/
class SearchStrategyNode : public runtime::Object {
public:
/*! \brief Tuning statistics time usage counter. */
static TuningStatistics time_counter;

/*! \brief Virtual destructor */
virtual ~SearchStrategyNode() = default;

Expand Down
4 changes: 3 additions & 1 deletion include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@ class TaskSchedulerNode : public runtime::Object {
Array<MeasureCallback> measure_callbacks;
/*! \brief The number of trials already conducted. */
int num_trials_already;
/*! \brief The tuning task's logging function. t*/
/*! \brief The task scheduler's logging function. */
PackedFunc logging_func;
/*! \brief Tuning statistics time usage counter. */
TuningStatistics time_counter;

/*! \brief The default destructor. */
virtual ~TaskSchedulerNode() = default;
Expand Down
12 changes: 12 additions & 0 deletions src/meta_schedule/search_strategy/evolutionary_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -505,10 +505,12 @@ std::vector<Schedule> EvolutionarySearchNode::State::EvolveWithCostModel(
SizedHeap heap(num);
for (int iter = 0;; ++iter) {
// Predict normalized score with the cost model,
self->time_counter.start_timer("SearchStrategy::EvolveWithCostModel::PredictNormalizedScore");
std::vector<double> scores = PredictNormalizedScore(population, //
GetRef<TuneContext>(self->context_), //
self->cost_model_, //
self->args_info_);
self->time_counter.end_timer("SearchStrategy::EvolveWithCostModel::PredictNormalizedScore");
ICHECK_EQ(scores.size(), population.size());
for (int i = 0, n = population.size(); i < n; ++i) {
Schedule sch = population.at(i);
Expand Down Expand Up @@ -567,7 +569,9 @@ std::vector<Schedule> EvolutionarySearchNode::State::EvolveWithCostModel(
result = population.at(sampled_trace_id);
}
};
self->time_counter.start_timer("SearchStrategy::EvolveWithCostModel::Mutation");
support::parallel_for_dynamic(0, self->population_size, self->num_threads_, f_find_candidate);
self->time_counter.end_timer("SearchStrategy::EvolveWithCostModel::Mutation");
population.swap(next_population);
TVM_PY_LOG(INFO, self->context_->logging_func) << "Evolve iter #" << iter << " done. Summary:\n"
<< pp.SummarizeFailures();
Expand Down Expand Up @@ -657,18 +661,26 @@ Optional<Array<MeasureCandidate>> EvolutionarySearchNode::State::GenerateMeasure
inits.reserve(pop);

TVM_PY_LOG(INFO, self->context_->logging_func) << "Generating candidates......";
self->time_counter.start_timer("SearchStrategy::PickBestFromDatabase");
std::vector<Schedule> measured = PickBestFromDatabase(pop * self->init_measured_ratio);
self->time_counter.end_timer("SearchStrategy::PickBestFromDatabase");
TVM_PY_LOG(INFO, self->context_->logging_func)
<< "Picked top " << measured.size() << " candidate(s) from database";
self->time_counter.start_timer("SearchStrategy::SampleInitPopulation");
std::vector<Schedule> unmeasured = SampleInitPopulation(pop - measured.size());
self->time_counter.end_timer("SearchStrategy::SampleInitPopulation");
TVM_PY_LOG(INFO, self->context_->logging_func)
<< "Sampled " << unmeasured.size() << " candidate(s)";
inits.insert(inits.end(), measured.begin(), measured.end());
inits.insert(inits.end(), unmeasured.begin(), unmeasured.end());
self->time_counter.start_timer("SearchStrategy::EvolveWithCostModel");
std::vector<Schedule> bests = EvolveWithCostModel(inits, sample_num);
self->time_counter.end_timer("SearchStrategy::EvolveWithCostModel");
TVM_PY_LOG(INFO, self->context_->logging_func)
<< "Got " << bests.size() << " candidate(s) with evolutionary search";
self->time_counter.start_timer("SearchStrategy::PickWithEpsGreedy");
std::vector<Schedule> picks = PickWithEpsGreedy(unmeasured, bests, sample_num);
self->time_counter.end_timer("SearchStrategy::PickWithEpsGreedy");
TVM_PY_LOG(INFO, self->context_->logging_func)
<< "Sending " << picks.size() << " candidates(s) for measurement";
if (picks.empty()) {
Expand Down
2 changes: 2 additions & 0 deletions src/meta_schedule/search_strategy/search_strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
namespace tvm {
namespace meta_schedule {

// Out-of-class definition of the static `time_counter` member declared in SearchStrategyNode.
// Being static, this single default-constructed counter is shared by all search strategies.
TuningStatistics SearchStrategyNode::time_counter;

MeasureCandidate::MeasureCandidate(tir::Schedule sch, Array<ArgInfo> args_info) {
ObjectPtr<MeasureCandidateNode> n = make_object<MeasureCandidateNode>();
n->sch = sch;
Expand Down
30 changes: 30 additions & 0 deletions src/meta_schedule/task_scheduler/gradient_based.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,37 @@ class GradientBasedNode final : public TaskSchedulerNode {
<< "\nTotal trials: " << total_trials //
<< "\nTotal latency (us): " << total_latency //
<< "\n";

os << "\nTuning Efficiency Statistics:"
<< "\n";

support::TablePrinter p_efficiency;
double total_time = time_counter.count_duration_ns("_global");
int id_counter = 0;
p_efficiency.Row() << "ID"
<< "Module Name"
<< "Time Usage(s)"
<< "Time Ratio(%)";
p_efficiency.Separator();

p_efficiency.Row() << id_counter++ << "Total Time" << total_time / 1e9 << 100.0;

for (auto iter : time_counter.total) {
p_efficiency.Row() << id_counter++ << iter.first << iter.second / 1e9
<< iter.second / total_time * 100;
}
for (auto iter : SearchStrategyNode::time_counter.total) {
p_efficiency.Row() << id_counter++ << iter.first << iter.second / 1e9
<< iter.second / total_time * 100;
}
p_efficiency.Separator();

os << p_efficiency.AsStr();
return os.str();
}

int NextTaskId() final {
time_counter.start_timer("TaskScheduler::NextTaskId");
int n_tasks = task_records_.size();
// Round robin
if (num_rounds_already_ == 0) {
Expand Down Expand Up @@ -155,10 +182,12 @@ class GradientBasedNode final : public TaskSchedulerNode {
if (tasks[task_id]->runner_futures.defined()) {
JoinRunningTask(task_id);
}
time_counter.end_timer("TaskScheduler::NextTaskId");
return task_id;
}

Array<RunnerResult> JoinRunningTask(int task_id) final {
time_counter.start_timer("TaskScheduler::JoinRunningTask");
TaskRecord& record = task_records_[task_id];
Array<RunnerResult> results = TaskSchedulerNode::JoinRunningTask(task_id);
double& best_time_cost = this->best_time_cost_per_task_[task_id];
Expand All @@ -172,6 +201,7 @@ class GradientBasedNode final : public TaskSchedulerNode {
TVM_PY_LOG(INFO, this->logging_func)
<< "[Updated] Task #" << task_id << ": " << record.task->task_name << "\n"
<< this->TuningStatistics();
time_counter.end_timer("TaskScheduler::JoinRunningTask");
return results;
}
};
Expand Down
13 changes: 13 additions & 0 deletions src/meta_schedule/task_scheduler/task_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ void SendToRunner(const Runner& runner, const TuneContext& context, PackedFunc l
}

void TaskSchedulerNode::InitializeTask(int task_id) {
time_counter.start_timer("TaskScheduler::InitializeTask");
TuneContext task = this->tasks[task_id];
TVM_PY_LOG(INFO, task->logging_func)
<< "Initializing Task #" << task_id << ": " << task->task_name;
Expand All @@ -116,9 +117,11 @@ void TaskSchedulerNode::InitializeTask(int task_id) {
<< Concat(trace->AsPython(false), "\n");
}
task->search_strategy.value()->PreTuning(design_spaces);
time_counter.end_timer("TaskScheduler::InitializeTask");
}

void TaskSchedulerNode::Tune() {
time_counter.start_timer("TaskScheduler::Tune");
int n_tasks = this->tasks.size();
for (int task_id = 0; task_id < n_tasks; ++task_id) {
InitializeTask(task_id);
Expand All @@ -133,8 +136,13 @@ void TaskSchedulerNode::Tune() {
SearchStrategy strategy = task->search_strategy.value();
if ((task->measure_candidates = strategy->GenerateMeasureCandidates()).defined()) {
num_trials_already += task->measure_candidates.value().size();
time_counter.start_timer("TaskScheduler::SendToBuilder");
SendToBuilder(this->builder, task, this->logging_func);
time_counter.end_timer("TaskScheduler::SendToBuilder");
time_counter.start_timer("TaskScheduler::SendToRunner");
SendToRunner(this->runner, task, this->logging_func);
time_counter.end_timer("TaskScheduler::SendToRunner");

} else {
ICHECK(!task->is_terminated);
task->is_terminated = true;
Expand All @@ -156,9 +164,11 @@ void TaskSchedulerNode::Tune() {
}
task->search_strategy.value()->PostTuning();
}
time_counter.end_timer("TaskScheduler::Tune");
}

void TaskSchedulerNode::TouchTask(int task_id) {
time_counter.start_timer("TaskScheduler::TouchTask");
TuneContext task = tasks[task_id];
if (!task->is_terminated && task->runner_futures.defined()) {
for (const RunnerFuture future : task->runner_futures.value()) {
Expand All @@ -168,9 +178,11 @@ void TaskSchedulerNode::TouchTask(int task_id) {
}
this->JoinRunningTask(task_id);
}
time_counter.end_timer("TaskScheduler::TouchTask");
}

Array<RunnerResult> TaskSchedulerNode::JoinRunningTask(int task_id) {
time_counter.start_timer("TaskScheduler::JoinRunningTask");
TuneContext task = tasks[task_id];
ICHECK(task->runner_futures.defined());
Array<RunnerFuture> futures = task->runner_futures.value();
Expand All @@ -194,6 +206,7 @@ Array<RunnerResult> TaskSchedulerNode::JoinRunningTask(int task_id) {
task->measure_candidates = NullOpt;
task->builder_results = NullOpt;
task->runner_futures = NullOpt;
time_counter.end_timer("TaskScheduler::JoinRunningTask");
return results;
}

Expand Down

0 comments on commit e9626bc

Please sign in to comment.