
Is NATS Extension of NAS_201 bench #48

Closed
Mars-204 opened this issue Dec 12, 2022 · 4 comments

Comments

@Mars-204

I was working with NAS_201 bench earlier and am now moving to NATS bench. I have the following questions:

  • Is the architecture index the same in NAS_201 and NATS bench for a given architecture? When I query the same architecture, NAS_201 returns index 12804 while NATS bench returns -1. Also, using the NAS_201 index in NATS bench gives a different architecture config.
  • When loading the weights, I get an empty result when using the NAS_201 index.

Below is my implementation. I am using the benchmark file for sss.

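```python
from nats_bench import create
from xautodl.models import get_cell_based_tiny_net

# d: the NATS-Bench benchmark file for the size search space (sss)
api = create(d, 'sss', fast_mode=False, verbose=True)
# convert_naslib_to_str (from NASLib) yields a NAS-Bench-201 cell string; this query returns -1
index = api.query_index_by_arch(convert_naslib_to_str(best_arch))
config = api.get_net_config(index, 'cifar10')
best_arch = get_cell_based_tiny_net(config)
logger.info("Queried results ({}): {}".format(metric, best_arch))
# seed=None returns a dict mapping each training seed to a state_dict
params = api.get_net_param(index, 'cifar10', None)
best_arch.load_state_dict(next(iter(params.values())))
```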

@D-X-Y
Owner

D-X-Y commented Dec 16, 2022

Hi @Mars-204, thanks for your interest.

For the tss (topology search space), the index is the same as in NAS-Bench-201. The sss (size search space) does not exist in NAS-Bench-201, so its indices do not correspond. For your example, you should compare `api = create(d, 'tss', fast_mode=False, verbose=True)` with NAS-Bench-201.
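To illustrate why the sss query returns -1, here is a minimal sketch assuming the arch-string formats used by NATS-Bench: tss cells use the NAS-Bench-201 notation `|op~0|+|op~0|op~1|+|op~0|op~1|op~2|`, with five candidate ops on six edges (5^6 = 15,625 architectures), while sss networks are written as five colon-separated channel counts chosen from 8 to 64 in steps of 8 (8^5 = 32,768 architectures). A NAS-Bench-201 cell string therefore has no match in the sss index, and index 12804 is valid in sss but points to an unrelated channel configuration. `detect_space` is a hypothetical helper, not part of the NATS-Bench API.

```python
# Candidate operations on each of the 6 edges of a tss cell (NAS-Bench-201 style).
TSS_OPS = ("none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3")
# Candidate channel counts for each of the 5 layers of an sss network (assumed 8..64).
SSS_CHANNELS = tuple(range(8, 65, 8))

TSS_SIZE = len(TSS_OPS) ** 6       # 15625 topologies
SSS_SIZE = len(SSS_CHANNELS) ** 5  # 32768 size configurations


def detect_space(arch_str: str) -> str:
    """Hypothetical helper: guess which NATS-Bench space an arch string belongs to."""
    if arch_str.startswith("|") and arch_str.endswith("|"):
        # tss: three nodes separated by '+', each edge written as op~input_node
        nodes = arch_str.split("+")
        edges = [edge for node in nodes for edge in node.strip("|").split("|")]
        ops = [edge.split("~")[0] for edge in edges]
        if len(nodes) == 3 and len(edges) == 6 and all(op in TSS_OPS for op in ops):
            return "tss"
        return "unknown"
    # sss: five channel counts such as "64:32:16:48:64"
    parts = arch_str.split(":")
    if len(parts) == 5 and all(p.isdigit() and int(p) in SSS_CHANNELS for p in parts):
        return "sss"
    return "unknown"
```

Running `detect_space` on the string produced by `convert_naslib_to_str` should report `tss`, which is why that string only resolves to an index in an API created with `'tss'`.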

@Mars-204
Author

Hi @D-X-Y, thanks for the reply. It works for tss, as you mentioned.

However, when I try to obtain the pre-trained weights for this index, I get an empty dict in return. I am currently using the compressed benchmark pickle for tss. Should I be using a different file for the pre-trained weights?

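```python
from nats_bench import create
from nats_bench.api_utils import pickle_load
from xautodl.models import get_cell_based_tiny_net

# Compressed benchmark pickle for the topology search space (tss)
d = pickle_load('/work/ws-tmp/g059997-naslib/g059997-naslib-1667607005/NASLib_mod/naslib/NATS-bench/Copy of NATS-tss-v1_0-3ffb9.pickle.pbz2')
api = create(d, 'tss', fast_mode=False, verbose=True)
index = api.query_index_by_arch(convert_naslib_to_str(best_arch))
config = api.get_net_config(index, 'cifar10')
best_arch = get_cell_based_tiny_net(config)
# With this file, params comes back as an empty dict
params = api.get_net_param(index, 'cifar10', None)
best_arch.load_state_dict(next(iter(params.values())))
```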

@D-X-Y
Owner

D-X-Y commented Dec 20, 2022

The network parameters are contained in the full archive files (please follow the instructions at https://github.com/D-X-Y/NATS-Bench#preparation-and-download). Note that these files are quite large.
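A minimal sketch of a defensive weight-loading step, assuming `api.get_net_param(index, dataset, None)` returns a dict mapping each training seed to a state_dict (as in the snippets in this thread) and an empty dict when the loaded file carries no parameters. `load_pretrained_weights` is a hypothetical helper, not part of NATS-Bench; it simply turns the silent empty result into an explicit error.

```python
from typing import Any


def load_pretrained_weights(api: Any, network: Any, index: int, dataset: str = "cifar10") -> Any:
    """Hypothetical helper: load the first available seed's weights into `network`.

    Assumes api.get_net_param(index, dataset, None) returns {seed: state_dict}.
    Returns the seed whose weights were loaded.
    """
    if index < 0:
        # query_index_by_arch returned -1: the arch string is not in this search space
        raise ValueError("architecture not found in this search space (index -1)")
    params = api.get_net_param(index, dataset, None)
    if not params:
        # Empty dict: the benchmark file in use has no network parameters
        raise RuntimeError("no weights returned; load the full archive files instead")
    seed = next(iter(params))
    network.load_state_dict(params[seed])
    return seed
```

With an `api` created from the full tss archive, you would build the network with `get_cell_based_tiny_net(api.get_net_config(index, 'cifar10'))` and then call `load_pretrained_weights(api, network, index)`.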

@Mars-204
Author

Mars-204 commented Jan 7, 2023

Thanks for the response. It is working.

Mars-204 closed this as completed Jan 7, 2023

2 participants