Skip to content

Commit

Permalink
Add new design of profiler.
Browse files Browse the repository at this point in the history
Support c++ side.

Add test.

Add scoped timers.

Add stack empty check and include postproc/mutator apply funcs.

Add scope timers.

Revert init & async timers.

Set macro and change settings to make sure timers does not overlap.

Remove Init from ReplayTrace test.

Revert a trivial change.

Fix tests.

Fix test.

remove all `TVMTimeScope`
  • Loading branch information
zxybazh authored and junrushao committed Jun 13, 2022
1 parent 8341e33 commit b921a16
Show file tree
Hide file tree
Showing 12 changed files with 439 additions and 36 deletions.
106 changes: 106 additions & 0 deletions include/tvm/meta_schedule/profiler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_META_SCHEDULE_PROFILER_H_
#define TVM_META_SCHEDULE_PROFILER_H_

#include <tvm/ir/module.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/optional.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/target/target.h>

#include <utility>
#include <vector>

namespace tvm {
namespace meta_schedule {

struct ScopedTimer {
std::function<void()> func;
explicit ScopedTimer(std::function<void()> func) : func(func) {}
~ScopedTimer() { func(); }
};

/*!
* \brief A profiler to count tuning time cost in different parts.
*/
class ProfilerNode : public runtime::Object {
public:
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("stats", &stats); }

/*!
* \brief Profile the time usage in the given scope in the given name.
* \param name Name for the scope.
* \return A scope timer for time profiling.
*/
static ScopedTimer TimeScope(String name);

/*!
* \brief Get the profiling results.
* \return The tuning profiling results as a dict.
*/
Map<String, FloatImm> Get() const { return stats; }

/*!
* \brief Start the timer for a new context.
* \param name Name of the context.
*/
void StartContextTimer(String name);

/*! \brief End the timer for the most recent context. */
void EndContextTimer();

static constexpr const char* _type_key = "meta_schedule.Profiler";
TVM_DECLARE_FINAL_OBJECT_INFO(ProfilerNode, runtime::Object);

protected:
Map<String, FloatImm> stats;
std::vector<std::pair<String, std::chrono::time_point<std::chrono::high_resolution_clock>>> stack;
};

/*!
* \brief Managed reference to ProfilerNode
* \sa ProfilerNode
*/
class Profiler : public runtime::ObjectRef {
public:
Profiler();
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Profiler, runtime::ObjectRef, ProfilerNode);

protected:
friend class ProfilerInternal;

/*! \brief Entering the scope of the context manager */
void EnterWithScope();
/*! \brief Exiting the scope of the context manager */
void ExitWithScope();
};

