Commit 729ce13

rm PD ; update NAS-Bench-102 baselines

D-X-Y committed Dec 23, 2019
1 parent 31c6e2b commit 729ce13
Showing 25 changed files with 66 additions and 972 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -118,3 +118,5 @@ cx.sh
 
 NAS-Bench-102-v1_0.pth
 lib/NAS-Bench-102-v1_0.pth
+others/TF
+scripts-search/l2s-algos
4 changes: 2 additions & 2 deletions NAS-Bench-102.md
@@ -14,11 +14,11 @@ Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`.
 
 ### Preparation and Download
 
-The benchmark file of NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan].
+The benchmark file of NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan (code:6u5d)](https://pan.baidu.com/s/1CiaNH6C12zuZf7q-Ilm09w).
 You can move it anywhere you like and pass its path to our API for initialization.
 - v1.0: `NAS-Bench-102-v1_0-e61699.pth`, where `e61699` is the last six digits of this file's hash.
 
-The training and evaluation data used in NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan].
+The training and evaluation data used in NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan (code:4fg7)](https://pan.baidu.com/s/1XAzavPKq3zcat1yBA1L2tQ).
 It is recommended to put these data into `$TORCH_HOME` (`~/.torch/` by default). These data are required if you want to generate NAS-Bench-102 (or similar NAS datasets) or train the models yourself.
 
 ## How to Use NAS-Bench-102
2 changes: 1 addition & 1 deletion README.md
@@ -13,7 +13,7 @@ More NAS resources can be found in [Awesome-NAS](https://github.com/D-X-Y/Awesome-NAS).
 
 ## Requirements and Preparation
 
-Please install `PyTorch>=1.1.0`, `Python>=3.6`, and `opencv`.
+Please install `PyTorch>=1.2.0`, `Python>=3.6`, and `opencv`.
 
 The CIFAR and ImageNet datasets should be downloaded and extracted into `$TORCH_HOME`.
 Some methods use knowledge distillation (KD), which requires pre-trained models. Please download these models from [Google Drive](https://drive.google.com/open?id=1ANmiYEGX-IQZTfH8w0aSpj-Wypg-0DR-) (or train them yourself) and save them into `.latent-data`.
2 changes: 1 addition & 1 deletion exps/NAS-Bench-102/main.py
@@ -213,7 +213,7 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se
 
 
 def generate_meta_info(save_dir, max_node, divide=40):
-  aa_nas_bench_ss = get_search_spaces('cell', 'aa-nas')
+  aa_nas_bench_ss = get_search_spaces('cell', 'nas-bench-102')
   archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False)
   print ('There are {:} archs vs {:}.'.format(len(archs), len(aa_nas_bench_ss) ** ((max_node-1)*max_node/2)))
 
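The `print` above counts all candidate cells: a cell with `max_node` nodes has (max_node-1)*max_node/2 edges, and each edge independently picks one operation. A quick check of the arithmetic, assuming `max_node=4` and the five operations of the `nas-bench-102` space:

```python
# Sanity check of the architecture count printed above, assuming max_node=4
# and the five operations of the 'nas-bench-102' search space.
ops = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
max_node = 4
num_edges = (max_node - 1) * max_node // 2  # 6 edges in the densely-connected cell
print(len(ops) ** num_edges)                # 5 ** 6 = 15625 candidate cells
```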
31 changes: 13 additions & 18 deletions exps/algos/DARTS-V1.py
@@ -17,6 +17,7 @@
 from utils import get_model_infos, obtain_accuracy
 from log_utils import AverageMeter, time_string, convert_secs2time
 from models import get_cell_based_tiny_net, get_search_spaces
+from nas_102_api import NASBench102API as API
 
 
 def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
@@ -144,6 +145,11 @@ def main(xargs):
   flop, param = get_model_infos(search_model, xshape)
   #logger.log('{:}'.format(search_model))
   logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
+  if xargs.arch_nas_dataset is None:
+    api = None
+  else:
+    api = API(xargs.arch_nas_dataset)
+  logger.log('{:} create API = {:} done'.format(time_string(), api))
 
   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
   network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
@@ -165,15 +171,16 @@ def main(xargs):
     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
 
   # start training
-  start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
+  start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
   for epoch in range(start_epoch, total_epoch):
     w_scheduler.update(epoch, 0.0)
     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
     epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
     logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
 
     search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
-    logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
+    search_time.update(time.time() - start_time)
+    logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
     valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
     logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
     # check the best accuracy
@@ -204,29 +211,17 @@ def main(xargs):
     if find_best:
       logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
       copy_checkpoint(model_base_path, model_best_path, logger)
+    if api is not None:
+      logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
     with torch.no_grad():
       logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
     # measure elapsed time
     epoch_time.update(time.time() - start_time)
     start_time = time.time()
 
   logger.log('\n' + '-'*100)
-  # check the performance from the architecture dataset
-  #if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
-  #  logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
-  #else:
-  #  nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset)
-  #  geno = genotypes[total_epoch-1]
-  #  logger.log('The last model is {:}'.format(geno))
-  #  info = nas_bench.query_by_arch( geno )
-  #  if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
-  #  else: logger.log('{:}'.format(info))
-  #  logger.log('-'*100)
-  #  geno = genotypes['best']
-  #  logger.log('The best model is {:}'.format(geno))
-  #  info = nas_bench.query_by_arch( geno )
-  #  if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
-  #  else: logger.log('{:}'.format(info))
+  logger.log('DARTS-V1 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
+  if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) ))
   logger.close()
 
 
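Note that `search_time` is updated right after `search_func` returns and before validation, so the `search_time.sum` reported at the end covers only the searching phase. A small sketch of this accounting pattern, with a stripped-down `AverageMeter` standing in for the one from `log_utils`:

```python
import time

# Stripped-down stand-in for log_utils.AverageMeter: keeps the last value
# and a running sum of everything passed to update().
class AverageMeter:
    def __init__(self):
        self.val, self.sum, self.count = 0.0, 0.0, 0
    def update(self, val, n=1):
        self.val, self.sum, self.count = val, self.sum + val * n, self.count + n

start_time, search_time = time.time(), AverageMeter()
for epoch in range(3):
    time.sleep(0.01)                              # stand-in for search_func(...)
    search_time.update(time.time() - start_time)  # only the search phase is counted
    time.sleep(0.01)                              # stand-in for valid_func(...): excluded
    start_time = time.time()                      # reset at the end of the epoch
print('searching cost {:.3f} s over 3 epochs'.format(search_time.sum))
```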
6 changes: 3 additions & 3 deletions exps/algos/R_EA.py
@@ -59,9 +59,9 @@ def train_and_eval(arch, nas_bench, extra_info):
   if nas_bench is not None:
     arch_index = nas_bench.query_index_by_arch( arch )
     assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
-    info = nas_bench.arch2infos[ arch_index ]
-    _, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
-    #import pdb; pdb.set_trace()
+    info = nas_bench.get_more_info(arch_index, 'cifar10-valid', True)
+    import pdb; pdb.set_trace()
+    #_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
   else:
     # train a model from scratch.
     raise ValueError('NOT IMPLEMENT YET')
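With this change, `train_and_eval` goes through `get_more_info` instead of reading `arch2infos` directly. A sketch of how the returned dictionary would be consumed, with the keys taken from the `get_more_info` implementation added later in this commit; the helper name is hypothetical:

```python
# Sketch: consuming the dict returned by get_more_info; the keys come from
# the implementation added in lib/nas_102_api/api.py below. `nas_bench` and
# `arch_index` are as in train_and_eval above; the function name is illustrative.
def lookup_valid_acc(nas_bench, arch_index):
    info = nas_bench.get_more_info(arch_index, 'cifar10-valid', True)  # True -> 12-epoch results
    valid_acc = info['valid-accuracy']
    time_cost = info['train-all-time'] + info['valid-per-time']
    return valid_acc, time_cost
```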
1 change: 1 addition & 0 deletions lib/models/__init__.py
@@ -36,6 +36,7 @@ def get_cell_based_tiny_net(config):
 def get_search_spaces(xtype, name):
   if xtype == 'cell':
     from .cell_operations import SearchSpaceNames
+    assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys())
     return SearchSpaceNames[name]
   else:
     raise ValueError('invalid search-space type is {:}'.format(xtype))
11 changes: 6 additions & 5 deletions lib/models/cell_operations.py
@@ -16,12 +16,13 @@
   'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine),
 }
 
-CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
-AA_NAS_BENCHMARK      = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
+CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
+NAS_BENCH_102         = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
 
-SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
-                    'aa-nas'      : AA_NAS_BENCHMARK,
-                    'full'        : sorted(list(OPS.keys()))}
+SearchSpaceNames = {'connect-nas'  : CONNECT_NAS_BENCHMARK,
+                    'aa-nas'       : NAS_BENCH_102,
+                    'nas-bench-102': NAS_BENCH_102,
+                    'full'         : sorted(list(OPS.keys()))}
 
 
 class ReLUConvBN(nn.Module):
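After this rename, `'aa-nas'` is kept as a backward-compatible alias: both keys resolve to the very same operation list. A quick check, assuming the repository's `lib` directory is on `PYTHONPATH`:

```python
# Quick check that 'aa-nas' and 'nas-bench-102' alias the same operation list
# (assumes the repository's `lib` directory is on PYTHONPATH).
from models.cell_operations import SearchSpaceNames

assert SearchSpaceNames['aa-nas'] is SearchSpaceNames['nas-bench-102']
print(SearchSpaceNames['nas-bench-102'])
# -> ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
```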
36 changes: 31 additions & 5 deletions lib/nas_102_api/api.py
@@ -129,6 +129,27 @@ def arch(self, index):
     assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
     return copy.deepcopy(self.meta_archs[index])
 
+  def get_more_info(self, index, dataset, use_12epochs_result=False):
+    if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
+    else                  : basestr, arch2infos = '200epochs', self.arch2infos_full
+    archresult = arch2infos[index]
+    if dataset == 'cifar10-valid':
+      train_info = archresult.get_metrics(dataset, 'train', is_random=True)
+      valid_info = archresult.get_metrics(dataset, 'x-valid', is_random=True)
+      test__info = archresult.get_metrics(dataset, 'ori-test', is_random=True)
+      total = train_info['iepoch'] + 1
+      return {'train-loss'    : train_info['loss'],
+              'train-accuracy': train_info['accuracy'],
+              'train-all-time': train_info['all_time'],
+              'valid-loss'    : valid_info['loss'],
+              'valid-accuracy': valid_info['accuracy'],
+              'valid-all-time': valid_info['all_time'],
+              'valid-per-time': valid_info['all_time'] / total,
+              'test-loss'     : test__info['loss'],
+              'test-accuracy' : test__info['accuracy']}
+    else:
+      raise ValueError('coming soon...')
+
   def show(self, index=-1):
     if index < 0: # show all architectures
       print(self)
@@ -367,23 +388,28 @@ def get_eval_set(self):
   def get_train(self, iepoch=None):
     if iepoch is None: iepoch = self.epochs-1
     assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
-    if self.train_times is not None: xtime = self.train_times[iepoch]
-    else                           : xtime = None
+    if self.train_times is not None:
+      xtime = self.train_times[iepoch]
+      atime = sum([self.train_times[i] for i in range(iepoch+1)])
+    else: xtime, atime = None, None
     return {'iepoch'  : iepoch,
             'loss'    : self.train_losses[iepoch],
             'accuracy': self.train_acc1es[iepoch],
-            'time'    : xtime}
+            'cur_time': xtime,
+            'all_time': atime}
 
   def get_eval(self, name, iepoch=None):
     if iepoch is None: iepoch = self.epochs-1
     assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
     if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
       xtime = self.eval_times['{:}@{:}'.format(name,iepoch)]
-    else: xtime = None
+      atime = sum([self.eval_times['{:}@{:}'.format(name,i)] for i in range(iepoch+1)])
+    else: xtime, atime = None, None
     return {'iepoch'  : iepoch,
             'loss'    : self.eval_losses['{:}@{:}'.format(name,iepoch)],
             'accuracy': self.eval_acc1es['{:}@{:}'.format(name,iepoch)],
-            'time'    : xtime}
+            'cur_time': xtime,
+            'all_time': atime}
 
   def get_net_param(self):
     return self.net_state_dict
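The former single `'time'` field is split into `'cur_time'` (this epoch alone) and `'all_time'` (accumulated through this epoch), which is what `get_more_info` sums over above. A sketch of reading both fields, where `result` stands for one architecture's results object whose `get_train` is shown here:

```python
# Sketch: reading the split timing fields from get_train; `result` stands for
# the per-architecture results object defined above. Both timing fields may
# be None when no per-epoch times were recorded.
def show_train_times(result):
    info = result.get_train(iepoch=11)
    print('epoch {:}: loss={:.3f}, accuracy={:.2f}%'.format(
        info['iepoch'], info['loss'], info['accuracy']))
    if info['all_time'] is not None:
        print('epoch 11 alone: {:.1f}s; epochs 0-11 total: {:.1f}s'.format(
            info['cur_time'], info['all_time']))
```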
3 changes: 0 additions & 3 deletions others/paddlepaddle/.gitignore

This file was deleted.

118 changes: 0 additions & 118 deletions others/paddlepaddle/README.md

This file was deleted.

3 changes: 0 additions & 3 deletions others/paddlepaddle/lib/models/__init__.py

This file was deleted.
