In [1]:
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
import time
import tqdm
import scipy.stats as stats
import matplotlib.pyplot as plt
import pickle

# XAutoDL 
from xautodl.config_utils import load_config, dict2config, configure2str
from xautodl.datasets import get_datasets, get_nas_search_loaders
from xautodl.procedures import (
    prepare_seed,
    prepare_logger,
    save_checkpoint,
    copy_checkpoint,
    get_optim_scheduler,
)
from xautodl.utils import get_model_infos, obtain_accuracy
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.models import get_search_spaces

# API
from nats_bench import create

# custom modules
from custom.tss_model_supernet import get_cell_based_tiny_net
from ZeroShotProxy import *

# Expose only physical GPU 0 to this process; the environment dump logged by the
# logger cell reports a single visible CUDA device.
# NOTE(review): presumably this has to run before anything initializes CUDA in
# order to take effect — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

2023-09-29 22:17:11.468586: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-09-29 22:17:11.514848: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Notebook configuration, expressed as an argparse parser so the same options can
# be reused from a script. Inside the notebook the defaults are always used because
# `parse_args` is called with an empty argument list.
parser = argparse.ArgumentParser("Training-free NAS on NATSBench (TSS)")
parser.add_argument("--data_path", type=str, default='./cifar.python', help="The path to dataset")
parser.add_argument("--dataset", type=str, default='cifar10',choices=["cifar10", "cifar100", "ImageNet16-120"], help="Choose between Cifar10/100 and ImageNet-16.")

# channels and number-of-cells
parser.add_argument("--search_space", type=str, default='tss', help="The search space name.")
parser.add_argument("--config_path", type=str, default='./configs/nas-benchmark/algos/weight-sharing.config', help="The path to the configuration.")
parser.add_argument("--max_nodes", type=int, default=4, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, default=16, help="The number of channels.")
parser.add_argument("--num_cells", type=int, default=5, help="The number of cells in one stage.")
# parser.add_argument("--channel", type=int, default=3, help="The number of channels.")
# parser.add_argument("--num_cells", type=int, default=1, help="The number of cells in one stage.")
# parser.add_argument("--select_num", type=int, default=100, help="The number of selected architectures to evaluate.")
parser.add_argument("--affine", type=int, default=0, choices=[0, 1], help="Whether use affine=True or False in the BN layer.")
parser.add_argument("--track_running_stats", type=int, default=0, choices=[0, 1], help="Whether use track_running_stats or not in the BN layer.")

# log
parser.add_argument("--print_freq", type=int, default=200, help="print frequency (default: 200)")

# custom
parser.add_argument("--gpu", type=int, default=0, help="")
parser.add_argument("--workers", type=int, default=2, help="number of data loading workers")
parser.add_argument("--api_data_path", type=str, default="./api_data/NATS-tss-v1_0-3ffb9-simple/", help="")
parser.add_argument("--save_dir", type=str, default='./results/tmp', help="Folder to save checkpoints and log.")
# The help text used to contain the invalid escape sequence "\#" (a SyntaxWarning
# on Python 3.12+) and the typo "form ZiCo"; both are fixed here.
parser.add_argument('--zero_shot_score', type=str, default='gradsign', help='could be: ETF; ZiCo (for ZiCo); params (for #Params); Zen (for Zen-NAS); TE (for TE-NAS)')
parser.add_argument("--rand_seed", type=int, default=1, help="manual seed")
# An explicit empty list keeps argparse from reading the kernel's own sys.argv.
args = parser.parse_args(args=[])

# A missing or negative seed means "pick one at random".
if args.rand_seed is None or args.rand_seed < 0:
    args.rand_seed = random.randint(1, 100000)

print(args.rand_seed)
print(args)
# `xargs` is an alias of the same namespace; the cells below use both names.
xargs=args

1
Namespace(data_path='./cifar.python', dataset='cifar10', search_space='tss', config_path='./configs/nas-benchmark/algos/weight-sharing.config', max_nodes=4, channel=16, num_cells=5, affine=0, track_running_stats=0, print_freq=200, gpu=0, workers=2, api_data_path='./api_data/NATS-tss-v1_0-3ffb9-simple/', save_dir='./results/tmp', zero_shot_score='gradsign', rand_seed=1)


