diff --git a/include/tvm/auto_scheduler/measure.h b/include/tvm/auto_scheduler/measure.h index 339f42896b66..e8c01e84f289 100755 --- a/include/tvm/auto_scheduler/measure.h +++ b/include/tvm/auto_scheduler/measure.h @@ -43,6 +43,7 @@ #include <string> #include <unordered_map> +#include <unordered_set> #include <utility> namespace tvm { @@ -436,6 +437,8 @@ class ProgramMeasurerNode : public Object { std::unordered_map<std::string, State> best_state; /*! \brief Workload key to best state's count index map. */ std::unordered_map<std::string, int> best_ct; + /*! \brief The set of workloads that have at least one valid schedule */ + std::unordered_set<std::string> has_valid; /*! \brief The ProgramBuilder to build each program. */ ProgramBuilder builder; /*! \brief The ProgramRunner to measure each program. */ diff --git a/python/tvm/auto_scheduler/cost_model/xgb_model.py b/python/tvm/auto_scheduler/cost_model/xgb_model.py index b9afd98be21d..ef5472d6b77e 100644 --- a/python/tvm/auto_scheduler/cost_model/xgb_model.py +++ b/python/tvm/auto_scheduler/cost_model/xgb_model.py @@ -20,7 +20,6 @@ import multiprocessing import logging from collections import defaultdict -import time import numpy as np @@ -138,7 +137,6 @@ def update(self, inputs, results): if len(inputs) <= 0: return assert len(inputs) == len(results) - tic = time.time() self.inputs.extend(inputs) self.results.extend(results) @@ -178,8 +176,6 @@ def update(self, inputs, results): ], ) - logger.info("XGBModel Training time: %.2f s", time.time() - tic) - def predict(self, task, states): """Predict the scores of states Parameters ---------- diff --git a/python/tvm/auto_scheduler/dispatcher.py b/python/tvm/auto_scheduler/dispatcher.py index 7c0c6ef64322..8822f3963f7b 100644 --- a/python/tvm/auto_scheduler/dispatcher.py +++ b/python/tvm/auto_scheduler/dispatcher.py @@ -249,7 +249,7 @@ def query(self, target, workload_key): if not self.silent: msg = ( - "Cannot find tuned schedule for target=%s, workload_key=%s. " + "Cannot find tuned schedules for target=%s, workload_key=%s. " "A fallback schedule is used, " "which may bring great performance regression." % (target, workload_key) ) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 642e8f85e86b..117cd4f8bc71 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -34,7 +34,6 @@ import os import time import shutil -import traceback import tempfile import multiprocessing @@ -48,10 +47,11 @@ from . import _ffi_api from .loop_state import StateObject from .utils import ( - get_const_tuple, call_func_with_timeout, - request_remote, check_remote, + get_const_tuple, + make_traceback_info, + request_remote, ) from .compute_dag import ComputeDAG from .search_task import SearchTask @@ -60,8 +60,6 @@ deserialize_workload_registry_entry, ) -# The maximum length of error message -MAX_ERROR_MSG_LEN = 512 # The time cost for measurements with errors # We use 1e10 instead of sys.float_info.max for better readability in log @@ -536,16 +534,6 @@ class MeasureErrorNo(object): UNKNOWN_ERROR = 8 # Unknown error -def make_error_msg(): - """ Get the error message from traceback. 
""" - error_msg = str(traceback.format_exc()) - if len(error_msg) > MAX_ERROR_MSG_LEN: - error_msg = ( - error_msg[: MAX_ERROR_MSG_LEN // 2] + "\n...\n" + error_msg[-MAX_ERROR_MSG_LEN // 2 :] - ) - return error_msg - - def _timed_func(inp_serialized, build_func, verbose): tic = time.time() inp = MeasureInput.deserialize(inp_serialized) @@ -560,14 +548,13 @@ def _timed_func(inp_serialized, build_func, verbose): # pylint: disable=broad-except except Exception: error_no = MeasureErrorNo.INSTANTIATION_ERROR - error_msg = make_error_msg() + error_msg = make_traceback_info() if error_no == 0: dirname = tempfile.mkdtemp() filename = os.path.join(dirname, "tmp_func." + build_func.output_format) try: - # TODO(merrymercy): Port the unroll pass. with transform.PassContext(): func = build_module.build( sch, args, target=task.target, target_host=task.target_host @@ -576,7 +563,7 @@ def _timed_func(inp_serialized, build_func, verbose): # pylint: disable=broad-except except Exception: error_no = MeasureErrorNo.COMPILE_HOST - error_msg = make_error_msg() + error_msg = make_traceback_info() else: filename = "" @@ -585,6 +572,7 @@ def _timed_func(inp_serialized, build_func, verbose): print(".", end="") else: print(".E", end="") # Build error + return filename, args, error_no, error_msg, time.time() - tic @@ -615,6 +603,10 @@ def local_build_worker(args): if verbose >= 1: print(".T", end="") # Build timeout res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout + elif isinstance(res, Exception): + if verbose >= 1: + print(".E", end="") # Build error + res = None, [], MeasureErrorNo.COMPILE_HOST, str(res), timeout return res @@ -703,7 +695,7 @@ def _timed_eval_func( except Exception: costs = (MAX_FLOAT,) error_no = MeasureErrorNo.COMPILE_DEVICE - error_msg = make_error_msg() + error_msg = make_traceback_info() if error_no == 0: try: @@ -718,7 +710,7 @@ def _timed_eval_func( except Exception: costs = (MAX_FLOAT,) error_no = MeasureErrorNo.RUNTIME_DEVICE - error_msg = make_error_msg() + error_msg = make_traceback_info() shutil.rmtree(os.path.dirname(build_res.filename)) toc = time.time() @@ -825,6 +817,17 @@ def local_run( build_res.time_cost + timeout, time.time(), ) + elif isinstance(res, Exception): + if verbose >= 1: + print("*E", end="") # Run error + res = ( + (MAX_FLOAT,), + MeasureErrorNo.RUNTIME_DEVICE, + str(res), + build_res.time_cost + timeout, + time.time(), + ) + measure_results.append(MeasureResult(*res)) if verbose >= 1: @@ -876,7 +879,7 @@ def _timed_rpc_run( except Exception: costs = (MAX_FLOAT,) error_no = MeasureErrorNo.COMPILE_DEVICE - error_msg = make_error_msg() + error_msg = make_traceback_info() if error_no == 0: try: @@ -900,7 +903,7 @@ def _timed_rpc_run( except Exception: costs = (MAX_FLOAT,) error_no = MeasureErrorNo.RUNTIME_DEVICE - error_msg = make_error_msg() + error_msg = make_traceback_info() shutil.rmtree(os.path.dirname(build_res.filename)) toc = time.time() @@ -939,7 +942,6 @@ def _rpc_run_worker(args): ) res = call_func_with_timeout(timeout, _timed_rpc_run, args=args) - if isinstance(res, TimeoutError): if verbose >= 1: print("*T", end="") # Run timeout @@ -950,6 +952,17 @@ def _rpc_run_worker(args): build_res.time_cost + timeout, time.time(), ) + elif isinstance(res, Exception): + if verbose >= 1: + print("*E", end="") # Run error + res = ( + (MAX_FLOAT,), + MeasureErrorNo.RUNTIME_DEVICE, + str(res), + build_res.time_cost + timeout, + time.time(), + ) + return res diff --git a/python/tvm/auto_scheduler/measure_record.py b/python/tvm/auto_scheduler/measure_record.py 
index f0d930e3257e..2569f3984f3c 100644 --- a/python/tvm/auto_scheduler/measure_record.py +++ b/python/tvm/auto_scheduler/measure_record.py @@ -14,8 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name, pointless-string-statement """ Serialization and other I/O support for measurement records (tuning logs). """ +import argparse +import logging +import os +import itertools import numpy as np @@ -24,6 +29,8 @@ from .measure import MeasureErrorNo, MeasureCallback from . import _ffi_api +logger = logging.getLogger("auto_scheduler") + @tvm._ffi.register_object("auto_scheduler.RecordToFile") class RecordToFile(MeasureCallback): @@ -36,7 +43,7 @@ class RecordToFile(MeasureCallback): File name for this callback to write log to. """ - def __init__(self, filename="auto_scheduler_tuning.json"): + def __init__(self, filename): self.__init_handle_by_constructor__(_ffi_api.RecordToFile, filename) @@ -47,11 +54,11 @@ class RecordReader(Object): Parameters ---------- - filename : str = "auto_scheduler_tuning.json" + filename : str File name for this reader to load log from. """ - def __init__(self, filename="auto_scheduler_tuning.json"): + def __init__(self, filename): self.__init_handle_by_constructor__(_ffi_api.RecordReader, filename) def read_lines(self, max_lines=None, skip_lines=0): @@ -173,3 +180,71 @@ def load_best(filename, workload_key=None, target=None): best_res = res return best_inp, best_res + + +def distill_record_file(in_file, out_file): + """ + Pick the best entries from a record file and store them to another file. + This function distills the useful log entries from a large log file. + If out_file already exists, the best entries from both + in_file and out_file will be saved. + + Parameters + ---------- + in_file: str + The filename of input + out_file: str or file + The filename of output + """ + # pylint: disable=import-outside-toplevel + from .dispatcher import ApplyHistoryBest + + context = load_records(in_file) + if os.path.isfile(out_file): + out_context = load_records(out_file) + context = itertools.chain(context, out_context) + context, context_clone = itertools.tee(context) + best_context = ApplyHistoryBest(context) + best_set = set() + + def measure_input_str_key(inp): + return _ffi_api.SerializeMeasureInput(inp) + + for v in best_context.best_by_model.values(): + best_set.add(measure_input_str_key(v[0])) + + for v in best_context.best_by_targetkey.values(): + best_set.add(measure_input_str_key(v[0])) + + inputs = [] + results = [] + for inp, res in context_clone: + if measure_input_str_key(inp) in best_set: + inputs.append(inp) + results.append(res) + best_set.remove(measure_input_str_key(inp)) + + # create a new file and save the best records + open(out_file, "w") + save_records(out_file, inputs, results) + logger.info("Extract %d best records from %s to %s", len(inputs), in_file, out_file) + + +""" +Usage: +* Distill the best entries from a large log file +e.g. 
python -m tvm.auto_scheduler.measure_record --mode distill --i input.json """ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--mode", choices=["distill"], required=True) parser.add_argument("--i", type=str, help="input file") parser.add_argument("--o", type=str, default=None, help="output file") args = parser.parse_args() logging.basicConfig() logger.setLevel(logging.INFO) if args.mode == "distill": args.o = args.o or args.i + ".best.json" distill_record_file(args.i, args.o) diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py index e45573be61c6..c81a4b680b95 100644 --- a/python/tvm/auto_scheduler/task_scheduler.py +++ b/python/tvm/auto_scheduler/task_scheduler.py @@ -31,9 +31,10 @@ from .search_policy import SearchPolicy, SketchPolicy from .cost_model import RandomModel, XGBModel -from .utils import array_mean, to_str_round +from .utils import array_mean from .measure import ProgramMeasurer from .measure_record import RecordReader +from . import _ffi_api logger = logging.getLogger("auto_scheduler") @@ -75,10 +76,10 @@ def make_search_policies( if model_type == "xgb": cost_model = XGBModel(num_warmup_sample=len(tasks) * num_measures_per_round) if load_model_file: - logger.info("Load pretrained model...") + logger.info("TaskScheduler: Load pretrained model...") cost_model.load(load_model_file) elif load_log_file: - cost_model.load_log_file(load_log_file) + cost_model.update_from_file(load_log_file) elif model_type == "random": cost_model = RandomModel() else: @@ -137,10 +138,18 @@ class TaskScheduler: ---------- tasks: List[SearchTask] All tasks to tune + task_weights: Optional[List[float]] + The weights of tasks. + If provided, the task scheduler will set the objective function to + sum(weight[t] * latency[t]), where weight[t] is the weight of a task + and latency[t] is the latency of the task. + If not provided, the task scheduler will assign equal weights to all + tasks (i.e., the objective function is sum(latency[t])). objective_func: Optional[Callable[List[float] -> float]] The objective function to be minimized. The objective function accepts the current latencies of all tasks and returns the - objective. If not presented, the objective is the sum of the latencies of all task. + objective. + If not provided, the objective is the weighted sum of the latencies of all tasks. strategy: str = "gradient" The scheduling strategy. "round-robin": Tune tasks in round robin order. 
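For reference, the weighted objective that task_weights sets up is the same computation this patch deletes from test_auto_scheduler_tuning.py as a hand-written lambda. A minimal sketch of it in isolation (the weights and latencies below are made-up illustration values, not part of the patch):

# Hypothetical inputs: one weight per task and the current best latency of each task, in seconds.
task_weights = [1.0, 2.0, 1.0]
latencies = [0.010, 0.020, 0.030]

# Equivalent to the objective TaskScheduler now builds internally from task_weights.
objective_func = lambda costs: sum(c * w for c, w in zip(costs, task_weights))
assert abs(objective_func(latencies) - 0.080) < 1e-12  # 0.010*1 + 0.020*2 + 0.030*1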
@@ -164,20 +173,26 @@ class TaskScheduler: def __init__( self, tasks, + task_weights=None, objective_func=None, strategy="gradient", load_model_file: str = None, load_log_file: str = None, - verbose: int = 1, alpha: float = 0.2, beta: float = 2, gamma: float = 0.5, backward_window_size: int = 3, ): self.tasks = tasks - self.objective_func = objective_func or sum + if objective_func: # use custom objective function + self.objective_func = objective_func + else: # use weighted sum + if task_weights: + self.objective_func = lambda costs: sum(c * w for c, w in zip(costs, task_weights)) + else: + self.objective_func = sum + self.strategy = strategy - self.verbose = verbose self.load_log_file = load_log_file self.load_model_file = load_model_file self.alpha = alpha @@ -198,7 +213,8 @@ def __init__( self.best_costs = 1e10 * np.ones(len(self.tasks)) self.cur_score = self._compute_score(self.best_costs) - self.tune_option = self.measurer = self.search_policies = self.ct = self.tic = None + self.tune_option = self.measurer = self.search_policies = None + self.ct = self.best_ct = self.best_score = self.tic = None self.num_measures_per_round = None self.dead_tasks = set() @@ -234,14 +250,17 @@ def tune(self, tune_option, search_policy="default"): """ # init members self.tune_option = tune_option + early_stopping = 1e20 if tune_option.early_stopping < 0 else tune_option.early_stopping + self.measurer = ProgramMeasurer( tune_option.builder, tune_option.runner, tune_option.measure_callbacks, tune_option.verbose, ) - self.ct = 0 + self.ct = self.best_ct = 0 self.tic = time.time() + # reset num_measures_per_round to make sure every task is tuned at least once self.num_measures_per_round = min( tune_option.num_measures_per_round, tune_option.num_measure_trials // len(self.tasks) @@ -266,6 +285,8 @@ def tune(self, tune_option, search_policy="default"): # do a round robin first to warm up for i in range(len(self.tasks)): self._tune_task(i) + self.best_ct = self.ct + self.best_score = self.cur_score # use the specific strategy to choose workload to tune task_idx = -1 @@ -282,7 +303,7 @@ def tune(self, tune_option, search_policy="default"): continue # compute gradient from chain rule : (delta f / delta g_i) - delta = 1e-7 + delta = 1e-4 new_costs = list(self.best_costs) new_costs[i] -= delta chain_grad = ( @@ -337,10 +358,54 @@ def tune(self, tune_option, search_policy="default"): self._tune_task(task_idx) self._adjust_similarity_group(task_idx) + if self.cur_score < self.best_score: + self.best_score = self.cur_score + self.best_ct = self.ct + elif self.ct - self.best_ct >= early_stopping and all( + cost < 1e9 for cost in self.best_costs + ): + if self.tune_option.verbose >= 1: + print( + "Stop early since no performance improvement in the last " + + str(early_stopping) + + " measurement trials." 
+ ) + break + + def _print_table_info(self, next_task_idx): + # table header + _ffi_api.PrintTitle("Task Scheduler") + print("| ID | Latency (ms) | Speed (GFLOPS) | Trials |") + print("-------------------------------------------------") + + # content + for i in range(len(self.tasks)): + id_str = "%d" % i + latency_str = "%.3f" % (1e3 * self.best_costs[i]) if self.best_costs[i] < 1e9 else "-" + speed_str = ( + "%.2f" % (self.tasks[i].compute_dag.flop_ct / self.best_costs[i] / 1e9) + if self.best_costs[i] < 1e9 + else "-" + ) + trials_str = "%d" % (self.task_cts[i] * self.num_measures_per_round) + print("| %4s | %12s | % 14s | %6s |" % (id_str, latency_str, speed_str, trials_str)) + print("-------------------------------------------------") + + # overall info + if all(cost < 1e9 for cost in self.best_costs): + total_latency_str = "%.3f" % (self.cur_score * 1e3) + else: + total_latency_str = "-" + print( + "Estimated total latency: %s ms\tTrials: %d\tUsed time : %.0f s\tNext ID: %d\t" + % (total_latency_str, self.ct, time.time() - self.tic, next_task_idx) + ) + def _tune_task(self, task_idx): """Tune the select task for one round""" - if self.verbose >= 1: - logger.info("TaskScheduler: task id:\t%d", task_idx) + if self.tune_option.verbose >= 1: + self._print_table_info(task_idx) + measure_inputs, measure_results = self.search_policies[task_idx].continue_search_one_round( self.num_measures_per_round, self.measurer ) @@ -359,17 +424,6 @@ def _tune_task(self, task_idx): self.ct += len(measure_inputs) self.cur_score = self._compute_score(self.best_costs) - if self.verbose >= 1: - logger.info( - "TaskScheduler\tct: %d\testimated cost (ms): %.3f\ttime elapsed: %.2f\t" - "best_costs (ms): %s\ttask_ct: %s", - self.ct, - self.cur_score * 1e3, - time.time() - self.tic, - to_str_round(self.best_costs * 1e3, decimal=3), - self.task_cts, - ) - def _compute_score(self, costs): """compute the objective function""" return self.objective_func(costs) diff --git a/python/tvm/auto_scheduler/utils.py b/python/tvm/auto_scheduler/utils.py index 0780d39e9042..9a7c199e6745 100644 --- a/python/tvm/auto_scheduler/utils.py +++ b/python/tvm/auto_scheduler/utils.py @@ -23,6 +23,7 @@ import queue import signal import threading +import traceback import os import numpy as np @@ -138,32 +139,49 @@ def kill_child_processes(parent_pid, sig=signal.SIGTERM): parent = psutil.Process(parent_pid) except psutil.NoSuchProcess: return - children = parent.children(recursive=True) - for process in children: - try: + + try: + children = parent.children(recursive=True) + for process in children: process.send_signal(sig) - except psutil.NoSuchProcess: - return + except psutil.NoSuchProcess: + return + + +# The maximum length of traceback information +MAX_TRACEBACK_INFO_LEN = 512 + + +def make_traceback_info(): + """ Get the error message from traceback. 
""" + info = str(traceback.format_exc()) + if len(info) > MAX_TRACEBACK_INFO_LEN: + info = ( + info[: MAX_TRACEBACK_INFO_LEN // 2] + "\n...\n" + info[-MAX_TRACEBACK_INFO_LEN // 2 :] + ) + return info def _func_wrapper(que, func, args, kwargs): """Call function and return the result over the queue.""" - if kwargs: - que.put(func(*args, **kwargs)) - else: - que.put(func(*args)) + try: + if kwargs: + que.put(func(*args, **kwargs)) + else: + que.put(func(*args)) + # pylint: disable=broad-except + except Exception: + que.put(Exception(make_traceback_info())) def call_func_with_timeout(timeout, func, args=(), kwargs=None): """Call a function with timeout""" - que = multiprocessing.Queue(2) process = multiprocessing.Process(target=_func_wrapper, args=(que, func, args, kwargs)) process.start() - process.join(timeout) try: - res = que.get(block=False) + res = que.get(timeout=timeout) except queue.Empty: res = TimeoutError() diff --git a/python/tvm/autotvm/record.py b/python/tvm/autotvm/record.py index a1b89404b5a1..4f11aea2911f 100644 --- a/python/tvm/autotvm/record.py +++ b/python/tvm/autotvm/record.py @@ -268,8 +268,8 @@ def split_workload(in_file, clean=True): def pick_best(in_file, out_file): """ - Pick best entries from a file and store it to another file. - This distill the useful log entries from a large log file. + Pick the best entries from a file and store them to another file. + This function distills the useful log entries from a large log file. If out_file already exists, the best entries from both in_file and out_file will be saved. diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index 8d17c4bba10f..a60c87cc600d 100755 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -1345,11 +1345,6 @@ void GetPerStoreFeaturesFromStates(const Array& states, const SearchTask& GetPerStoreFeaturesWorkerFunc(task, states[i], max_n_bufs, &(*features)[i], &error_ct); }); - - if (error_ct > 0) { - std::cerr << "Encountered " << error_ct - << " errors during feature extraction, which are safely ignored." << std::endl; - } } void GetPerStoreFeaturesFromStates(const Array& states, const std::vector& tasks, @@ -1365,11 +1360,6 @@ void GetPerStoreFeaturesFromStates(const Array& states, const std::vector GetPerStoreFeaturesWorkerFunc(tasks[i], states[i], max_n_bufs, &(*features)[i], &error_ct); }); - - if (error_ct > 0) { - std::cerr << "Encountered " << error_ct - << " errors during feature extraction. which are safely ignored." << std::endl; - } } void GetPerStoreFeaturesFromFile(const std::string& filename, int max_lines, int max_n_bufs, diff --git a/src/auto_scheduler/measure.cc b/src/auto_scheduler/measure.cc index 6c5c10e5aaee..c77bafc84e6e 100755 --- a/src/auto_scheduler/measure.cc +++ b/src/auto_scheduler/measure.cc @@ -203,6 +203,7 @@ void ProgramMeasurerNode::Reset() { best_flops.clear(); best_ct.clear(); best_state.clear(); + has_valid.clear(); } Array ProgramMeasurerNode::Measure(const SearchTask& task, @@ -217,8 +218,7 @@ Array ProgramMeasurerNode::Measure(const SearchTask& task, batch_size = builder->n_parallel * 2; } - StdCout(verbose) << "Get " << inputs.size() << " programs for measure. (This may take a while)" - << std::endl; + StdCout(verbose) << "Get " << inputs.size() << " programs to measure." 
<< std::endl; for (size_t i = 0; i < inputs.size(); i += batch_size) { Array input_batch(inputs.begin() + i, @@ -230,16 +230,18 @@ Array ProgramMeasurerNode::Measure(const SearchTask& task, // update current best state according to the new measure result for (size_t j = 0; j < input_batch.size(); ++j) { + const String& workload_key = input_batch[j]->task->workload_key; double flops; + if (result_batch[j]->error_no == 0) { flops = task->compute_dag->flop_ct / FloatArrayMean(result_batch[j]->costs); error_ct = 0; + has_valid.insert(workload_key); } else { flops = 0.0; error_ct++; } - const String& workload_key = input_batch[j]->task->workload_key; if (flops > best_flops[workload_key]) { best_flops[workload_key] = flops; best_state[workload_key] = input_batch[j]->state; @@ -247,11 +249,12 @@ Array ProgramMeasurerNode::Measure(const SearchTask& task, } ct++; - StdCout(verbose) << std::fixed << std::setprecision(2) << Chars('=', 50) << "\n" - << "No: " << ct << "\tGFLOPS: " << flops / 1e9 << " / " - << best_flops[workload_key] / 1e9 << "\tresults: " << result_batch[j] << "\n" - << Chars('=', 50) << "\n" - << input_batch[j]->state << "\n"; + StdCout(verbose, 2) << std::fixed << std::setprecision(2) << Chars('=', 50) << "\n" + << "No: " << ct << "\tGFLOPS: " << flops / 1e9 << " / " + << best_flops[workload_key] / 1e9 << "\tresults: " << result_batch[j] + << "\n" + << Chars('=', 50) << "\n" + << input_batch[j]->state << "\n"; } // Call callback functions diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index b64776ff342a..4c3e8ac5593d 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -162,9 +162,17 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure Array results; while (ct < n_trials) { if (!inputs.empty()) { - // Retrain cost models before the next search round + auto tic_begin = std::chrono::high_resolution_clock::now(); + + // Retrain the cost model before the next search round PrintTitle("Train cost model", verbose); program_cost_model->Update(inputs, results); + + double duration = std::chrono::duration_cast>( + std::chrono::high_resolution_clock::now() - tic_begin) + .count(); + StdCout(verbose) << "Time elapsed: " << std::fixed << std::setprecision(2) << duration + << " s" << std::endl; } // Search one round to get promising states @@ -200,9 +208,10 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure ct += inputs.size(); // Check if reach the early stopping condition - if (ct - measurer->best_ct[search_task->workload_key] > early_stopping) { + if (ct - measurer->best_ct[search_task->workload_key] > early_stopping && + measurer->has_valid.count(search_task->workload_key)) { StdCout(verbose) << "Stop early since no performance improvement in the last " - << early_stopping << " measure steps.\n"; + << early_stopping << " measurements trials.\n"; break; } @@ -249,10 +258,18 @@ std::pair, Array> SketchPolicyNode::ContinueS measured_states_throughputs_.push_back(1.0 / FloatArrayMean(res->costs)); } + auto tic_begin = std::chrono::high_resolution_clock::now(); + // Update the cost model PrintTitle("Train cost model", verbose); program_cost_model->Update(inputs, results); + double duration = std::chrono::duration_cast>( + std::chrono::high_resolution_clock::now() - tic_begin) + .count(); + StdCout(verbose) << "Time elapsed: " << std::fixed << std::setprecision(2) << duration << " s" + << std::endl; 
+ return std::make_pair(std::move(inputs), std::move(results)); } @@ -362,6 +379,8 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches // At least we should sample this number of valid programs int min_population = GetIntParam(params, SketchParamKey::SampleInitPopulation::min_population); + auto tic_begin = std::chrono::high_resolution_clock::now(); + int fail_ct = 0; Array<State> out_states; std::vector<std::mt19937> rand_gens; @@ -369,7 +388,6 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches for (int i = 0; i < population; i++) { rand_gens.push_back(std::mt19937(rand_gen())); } - auto tic_begin = std::chrono::high_resolution_clock::now(); std::unordered_set<std::string> explored_state_strs; size_t iter = 1; @@ -673,5 +691,9 @@ TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicyEvolutionarySearch") return states; }); +TVM_REGISTER_GLOBAL("auto_scheduler.PrintTitle").set_body_typed([](std::string title) { + PrintTitle(title, 1); +}); + } // namespace auto_scheduler } // namespace tvm diff --git a/src/auto_scheduler/search_policy/utils.h b/src/auto_scheduler/search_policy/utils.h index ecc46af5a5de..d59a6ca220ca 100644 --- a/src/auto_scheduler/search_policy/utils.h +++ b/src/auto_scheduler/search_policy/utils.h @@ -657,9 +657,9 @@ inline int RandomChoose(const std::vector<double>& prefix_sum_probs, std::mt1993 /*! \brief Print a title */ inline void PrintTitle(const std::string& title, int verbose) { - StdCout(verbose) << Chars('-', 60) << "\n" - << Chars('-', 25) << " [ " << title << " ]\n" - << Chars('-', 60) << std::endl; + StdCout(verbose) << Chars('-', 70) << "\n" + << Chars('-', 30) << " [ " << title << " ]\n" + << Chars('-', 70) << std::endl; } /*! diff --git a/tests/python/relay/test_auto_scheduler_tuning.py b/tests/python/relay/test_auto_scheduler_tuning.py index ad882b3eaf24..089f51cdf047 100644 --- a/tests/python/relay/test_auto_scheduler_tuning.py +++ b/tests/python/relay/test_auto_scheduler_tuning.py @@ -30,17 +30,17 @@ def tune_network(network, target): mod, params = get_network(network) target = tvm.target.Target(target) tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) - objective = lambda costs: sum(c * w for c, w in zip(costs, task_weights)) with tempfile.NamedTemporaryFile() as fp: log_file = fp.name # Tuning measure_ctx = auto_scheduler.LocalRPCMeasureContext(timeout=60) - tuner = auto_scheduler.TaskScheduler(tasks, objective) + tuner = auto_scheduler.TaskScheduler(tasks, task_weights) tune_option = auto_scheduler.TuningOptions( - num_measure_trials=4, + num_measure_trials=100, num_measures_per_round=2, + early_stopping=1, runner=measure_ctx.runner, builder=auto_scheduler.LocalBuilder(timeout=60), measure_callbacks=[auto_scheduler.RecordToFile(log_file)], diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py index 6493c246d406..a4f3c4e06843 100644 --- a/tests/python/unittest/test_auto_scheduler_search_policy.py +++ b/tests/python/unittest/test_auto_scheduler_search_policy.py @@ -33,14 +33,14 @@ def search_common( workload=matmul_auto_scheduler_test, target="llvm", - search_policy="empty", - seed=random.randint(1, 1 << 30), + search_policy="sketch", + seed=0, runner="local", + num_measure_trials=100, cost_model=auto_scheduler.RandomModel(), - num_measure_trials=10, init_search_callbacks=None, ): - print("Test %s schedule search with the default search policy" % (target)) + print("Test search policy '%s' for '%s'" % (search_policy, target)) random.seed(seed) N = 128 @@ 
-59,17 +59,18 @@ def search_common( search_policy = auto_scheduler.SketchPolicy( task, program_cost_model=cost_model, init_search_callbacks=init_search_callbacks ) + else: + raise ValueError("Invalid policy: " + search_policy) tuning_options = auto_scheduler.TuningOptions( num_measure_trials=num_measure_trials, + num_measures_per_round=2, + early_stopping=1, runner=runner, - verbose=1, + verbose=2, measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) sch, args = auto_scheduler.auto_schedule(task, search_policy, tuning_options) - print("*" * 80) - print(target) - print("*" * 80) inp, res = auto_scheduler.load_best(log_file, task.workload_key, target) print("==== Python Code ====") @@ -97,17 +98,30 @@ def search_common( def test_workload_registry_search_basic(): # wrap the search in a new thread to avoid the conflict # between python's multiprocessing and tvm's thread pool - t = PropagatingThread(target=search_common, kwargs={"seed": 944563397}) + t = PropagatingThread( + target=search_common, kwargs={"search_policy": "empty", "num_measure_trials": 2} + ) t.start() t.join() + t = PropagatingThread( - target=search_common, kwargs={"seed": 944563397, "workload": "matmul_auto_scheduler_test"} + target=search_common, + kwargs={ + "workload": "matmul_auto_scheduler_test", + "num_measure_trials": 2, + "search_policy": "empty", + }, ) t.start() t.join() + t = PropagatingThread( target=search_common, - kwargs={"seed": 944563397, "workload": "matmul_auto_scheduler_test_rename_1"}, + kwargs={ + "workload": "matmul_auto_scheduler_test_rename_1", + "num_measure_trials": 2, + "search_policy": "empty", + }, ) t.start() t.join() @@ -117,9 +131,7 @@ def test_workload_registry_search_basic(): def test_sketch_search_policy_basic(): # wrap the search in a new thread to avoid the conflict # between python's multiprocessing and tvm's thread pool - t = PropagatingThread( - target=search_common, kwargs={"seed": 944563397, "search_policy": "sketch"} - ) + t = PropagatingThread(target=search_common) t.start() t.join() @@ -144,8 +156,6 @@ def test_sketch_search_policy_xgbmodel(): t = PropagatingThread( target=search_common, kwargs={ - "seed": 944563397, - "search_policy": "sketch", "cost_model": auto_scheduler.XGBModel(), }, ) @@ -161,8 +171,6 @@ def test_sketch_search_policy_cuda_rpc_runner(): t = PropagatingThread( target=search_common, kwargs={ - "seed": 944563397, - "search_policy": "sketch", "target": "cuda", "runner": measure_ctx.runner, }, @@ -179,8 +187,6 @@ def test_sketch_search_policy_cuda_xgbmodel_rpc_runner(): t = PropagatingThread( target=search_common, kwargs={ - "seed": 944563397, - "search_policy": "sketch", "target": "cuda", "runner": measure_ctx.runner, "cost_model": auto_scheduler.XGBModel(), diff --git a/tests/python/unittest/test_auto_scheduler_task_scheduler.py b/tests/python/unittest/test_auto_scheduler_task_scheduler.py index 2debc14fc356..b0fb37a830f7 100644 --- a/tests/python/unittest/test_auto_scheduler_task_scheduler.py +++ b/tests/python/unittest/test_auto_scheduler_task_scheduler.py @@ -34,9 +34,6 @@ def test_task_scheduler_round_robin(): for n in [2, 4, 8]: tasks.append(auto_scheduler.create_task(matmul_auto_scheduler_test, (n, n, n), "llvm")) - def objective_func(costs): - return sum(costs) - with tempfile.NamedTemporaryFile() as fp: log_file = fp.name num_trials_per_task = 2 @@ -49,7 +46,7 @@ def objective_func(costs): num_measures_per_round=1, measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) - task_scheduler = auto_scheduler.TaskScheduler(tasks, objective_func, 
strategy="round-robin") + task_scheduler = auto_scheduler.TaskScheduler(tasks, strategy="round-robin") task_scheduler.tune(tune_option, search_policy="sketch.random") # Check the result of round robin @@ -65,7 +62,7 @@ def objective_func(costs): # test continuous tuning (restoring the status) task_scheduler = auto_scheduler.TaskScheduler( - tasks, objective_func, strategy="round-robin", load_log_file=log_file + tasks, strategy="round-robin", load_log_file=log_file ) tune_option = auto_scheduler.TuningOptions( num_measure_trials=len(tasks), @@ -111,7 +108,7 @@ def objective_func(costs): num_measures_per_round=1, measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) - task_scheduler = auto_scheduler.TaskScheduler(tasks, objective_func) + task_scheduler = auto_scheduler.TaskScheduler(tasks, objective_func=objective_func) # Forcely rewrite the initial values. # This can make this test more stable on the slow CI machines diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index 3d229651cb4f..459b680daeb1 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -40,7 +40,7 @@ rm -rf docs/doxygen # prepare auto scheduler tutorials rm -rf tutorials/auto_scheduler/*.json -cp -f tutorials/auto_scheduler/ci_logs/{matmul,conv2d}.json tutorials/auto_scheduler +cp -f tutorials/auto_scheduler/ci_logs/*.json tutorials/auto_scheduler # remove stale tutorials and always build from scratch. rm -rf docs/tutorials diff --git a/tutorials/auto_scheduler/ci_logs/resnet-18-NHWC-B1.json b/tutorials/auto_scheduler/ci_logs/resnet-18-NHWC-B1.json new file mode 100644 index 000000000000..37a129844390 --- /dev/null +++ b/tutorials/auto_scheduler/ci_logs/resnet-18-NHWC-B1.json @@ -0,0 +1,23 @@ +# Provide valid schedules for resnet-18. +# This is used to run the tutorial on the documentation web server. 
+{"i": [["[\"d09dc1a6bb90d59c91b68989ad3492ff\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["SP", 2, 0, 1, [1, 1, 1, 1], 1], ["SP", 2, 5, 1000, [2, 5, 2, 1], 1], ["SP", 2, 10, 512, [1, 16], 1], ["RE", 2, [0, 5, 1, 6, 2, 7, 10, 11, 3, 8, 12, 4, 9]], ["FSP", 4, 0, 0, 3], ["FSP", 4, 4, 1, 3], ["RE", 4, [0, 4, 1, 5, 2, 6, 3, 7]], ["CA", 2, 4, 5], ["CHR", 1, "shared", [2]], ["CA", 2, 3, 6], ["CHR", 0, "shared", [3]], ["CA", 1, 4, 6], ["FU", 6, [0, 1]], ["AN", 6, 0, 5], ["FU", 6, [1, 2]], ["AN", 6, 1, 4], ["FU", 6, [2, 3]], ["AN", 6, 2, 6], ["FU", 3, [0, 1]], ["SP", 3, 0, 2, [1], 1], ["AN", 3, 1, 2], ["FFSP", 3, 0, [1, 0], 1, 1], ["AN", 3, 1, 6], ["FU", 1, [0, 1]], ["SP", 1, 0, 1, [1], 1], ["AN", 1, 1, 2], ["FFSP", 1, 0, [1, 0], 1, 1], ["AN", 1, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$16"]]]], "r": [[7.2561e-05], 0, 1.93892, 1605186325], "v": "v0.3"} +{"i": [["[\"8d5a93959138dc7b2ee1f1b3219dfa14\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 15], ["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [16], 1], ["SP", 8, 4, 512, [32], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [2, 2, 1, 1], 1], ["SP", 6, 5, 4, [1, 1, 4, 1], 1], ["SP", 6, 10, 16, [4, 2, 1, 1], 1], ["SP", 6, 15, 512, [1, 16, 1, 1], 1], ["SP", 6, 20, 512, [2, 8], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 13, 3], ["FSP", 7, 4, 14, 3], ["FSP", 7, 8, 15, 3], ["FSP", 7, 12, 16, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 16, [8], 1], ["SP", 4, 4, 512, [16], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 19, [0, 1, 2, 3]], ["SP", 19, 0, 25088, [32], 1], ["AN", 19, 0, 5], ["AN", 19, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 8192, [32], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 64, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [16, 15, 14, 13], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 64, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [16, 15, 14, 13], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 8192, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$0"], ["PR", 8, 0, "auto_unroll_max_step$64"], ["PR", 11, 0, "auto_unroll_max_step$0"]]]], "r": [[0.000195701], 0, 2.67988, 1605186412], "v": "v0.3"} +{"i": [["[\"ac6920940de3797cc3f9f9c260675e5d\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [16], 1], ["SP", 8, 4, 512, [32], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 2, 1], 1], ["SP", 6, 5, 4, [1, 4, 1, 1], 1], ["SP", 6, 10, 16, [1, 16, 1, 1], 1], ["SP", 6, 15, 512, [2, 1, 4, 1], 1], ["SP", 6, 20, 512, [32, 1], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], 
["FSP", 7, 0, 11, 3], ["FSP", 7, 4, 12, 3], ["FSP", 7, 8, 13, 3], ["FSP", 7, 12, 14, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 16, [16], 1], ["SP", 4, 4, 512, [64], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 15, [0, 1, 2, 3]], ["SP", 15, 0, 25088, [32], 1], ["AN", 15, 0, 5], ["AN", 15, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 8192, [64], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 128, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [14, 13, 12, 11], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 32, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [14, 13, 12, 11], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 8192, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$64"], ["PR", 8, 0, "auto_unroll_max_step$64"], ["PR", 11, 0, "auto_unroll_max_step$512"]]]], "r": [[0.000162045], 0, 2.32406, 1605186499], "v": "v0.3"} +{"i": [["[\"7e83a2ee5cd5d50282ed19310700046a\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [4], 1], ["SP", 8, 4, 512, [32], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 1], 1], ["SP", 6, 5, 4, [1, 1, 1, 1], 1], ["SP", 6, 10, 16, [1, 1, 8, 1], 1], ["SP", 6, 15, 512, [2, 64, 1, 1], 1], ["SP", 6, 20, 512, [16, 1], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 10, 3], ["FSP", 7, 4, 11, 3], ["FSP", 7, 8, 12, 3], ["FSP", 7, 12, 13, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 16, [4], 1], ["SP", 4, 4, 512, [32], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 14, [0, 1, 2, 3]], ["SP", 14, 0, 25088, [32], 1], ["AN", 14, 0, 5], ["AN", 14, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 8192, [64], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 16, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [13, 12, 11, 10], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 128, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [13, 12, 11, 10], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 8192, [2], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$1024"], ["PR", 8, 0, "auto_unroll_max_step$512"], ["PR", 11, 0, "auto_unroll_max_step$512"]]]], "r": [[0.000102843], 0, 2.42044, 1605186574], "v": "v0.3"} +{"i": [["[\"424ba83160af31badc0b098136e1a3b0\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [49], 1], ["SP", 8, 4, 256, [8], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 
1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 2, 1], 1], ["SP", 6, 5, 4, [1, 1, 1, 2], 1], ["SP", 6, 10, 49, [1, 7, 1, 7], 1], ["SP", 6, 15, 256, [1, 8, 1, 2], 1], ["SP", 6, 20, 256, [2, 4], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 12, 3], ["FSP", 7, 4, 13, 3], ["FSP", 7, 8, 14, 3], ["FSP", 7, 12, 15, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 49, [1], 1], ["SP", 4, 4, 256, [16], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 17, [0, 1, 2, 3]], ["SP", 17, 0, 50176, [32], 1], ["AN", 17, 0, 5], ["AN", 17, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [32], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 64, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [15, 14, 13, 12], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 112, [2], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [15, 14, 13, 12], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$1024"], ["PR", 8, 0, "auto_unroll_max_step$512"], ["PR", 11, 0, "auto_unroll_max_step$16"]]]], "r": [[9.61516e-05], 0, 2.69389, 1605186690], "v": "v0.3"} +{"i": [["[\"a169cd0053d3a7ca82998fcb62e42c58\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [7], 1], ["SP", 8, 4, 256, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 2], 1], ["SP", 6, 5, 4, [1, 4, 1, 1], 1], ["SP", 6, 10, 49, [1, 7, 1, 1], 1], ["SP", 6, 15, 256, [1, 4, 8, 1], 1], ["SP", 6, 20, 256, [1, 1], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 11, 3], ["FSP", 7, 4, 12, 3], ["FSP", 7, 8, 13, 3], ["FSP", 7, 12, 14, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 49, [49], 1], ["SP", 4, 4, 256, [32], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 15, [0, 1, 2, 3]], ["SP", 15, 0, 50176, [32], 1], ["AN", 15, 0, 5], ["AN", 15, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [32], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 16, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [14, 13, 12, 11], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 2, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [14, 13, 12, 11], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$1024"], ["PR", 8, 0, "auto_unroll_max_step$64"], ["PR", 11, 0, "auto_unroll_max_step$1024"]]]], "r": [[0.000156995], 
0, 2.11666, 1605186772], "v": "v0.3"} +{"i": [["[\"0141ffc4fbabc10cc5a94c954419055b\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [49], 1], ["SP", 8, 4, 256, [8], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 4, 1, 1], 1], ["SP", 6, 5, 4, [1, 1, 4, 1], 1], ["SP", 6, 10, 49, [1, 49, 1, 1], 1], ["SP", 6, 15, 256, [4, 2, 1, 1], 1], ["SP", 6, 20, 256, [1, 4], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 10, 3], ["FSP", 7, 4, 11, 3], ["FSP", 7, 8, 12, 3], ["FSP", 7, 12, 13, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 49, [7], 1], ["SP", 4, 4, 256, [64], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 14, [0, 1, 2, 3]], ["SP", 14, 0, 50176, [32], 1], ["AN", 14, 0, 5], ["AN", 14, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [4], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 32, [4], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [13, 12, 11, 10], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 32, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [13, 12, 11, 10], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$64"], ["PR", 8, 0, "auto_unroll_max_step$64"], ["PR", 11, 0, "auto_unroll_max_step$16"]]]], "r": [[0.000131082], 0, 2.24166, 1605186844], "v": "v0.3"} +{"i": [["[\"c7a6b56bdc04b94c829fb2ef9874019e\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [1], 1], ["SP", 8, 4, 128, [64], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [2, 1, 1, 1], 1], ["SP", 6, 5, 4, [2, 2, 1, 1], 1], ["SP", 6, 10, 196, [2, 7, 2, 1], 1], ["SP", 6, 15, 128, [1, 32, 1, 4], 1], ["SP", 6, 20, 128, [4, 1], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 12, 3], ["FSP", 7, 4, 13, 3], ["FSP", 7, 8, 14, 3], ["FSP", 7, 12, 15, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [2], 1], ["SP", 4, 4, 128, [16], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 17, [0, 1, 2, 3]], ["SP", 17, 0, 100352, [16], 1], ["AN", 17, 0, 5], ["AN", 17, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 25088, [64], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 32, [16], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [15, 14, 13, 12], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 
3]], ["SP", 5, 0, 16, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [15, 14, 13, 12], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 25088, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$16"], ["PR", 8, 0, "auto_unroll_max_step$1024"], ["PR", 11, 0, "auto_unroll_max_step$1024"]]]], "r": [[0.000166673], 0, 2.43832, 1605186977], "v": "v0.3"} +{"i": [["[\"c035cc8b0568a8e054d06bd7f4950550\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [14], 1], ["SP", 8, 4, 128, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 2, 1], 1], ["SP", 6, 5, 4, [1, 2, 1, 1], 1], ["SP", 6, 10, 196, [2, 49, 1, 1], 1], ["SP", 6, 15, 128, [1, 1, 4, 8], 1], ["SP", 6, 20, 128, [2, 4], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 11, 3], ["FSP", 7, 4, 12, 3], ["FSP", 7, 8, 13, 3], ["FSP", 7, 12, 14, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [49], 1], ["SP", 4, 4, 128, [8], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 15, [0, 1, 2, 3]], ["SP", 15, 0, 100352, [32], 1], ["AN", 15, 0, 5], ["AN", 15, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 25088, [32], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 1024, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [14, 13, 12, 11], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 32, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [14, 13, 12, 11], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 25088, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$512"], ["PR", 8, 0, "auto_unroll_max_step$1024"], ["PR", 11, 0, "auto_unroll_max_step$16"]]]], "r": [[0.000108367], 0, 3.89975, 1605187058], "v": "v0.3"} +{"i": [["[\"c5ee3e05edd9754492d0763aa41fd025\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [14], 1], ["SP", 8, 4, 128, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 1], 1], ["SP", 6, 5, 4, [1, 1, 2, 2], 1], ["SP", 6, 10, 196, [1, 4, 7, 1], 1], ["SP", 6, 15, 128, [2, 16, 2, 1], 1], ["SP", 6, 20, 128, [4, 1], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 10, 3], ["FSP", 7, 4, 11, 3], ["FSP", 7, 8, 12, 3], ["FSP", 7, 12, 13, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [2], 1], ["SP", 4, 4, 128, [2], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 14, [0, 1, 2, 3]], ["SP", 14, 0, 100352, [32], 1], ["AN", 14, 0, 5], ["AN", 14, 1, 6], ["FU", 11, [0, 1, 2, 
3]], ["SP", 11, 0, 25088, [32], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 32, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [13, 12, 11, 10], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 112, [4], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [13, 12, 11, 10], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 25088, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$512"], ["PR", 8, 0, "auto_unroll_max_step$64"], ["PR", 11, 0, "auto_unroll_max_step$16"]]]], "r": [[8.0137e-05], 0, 2.28468, 1605187134], "v": "v0.3"} +{"i": [["[\"022ebb6b7c55c5ed030421380ec83a04\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 28, [1, 2, 2, 1], 1], ["SP", 3, 10, 28, [1, 14, 1, 1], 1], ["SP", 3, 15, 128, [1, 2, 16, 1], 1], ["SP", 3, 20, 3, [3, 1], 1], ["SP", 3, 23, 3, [1, 3], 1], ["SP", 3, 26, 64, [1, 1], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 6, 0, 1, 3], ["FSP", 6, 4, 2, 3], ["FSP", 6, 8, 3, 3], ["FSP", 6, 12, 4, 3], ["RE", 6, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 6, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 8, [0, 1, 2, 3]], ["AN", 8, 0, 5], ["FU", 8, [1, 2, 3, 4]], ["AN", 8, 1, 4], ["FU", 8, [2, 3, 4, 5]], ["AN", 8, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 384, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 24, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$16"]]]], "r": [[9.74847e-05], 0, 1.97907, 1605187182], "v": "v0.3"} +{"i": [["[\"de0df0893e01892cfe69f7bc2c24111f\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [14], 1], ["SP", 8, 4, 64, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 1, 1, 1], 1], ["SP", 6, 5, 6, [1, 1, 2, 1], 1], ["SP", 6, 10, 196, [1, 7, 14, 1], 1], ["SP", 6, 15, 64, [2, 4, 2, 1], 1], ["SP", 6, 20, 64, [1, 2], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 12, 3], ["FSP", 7, 4, 13, 3], ["FSP", 7, 8, 14, 3], ["FSP", 7, 12, 15, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [1], 1], ["SP", 4, 4, 64, [16], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 17, [0, 1, 2, 3]], ["SP", 17, 0, 200704, [32], 1], ["AN", 17, 0, 5], ["AN", 17, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [16], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 8, [8], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [15, 14, 13, 12], 1, 1], ["AN", 7, 1, 6], 
["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 56, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [15, 14, 13, 12], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$512"], ["PR", 8, 0, "auto_unroll_max_step$64"], ["PR", 11, 0, "auto_unroll_max_step$16"]]]], "r": [[8.09982e-05], 0, 3.52776, 1605187295], "v": "v0.3"} +{"i": [["[\"f2e3c09a00e7d0a9897f70497e089f1e\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [7], 1], ["SP", 8, 4, 64, [64], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 2, 1, 1], 1], ["SP", 6, 5, 6, [1, 3, 1, 1], 1], ["SP", 6, 10, 196, [1, 14, 1, 2], 1], ["SP", 6, 15, 64, [1, 2, 8, 2], 1], ["SP", 6, 20, 64, [4, 4], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 11, 3], ["FSP", 7, 4, 12, 3], ["FSP", 7, 8, 13, 3], ["FSP", 7, 12, 14, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [4], 1], ["SP", 4, 4, 64, [4], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 15, [0, 1, 2, 3]], ["SP", 15, 0, 200704, [64], 1], ["AN", 15, 0, 5], ["AN", 15, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [32], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 512, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [14, 13, 12, 11], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 64, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [14, 13, 12, 11], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$512"], ["PR", 8, 0, "auto_unroll_max_step$64"], ["PR", 11, 0, "auto_unroll_max_step$64"]]]], "r": [[7.15745e-05], 0, 3.73944, 1605187404], "v": "v0.3"} +{"i": [["[\"fa26946d7ac51126bfa859cb183f9ca1\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [7], 1], ["SP", 8, 4, 64, [2], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 2, 1, 3], 1], ["SP", 6, 5, 6, [1, 2, 3, 1], 1], ["SP", 6, 10, 196, [1, 4, 1, 7], 1], ["SP", 6, 15, 64, [1, 8, 2, 1], 1], ["SP", 6, 20, 64, [2, 2], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 10, 3], ["FSP", 7, 4, 11, 3], ["FSP", 7, 8, 12, 3], ["FSP", 7, 12, 13, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [4], 1], ["SP", 4, 4, 64, [16], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 14, [0, 1, 2, 3]], ["SP", 14, 0, 200704, [64], 1], ["AN", 14, 0, 5], ["AN", 14, 1, 6], ["FU", 11, [0, 1, 
2, 3]], ["SP", 11, 0, 12544, [32], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 144, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [13, 12, 11, 10], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 252, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [13, 12, 11, 10], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$16"], ["PR", 8, 0, "auto_unroll_max_step$1024"], ["PR", 11, 0, "auto_unroll_max_step$16"]]]], "r": [[6.79478e-05], 0, 5.10446, 1605187506], "v": "v0.3"} +{"i": [["[\"a0eb8d6048282a4a0986cc2ccf14eaa2\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 112, [2, 14, 1, 1], 1], ["SP", 3, 10, 112, [1, 8, 2, 1], 1], ["SP", 3, 15, 64, [2, 2, 2, 2], 1], ["SP", 3, 20, 7, [7, 1], 1], ["SP", 3, 23, 7, [7, 1], 1], ["SP", 3, 26, 3, [3, 1], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 6, 0, 1, 3], ["FSP", 6, 4, 2, 3], ["FSP", 6, 8, 3, 3], ["FSP", 6, 12, 4, 3], ["RE", 6, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 6, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 8, [0, 1, 2, 3]], ["AN", 8, 0, 5], ["FU", 8, [1, 2, 3, 4]], ["AN", 8, 1, 4], ["FU", 8, [2, 3, 4, 5]], ["AN", 8, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 1176, [21], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 189, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$512"]]]], "r": [[5.53397e-05], 0, 2.2607, 1605187548], "v": "v0.3"} +{"i": [["[\"bf78a7bf0209980f72953637dfd14a6f\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 56, [2, 28, 1, 1], 1], ["SP", 3, 10, 56, [1, 2, 2, 1], 1], ["SP", 3, 15, 64, [2, 16, 1, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 64, [1, 8], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 4, 0, 1, 3], ["FSP", 4, 4, 2, 3], ["FSP", 4, 8, 3, 3], ["FSP", 4, 12, 4, 3], ["RE", 4, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 4, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 6, [0, 1, 2, 3]], ["AN", 6, 0, 5], ["FU", 6, [1, 2, 3, 4]], ["AN", 6, 1, 4], ["FU", 6, [2, 3, 4, 5]], ["AN", 6, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 8, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 16, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$1024"]]]], "r": [[8.11163e-06], 0, 1.93343, 1605187596], "v": "v0.3"} +{"i": [["[\"6630936c26852f2b89dbfa2ff37fbb9c\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 28, [2, 2, 1, 1], 1], ["SP", 3, 10, 28, 
[1, 2, 1, 1], 1], ["SP", 3, 15, 128, [2, 8, 4, 2], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 64, [4, 8], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 4, 0, 1, 3], ["FSP", 4, 4, 2, 3], ["FSP", 4, 8, 3, 3], ["FSP", 4, 12, 4, 3], ["RE", 4, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 4, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 6, [0, 1, 2, 3]], ["AN", 6, 0, 5], ["FU", 6, [1, 2, 3, 4]], ["AN", 6, 1, 4], ["FU", 6, [2, 3, 4, 5]], ["AN", 6, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 256, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 96, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$64"]]]], "r": [[1.40126e-05], 0, 1.82931, 1605187624], "v": "v0.3"} +{"i": [["[\"ba5f918733ccbbd4a1d7fd3724665a2f\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 14, [1, 7, 1, 2], 1], ["SP", 3, 10, 14, [1, 1, 1, 2], 1], ["SP", 3, 15, 256, [4, 64, 1, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 128, [1, 16], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 4, 0, 1, 3], ["FSP", 4, 4, 2, 3], ["FSP", 4, 8, 3, 3], ["FSP", 4, 12, 4, 3], ["RE", 4, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 4, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 6, [0, 1, 2, 3]], ["AN", 6, 0, 5], ["FU", 6, [1, 2, 3, 4]], ["AN", 6, 1, 4], ["FU", 6, [2, 3, 4, 5]], ["AN", 6, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 32, [16], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 324, [6], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$64"]]]], "r": [[2.35384e-05], 0, 1.78652, 1605187663], "v": "v0.3"} +{"i": [["[\"21ad409d72953de188314010134e3acd\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 7, [1, 1, 1, 1], 1], ["SP", 3, 10, 7, [1, 7, 1, 1], 1], ["SP", 3, 15, 512, [1, 32, 4, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 256, [1, 64], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 4, 0, 1, 3], ["FSP", 4, 4, 2, 3], ["FSP", 4, 8, 3, 3], ["FSP", 4, 12, 4, 3], ["RE", 4, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 4, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 6, [0, 1, 2, 3]], ["AN", 6, 0, 5], ["FU", 6, [1, 2, 3, 4]], ["AN", 6, 1, 4], ["FU", 6, [2, 3, 4, 5]], ["AN", 6, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 16, [4], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 4, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$0"]]]], "r": [[3.09105e-05], 0, 1.85659, 1605187687], "v": 
"v0.3"} +{"i": [["[\"1f6cd3637ec856bf5cf5010a623eed05\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 7, [1, 7, 1, 1], 1], ["SP", 3, 10, 7, [1, 7, 1, 1], 1], ["SP", 3, 15, 512, [1, 8, 2, 1], 1], ["SP", 3, 20, 3, [3, 1], 1], ["SP", 3, 23, 3, [1, 1], 1], ["SP", 3, 26, 256, [1, 16], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 6, 0, 1, 3], ["FSP", 6, 4, 2, 3], ["FSP", 6, 8, 3, 3], ["FSP", 6, 12, 4, 3], ["RE", 6, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 6, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 8, [0, 1, 2, 3]], ["AN", 8, 0, 5], ["FU", 8, [1, 2, 3, 4]], ["AN", 8, 1, 4], ["FU", 8, [2, 3, 4, 5]], ["AN", 8, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 96, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 48, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$1024"]]]], "r": [[0.000154153], 0, 2.18601, 1605187723], "v": "v0.3"} +{"i": [["[\"81aae4b8e2c076a4014d403e8a2c70a1\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 14, [1, 1, 1, 2], 1], ["SP", 3, 10, 14, [1, 14, 1, 1], 1], ["SP", 3, 15, 256, [1, 32, 1, 2], 1], ["SP", 3, 20, 3, [3, 1], 1], ["SP", 3, 23, 3, [3, 1], 1], ["SP", 3, 26, 128, [2, 4], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 6, 0, 1, 3], ["FSP", 6, 4, 2, 3], ["FSP", 6, 8, 3, 3], ["FSP", 6, 12, 4, 3], ["RE", 6, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 6, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 8, [0, 1, 2, 3]], ["AN", 8, 0, 5], ["FU", 8, [1, 2, 3, 4]], ["AN", 8, 1, 4], ["FU", 8, [2, 3, 4, 5]], ["AN", 8, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 144, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 72, [24], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$1024"]]]], "r": [[5.97747e-05], 0, 2.13918, 1605187759], "v": "v0.3"} diff --git a/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py index d1b3c22d2084..a8bb8dd08f59 100644 --- a/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py +++ b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py @@ -76,11 +76,11 @@ def conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding): ###################################################################### # Next, we set parameters for the auto-scheduler. These parameters -# mainly specify how we do the measurement during the search and auto-tuning. +# mainly specify how we do the measurement during the search. # -# * :code:`measure_ctx` launches a different process for measurement. This -# provides an isolation. It can protect the master process from GPU crashes -# happended during measurement and avoid other runtime conflicts. +# * :code:`measure_ctx` launches a different process for measurement to +# provide isolation. 
It can protect the master process from GPU crashes
+#   during measurement and avoid other runtime conflicts.
 # * :code:`min_repeat_ms` defines the minimum duration of one "repeat" in every measurement.
 #   This can warm up the GPU, which is necessary to get accurate measurement results.
 #   Typically, we recommend a value > 300 ms.
@@ -96,7 +96,7 @@ def conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding):
 log_file = "conv2d.json"
 measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
 tune_option = auto_scheduler.TuningOptions(
-    num_measure_trials=10,
+    num_measure_trials=10,  # change this to 1000 to achieve the best performance
     runner=measure_ctx.runner,
     measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
 )
diff --git a/tutorials/auto_scheduler/tune_network_cuda.py b/tutorials/auto_scheduler/tune_network_cuda.py
new file mode 100644
index 000000000000..9eb5d5cdff0c
--- /dev/null
+++ b/tutorials/auto_scheduler/tune_network_cuda.py
@@ -0,0 +1,302 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Auto-tuning a Neural Network for NVIDIA GPU
+===========================================
+**Author**: `Lianmin Zheng `_
+
+Auto-tuning for specific devices and workloads is critical for getting the
+best performance. This is a tutorial on how to tune a whole neural
+network for NVIDIA GPU with the auto-scheduler.
+
+To auto-tune a neural network, we partition the network into small subgraphs and
+tune them independently. Each subgraph is treated as one search task.
+A task scheduler slices the time and dynamically allocates time resources to
+these tasks. The task scheduler predicts the impact of each task on the end-to-end
+execution time and prioritizes the one that can reduce the execution time the most.
+
+For each subgraph, we use the compute declaration in :code:`tvm/python/topi` to
+get the computational DAG in the tensor expression form.
+We then use the auto-scheduler to construct a search space of this DAG and search
+for good schedules (low-level optimizations).
+
+Different from the template-based :ref:`autotvm `, which relies on
+manual templates to define the search space, the auto-scheduler does not require any
+schedule templates. In other words, the auto-scheduler only uses the compute declarations
+in :code:`tvm/python/topi` and does not use existing schedule templates.
+
+Note that this tutorial will not run on Windows or recent versions of macOS. To
+get it to run, you will need to wrap the body of this tutorial in an :code:`if
+__name__ == "__main__":` block.
+""" + +import numpy as np + +import tvm +from tvm import relay, auto_scheduler +import tvm.relay.testing +from tvm.contrib import graph_runtime + +################################################################# +# Define a Network +# ---------------- +# First, we need to define the network with relay frontend API. +# We can load some pre-defined network from :code:`tvm.relay.testing`. +# We can also load models from MXNet, ONNX, PyTorch, and TensorFlow +# (see :ref:`front end tutorials`). +# +# Note that although auto-scheduler can work with any layouts, +# we found that the best performance is typically archived with NHWC layout +# for convolutional neural networks, so we use NHWC layout in this tutorial. +# + + +def get_network(name, batch_size, layout="NHWC", dtype="float32"): + """Get the symbol definition and random weight of a network""" + + # auto-scheduler prefers NHWC layout + if layout == "NHWC": + image_shape = (224, 224, 3) + elif layout == "NCHW": + image_shape = (3, 224, 224) + else: + raise ValueError("Invalid layout: " + layout) + + input_shape = (batch_size,) + image_shape + output_shape = (batch_size, 1000) + + if name.startswith("resnet-"): + n_layer = int(name.split("-")[1]) + mod, params = relay.testing.resnet.get_workload( + num_layers=n_layer, + batch_size=batch_size, + layout=layout, + dtype=dtype, + image_shape=image_shape, + ) + elif name.startswith("resnet3d-"): + n_layer = int(name.split("-")[1]) + mod, params = relay.testing.resnet.get_workload( + num_layers=n_layer, + batch_size=batch_size, + layout=layout, + dtype=dtype, + image_shape=image_shape, + ) + elif name == "mobilenet": + mod, params = relay.testing.mobilenet.get_workload( + batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape + ) + elif name == "squeezenet_v1.1": + mod, params = relay.testing.squeezenet.get_workload( + version="1.1", + batch_size=batch_size, + layout=layout, + dtype=dtype, + image_shape=image_shape, + ) + elif name == "inception_v3": + input_shape = (batch_size, 3, 299, 299) if layout == "NCHW" else (batch_size, 299, 299, 3) + mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype) + elif name == "mxnet": + # an example for mxnet model + from mxnet.gluon.model_zoo.vision import get_model + + assert layout == "NCHW" + + block = get_model("resnet18_v1", pretrained=True) + mod, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype) + net = mod["main"] + net = relay.Function( + net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs + ) + mod = tvm.IRModule.from_expr(net) + + return mod, params, input_shape, output_shape + + +# Define the neural network and compilation target +network = "resnet-18" +batch_size = 1 +layout = "NHWC" +target = tvm.target.Target("cuda") +dtype = "float32" +log_file = "%s-%s-B%d.json" % (network, layout, batch_size) + +################################################################# +# Extract Search Tasks +# -------------------- +# Next, we extract the search tasks and their weights from a network. +# The weight of a task is the number of appearances of the task's subgraph +# in the whole network. +# By using the weight, we can approximate the end-to-end latency of the network +# as :code:`sum(latency[t] * weight[t])`, where :code:`latency[t]` is the +# latency of a task and :code:`weight[t]` is the weight of the task. +# The task scheduler will just optimize this objective. 
+
+# Enable auto-scheduler in relay
+auto_scheduler.enable_relay_integration()
+
+# Extract tasks from the network
+print("Extract tasks...")
+mod, params, input_shape, output_shape = get_network(network, batch_size, layout, dtype=dtype)
+tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)
+
+#################################################################
+# Begin Tuning
+# ------------
+# Now, we set some options for tuning and launch the search tasks.
+#
+# * :code:`measure_ctx` launches a different process for measurement to
+#   provide isolation. It can protect the master process from GPU crashes
+#   during measurement and avoid other runtime conflicts.
+# * :code:`min_repeat_ms` defines the minimum duration of one "repeat" in every measurement.
+#   This can warm up the GPU, which is necessary to get accurate measurement results.
+#   Typically, we recommend a value > 300 ms.
+# * :code:`num_measure_trials` is the number of measurement trials we can use during the tuning.
+#   You can set it to a small number (e.g., 200) for a fast demonstrative run.
+#   In practice, we recommend setting it around :code:`1000 * len(tasks)`,
+#   which is typically enough for the search to converge.
+#   For example, there are 21 tasks in resnet-18, so we can set it to 20000.
+#   You can adjust this parameter according to your time budget.
+# * In addition, we use :code:`RecordToFile` to dump measurement records into the log file.
+#   The measurement records can be used to query the history best, resume the search,
+#   and do more analyses later.
+# * See :any:`auto_scheduler.TuningOptions` and
+#   :any:`auto_scheduler.LocalRPCMeasureContext` for more parameters.
+#
+
+
+def run_tuning():
+    print("Begin tuning...")
+    measure_ctx = auto_scheduler.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400, timeout=10)
+
+    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
+    tune_option = auto_scheduler.TuningOptions(
+        num_measure_trials=200,  # change this to 20000 to achieve the best performance
+        runner=measure_ctx.runner,
+        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
+    )
+
+    tuner.tune(tune_option)
+
+
+# We do not run the tuning in our webpage server since it takes too long.
+# Uncomment the following line to run it by yourself.
+
+# run_tuning()
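If a previous run was interrupted, the search can continue from the records already in the log. A minimal sketch, assuming the :code:`load_log_file` parameter of :code:`TaskScheduler` (verify the exact name in your auto-scheduler version); it reuses :code:`tasks`, :code:`task_weights`, and :code:`log_file` defined above.

```python
def resume_tuning():
    # Reload previous measurement records so the search resumes instead of restarting.
    # `load_log_file` is an assumed parameter name; check your auto_scheduler version.
    measure_ctx = auto_scheduler.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400, timeout=10)
    tuner = auto_scheduler.TaskScheduler(tasks, task_weights, load_log_file=log_file)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=200,
        runner=measure_ctx.runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],  # keep appending to the same log
    )
    tuner.tune(tune_option)
```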
+
+
+######################################################################
+# .. note:: Explain the printed information during tuning
+#
+#   During the tuning, a lot of information will be printed on the console.
+#   They are used for debugging purposes. The most important info is the output
+#   of the task scheduler. The following table is a sample output.
+#
+#   .. code-block:: c
+#
+#     ----------------------------------------------------------------------
+#     ------------------------------  [ Task Scheduler ]
+#     ----------------------------------------------------------------------
+#     |  ID  | Latency (ms) | Speed (GFLOPS) | Trials |
+#     -------------------------------------------------
+#     |    0 |        0.014 |          72.07 |     64 |
+#     |    1 |        0.185 |        1250.68 |    128 |
+#     |    2 |        0.142 |        1626.36 |    192 |
+#     |    3 |        0.137 |        1689.42 |    128 |
+#     |    4 |        0.097 |        1189.75 |    128 |
+#     |    5 |        0.092 |        2505.25 |    128 |
+#     |    6 |        0.080 |        2893.08 |    128 |
+#     |    7 |        0.119 |        1947.84 |    128 |
+#     |    8 |        0.090 |        1292.62 |     64 |
+#     |    9 |        0.107 |        2172.30 |     64 |
+#     |   10 |        0.095 |        2439.36 |     64 |
+#     |   11 |        0.077 |        3003.22 |     64 |
+#     |   12 |        0.068 |        1695.13 |     64 |
+#     |   13 |        0.058 |        3979.29 |     64 |
+#     |   14 |        0.048 |        4859.95 |    128 |
+#     |   15 |        0.073 |        3151.76 |     64 |
+#     |   16 |        0.056 |        4265.94 |     64 |
+#     |   17 |        0.009 |        2754.90 |     64 |
+#     |   18 |        0.011 |        1156.08 |     64 |
+#     |   19 |        0.013 |         955.80 |     64 |
+#     |   20 |        0.029 |         437.71 |     64 |
+#     -------------------------------------------------
+#     Estimated total latency: 1.649 ms  Trials: 1920  Used time : 3598 s  Next ID: 9
+#
+#   This table lists the latency and (estimated) speed of all tasks.
+#   It also lists the allocation of measurement trials for all tasks.
+#   The last line prints the total weighted latency of these tasks,
+#   which can be a rough estimation of the end-to-end execution time
+#   of the network.
+#   The last line also prints the total number of measurement trials,
+#   total time spent on auto-tuning, and the id of the next task to tune.
+#
+#   There will also be some "dmlc::Error"s and CUDA errors, because the
+#   auto-scheduler will try some invalid schedules.
+#   You can safely ignore them if the tuning can continue, because these
+#   errors are isolated from the master process.
+#
+
+######################################################################
+# .. note:: Terminate the tuning earlier
+#
+#   You can terminate the tuning earlier by forcibly killing this process.
+#   As long as you get at least one valid schedule for each task in the log file,
+#   you should be able to do the compilation (the section below).
+#
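Before compiling after an early termination, it can help to check that every task has at least one valid record. A small sketch using :code:`auto_scheduler.load_records`; it assumes :code:`error_no == 0` marks a valid measurement (see :code:`MeasureErrorNo`) and reuses :code:`tasks` and :code:`log_file` from above.

```python
from collections import Counter

# Count valid measurement records per workload in the log file.
valid_counts = Counter()
for inp, res in auto_scheduler.load_records(log_file):
    if res.error_no == 0:  # MeasureErrorNo.NO_ERROR
        valid_counts[inp.task.workload_key] += 1

# Warn about tasks that still have no valid schedule.
for task in tasks:
    if valid_counts[task.workload_key] == 0:
        print("No valid schedule yet for workload:", task.workload_key)
```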
+
+
+#################################################################
+# Compile and Evaluate
+# --------------------
+# After auto-tuning, we can compile the network with the best schedules we found.
+# All measurement records are dumped into the log file during auto-tuning,
+# so we can read the log file and load the best schedules.
+
+# Compile with the history best
+print("Compile...")
+with auto_scheduler.ApplyHistoryBest(log_file):
+    with tvm.transform.PassContext(opt_level=3):
+        lib = relay.build(mod, target=target, params=params)
+
+# Create graph runtime
+ctx = tvm.context(str(target), 0)
+module = graph_runtime.GraphModule(lib["default"](ctx))
+data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
+module.set_input("data", data_tvm)
+
+# Evaluate
+print("Evaluate inference time cost...")
+ftimer = module.module.time_evaluator("run", ctx, repeat=3, min_repeat_ms=500)
+prof_res = np.array(ftimer().results) * 1e3  # convert to millisecond
+print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res)))
+
+
+#################################################################
+# Other Tips
+# ----------
+# 1. During the tuning, the auto-scheduler needs to compile many programs and
+#    extract features from them. This part is CPU-intensive,
+#    so a high-performance CPU with many cores is recommended for faster search.
+# 2. If you have multiple GPUs, you can use all of them to parallelize the
+#    measurements. Check this :ref:`section <tutorials-autotvm-rpc-tracker>`
+#    to learn how to use the RPC Tracker and RPC Server.
+#    To use the RPC Tracker in auto-scheduler, replace the runner in :code:`TuningOptions`
+#    with :any:`auto_scheduler.RPCRunner`; a sketch is given at the end of this document.
+#
diff --git a/tutorials/autotvm/tune_relay_cuda.py b/tutorials/autotvm/tune_relay_cuda.py
index 3dccefef4de9..91407133d695 100644
--- a/tutorials/autotvm/tune_relay_cuda.py
+++ b/tutorials/autotvm/tune_relay_cuda.py
@@ -315,6 +315,7 @@ def tune_and_evaluate(tuning_opt):
 #################################################################
 # Scale up measurement by using multiple devices
 # ----------------------------------------------
+# .. _tutorials-autotvm-rpc-tracker:
 #
 # If you have multiple devices, you can use all of them for measurement.
 # TVM uses the RPC Tracker to manage distributed devices.
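As referenced in the "Other Tips" above, a sketch of pointing the auto-scheduler at an RPC Tracker instead of a local runner. The device key :code:`"v100"` and the tracker address :code:`127.0.0.1:9190` are placeholders; use the values your servers were registered with.

```python
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=20000,
    # Replace the LocalRPCMeasureContext runner with an RPC-tracker-based runner.
    runner=auto_scheduler.RPCRunner(
        "v100",            # device key registered with the tracker (placeholder)
        host="127.0.0.1",  # tracker host (placeholder)
        port=9190,         # tracker port (placeholder)
        repeat=1,
        min_repeat_ms=400,
        timeout=10,
    ),
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)
```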