Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoScheduler] Tutorial on auto-scheduling a network for GPU #6882

Merged
merged 17 commits into from Nov 13, 2020
Merged
4 changes: 0 additions & 4 deletions python/tvm/auto_scheduler/cost_model/xgb_model.py
Expand Up @@ -20,7 +20,6 @@
import multiprocessing
import logging
from collections import defaultdict
import time

import numpy as np

Expand Down Expand Up @@ -138,7 +137,6 @@ def update(self, inputs, results):
if len(inputs) <= 0:
return
assert len(inputs) == len(results)
tic = time.time()

self.inputs.extend(inputs)
self.results.extend(results)
Expand Down Expand Up @@ -178,8 +176,6 @@ def update(self, inputs, results):
],
)

logger.info("XGBModel Training time: %.2f s", time.time() - tic)

def predict(self, task, states):
"""Predict the scores of states
Parameters
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/auto_scheduler/dispatcher.py
Expand Up @@ -191,7 +191,7 @@ def load(self, records, n_lines=None):
if np.mean(other_costs) > np.mean(costs):
best_by_model[key] = (inp, res)

logger.debug("Finish loading %d records", counter)
logger.info("Finish loading %d records", counter)

def _query_inside(self, target, workload_key):
if target is None:
Expand Down Expand Up @@ -249,7 +249,7 @@ def query(self, target, workload_key):

if not self.silent:
msg = (
"Cannot find tuned schedule for target=%s, workload_key=%s. "
"Cannot find tuned schedules for target=%s, workload_key=%s. "
"A fallback schedule is used, "
"which may bring great performance regression." % (target, workload_key)
)
Expand Down
80 changes: 77 additions & 3 deletions python/tvm/auto_scheduler/measure_record.py
Expand Up @@ -14,8 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,pointless-string-statement
merrymercy marked this conversation as resolved.
Show resolved Hide resolved

""" Serialization and other I/O support for measurement records (tuning logs). """
import argparse
import logging
import os
import itertools

import numpy as np

Expand All @@ -24,6 +29,8 @@
from .measure import MeasureErrorNo, MeasureCallback
from . import _ffi_api

logger = logging.getLogger("auto_scheduler")


@tvm._ffi.register_object("auto_scheduler.RecordToFile")
class RecordToFile(MeasureCallback):
Expand All @@ -36,7 +43,7 @@ class RecordToFile(MeasureCallback):
File name for this callback to write log to.
"""

def __init__(self, filename="auto_scheduler_tuning.json"):
def __init__(self, filename):
self.__init_handle_by_constructor__(_ffi_api.RecordToFile, filename)


Expand All @@ -47,11 +54,11 @@ class RecordReader(Object):