In [3]:
# Everything below runs on the GPU, so stop right away if none is usable.
assert torch.cuda.is_available(), "CUDA is not available."

# cuDNN stays enabled, with autotuning off and deterministic kernels requested.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = True

# CPU thread count follows the data-loading worker count.
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
# `xargs` and `args` are the same namespace object.
logger = prepare_logger(xargs)

Main Function with logger : Logger(dir=results/tmp, use-tf=False, writer=None)
Arguments : -------------------------------
data_path        : ./cifar.python
dataset          : cifar10
search_space     : tss
config_path      : ./configs/nas-benchmark/algos/weight-sharing.config
max_nodes        : 4
channel          : 16
num_cells        : 5
affine           : 0
track_running_stats : 0
print_freq       : 200
gpu              : 0
workers          : 2
api_data_path    : ./api_data/NATS-tss-v1_0-3ffb9-simple/
save_dir         : ./results/tmp
zero_shot_score  : gradsign
rand_seed        : 1
Python  Version  : 3.10.11 (main, Apr 20 2023, 19:02:41) [GCC 11.2.0]
Pillow  Version  : 9.4.0
PyTorch Version  : 2.0.1
cuDNN   Version  : 8500
CUDA available   : True
CUDA GPU numbers : 1
CUDA_VISIBLE_DEVICES : 0


In [4]:
## API
# Benchmark lookup object; used later to query per-architecture accuracy and cost.
api = create(xargs.api_data_path, xargs.search_space, fast_mode=True, verbose=False)
logger.log("Create API = {:} done".format(api))

## data
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
config = load_config(
    xargs.config_path,
    {"class_num": class_num, "xshape": xshape},
    logger,
)
search_loader, train_loader, valid_loader = get_nas_search_loaders(
    train_data,
    valid_data,
    xargs.dataset,
    "./configs/nas-benchmark/",
    (config.batch_size, config.test_batch_size),
    xargs.workers,
)
logger.log(
    "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
        xargs.dataset, len(search_loader), len(valid_loader), config.batch_size
    )
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))

## model
# The supernet is built from the operation set of the chosen search space; the
# BN flags arrive as 0/1 integers from argparse and are converted to booleans.
search_space = get_search_spaces(xargs.search_space, "nats-bench")
model_config = dict2config(
    {
        "name": "generic",
        "C": xargs.channel,
        "N": xargs.num_cells,
        "max_nodes": xargs.max_nodes,
        "num_classes": class_num,
        "space": search_space,
        "affine": bool(xargs.affine),
        "track_running_stats": bool(xargs.track_running_stats),
    },
    None,
)
logger.log("search space : {:}".format(search_space))
logger.log("model config : {:}".format(model_config))
search_model = get_cell_based_tiny_net(model_config)

# Move the supernet onto the selected GPU; `network` is the handle used by the search cells.
device = torch.device('cuda:{}'.format(xargs.gpu))
network = search_model.to(device)

