Skip to content

Commit

Permalink
[MetaSchedule] Add Profiler Support For Tuning Efficiency Optimization (
Browse files Browse the repository at this point in the history
#11486)

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
  • Loading branch information
zxybazh and junrushao committed Jun 13, 2022
1 parent 2df4524 commit e61ad7a
Show file tree
Hide file tree
Showing 15 changed files with 434 additions and 55 deletions.
103 changes: 103 additions & 0 deletions include/tvm/meta_schedule/profiler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_META_SCHEDULE_PROFILER_H_
#define TVM_META_SCHEDULE_PROFILER_H_

#include <tvm/ir/module.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/optional.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/target/target.h>

#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace tvm {
namespace meta_schedule {

class ScopedTimer {
public:
~ScopedTimer() {
if (deferred_ != nullptr) {
deferred_();
}
}

private:
friend class Profiler;

explicit ScopedTimer(runtime::TypedPackedFunc<void()> deferred) : deferred_(deferred) {}
runtime::TypedPackedFunc<void()> deferred_;
};

/*! \brief A generic profiler */
class ProfilerNode : public runtime::Object {
public:
/*! \brief The segments that are already profiled */
std::unordered_map<std::string, double> stats_sec;
/*! \brief Counter for the total time used */
runtime::PackedFunc total_timer;

void VisitAttrs(tvm::AttrVisitor* v) {
// `stats_sec` is not visited.
// `total_timer` is not visited.
}

static constexpr const char* _type_key = "meta_schedule.Profiler";
TVM_DECLARE_FINAL_OBJECT_INFO(ProfilerNode, runtime::Object);

public:
/*! \brief Get the internal stats of the running time */
Map<String, FloatImm> Get() const;
/*! \brief Return a summary of profiling results as table format */
String Table() const;
};

/*!
* \brief Managed reference to ProfilerNode
* \sa ProfilerNode
*/
class Profiler : public runtime::ObjectRef {
public:
Profiler();
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Profiler, runtime::ObjectRef, ProfilerNode);

/*! \brief Entering the scope of the context manager */
void EnterWithScope();
/*! \brief Exiting the scope of the context manager */
void ExitWithScope();
/*! \brief Returns the current profiler */
static Optional<Profiler> Current();
/*!
* \brief Profile the time usage in the given scope in the given name.
* \param name Name for the scope.
* \return A scope timer for time profiling.
*/
static ScopedTimer TimedScope(String name);
};

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_PROFILER_H_
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
search_strategy,
space_generator,
)
from .profiler import Profiler
from .apply_history_best import ApplyHistoryBest
from .extracted_task import ExtractedTask
from .relay_integration import extract_task_from_relay
Expand Down
76 changes: 76 additions & 0 deletions python/tvm/meta_schedule/profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A context manager that profiles tuning time cost for different parts."""
from __future__ import annotations

import logging
from contextlib import contextmanager
from typing import Dict, Optional

from tvm._ffi import register_object
from tvm.runtime import Object

from . import _ffi_api

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


@register_object("meta_schedule.Profiler")
class Profiler(Object):
"""Tuning time profiler."""

def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.Profiler, # type: ignore # pylint: disable=no-member
)

def get(self) -> Dict[str, float]:
"""Get the profiling results in minutes"""
return _ffi_api.ProfilerGet(self) # type: ignore # pylint: disable=no-member

def table(self) -> str:
"""Get the profiling results in a table format"""
return _ffi_api.ProfilerTable(self) # type: ignore # pylint: disable=no-member

def __enter__(self) -> "Profiler":
"""Entering the scope of the context manager"""
_ffi_api.ProfilerEnterWithScope(self) # type: ignore # pylint: disable=no-member
return self

def __exit__(self, ptype, value, trace) -> None:
"""Exiting the scope of the context manager"""
_ffi_api.ProfilerExitWithScope(self) # type: ignore # pylint: disable=no-member

@staticmethod
def current() -> Optional["Profiler"]:
"""Get the current profiler."""
return _ffi_api.ProfilerCurrent() # type: ignore # pylint: disable=no-member

@staticmethod
def timeit(name: str):
"""Timeit a block of code"""

@contextmanager
def _timeit():
try:
f = _ffi_api.ProfilerTimedScope(name) # type: ignore # pylint: disable=no-member
yield
finally:
if f:
f()

return _timeit()
29 changes: 16 additions & 13 deletions python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,19 +122,22 @@ def main():
alloc_repeat=alloc_repeat,
max_workers=ARGS.rpc_workers,
)
lib = ms.tune_relay(
mod=mod,
target=ARGS.target,
config=ms.TuneConfig(
strategy="evolutionary",
num_trials_per_iter=64,
max_trials_per_task=ARGS.num_trials,
max_trials_global=ARGS.num_trials,
),
runner=runner, # type: ignore
work_dir=ARGS.work_dir,
params=params,
)
with ms.Profiler() as profiler:
lib = ms.tune_relay(
mod=mod,
target=ARGS.target,
config=ms.TuneConfig(
strategy="evolutionary",
num_trials_per_iter=64,
max_trials_per_task=ARGS.num_trials,
max_trials_global=ARGS.num_trials,
),
runner=runner, # type: ignore
work_dir=ARGS.work_dir,
params=params,
)
print("Tuning Time:")
print(profiler.table())
graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
for input_name, input_shape in input_info.items():
if input_dtype.startswith("float"):
Expand Down
1 change: 1 addition & 0 deletions src/meta_schedule/measure_callback/add_to_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class AddToDatabaseNode : public MeasureCallbackNode {
if (!task_scheduler->database.defined()) {
return;
}
auto _ = Profiler::TimedScope("AddToDatabase");
TuneContext task = task_scheduler->tasks[task_id];
Database database = task_scheduler->database.value();
Workload workload = database->CommitWorkload(task->mod.value());
Expand Down
1 change: 1 addition & 0 deletions src/meta_schedule/measure_callback/echo_statistics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class EchoStatisticsNode : public MeasureCallbackNode {
if (this->task_info.empty()) {
SetupTaskInfo(task_scheduler->tasks);
}
auto _ = Profiler::TimedScope("EchoStatistics");
ICHECK_EQ(measure_candidates.size(), builder_results.size());
ICHECK_EQ(measure_candidates.size(), runner_results.size());
int n = measure_candidates.size();
Expand Down
1 change: 1 addition & 0 deletions src/meta_schedule/measure_callback/measure_callback.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ void PyMeasureCallbackNode::Apply(const TaskScheduler& task_scheduler,
const Array<BuilderResult>& builds, //
const Array<RunnerResult>& results) {
ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!";
auto _ = Profiler::TimedScope(this->f_as_string());
return f_apply(task_scheduler, task_id, measure_candidates, builds, results);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class RemoveBuildArtifactNode : public MeasureCallbackNode {
const Array<BuilderResult>& builder_results,
const Array<RunnerResult>& runner_results) final {
static const PackedFunc* f_rm = runtime::Registry::Get("meta_schedule.remove_build_dir");
auto _ = Profiler::TimedScope("RemoveBuildArtifact");
for (const BuilderResult& build_result : builder_results) {
if (Optional<String> path = build_result->artifact_path) {
(*f_rm)(path.value());
Expand Down
6 changes: 3 additions & 3 deletions src/meta_schedule/measure_callback/update_cost_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ class UpdateCostModelNode : public MeasureCallbackNode {
const Array<MeasureCandidate>& measure_candidates,
const Array<BuilderResult>& builder_results,
const Array<RunnerResult>& runner_results) final {
auto _ = Profiler::TimedScope("UpdateCostModel");
TuneContext task = task_scheduler->tasks[task_id];
ICHECK(task_scheduler->cost_model.defined()) //
ICHECK(task_scheduler->cost_model.defined())
<< "Cost model must be defined for the task scheduler!";
ICHECK(task->measure_candidates.defined()) //
<< "Task's measure candidates must be present!";
ICHECK(task->measure_candidates.defined()) << "Task's measure candidates must be present!";
CostModel cost_model = task_scheduler->cost_model.value();
ICHECK_EQ(measure_candidates.size(), builder_results.size());
ICHECK_EQ(runner_results.size(), builder_results.size());
Expand Down

0 comments on commit e61ad7a

Please sign in to comment.