Parameters
----------
filename : str = "auto_scheduler_tuning.json"
filename : str
File name for this reader to load log from.
"""

def __init__(self, filename="auto_scheduler_tuning.json"):
def __init__(self, filename):
self.__init_handle_by_constructor__(_ffi_api.RecordReader, filename)

def read_lines(self, max_lines=None, skip_lines=0):
Expand Down Expand Up @@ -173,3 +180,70 @@ def load_best(filename, workload_key=None, target=None):
best_res = res

return best_inp, best_res


def distill_record_file(in_file, out_file):
    """
    Pick the best entries from a record file and store them to another file.
    This function distills the useful log entries from a large log file.
    If out_file already exists, the best entries from both
    in_file and out_file will be saved.

    Parameters
    ----------
    in_file: str
        The filename of input
    out_file: str or file
        The filename of output
    """
    # pylint: disable=import-outside-toplevel
    from .dispatcher import ApplyHistoryBest

    context = load_records(in_file)
    if os.path.isfile(out_file):
        # Merge existing output records so previously-saved best entries survive.
        out_context = load_records(out_file)
        context = itertools.chain(context, out_context)
    # tee the iterator: one pass to find the best records, a second to re-emit them.
    context, context_clone = itertools.tee(context)
    best_context = ApplyHistoryBest(context)
    best_set = set()

    def measure_input_str_key(inp):
        # The serialized form is a stable identity key for a MeasureInput.
        return _ffi_api.SerializeMeasureInput(inp)

    for v in best_context.best_by_model.values():
        best_set.add(measure_input_str_key(v[0]))

    for v in best_context.best_by_targetkey.values():
        best_set.add(measure_input_str_key(v[0]))

    inputs = []
    results = []
    for inp, res in context_clone:
        key = measure_input_str_key(inp)  # serialize once per record, not twice
        if key in best_set:
            inputs.append(inp)
            results.append(res)
            # Remove the key so a duplicated best record is kept only once.
            best_set.remove(key)

    # Truncate (or create) the output file before saving; using a context manager
    # instead of a bare open() so the file handle is not leaked.
    with open(out_file, "w"):
        pass
    save_records(out_file, inputs, results)
    logger.info("Extract %d best records from %s to %s", len(inputs), in_file, out_file)


"""
Usage:
* Distill the best entries from a large log file
e.g. python -m tvm.auto_scheduler.measure_record --mode distill --i collect.log
"""
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--mode", choices=["distill"], required=True)
parser.add_argument("--i", type=str, help="input file")
parser.add_argument("--o", type=str, default=None, help="output file")

args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
merrymercy marked this conversation as resolved.
Show resolved Hide resolved

if args.mode == "distill":
args.o = args.o or args.i + ".best.json"
distill_record_file(args.i, args.o)
72 changes: 54 additions & 18 deletions python/tvm/auto_scheduler/task_scheduler.py
Expand Up @@ -31,9 +31,10 @@

from .search_policy import SearchPolicy, SketchPolicy
from .cost_model import RandomModel, XGBModel
from .utils import array_mean, to_str_round
from .utils import array_mean
from .measure import ProgramMeasurer
from .measure_record import RecordReader
from . import _ffi_api

logger = logging.getLogger("auto_scheduler")

Expand Down Expand Up @@ -75,10 +76,10 @@ def make_search_policies(
if model_type == "xgb":
cost_model = XGBModel(num_warmup_sample=len(tasks) * num_measures_per_round)
if load_model_file:
logger.info("Load pretrained model...")
logger.info("TaskScheduler: Load pretrained model...")
merrymercy marked this conversation as resolved.
Show resolved Hide resolved
cost_model.load(load_model_file)
elif load_log_file:
cost_model.load_log_file(load_log_file)
cost_model.update_from_file(load_log_file)
elif model_type == "random":
cost_model = RandomModel()
else:
Expand Down Expand Up @@ -137,10 +138,18 @@ class TaskScheduler:
----------
tasks: List[SearchTask]
All tasks to tune
task_weights: Optional[List[float]]
The weights of tasks.
If provided, the task scheduler will set the objective function to
sum(weight[t] * latency[t]), where weight[t] is the weight of a task
and latency[t] is the latency of the task.
If not provided, the task scheduler will assign equal weights to all
tasks (i.e., the objective function is sum(latency[t])).
objective_func: Optional[Callable[List[float] -> float]]
The objective function to be minimized.
The objective function accepts the current latencies of all tasks and returns the
objective. If not presented, the objective is the sum of the latencies of all task.
objective.
If not provided, the objective is the weighted sum of the latencies of all tasks.
strategy: str = "gradient"
The scheduling strategy.
"round-robin": Tune tasks in round robin order.
Expand All @@ -164,6 +173,7 @@ class TaskScheduler:
def __init__(
self,
tasks,
task_weights=None,
objective_func=None,
strategy="gradient",
load_model_file: str = None,
Expand All @@ -175,7 +185,14 @@ def __init__(
backward_window_size: int = 3,
):
self.tasks = tasks
self.objective_func = objective_func or sum
if objective_func: # use custom objective function
self.objective_func = objective_func
else: # use weighted sum
if task_weights:
self.objective_func = lambda costs: sum(c * w for c, w in zip(costs, task_weights))
else:
self.objective_func = sum
merrymercy marked this conversation as resolved.
Show resolved Hide resolved

self.strategy = strategy
self.verbose = verbose
self.load_log_file = load_log_file
Expand Down Expand Up @@ -282,7 +299,7 @@ def tune(self, tune_option, search_policy="default"):
continue

# compute gradient from chain rule : (delta f / delta g_i)
delta = 1e-7
delta = 1e-4
new_costs = list(self.best_costs)
new_costs[i] -= delta
chain_grad = (
Expand Down Expand Up @@ -337,10 +354,40 @@ def tune(self, tune_option, search_policy="default"):
self._tune_task(task_idx)
self._adjust_similarity_group(task_idx)

def _print_table_info(self, next_task_idx):
    """Print a status table of all tuning tasks.

    One row per task: task id, best latency (ms), achieved speed (GFLOPS),
    and the number of measurement trials spent so far. A trailing summary
    line shows the estimated total latency, total trial count, elapsed
    time, and the id of the next task to tune.

    Parameters
    ----------
    next_task_idx : int
        Index of the task that will be tuned in the next round.
    """
    # table header
    _ffi_api.PrintTitle("Task Scheduler")
    print("| ID | Latency (ms) | Speed (GFLOPS) | Trials |")
    print("-------------------------------------------------")

    # content: best_costs[i] >= 1e9 is the sentinel for "no valid schedule yet",
    # rendered as "-" in the latency and speed columns.
    for i in range(len(self.tasks)):
        id_str = "%d" % i
        latency_str = "%.3f" % (1e3 * self.best_costs[i]) if self.best_costs[i] < 1e9 else "-"
        speed_str = (
            "%.2f" % (self.tasks[i].compute_dag.flop_ct / self.best_costs[i] / 1e9)
            if self.best_costs[i] < 1e9
            else "-"
        )
        trials_str = "%d" % (self.task_cts[i] * self.num_measures_per_round)
        print("| %4s | %12s | % 14s | %6s |" % (id_str, latency_str, speed_str, trials_str))
    print("-------------------------------------------------")

    # overall info: the total estimate is only meaningful once every task
    # has at least one valid measurement.
    if all(self.best_costs[i] < 1e9 for i in range(len(self.tasks))):
        total_latency_str = "%.3f" % (self.cur_score * 1e3)
    else:
        total_latency_str = "-"
    print(
        "Estimated total latency: %s ms\tTrials: %d\tUsed time : %.0f s\tNext ID: %d\t"
        % (total_latency_str, self.ct, time.time() - self.tic, next_task_idx)
    )

def _tune_task(self, task_idx):
"""Tune the select task for one round"""
if self.verbose >= 1:
logger.info("TaskScheduler: task id:\t%d", task_idx)
self._print_table_info(task_idx)

measure_inputs, measure_results = self.search_policies[task_idx].continue_search_one_round(
self.num_measures_per_round, self.measurer
)
Expand All @@ -359,17 +406,6 @@ def _tune_task(self, task_idx):
self.ct += len(measure_inputs)
self.cur_score = self._compute_score(self.best_costs)

if self.verbose >= 1:
logger.info(
"TaskScheduler\tct: %d\testimated cost (ms): %.3f\ttime elapsed: %.2f\t"
"best_costs (ms): %s\ttask_ct: %s",
self.ct,
self.cur_score * 1e3,
time.time() - self.tic,
to_str_round(self.best_costs * 1e3, decimal=3),
self.task_cts,
)

def _compute_score(self, costs):
"""compute the objective function"""
return self.objective_func(costs)
Expand Down
11 changes: 6 additions & 5 deletions python/tvm/auto_scheduler/utils.py
Expand Up @@ -138,12 +138,13 @@ def kill_child_processes(parent_pid, sig=signal.SIGTERM):
parent = psutil.Process(parent_pid)
except psutil.NoSuchProcess:
return
children = parent.children(recursive=True)
for process in children:
try:

try:
children = parent.children(recursive=True)
for process in children:
process.send_signal(sig)
except psutil.NoSuchProcess:
return
except psutil.NoSuchProcess:
return


def _func_wrapper(que, func, args, kwargs):
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/autotvm/record.py
Expand Up @@ -268,8 +268,8 @@ def split_workload(in_file, clean=True):

def pick_best(in_file, out_file):
"""
Pick best entries from a file and store it to another file.
This distill the useful log entries from a large log file.
Pick the best entries from a file and store them to another file.
This function distills the useful log entries from a large log file.
If out_file already exists, the best entries from both
in_file and out_file will be saved.