Create API = NATStopology(0/15625 architectures, fast_mode=True, file=) done
Files already downloaded and verified
Files already downloaded and verified
./configs/nas-benchmark/algos/weight-sharing.config
Configure(scheduler='cos', LR=0.025, eta_min=0.001, epochs=100, warmup=0, optim='SGD', decay=0.0005, momentum=0.9, nesterov=True, criterion='Softmax', batch_size=64, test_batch_size=512, class_num=10, xshape=(1, 3, 32, 32))
||||||| cifar10    ||||||| Search-Loader-Num=391, Valid-Loader-Num=49, batch size=64
||||||| cifar10    ||||||| Config=Configure(scheduler='cos', LR=0.025, eta_min=0.001, epochs=100, warmup=0, optim='SGD', decay=0.0005, momentum=0.9, nesterov=True, criterion='Softmax', batch_size=64, test_batch_size=512, class_num=10, xshape=(1, 3, 32, 32))
search space : ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
model config : Configure(name='generic', C=16, N=5, max_nodes=4, num_classes=10, space=['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x

In [5]:
# Proxies for which a data loader is handed to `compute_nas_score`; every other
# proxy receives `trainloader=None`.
real_input_metrics = ['zico', 'snip', 'grasp', 'te_nas', 'gradsign']

def search_find_best(xargs, xloader, network, n_samples = None, archs = None):
    """Score candidate architectures with the configured zero-shot proxy.

    Exactly one of ``n_samples`` / ``archs`` selects the candidates:

    * ``n_samples``: draw that many architectures via ``network.random_genotype(True)``.
    * ``archs``: score the given architectures, assigning each one to
      ``network.arch_cache`` before it is scored.

    Args:
        xargs: Namespace providing ``zero_shot_score`` (proxy name) and ``gpu``.
        xloader: Data loader. Its first batch determines ``resolution`` and
            ``batch_size``; it is also the loader passed to the proxy when the
            proxy is listed in ``real_input_metrics``.
        network: Supernet exposing ``train()``, ``random_genotype`` and ``arch_cache``.
        n_samples: Number of random architectures to score, or ``None``.
        archs: Iterable of architectures to score, or ``None``.

    Returns:
        ``(arch_list, zero_shot_score_dict)``. ``arch_list`` holds the scored
        architectures in order. ``zero_shot_score_dict`` maps every key of the
        proxy's result dict to the list of per-architecture values, and is
        ``None`` when no architecture was scored.

    Raises:
        ValueError: If neither or both of ``n_samples`` / ``archs`` are given, or
            if no ``compute_<proxy>_score`` object exists in the notebook globals.
    """
    # Previously a wrong combination fell through both branches and silently
    # returned ([], None); fail loudly before any work is done.
    if (archs is None) == (n_samples is None):
        raise ValueError("Specify exactly one of `n_samples` or `archs`.")

    metric_name = xargs.zero_shot_score.lower()
    logger.log("Searching with {}".format(metric_name))
    score_fn_name = "compute_{}_score".format(metric_name)
    # The proxy modules are expected in the notebook namespace (star-import).
    score_fn = globals().get(score_fn_name)
    if score_fn is None:
        raise ValueError("Unknown zero-shot score '{}': '{}' is not defined.".format(xargs.zero_shot_score, score_fn_name))

    input_, _ = next(iter(xloader))
    resolution = input_.size(2)
    batch_size = input_.size(0)
    # Use the loader that was passed in rather than the notebook-global
    # `train_loader`, so the function no longer depends on hidden cell state.
    trainloader = xloader if metric_name in real_input_metrics else None

    random_sampling = archs is None
    candidates = range(n_samples) if random_sampling else archs

    arch_list = []
    zero_shot_score_dict = None
    all_time = []
    all_mem = []
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    network.train()
    for candidate in tqdm.tqdm(candidates):
        if random_sampling:
            # random sampling
            # NOTE(review): the `True` flag presumably also makes the network use
            # the sampled architecture — confirm against the supernet code.
            arch = network.random_genotype(True)
        else:
            arch = candidate
            network.arch_cache = arch

        start.record()
        torch.cuda.reset_peak_memory_stats()

        info_dict = score_fn.compute_nas_score(network, gpu=xargs.gpu, trainloader=trainloader, resolution=resolution, batch_size=batch_size)

        end.record()
        # elapsed_time() is only read after the device has been synchronized.
        torch.cuda.synchronize()
        all_time.append(start.elapsed_time(end))
        all_mem.append(torch.cuda.max_memory_allocated())
        torch.cuda.empty_cache()

        arch_list.append(arch)
        if zero_shot_score_dict is None: # initialize dict from the first result
            zero_shot_score_dict = {k: [] for k in info_dict.keys()}
        for k, v in info_dict.items():
            zero_shot_score_dict[k].append(v)

    # np.max raises on an empty sequence, so only report when something was scored.
    if all_time:
        logger.log("------Runtime------")
        logger.log("All: {:.5f} ms".format(np.mean(all_time)))
        logger.log("------Avg Mem------")
        logger.log("All: {:.5f} GB".format(np.mean(all_mem)/1e9))
        logger.log("------Max Mem------")
        logger.log("All: {:.5f} GB".format(np.max(all_mem)/1e9))

    return arch_list, zero_shot_score_dict

In [None]:
# ######### search across random N archs #########
# archs, results = search_find_best(xargs, train_loader, network, n_samples=3000)

# ######### search across N archs uniformly sampled according to test acc. #########
# def uniform_sample_archs(network, xargs, api, n_samples=1000, dataset='cifar10'):
#     arch = network.random_genotype(True)
#     search_space = get_search_spaces(xargs.search_space, "nats-bench")
#     archs = arch.gen_all(search_space, xargs.max_nodes, False)
    
#     def get_results_from_api(api, arch, dataset='cifar10'):
#         dataset_candidates = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
#         assert dataset in dataset_candidates
#         index = api.query_index_by_arch(arch)
#         api._prepare_info(index)
#         archresult = api.arch2infos_dict[index]['200']
#         if dataset == 'cifar10-valid':
#             acc = archresult.get_metrics(dataset, 'x-valid', iepoch=199, is_random=False)['accuracy']
#         else:
#             acc = archresult.get_metrics(dataset, 'ori-test', iepoch=199, is_random=False)['accuracy']
#         return acc

#     accs = []
#     for a in archs:
#         accs.append(get_results_from_api(api, a, dataset))
#     interval = len(archs) // n_samples
#     sorted_indices = np.argsort(accs)
#     new_archs = []
#     for i, idx in enumerate(sorted_indices):
#         if i % interval == 0:
#             new_archs.append(archs[idx])
#     archs = new_archs
    
#     return archs

# if os.path.exists("./tss_uniform_arch.pickle"):
#     with open("./tss_uniform_arch.pickle", "rb") as fp:
#         uniform_archs = pickle.load(fp)
# else:
#     uniform_archs = uniform_sample_archs(network, xargs, api, 1000, 'cifar10')
#     with open("./tss_uniform_arch.pickle", "wb") as fp:
#         pickle.dump(uniform_archs, fp)
# archs, results = search_find_best(xargs, train_loader, network, archs=uniform_archs)


######### search across all archs #########
def generate_all_archs(network, xargs):
    """Enumerate every architecture of the configured search space.

    A genotype obtained from ``network.random_genotype(True)`` is used only as the
    object on which ``gen_all`` is called, with the operation set of
    ``xargs.search_space`` and ``xargs.max_nodes`` nodes.
    """
    seed_genotype = network.random_genotype(True)
    candidate_ops = get_search_spaces(xargs.search_space, "nats-bench")
    return seed_genotype.gen_all(candidate_ops, xargs.max_nodes, False)

# Enumerate the whole search space and score every architecture in it.
# NOTE: this is the expensive cell. The progress bar captured in the saved output
# shows about 1.6 s per architecture over 15625 candidates, i.e. several hours.
all_archs = generate_all_archs(network, xargs)
# `archs` is the list of scored architectures and `results` maps each proxy output
# key to its per-architecture values; both are consumed by the cells below.
archs, results = search_find_best(xargs, train_loader, network, archs=all_archs)

Searching with gradsign


  4%|▍         | 652/15625 [19:53<6:46:37,  1.63s/it] 

In [None]:
def get_results_from_api(api, arch, dataset='cifar10'):
    """Look up accuracy and compute cost of ``arch`` in the benchmark API.

    The record under key ``'200'`` of ``api.arch2infos_dict`` is read and the
    accuracy at ``iepoch=199`` is returned. For ``'cifar10-valid'`` the
    ``'x-valid'`` split is queried; for every other dataset the ``'ori-test'``
    split is queried.

    Args:
        api: Benchmark API object (needs ``query_index_by_arch``,
            ``_prepare_info`` and ``arch2infos_dict``).
        arch: Architecture to look up.
        dataset: One of ``'cifar10-valid'``, ``'cifar10'``, ``'cifar100'``,
            ``'ImageNet16-120'``.

    Returns:
        Tuple ``(accuracy, flops, params)``.

    Raises:
        AssertionError: If ``dataset`` is not one of the supported names.
    """
    dataset_candidates = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
    assert dataset in dataset_candidates, "unsupported dataset: {!r}".format(dataset)
    index = api.query_index_by_arch(arch)
    # Make sure the record for this index is available before reading it.
    api._prepare_info(index)
    archresult = api.arch2infos_dict[index]['200']

    split = 'x-valid' if dataset == 'cifar10-valid' else 'ori-test'
    acc = archresult.get_metrics(dataset, split, iepoch=199, is_random=False)['accuracy']
    # One lookup serves both cost values (previously queried twice).
    costs = archresult.get_compute_costs(dataset)

    return acc, costs['flops'], costs['params']

# Benchmark ground truth for every scored architecture, in the same order as `archs`.
# With dataset 'cifar10' the accuracy comes from the 'ori-test' split, despite the
# `api_valid_accs` name.
api_records = [get_results_from_api(api, arch, 'cifar10') for arch in archs]
# api_records = [get_results_from_api(api, arch, 'cifar100') for arch in archs]
# api_records = [get_results_from_api(api, arch, 'ImageNet16-120') for arch in archs]

# Split the (acc, flops, params) tuples into three parallel lists.
api_valid_accs = [record[0] for record in api_records]
api_flops = [record[1] for record in api_records]
api_params = [record[2] for record in api_records]

In [None]:
fig_scale = 1.1


def _report_rank_correlation(label, scores, accs, highlight_idx, title):
    """Print rank correlations of ``scores`` vs ``accs`` and scatter-plot the ranks.

    Both inputs are converted to ranks first. The printed line is
    ``"<label>: <kendall tau>\\t<pearson>\\t<spearman>\\t"``, where all three
    statistics are computed on the ranks. The point at ``highlight_idx`` is drawn
    in red on top of the scatter plot.
    """
    x = stats.rankdata(scores)
    y = stats.rankdata(accs)
    kendalltau = stats.kendalltau(x, y)
    spearmanr = stats.spearmanr(x, y)
    pearsonr = stats.pearsonr(x, y)
    print("{}: {}\t{}\t{}\t".format(label, kendalltau[0], pearsonr[0], spearmanr[0]))
    plt.figure(figsize=(4*fig_scale,3*fig_scale))
    plt.scatter(x, y, linewidths=0.1)
    plt.scatter(x[highlight_idx], y[highlight_idx], c="r", linewidths=0.1)
    plt.title(title)
    plt.show()


# Aggregate score: sum of log(rank / N) over FLOPs and every proxy output.
# FLOPs seeds the sum, so the former `rank_agg is None` branch was unreachable.
num_archs = len(api_flops)
rank_agg = np.log(stats.rankdata(api_flops) / num_archs)
for k in results.keys():
    rank_agg = rank_agg + np.log( stats.rankdata(results[k]) / num_archs)

# The architecture with the highest aggregated score is the one selected.
best_idx = np.argmax(rank_agg)

best_arch, acc = archs[best_idx], api_valid_accs[best_idx]
if api is not None:
    print("{:}".format(api.query_by_arch(best_arch, "200")))

_report_rank_correlation("aggregated", rank_agg, api_valid_accs, best_idx, "Rank_agg")


# for each
# `metrics` is read again by the cells below, so it stays a notebook-level name.
metrics = {'FLOPs':api_flops, 'Params':api_params}
for k, v in results.items():
    metrics[k] = v
for k in metrics.keys():
    _report_rank_correlation(k, metrics[k], api_valid_accs, best_idx, "Rank_{}".format(k))

In [None]:
# Dump the raw per-architecture values stored under the 'gradsign' key.
# NOTE(review): assumes the proxy's result dict contains a 'gradsign' key, i.e.
# the notebook was run with --zero_shot_score=gradsign — confirm.
print(metrics['gradsign'])

In [None]:
# Print every scored architecture next to its 'gradsign' value.
for arch, score in zip(archs, metrics['gradsign']):
    print(arch, score)