Skip to content

Commit

Permalink
Support c++ side.
Browse files Browse the repository at this point in the history
  • Loading branch information
zxybazh committed Jun 10, 2022
1 parent e03a5bb commit 716e15d
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 29 deletions.
13 changes: 11 additions & 2 deletions include/tvm/meta_schedule/profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@
namespace tvm {
namespace meta_schedule {

struct ScopedTimer;
struct ScopedTimer {
std::function<void()> func;
explicit ScopedTimer(std::function<void()> func) : func(func) {}
~ScopedTimer() { func(); }
};

/*!
* \brief A profiler to count tuning time cost in different parts.
Expand All @@ -49,7 +53,7 @@ class ProfilerNode : public runtime::Object {
* \param name Name for the scope.
* \return A scope timer for time profiling.
*/
ScopedTimer TimeIt(String name);
static ScopedTimer TimeScope(String name);

/*!
* \brief Get the profiling results.
Expand Down Expand Up @@ -92,6 +96,11 @@ class Profiler : public runtime::ObjectRef {
void ExitWithScope();
};

struct ProfilerThreadLocalEntry {
Optional<Profiler> ctx;
};
using ProfilerThreadLocalStore = dmlc::ThreadLocalStore<ProfilerThreadLocalEntry>;

} // namespace meta_schedule
} // namespace tvm

Expand Down
41 changes: 14 additions & 27 deletions src/meta_schedule/profiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,12 @@ namespace meta_schedule {

/**************** Context Manager ****************/

struct ScopedTimer {
std::function<void()> func;

explicit ScopedTimer(std::function<void()> func) : func(func) {}
~ScopedTimer() {
LOG(INFO) << "destructed";
func();
}
};

struct ProfilerThreadLocalEntry {
Optional<Profiler> ctx;
};

class ProfilerInternal {
public:
static void EnterScope(Profiler ctx) { ctx.EnterWithScope(); }
static void ExitScope(Profiler ctx) { ctx.ExitWithScope(); }
};

using ProfilerThreadLocalStore = dmlc::ThreadLocalStore<ProfilerThreadLocalEntry>;

void Profiler::EnterWithScope() {
Optional<Profiler>& ctx = ProfilerThreadLocalStore::Get()->ctx;
CHECK(!ctx.defined()) << "ValueError: Nested Profiler context managers are not allowed";
Expand All @@ -64,17 +48,20 @@ Profiler::Profiler() {
data_ = n;
}

ScopedTimer ProfilerNode::TimeIt(String name) {
return ScopedTimer([this, name, tick = std::chrono::high_resolution_clock::now()]() -> void {
double duration = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::high_resolution_clock::now() - tick)
.count() /
1e9 / 60;
Map<String, FloatImm> stats = this->stats;
if (stats.find(name) != stats.end()) {
stats.Set(name, FloatImm(DataType::Float(64), stats.at(name)->value + duration));
} else {
stats.Set(name, FloatImm(DataType::Float(64), duration));
ScopedTimer ProfilerNode::TimeScope(String name) {
return ScopedTimer([name, tick = std::chrono::high_resolution_clock::now()]() -> void {
Optional<Profiler> profiler = ProfilerThreadLocalStore::Get()->ctx;
if (profiler.defined()) {
Map<String, FloatImm>& stats = profiler.value()->stats;
double duration = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::high_resolution_clock::now() - tick)
.count() /
1e9 / 60;
if (stats.find(name) != stats.end()) {
stats.Set(name, FloatImm(DataType::Float(64), stats.at(name)->value + duration));
} else {
stats.Set(name, FloatImm(DataType::Float(64), duration));
}
}
});
}
Expand Down
5 changes: 5 additions & 0 deletions src/meta_schedule/search_strategy/replay_trace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class ReplayTraceNode : public SearchStrategyNode {
TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode);

void InitializeWithTuneContext(const TuneContext& context) final {
ScopedTimer timer = ProfilerNode::TimeScope("ReplayTrace::InitializeWithTuneContext");
CHECK(context->mod.defined()) << "ValueError: TuneContext.mod is not defined";
this->context_ = context.get();
this->rand_state_ = ForkSeed(&context->rand_state);
Expand All @@ -91,6 +92,7 @@ class ReplayTraceNode : public SearchStrategyNode {

void PreTuning(const Array<tir::Schedule>& design_spaces, const Optional<Database>& database,
const Optional<CostModel>& cost_model) final {
ScopedTimer timer = ProfilerNode::TimeScope("ReplayTrace::PreTuning");
ICHECK(!design_spaces.empty());
CHECK(this->context_ != nullptr) << "ValueError: Did you forget to initialize the TuneContext?";
if (this->state_ != nullptr) {
Expand All @@ -107,17 +109,20 @@ class ReplayTraceNode : public SearchStrategyNode {
}

void PostTuning() final {
ScopedTimer timer = ProfilerNode::TimeScope("ReplayTrace::PostTuning");
ICHECK(this->state_ != nullptr);
this->state_.reset();
}

Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final {
ScopedTimer timer = ProfilerNode::TimeScope("ReplayTrace::GenerateMeasureCandidates");
ICHECK(this->state_ != nullptr);
return this->state_->GenerateMeasureCandidates();
}

void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
const Array<RunnerResult>& results) final {
ScopedTimer timer = ProfilerNode::TimeScope("ReplayTrace::NotifyRunnerResults");
ICHECK(this->state_ != nullptr);
this->state_->NotifyRunnerResults(results);
}
Expand Down

0 comments on commit 716e15d

Please sign in to comment.