feat: support multiple models on autotune service (#107)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
shjwudp and github-actions[bot] authored Jul 7, 2021
1 parent 328141c commit 0aec789
Showing 25 changed files with 430 additions and 665 deletions.
10 files renamed without changes.
1 change: 1 addition & 0 deletions bagua/bagua_define.py
@@ -37,6 +37,7 @@ class BaguaHyperparameter(BaseModel):
"""

buckets: List[List[TensorDeclaration]] = []
bucket_size: int = 0
is_hierarchical_reduce: bool = False

def update(self, param_dict: dict):
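For context, here is a minimal, self-contained sketch of how a BaguaHyperparameter carrying the new bucket_size field might be built and updated. The TensorDeclaration fields and the update() behavior below are simplified assumptions for illustration, not the exact definitions in bagua_define.py.

# Illustrative sketch only: TensorDeclaration is simplified and update() is an
# assumed overwrite-by-key behavior, not necessarily Bagua's exact implementation.
from typing import List
from pydantic import BaseModel


class TensorDeclaration(BaseModel):
    name: str
    num_elements: int  # hypothetical fields; the real declaration may differ


class BaguaHyperparameter(BaseModel):
    buckets: List[List[TensorDeclaration]] = []
    bucket_size: int = 0  # new field introduced by this commit
    is_hierarchical_reduce: bool = False

    def update(self, param_dict: dict):
        # Overwrite only fields that exist on the model.
        for key, value in param_dict.items():
            if key in self.__fields__:
                setattr(self, key, value)
        return self


hp = BaguaHyperparameter()
hp.update({"bucket_size": 10 * 1024 ** 2, "is_hierarchical_reduce": True})
print(hp.bucket_size, hp.is_hierarchical_reduce)  # 10485760 True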
7 changes: 5 additions & 2 deletions bagua/distributed/launch.py
@@ -111,7 +111,10 @@ def parse_args():
help="Bagua automatic super parameter search level. The higher the level, the better the theoretical effect, and the longer it takes",
)
parser.add_argument(
"--autotune_logfile", type=str, default="/tmp/bagua_autotune.log"
"--is_output_autotune_log",
action="store_true",
default=False,
help="Whether autotune output log or not. default is False",
)
parser.add_argument(
"--report_metrics",
@@ -152,7 +155,7 @@ def set_bagua_env(args, current_env):
args.autotune_sampling_confidence_time
)
current_env["BAGUA_AUTOTUNE_WARMUP_TIME_S"] = str(args.autotune_warmup_time)
current_env["BAGUA_AUTOTUNE_LOGFILE_PATH"] = args.autotune_logfile
current_env["BAGUA_IS_OUTPUT_AUTOTUNE_LOG"] = str(int(args.is_output_autotune_log))


def main():
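For reference, the flag is encoded into the environment as "0"/"1" via str(int(...)). A small sketch of that round trip follows; the is_output_autotune_log() reader is hypothetical and only illustrates the decoding, not Bagua's actual consumer.

# Illustrative round trip for the new env var; the reader function is hypothetical.
import os

# What the launcher writes (mirrors set_bagua_env in launch.py):
os.environ["BAGUA_IS_OUTPUT_AUTOTUNE_LOG"] = str(int(True))  # -> "1"


def is_output_autotune_log() -> bool:
    # Decode "0"/"1" back to a bool; default to False when the variable is unset.
    return bool(int(os.environ.get("BAGUA_IS_OUTPUT_AUTOTUNE_LOG", "0")))


print(is_output_autotune_log())  # True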
8 changes: 6 additions & 2 deletions bagua/distributed/run.py
@@ -383,7 +383,11 @@ def get_args_parser() -> ArgumentParser:
parser.add_argument("--autotune_sampling_confidence_time", type=float, default=5.0)
parser.add_argument("--autotune_warmup_time", type=float, default=30.0)
parser.add_argument(
"--autotune_logfile", type=str, default="/tmp/bagua_autotune.log"
"--is_output_autotune_log",
action="store_true",
default=False,
help="Whether to output the autotune log. Default is False.",
)

#
@@ -573,7 +577,7 @@ def set_bagua_env(args, current_env):
args.autotune_sampling_confidence_time
)
current_env["BAGUA_AUTOTUNE_WARMUP_TIME_S"] = str(args.autotune_warmup_time)
current_env["BAGUA_AUTOTUNE_LOGFILE_PATH"] = args.autotune_logfile
current_env["BAGUA_IS_OUTPUT_AUTOTUNE_LOG"] = args.is_output_autotune_log


def run(args):
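As a quick sanity check of the flag semantics, here is a stand-alone argparse stand-in for the new option (store_true actions cannot be combined with type=, which is why the flag is declared with the action alone, as in launch.py):

# Minimal stand-in for the new flag; demonstrates store_true behavior.
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument(
    "--is_output_autotune_log",
    action="store_true",  # store_true supplies the bool; no type= needed
    default=False,
    help="Whether to output the autotune log. Default is False.",
)

print(parser.parse_args([]).is_output_autotune_log)  # False
print(parser.parse_args(["--is_output_autotune_log"]).is_output_autotune_log)  # True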
7 changes: 2 additions & 5 deletions bagua/service/__init__.py
@@ -1,5 +1,2 @@
from .autotune_service import AutotuneService, BaguaHyperparameter
from .service_discovery import (
generate_and_broadcast_server_addr,
pick_n_free_ports,
)
from .autotune_service import AutotuneService, AutotuneClient # noqa: F401
from . import autotune # noqa: F401
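With this change, the package root presumably re-exports only the autotune pieces; the service_discovery helpers are no longer exposed from bagua.service:

# Public surface re-exported by bagua/service/__init__.py after this commit.
from bagua.service import AutotuneService, AutotuneClient  # noqa: F401
from bagua.service import autotune  # noqa: F401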
9 changes: 7 additions & 2 deletions bagua/service/autotune.py
@@ -43,16 +43,21 @@ class BayesianOptimizer:
Simple wrapper around a Bayesian optimizer
"""

def __init__(self, param_declaration: dict, n_initial_points: int = 10):
def __init__(
self,
param_declaration: dict,
n_initial_points: int = 20,
initial_point_generator: str = "halton",
):
self.param_declaration = collections.OrderedDict(param_declaration)
search_space = [
declar.space_dimension for _, declar in self.param_declaration.items()
]

self.bayesian_optimizer = skopt.Optimizer(
dimensions=search_space,
base_estimator="GP",
n_initial_points=n_initial_points,
initial_point_generator=initial_point_generator,
n_jobs=-1,
)

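For reference, a minimal sketch of the ask/tell loop that a scikit-optimize Optimizer configured this way supports; the search space and the score() objective below are illustrative placeholders, not taken from the autotune service.

# Illustrative only: the dimension and the objective are placeholders.
import skopt
from skopt.space import Integer

opt = skopt.Optimizer(
    dimensions=[Integer(1, 512, name="bucket_size_mb")],
    base_estimator="GP",
    n_initial_points=20,  # new default in this commit
    initial_point_generator="halton",  # quasi-random seeding instead of uniform random
    n_jobs=-1,
)


def score(bucket_size_mb: int) -> float:
    # Placeholder objective; skopt minimizes, so lower is better.
    return abs(bucket_size_mb - 100) / 100.0


for _ in range(30):
    x = opt.ask()  # next candidate point
    y = score(x[0])  # evaluate the candidate
    opt.tell(x, y)  # feed the observation back to the optimizer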