In [1]:
from nats_bench import create
from tqdm import tqdm
import numpy as np

def close_to(a, b, eps=1e-4):
    """Return True if ``a`` and ``b`` agree within the relative tolerance ``eps``.

    The absolute difference ``|a - b|`` must be at most ``eps`` times ``|b|``
    (skipped when ``b == 0``) AND at most ``eps`` times ``|a|`` (skipped when
    ``a == 0``). Two zeros therefore compare close, while zero versus a
    non-zero value does not (for ``eps < 1``).

    NaN never compares close to anything, including another NaN. Every
    comparison involving NaN is False, so a guard written as ``ratio > eps``
    would silently accept NaN. The guards are therefore written as
    ``not (ratio <= eps)``.

    Args:
        a: First number (e.g. a FLOPs or parameter count).
        b: Second number.
        eps: Relative tolerance (default ``1e-4``).

    Returns:
        bool: True when the two values are considered equal.
    """
    # Exact equality (including 0 == 0 and inf == inf) is always close; without
    # this, inf - inf would produce NaN and be rejected by the checks below.
    if a == b:
        return True
    if b != 0 and not (abs(a - b) / abs(b) <= eps):
        return False
    if a != 0 and not (abs(a - b) / abs(a) <= eps):
        return False
    return True

In [2]:
def check_flops_params(xapi, datasets=None, eps=1e-4):
    """Check that FLOPs and #params match between the 12-epoch and full-epoch records.

    For every architecture index in ``xapi`` and every dataset, the cost info
    returned by ``get_cost_info(..., hp="12")`` is compared with ``close_to``
    against the cost info for ``hp=xapi.full_epochs_in_paper``.

    Args:
        xapi: A NATS-Bench API object (as returned by ``create``). It must
            support ``len()``, ``get_cost_info(index, dataset, hp=...)``
            returning a mapping with ``'flops'`` and ``'params'`` keys, and
            expose ``full_epochs_in_paper``.
        datasets: Iterable of dataset names to check. Defaults to
            ``("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120")``.
        eps: Relative tolerance forwarded to ``close_to``.

    Raises:
        AssertionError: on the first architecture/dataset/metric mismatch.
            An explicit ``raise`` is used rather than ``assert`` so the check
            is not silently removed when Python runs with ``-O``.
    """
    if datasets is None:
        datasets = ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120")
    # Materialize once: the datasets are re-iterated for every architecture,
    # which would silently exhaust a generator after the first index.
    datasets = tuple(datasets)
    print(f"Check {xapi}")
    counts = 0
    for index in tqdm(range(len(xapi))):
        for dataset in datasets:
            info_12 = xapi.get_cost_info(index, dataset, hp="12")
            info_full = xapi.get_cost_info(index, dataset, hp=xapi.full_epochs_in_paper)
            # The check expects the FLOPs and the number of parameters to be the
            # same regardless of which hp record they were read from.
            for metric in ("flops", "params"):
                if not close_to(info_12[metric], info_full[metric], eps):
                    raise AssertionError(
                        f"The {index}-th architecture has issues on {dataset} "
                        f"-- {metric}: {info_12[metric]} vs {info_full[metric]}."
                    )
            counts += 1
    print(f"Check {xapi} completed -- {counts} arch-dataset pairs.")

In [3]:
# Size search space ("sss"): build the benchmark API from the default path
# (path=None) in fast mode with verbose logging off, report how many
# architectures it holds, then run the FLOPs / #params consistency check.
api = create(
    None,
    'sss',
    fast_mode=True,
    verbose=False,
)
print(f'There are {len(api)} architectures in the size search space -- {api}')
check_flops_params(api)

In [4]:
# Topology search space ("tss"): build the benchmark API from the default path
# (path=None) in fast mode with verbose logging off, report how many
# architectures it holds, then run the FLOPs / #params consistency check.
api = create(
    None,
    'tss',
    fast_mode=True,
    verbose=False,
)
print(f'There are {len(api)} architectures in the topology search space -- {api}')
check_flops_params(api)

  0%|          | 0/15625 [00:00<?, ?it/s]

[2022-01-20 01:03:30] Try to use the default NATS-Bench (topology) path from fast_mode=True and path=None.
There are 15625 architectures in the topology search space -- NATStopology(0/15625 architectures, fast_mode=True, file=None)
Check NATStopology(0/15625 architectures, fast_mode=True, file=None)


100%|██████████| 15625/15625 [20:06<00:00, 12.96it/s]

Check NATStopology(15625/15625 architectures, fast_mode=True, file=None) completed -- 62500 arch-dataset pairs.





In [5]:
# # This (disabled) code block investigates the root cause of issue #16.
# from xautodl.models import get_cell_based_tiny_net
# from xautodl.utils import count_parameters_in_MB
# from xautodl.utils import get_model_infos

# api = create(None, 'tss', fast_mode=True, verbose=False)
# print(api)

# index, dataset = 296, "cifar10"
# arch = "|skip_connect~0|+|none~0|nor_conv_3x3~1|+|avg_pool_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|"
# index = api.query_index_by_arch(arch)


# info_12 = api.get_cost_info(index, dataset, hp="12")
# info_full = api.get_cost_info(index, dataset, hp=api.full_epochs_in_paper)
# print(info_12)
# print(info_full)

# config_12 = api.get_net_config(index, dataset)
# print(config_12)
# config_full = api.get_net_config(index, dataset)
# print(config_full)

# # create the network, which is the sub-class of torch.nn.Module
# network = get_cell_based_tiny_net(config_full)

# flop, param = get_model_infos(network, (1, 3, 32, 32))
# print(f"FLOPs={flop}, param={param}")

In [6]:
# results = api.query_meta_info_by_index(index, hp=api.full_epochs_in_paper)
# print(results.all_results.keys())
# print("")
# print(results.dataset_seed[dataset])
# print(results.get_compute_costs(dataset))
# print("")
# print(results.all_results[(dataset, 777)].flop)
# print(results.all_results[(dataset, 888)].flop)
# print(results.all_results[(dataset, 999)].flop)
# print(results.all_results[('cifar100', 777)])