Expand Down
10 changes: 0 additions & 10 deletions src/auto_scheduler/feature.cc
Expand Up @@ -1345,11 +1345,6 @@ void GetPerStoreFeaturesFromStates(const Array<State>& states, const SearchTask&
GetPerStoreFeaturesWorkerFunc(task, states[i], max_n_bufs,
&(*features)[i], &error_ct);
});

if (error_ct > 0) {
std::cerr << "Encountered " << error_ct
<< " errors during feature extraction, which are safely ignored." << std::endl;
}
merrymercy marked this conversation as resolved.
Show resolved Hide resolved
}

void GetPerStoreFeaturesFromStates(const Array<State>& states, const std::vector<SearchTask>& tasks,
Expand All @@ -1365,11 +1360,6 @@ void GetPerStoreFeaturesFromStates(const Array<State>& states, const std::vector
GetPerStoreFeaturesWorkerFunc(tasks[i], states[i], max_n_bufs,
&(*features)[i], &error_ct);
});

if (error_ct > 0) {
std::cerr << "Encountered " << error_ct
<< " errors during feature extraction. which are safely ignored." << std::endl;
}
}

void GetPerStoreFeaturesFromFile(const std::string& filename, int max_lines, int max_n_bufs,
Expand Down
14 changes: 7 additions & 7 deletions src/auto_scheduler/measure.cc
Expand Up @@ -217,8 +217,7 @@ Array<MeasureResult> ProgramMeasurerNode::Measure(const SearchTask& task,
batch_size = builder->n_parallel * 2;
}

StdCout(verbose) << "Get " << inputs.size() << " programs for measure. (This may take a while)"
<< std::endl;
StdCout(verbose) << "Get " << inputs.size() << " programs to measure." << std::endl;

for (size_t i = 0; i < inputs.size(); i += batch_size) {
Array<MeasureInput> input_batch(inputs.begin() + i,
Expand Down Expand Up @@ -247,11 +246,12 @@ Array<MeasureResult> ProgramMeasurerNode::Measure(const SearchTask& task,
}

ct++;
StdCout(verbose) << std::fixed << std::setprecision(2) << Chars('=', 50) << "\n"
<< "No: " << ct << "\tGFLOPS: " << flops / 1e9 << " / "
<< best_flops[workload_key] / 1e9 << "\tresults: " << result_batch[j] << "\n"
<< Chars('=', 50) << "\n"
<< input_batch[j]->state << "\n";
StdCout(verbose, 2) << std::fixed << std::setprecision(2) << Chars('=', 50) << "\n"
<< "No: " << ct << "\tGFLOPS: " << flops / 1e9 << " / "
<< best_flops[workload_key] / 1e9 << "\tresults: " << result_batch[j]
<< "\n"
<< Chars('=', 50) << "\n"
<< input_batch[j]->state << "\n";
}

// Call callback functions
Expand Down