struct ProfilerThreadLocalEntry {
Optional<Profiler> ctx;
};
using ProfilerThreadLocalStore = dmlc::ThreadLocalStore<ProfilerThreadLocalEntry>;

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_PROFILER_H_
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
search_strategy,
space_generator,
)
from .profiler import Profiler
from .apply_history_best import ApplyHistoryBest
from .extracted_task import ExtractedTask
from .relay_integration import extract_task_from_relay
Expand Down
72 changes: 72 additions & 0 deletions python/tvm/meta_schedule/profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A context manager that profiles tuning time cost for different parts."""
from __future__ import annotations

import logging
from typing import Dict

from tvm._ffi import register_object
from tvm.runtime import Object

from . import _ffi_api

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


class TimeItContext:
"""The context to profile given scope."""

profiler: Profiler
name: str

def __init__(self, profiler: "Profiler", name: str):
self.profiler = profiler
self.name = name

def __enter__(self):
_ffi_api.ProfilerStartContextTimer(self.profiler, self.name) # type: ignore # pylint: disable=no-member
return self

def __exit__(self, exctype, excinst, exctb):
_ffi_api.ProfilerEndContextTimer(self.profiler) # type: ignore # pylint: disable=no-member


@register_object("meta_schedule.Profiler")
class Profiler(Object):
"""A profiler to count tuning time cost in different parts."""

def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.Profiler, # type: ignore # pylint: disable=no-member
)

def get(self) -> Dict[str, float]:
"""Get the profiling results in minutes"""
return _ffi_api.ProfilerGet(self) # type: ignore # pylint: disable=no-member

def timeit(self, name: str) -> TimeItContext:
return TimeItContext(self, name)

def __enter__(self) -> "Profiler":
"""Entering the scope of the context manager"""
_ffi_api.ProfilerEnterScope(self) # type: ignore # pylint: disable=no-member
return self

def __exit__(self, ptype, value, trace) -> None:
"""Exiting the scope of the context manager"""
_ffi_api.ProfilerExitScope(self) # type: ignore # pylint: disable=no-member
9 changes: 4 additions & 5 deletions src/meta_schedule/apply_history_best.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,19 @@ ApplyHistoryBest::ApplyHistoryBest(Database database, PackedFunc logging_func) {
Optional<IRModule> ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod,
Target target,
Optional<Array<IRModule>> dispatched) {
IRModule prim_mod;
GlobalVar gv;
ICHECK(dispatched.defined());
ICHECK_EQ(dispatched.value().size(), 1);
ICHECK(HasOnlyOneFunction<relay::Function>(mod)) << mod;
IRModule prim_mod = dispatched.value()[0];
prim_mod = dispatched.value()[0];
ICHECK(HasOnlyOneFunction<tir::PrimFunc>(prim_mod)) << prim_mod;

// Keep the original func name to be returned later.
GlobalVar gv = GetOnlyOneFunctionKey<tir::PrimFunc>(prim_mod).value();

gv = GetOnlyOneFunctionKey<tir::PrimFunc>(prim_mod).value();
// Unify func name to make sure it can be found in database
const auto* parse_mod_func = runtime::Registry::Get("tvm.meta_schedule.tune.parse_mod");
ICHECK(parse_mod_func) << "Parse mod function not defined!";
prim_mod = (*parse_mod_func)(prim_mod);

if (database->HasWorkload(prim_mod)) {
Array<TuningRecord> records = database->GetTopK(database->CommitWorkload(prim_mod), 1);
if (records.size() == 1) {
Expand Down
102 changes: 102 additions & 0 deletions src/meta_schedule/profiler.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "./utils.h"

namespace tvm {
namespace meta_schedule {

/**************** Context Manager ****************/

class ProfilerInternal {
public:
static void EnterScope(Profiler ctx) { ctx.EnterWithScope(); }
static void ExitScope(Profiler ctx) { ctx.ExitWithScope(); }
};

void Profiler::EnterWithScope() {
Optional<Profiler>& ctx = ProfilerThreadLocalStore::Get()->ctx;
CHECK(!ctx.defined()) << "ValueError: Nested Profiler context managers are not allowed";
ctx = *this;
}

void Profiler::ExitWithScope() {
Optional<Profiler>& ctx = ProfilerThreadLocalStore::Get()->ctx;
ICHECK(ctx.defined());
ctx = NullOpt;
}

/**************** Profiler ****************/

Profiler::Profiler() {
ObjectPtr<ProfilerNode> n = make_object<ProfilerNode>();
data_ = n;
}

ScopedTimer ProfilerNode::TimeScope(String name) {
return ScopedTimer([name, tick = std::chrono::high_resolution_clock::now()]() -> void {
Optional<Profiler> profiler = ProfilerThreadLocalStore::Get()->ctx;
if (profiler.defined()) {
Map<String, FloatImm>& stats = profiler.value()->stats;
double duration = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::high_resolution_clock::now() - tick)
.count() /
1e9 / 60;
if (stats.find(name) != stats.end()) {
stats.Set(name, FloatImm(DataType::Float(64), stats.at(name)->value + duration));
} else {
stats.Set(name, FloatImm(DataType::Float(64), duration));
}
}
});
}

void ProfilerNode::StartContextTimer(String name) {
stack.push_back(std::make_pair(name, std::chrono::high_resolution_clock::now()));
}

void ProfilerNode::EndContextTimer() {
ICHECK(stack.size() > 0) << "There is no timer context running!";
String name = stack.back().first;
double duration = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::high_resolution_clock::now() - stack.back().second)
.count() /
1e9 / 60;
if (stats.find(name) != stats.end()) {
stats.Set(name, FloatImm(DataType::Float(64), stats.at(name)->value + duration));
} else {
stats.Set(name, FloatImm(DataType::Float(64), duration));
}
stack.pop_back();
}

TVM_REGISTER_NODE_TYPE(ProfilerNode);
TVM_REGISTER_GLOBAL("meta_schedule.Profiler").set_body_typed([]() -> Profiler {
return Profiler();
});
TVM_REGISTER_GLOBAL("meta_schedule.ProfilerEnterScope")
.set_body_typed(ProfilerInternal::EnterScope);
TVM_REGISTER_GLOBAL("meta_schedule.ProfilerExitScope").set_body_typed(ProfilerInternal::ExitScope);
TVM_REGISTER_GLOBAL("meta_schedule.ProfilerStartContextTimer")
.set_body_method<Profiler>(&ProfilerNode::StartContextTimer);
TVM_REGISTER_GLOBAL("meta_schedule.ProfilerEndContextTimer")
.set_body_method<Profiler>(&ProfilerNode::EndContextTimer);
TVM_REGISTER_GLOBAL("meta_schedule.ProfilerGet").set_body_method<Profiler>(&ProfilerNode::Get);

} // namespace meta_schedule
} // namespace tvm

0 comments on commit b921a16

Please sign in to comment.