Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[Fix][Autoscheduler] Costmodel enhancement #7197

Merged
merged 12 commits into from Jan 6, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 20 additions & 1 deletion python/tvm/auto_scheduler/cost_model/xgb_model.py
Expand Up @@ -19,6 +19,7 @@
"""Cost model based on xgboost"""
import multiprocessing
import logging
import os
from collections import defaultdict

import numpy as np
Expand Down Expand Up @@ -88,7 +89,7 @@ class XGBModel(PythonBasedModel):
their predictions.
"""

def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None):
def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None, model_file=None):
global xgb
try:
if xgb is None:
Expand Down Expand Up @@ -116,12 +117,17 @@ def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None):
self.plan_size = 32
self.num_warmup_sample = num_warmup_sample
self.verbose_eval = verbose_eval
self.model_file = model_file
if model_file:
logger.info("XGBModel: Load pretrained model from %s...", model_file)
self.load(model_file)
merrymercy marked this conversation as resolved.
Show resolved Hide resolved

super().__init__()

# cache measurement input/result pairs and extracted features
self.inputs = []
self.results = []
self.last_train_length = 0
self.inputs_feature_cache = []

def update(self, inputs, results):
Expand All @@ -141,6 +147,12 @@ def update(self, inputs, results):
self.inputs.extend(inputs)
self.results.extend(results)

if len(self.inputs) - self.last_train_length < self.last_train_length / 5:
jcf94 marked this conversation as resolved.
Show resolved Hide resolved
merrymercy marked this conversation as resolved.
Show resolved Hide resolved
# Set a training threshold related to `last_train_length` to reduce the training
# overhead when there're too many logs
return
self.last_train_length = len(self.inputs)

# extract feature
n_cached = len(self.inputs_feature_cache)
features, normalized_throughputs, task_ids = get_per_store_features_from_measure_pairs(
Expand Down Expand Up @@ -176,6 +188,10 @@ def update(self, inputs, results):
],
)

# Update the model file if it has been set
if self.model_file:
self.save(self.model_file)

def predict(self, task, states):
"""Predict the scores of states
Parameters
Expand Down Expand Up @@ -298,6 +314,9 @@ def load(self, file_name: str):
file_name: str
The filename
"""
if not os.path.isfile(file_name):
return

if self.bst is None:
self.bst = xgb.Booster(self.xgb_params)
self.bst.load_model(file_name)
Expand Down
11 changes: 6 additions & 5 deletions python/tvm/auto_scheduler/task_scheduler.py
Expand Up @@ -82,11 +82,12 @@ def make_search_policies(
if isinstance(search_policy, str):
policy_type, model_type = search_policy.split(".")
if model_type == "xgb":
cost_model = XGBModel(num_warmup_sample=len(tasks) * num_measures_per_round)
if load_model_file:
logger.info("TaskScheduler: Load pretrained model...")
cost_model.load(load_model_file)
elif load_log_file:
cost_model = XGBModel(
num_warmup_sample=len(tasks) * num_measures_per_round,
model_file=load_model_file,
)
if load_log_file:
logger.info("TaskScheduler: Reload measured states and train the model...")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

load_model_file and load_log_file are mutually exclusive, because update_from_file will retrain a model and overwrite the loaded model.
I think the old code is better.

I don't see why the old code cannot meet your needs.

Copy link
Contributor Author

@jcf94 jcf94 Jan 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old one is fine. I was just going to add a `self.model_file` to the cost model so it can be saved after training; this line was modified in passing.

cost_model.update_from_file(load_log_file)
elif model_type == "random":
cost_model = RandomModel()
Expand Down
18 changes: 12 additions & 6 deletions src/auto_scheduler/feature.cc
Expand Up @@ -1462,12 +1462,18 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs,
if (find_res == task_cache.end()) {
if (inputs[i]->task->compute_dag.defined()) { // the measure input is complete
task = inputs[i]->task;
} else { // the measure input is incomplete
// rebuild task for incomplete measure pairs read from file
Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target,
inputs[i]->task->target_host, inputs[i]->task->hardware_params,
inputs[i]->task->layout_rewrite_option);
} else {
// The measure input is incomplete, rebuild task for incomplete measure pairs read from file
try {
Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target,
inputs[i]->task->target_host, inputs[i]->task->hardware_params,
inputs[i]->task->layout_rewrite_option);
} catch (std::exception& e) {
// Cannot build ComputeDAG from workload key, the task may have not been registered in
// this search round
continue;
comaniac marked this conversation as resolved.
Show resolved Hide resolved
}
}
task_id = task_cache.size();

Expand Down
7 changes: 4 additions & 3 deletions src/runtime/graph/debug/graph_runtime_debug.cc
Expand Up @@ -153,9 +153,10 @@ class GraphRuntimeDebug : public GraphRuntime {
const TVMContext& ctx = data_entry_[entry_id(index, 0)]->ctx;
TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
auto op_tend = std::chrono::high_resolution_clock::now();
double op_duration =
std::chrono::duration_cast<std::chrono::duration<double> >(op_tend - op_tbegin).count();
return op_duration;
double op_duration_us =
std::chrono::duration_cast<std::chrono::duration<double> >(op_tend - op_tbegin).count() *
jcf94 marked this conversation as resolved.
Show resolved Hide resolved
1e6;
return op_duration_us;
}

/*!
Expand Down