In [1]:
from nats_bench import create
from nats_bench.api_utils import time_string
import numpy as np

# Create the API for the topology search space ("tss"). Passing None as the
# path makes NATS-Bench fall back to its default benchmark file location.
api_tss = create(None, "tss", fast_mode=False, verbose=False)

[2021-10-21 07:08:52] Try to use the default NATS-Bench (topology) path from fast_mode=False and path=/Users/xuanyidong/.torch/NATS-tss-v1_0-3ffb9.pickle.pbz2.


In [2]:
def get_valid_test_acc(api, arch, dataset):
    """Fetch the validation and test accuracy of one architecture.

    Args:
        api: NATS-Bench API object; this function reads its
            ``search_space_name`` attribute and calls ``get_more_info``.
        arch: architecture identifier forwarded to ``get_more_info``
            (this notebook passes integer indices).
        dataset: benchmark dataset name, e.g. "cifar10" or "ImageNet16-120".

    Returns:
        A tuple ``(valid_acc, test_acc, summary)``, where ``summary`` is a
        newline-terminated "validation = X, test = Y" string (2 decimals).
    """
    # Training-schedule selector: 90 for the "size" search space, 200 otherwise.
    schedule = 90 if api.search_space_name == "size" else 200

    def _lookup(split_name):
        # Deterministic (non-random) lookup for the requested dataset entry.
        return api.get_more_info(
            arch, dataset=split_name, hp=schedule, is_random=False
        )

    test_info = _lookup(dataset)
    # For CIFAR-10 the validation accuracy comes from the separate
    # "cifar10-valid" entry (presumably because the "cifar10" entry has no
    # held-out validation split — see NATS-Bench docs); every other dataset
    # reports both accuracies in a single entry.
    valid_info = _lookup("cifar10-valid") if dataset == "cifar10" else test_info

    valid_acc = valid_info["valid-accuracy"]
    test_acc = test_info["test-accuracy"]
    summary = "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc)
    return valid_acc, test_acc, summary

def find_best_valid(api, dataset):
    """Compare picking an architecture by validation vs. by test accuracy.

    Queries every architecture in ``api`` once through ``get_valid_test_acc``,
    then prints the architecture with the highest validation accuracy and the
    one with the highest test accuracy, each with its validation/test scores.

    Args:
        api: NATS-Bench API object; must support ``enumerate(api)`` and
            ``api[index]`` (used here to print the architecture string).
        dataset: dataset name forwarded to ``get_valid_test_acc``.

    Raises:
        ValueError: if ``api`` yields no architectures.
    """
    # One benchmark query per architecture; results[i] holds
    # (valid_acc, test_acc, perf_str) for architecture index i.
    results = [
        get_valid_test_acc(api, index, dataset) for index, _ in enumerate(api)
    ]
    if not results:
        raise ValueError("api contains no architectures to compare")

    # max() returns the first index among ties, matching a stable descending sort.
    candidates = range(len(results))
    best_valid_index = max(candidates, key=lambda i: results[i][0])
    best_test_index = max(candidates, key=lambda i: results[i][1])

    print("-" * 50 + "{:10s}".format(dataset) + "-" * 50)
    print(
        "Best ({:}) architecture on validation: {:}".format(
            best_valid_index, api[best_valid_index]
        )
    )
    print(
        "Best ({:}) architecture on       test: {:}".format(
            best_test_index, api[best_test_index]
        )
    )
    # Reuse the already-computed summaries instead of re-querying the benchmark.
    print("using validation ::: {:}".format(results[best_valid_index][2]))
    print("using test       ::: {:}".format(results[best_test_index][2]))

# Run the validation-vs-test selection comparison on ImageNet16-120.
dataset = "ImageNet16-120"
find_best_valid(api=api_tss, dataset=dataset)

--------------------------------------------------ImageNet16-120--------------------------------------------------
Best (10676) architecture on validation: |nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_1x1~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|
Best (857) architecture on       test: |nor_conv_1x1~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|
using validation ::: validation = 46.73, test = 46.20

using test       ::: validation = 46.38, test = 47.31

