Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions python/tvm/auto_scheduler/cost_model/xgb_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ class XGBModel(PythonBasedModel):
The random seed
model_file: Optional[str]
If is not None, save model to this file after every update.
adapative_training: bool = False
Whether to use adapatie training, which reduces the training frequency when there are
adaptive_training: bool = False
Whether to use adaptive training, which reduces the training frequency when there are
too many logs.
"""

Expand All @@ -109,7 +109,7 @@ def __init__(
num_warmup_sample=100,
seed=None,
model_file=None,
adapative_training=False,
adaptive_training=False,
):
global xgb
try:
Expand Down Expand Up @@ -141,7 +141,7 @@ def __init__(
self.num_warmup_sample = num_warmup_sample
self.verbose_eval = verbose_eval
self.model_file = model_file
self.adapative_training = adapative_training
self.adaptive_training = adaptive_training

super().__init__()

Expand Down Expand Up @@ -169,7 +169,7 @@ def update(self, inputs, results):
self.results.extend(results)

if (
self.adapative_training
self.adaptive_training
and len(self.inputs) - self.last_train_length < self.last_train_length / 5
):
# Set a training threshold related to `last_train_length` to reduce the training
Expand Down
12 changes: 6 additions & 6 deletions python/tvm/auto_scheduler/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def make_search_policies(
verbose,
load_model_file=None,
load_log_file=None,
adapative_training=False,
adaptive_training=False,
):
"""Make a list of search policies for a list of search tasks.
It creates one policy per task.
Expand All @@ -71,7 +71,7 @@ def make_search_policies(
load_log_file: Optional[str]
Load measurement records from this file. If it is not None, the status of the
task scheduler, search policies and cost models will be restored according to this file.
adapative_training: bool = False
adaptive_training: bool = False
Option used by XGBModel to reduce the model training frequency when there're too
many logs.

Expand All @@ -89,7 +89,7 @@ def make_search_policies(
cost_model = XGBModel(
num_warmup_sample=len(tasks) * num_measures_per_round,
model_file=load_model_file,
adapative_training=adapative_training,
adaptive_training=adaptive_training,
)
if load_model_file and os.path.isfile(load_model_file):
logger.info("TaskScheduler: Load pretrained model...")
Expand Down Expand Up @@ -283,7 +283,7 @@ def tune(
tune_option,
search_policy="default",
search_policy_params=None,
adapative_training=False,
adaptive_training=False,
per_task_early_stopping=None,
):
"""Tune a batch of tasks together.
Expand All @@ -300,7 +300,7 @@ def tune(
"sketch.random" for SketchPolicy + RandomModel.
search_policy_params : Optional[Dict[str, Any]]
The parameters of the search policy
adapative_training : bool = False
adaptive_training : bool = False
Option used by XGBModel to reduce the model training frequency when there're
too many logs.
per_task_early_stopping : Optional[int]
Expand Down Expand Up @@ -347,7 +347,7 @@ def tune(
tune_option.verbose,
self.load_model_file,
self.load_log_file,
adapative_training,
adaptive_training,
)

# do a round robin first to warm up
Expand Down