Check issue #36

D-X-Y committed Jan 19, 2022
1 parent b3d60c6 commit 2d82b2d
Showing 9 changed files with 734 additions and 184 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
@@ -28,16 +28,16 @@ jobs:
cd ..
if [ "$RUNNER_OS" == "Windows" ]; then
python.exe -m pip install black
python.exe -m black NATS-Bench/nats_bench -l 120 --check --diff
python.exe -m black NATS-Bench/tests -l 120 --check --diff
python.exe -m black NATS-Bench/nats_bench -l 88 --check --diff
python.exe -m black NATS-Bench/tests -l 88 --check --diff
else
python -m pip install black
python --version
python -m black --version
echo $PWD
ls
python -m black NATS-Bench/nats_bench -l 120 --check --diff --verbose
python -m black NATS-Bench/tests -l 120 --check --diff --verbose
python -m black NATS-Bench/nats_bench -l 88 --check --diff --verbose
python -m black NATS-Bench/tests -l 88 --check --diff --verbose
fi
shell: bash

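The substantive change above is black's target line length, dropped from 120 to 88 (black's default); most of the Python diffs below are the re-wraps that fall out of that. A minimal sketch of reproducing one such re-wrap programmatically, assuming the black package is installed (black.format_str and black.Mode are its programmatic entry points):

import black

# One of the lines this commit re-wraps (from nats_bench/__init__.py below).
SRC = (
    'nats_tss = dict(op_names=["none", "skip_connect", "nor_conv_1x1",'
    ' "nor_conv_3x3", "avg_pool_3x3"], num_nodes=4)\n'
)

# At 120 columns black leaves the line alone; at 88 it re-wraps the call
# (inside search_space_info, the extra indentation then forces the fully
# exploded form shown in the next file).
print(black.format_str(SRC, mode=black.Mode(line_length=120)))
print(black.format_str(SRC, mode=black.Mode(line_length=88)))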
11 changes: 10 additions & 1 deletion nats_bench/__init__.py
@@ -58,7 +58,16 @@ def create(file_path_or_dict, search_space, fast_mode=False, verbose=True):
def search_space_info(main_tag: Text, aux_tag: Optional[Text]):
"""Obtain the search space information."""
nats_sss = dict(candidates=[8, 16, 24, 32, 40, 48, 56, 64], num_layers=5)
nats_tss = dict(op_names=["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"], num_nodes=4)
nats_tss = dict(
op_names=[
"none",
"skip_connect",
"nor_conv_1x1",
"nor_conv_3x3",
"avg_pool_3x3",
],
num_nodes=4,
)
if main_tag == "nats-bench":
if aux_tag in NATS_BENCH_SSS_NAMEs:
return nats_sss
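A hedged usage sketch of search_space_info as defined above; "sss" and "tss" as aux_tag values are assumptions based on NATS-Bench naming, since the contents of NATS_BENCH_SSS_NAMEs are not shown in this diff:

from nats_bench import search_space_info

# "sss" (size) and "tss" (topology) aux tags are assumed, not shown above.
sss = search_space_info("nats-bench", "sss")
tss = search_space_info("nats-bench", "tss")
print(sss["candidates"], sss["num_layers"])  # channel candidates, 5 layers
print(tss["op_names"], tss["num_nodes"])  # cell operations, 4 nodes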
117 changes: 91 additions & 26 deletions nats_bench/api_size.py
@@ -31,7 +31,10 @@
def print_information(information, extra_info=None, show=False):
"""print out the information of a given ArchResults."""
dataset_names = information.get_dataset_names()
strings = [information.arch_str, "datasets : {:}, extra-info : {:}".format(dataset_names, extra_info)]
strings = [
information.arch_str,
"datasets : {:}, extra-info : {:}".format(dataset_names, extra_info),
]

def metric2str(loss, acc):
return "loss = {:.3f} & top1 = {:.2f}%".format(loss, acc)
@@ -40,7 +43,12 @@ def metric2str(loss, acc):
metric = information.get_compute_costs(dataset)
flop, param, latency = metric["flops"], metric["params"], metric["latency"]
str1 = "{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.".format(
dataset, flop, param, "{:.2f}".format(latency * 1000) if latency is not None and latency > 0 else None
dataset,
flop,
param,
"{:.2f}".format(latency * 1000)
if latency is not None and latency > 0
else None,
)
train_info = information.get_metrics(dataset, "train")
if dataset == "cifar10-valid":
@@ -93,34 +101,48 @@ def __init__(
self.reset_time()
if file_path_or_dict is None:
if self._fast_mode:
self._archive_dir = os.path.join(get_torch_home(), "{:}-simple".format(ALL_BASE_NAMES[-1]))
self._archive_dir = os.path.join(
get_torch_home(), "{:}-simple".format(ALL_BASE_NAMES[-1])
)
else:
file_path_or_dict = os.path.join(get_torch_home(), "{:}.{:}".format(ALL_BASE_NAMES[-1], PICKLE_EXT))
file_path_or_dict = os.path.join(
get_torch_home(), "{:}.{:}".format(ALL_BASE_NAMES[-1], PICKLE_EXT)
)
print(
"{:} Try to use the default NATS-Bench (size) path from "
"fast_mode={:} and path={:}.".format(time_string(), self._fast_mode, file_path_or_dict)
"fast_mode={:} and path={:}.".format(
time_string(), self._fast_mode, file_path_or_dict
)
)
if isinstance(file_path_or_dict, str):
file_path_or_dict = str(file_path_or_dict)
if verbose:
print(
"{:} Try to create the NATS-Bench (size) api "
"from {:} with fast_mode={:}".format(time_string(), file_path_or_dict, fast_mode)
"from {:} with fast_mode={:}".format(
time_string(), file_path_or_dict, fast_mode
)
)
if not nats_is_file(file_path_or_dict) and not nats_is_dir(file_path_or_dict):
raise ValueError("{:} is neither a file nor a dir.".format(file_path_or_dict))
if not nats_is_file(file_path_or_dict) and not nats_is_dir(
file_path_or_dict
):
raise ValueError(
"{:} is neither a file nor a dir.".format(file_path_or_dict)
)
self.filename = os.path.basename(file_path_or_dict)
if fast_mode:
if nats_is_file(file_path_or_dict):
raise ValueError(
"fast_mode={:} must feed the path for directory " ": {:}".format(fast_mode, file_path_or_dict)
"fast_mode={:} must feed the path for directory "
": {:}".format(fast_mode, file_path_or_dict)
)
else:
self._archive_dir = file_path_or_dict
else:
if nats_is_dir(file_path_or_dict):
raise ValueError(
"fast_mode={:} must feed the path for file " ": {:}".format(fast_mode, file_path_or_dict)
"fast_mode={:} must feed the path for file "
": {:}".format(fast_mode, file_path_or_dict)
)
else:
file_path_or_dict = pickle_load(file_path_or_dict)
@@ -142,32 +164,52 @@ def __init__(
hp2archres = collections.OrderedDict()
for hp_key, results in all_infos.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key)  # save the available hyper-parameter
self._avaliable_hps.add(
hp_key
)  # save the available hyper-parameter
self.arch2infos_dict[xkey] = hp2archres
self.evaluated_indexes = set(file_path_or_dict["evaluated_indexes"])
elif self.archive_dir is not None:
benchmark_meta = pickle_load("{:}/meta.{:}".format(self.archive_dir, PICKLE_EXT))
benchmark_meta = pickle_load(
"{:}/meta.{:}".format(self.archive_dir, PICKLE_EXT)
)
self.meta_archs = copy.deepcopy(benchmark_meta["meta_archs"])
self.arch2infos_dict = collections.OrderedDict()
self._avaliable_hps = set()
self.evaluated_indexes = set()
else:
raise ValueError(
"file_path_or_dict [{:}] must be a dict or archive_dir " "must be set".format(type(file_path_or_dict))
"file_path_or_dict [{:}] must be a dict or archive_dir "
"must be set".format(type(file_path_or_dict))
)
self.archstr2index = {}
for idx, arch in enumerate(self.meta_archs):
if arch in self.archstr2index:
raise ValueError(
"This [{:}]-th arch {:} already in the " "dict ({:}).".format(idx, arch, self.archstr2index[arch])
"This [{:}]-th arch {:} already in the "
"dict ({:}).".format(idx, arch, self.archstr2index[arch])
)
self.archstr2index[arch] = idx
if self.verbose:
print(
"{:} Create NATS-Bench (size) done with {:}/{:} architectures "
"avaliable.".format(time_string(), len(self.evaluated_indexes), len(self.meta_archs))
"avaliable.".format(
time_string(), len(self.evaluated_indexes), len(self.meta_archs)
)
)

@property
def is_size(self):
return True

@property
def is_topology(self):
return False

@property
def full_epochs_in_paper(self):
return 90
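For orientation, a minimal sketch of constructing this size-space API through the create helper from nats_bench/__init__.py above; it assumes the benchmark files have been downloaded, and "sss" selecting the size search space is an assumption consistent with this class:

from nats_bench import create

# file_path_or_dict=None falls back to the default path under torch home;
# fast_mode=True points the API at the per-architecture archive directory.
api = create(None, "sss", fast_mode=True, verbose=False)
print(api.is_size, api.is_topology, api.full_epochs_in_paper)  # True False 90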

def query_info_str_by_arch(self, arch, hp: Text = "12"):
"""Query the information of a specific architecture.
@@ -181,10 +223,15 @@ def query_info_str_by_arch(self, arch, hp: Text = "12"):
ArchResults instance
"""
if self.verbose:
print("{:} Call query_info_str_by_arch with arch={:}" "and hp={:}".format(time_string(), arch, hp))
print(
"{:} Call query_info_str_by_arch with arch={:}"
"and hp={:}".format(time_string(), arch, hp)
)
return self._query_info_str_by_arch(arch, hp, print_information)
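A hedged call sketch, reusing the api object from the create() sketch earlier; index 0 standing in for any valid architecture index (or arch string) and hp="12" matching the default 12-epoch schedule are assumptions:

# Assumes `api` from the earlier sketch; 0 is an arbitrary valid index.
info = api.query_info_str_by_arch(0, hp="12")
print(info)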

def get_more_info(self, index, dataset, iepoch=None, hp: Text = "12", is_random: bool = True):
def get_more_info(
self, index, dataset, iepoch=None, hp: Text = "12", is_random: bool = True
):
"""Return the metric for the `index`-th architecture.
Args:
@@ -211,9 +258,13 @@ def get_more_info(self, index, dataset, iepoch=None, hp: Text = "12", is_random:
if self.verbose:
print(
"{:} Call the get_more_info function with index={:}, dataset={:}, "
"iepoch={:}, hp={:}, and is_random={:}.".format(time_string(), index, dataset, iepoch, hp, is_random)
"iepoch={:}, hp={:}, and is_random={:}.".format(
time_string(), index, dataset, iepoch, hp, is_random
)
)
index = self.query_index_by_arch(index)  # in case the input is an arch string or an arch object rather than an index
index = self.query_index_by_arch(
index
)  # in case the input is an arch string or an arch object rather than an index
self._prepare_info(index)
if index not in self.arch2infos_dict:
raise ValueError("Did not find {:} from arch2infos_dict.".format(index))
@@ -223,7 +274,9 @@ def get_more_info(self, index, dataset, iepoch=None, hp: Text = "12", is_random:
seeds = archresult.get_dataset_seeds(dataset)
is_random = random.choice(seeds)
# collect the training information
train_info = archresult.get_metrics(dataset, "train", iepoch=iepoch, is_random=is_random)
train_info = archresult.get_metrics(
dataset, "train", iepoch=iepoch, is_random=is_random
)
total = train_info["iepoch"] + 1
xinfo = {
"train-loss": train_info["loss"],
@@ -233,9 +286,13 @@ def get_more_info(self, index, dataset, iepoch=None, hp: Text = "12", is_random:
}
# collect the evaluation information
if dataset == "cifar10-valid":
valid_info = archresult.get_metrics(dataset, "x-valid", iepoch=iepoch, is_random=is_random)
valid_info = archresult.get_metrics(
dataset, "x-valid", iepoch=iepoch, is_random=is_random
)
try:
test_info = archresult.get_metrics(dataset, "ori-test", iepoch=iepoch, is_random=is_random)
test_info = archresult.get_metrics(
dataset, "ori-test", iepoch=iepoch, is_random=is_random
)
except Exception as unused_e: # pylint: disable=broad-except
test_info = None
valtest_info = None
@@ -253,18 +310,26 @@ def get_more_info(self, index, dataset, iepoch=None, hp: Text = "12", is_random:
)
try: # collect results on the proposed test set
if dataset == "cifar10":
test_info = archresult.get_metrics(dataset, "ori-test", iepoch=iepoch, is_random=is_random)
test_info = archresult.get_metrics(
dataset, "ori-test", iepoch=iepoch, is_random=is_random
)
else:
test_info = archresult.get_metrics(dataset, "x-test", iepoch=iepoch, is_random=is_random)
test_info = archresult.get_metrics(
dataset, "x-test", iepoch=iepoch, is_random=is_random
)
except Exception as unused_e: # pylint: disable=broad-except
test_info = None
try: # collect results on the proposed validation set
valid_info = archresult.get_metrics(dataset, "x-valid", iepoch=iepoch, is_random=is_random)
valid_info = archresult.get_metrics(
dataset, "x-valid", iepoch=iepoch, is_random=is_random
)
except Exception as unused_e: # pylint: disable=broad-except
valid_info = None
try:
if dataset != "cifar10":
valtest_info = archresult.get_metrics(dataset, "ori-test", iepoch=iepoch, is_random=is_random)
valtest_info = archresult.get_metrics(
dataset, "ori-test", iepoch=iepoch, is_random=is_random
)
else:
valtest_info = None
except Exception as unused_e: # pylint: disable=broad-except
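Finally, a hedged usage sketch of get_more_info; the "cifar10" dataset name and hp="90" (the full schedule, per full_epochs_in_paper above) are assumptions consistent with the code:

# Assumes `api` from the earlier create() sketch. is_random=True makes
# get_more_info sample one recorded training seed (random.choice(seeds)
# above); iepoch=None leaves epoch selection to the API's default.
metrics = api.get_more_info(0, "cifar10", iepoch=None, hp="90", is_random=True)
print(metrics["train-loss"])  # plus the eval metrics collected above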
