This repository contains the necessary resources and instructions to reproduce the results presented in our paper "On Latency Predictors for Neural Architecture Search." Our work introduces a comprehensive suite of latency prediction tasks and a novel latency predictor, NASFLAT, that significantly outperforms existing methods in hardware-aware NAS.
- Ensure the `env_setup.py` script is executed correctly for environment setup. Modify the script paths as necessary for your system.
- Download the NDS dataset from here and place it in the `nas_embedding_suite` folder with the structure `NDS/nds_data/*.json`.
- Download `nasflat_embeddings_04_03_24.zip` from Google Drive and unzip it into `./nas_embedding_suite/`.
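The steps above can be sketched as follows; the archive name and paths are taken from this README, and the downloads themselves are linked above:

```shell
# Create the layout the scripts expect for the NDS data (paths from this README).
mkdir -p nas_embedding_suite/NDS/nds_data
# Place the downloaded NDS *.json files under nas_embedding_suite/NDS/nds_data/,
# then unpack the embeddings archive once it has been downloaded:
[ -f nasflat_embeddings_04_03_24.zip ] \
  && unzip -q nasflat_embeddings_04_03_24.zip -d ./nas_embedding_suite/ \
  || echo "nasflat_embeddings_04_03_24.zip not found; download it first"
```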
- Training and testing commands for the predictors are in `./correlation_trainer/large_run_slurms/unified_joblist.log`.
- To reproduce results for MultiPredict and HELP, refer to `multipredict_unified_joblist.log` and `help_unified_joblist.log`.
- For SLURM setups, use `parallelized_executor.sh`, adapting it as necessary for your environment. These commands can also be adjusted for non-SLURM execution.
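For non-SLURM machines, a minimal sequential runner over a joblist could look like the sketch below (not part of the repository; it assumes each non-empty line of the joblist is an independent shell command):

```shell
#!/bin/sh
# Run each non-empty line of a joblist file as a shell command, one at a time.
run_joblist() {
  while IFS= read -r cmd; do
    [ -n "$cmd" ] || continue     # skip blank lines
    echo "running: $cmd"
    sh -c "$cmd"                  # execute the job; add logging/retries as needed
  done < "$1"
}

# Demo with a throwaway joblist; point it at e.g.
# ./correlation_trainer/large_run_slurms/unified_joblist.log instead.
printf 'echo job-a\necho job-b\n' > /tmp/demo_joblist.log
run_joblist /tmp/demo_joblist.log
```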
Below are specific example commands that demonstrate how to execute various processes within the framework. These examples cover training from scratch, utilizing supplementary encodings, transferring predictors between spaces, and running NAS on a given search space.
The scripts referenced below are located in `correlation_trainer` and `nas_search`.
```shell
python main_trf.py --seed 42 --name_desc study_6_5_f_zcp --sample_sizes 800 --task_index 5 --representation adj_gin_zcp --num_trials 3 --transfer_sample_sizes 20 --transfer_lr 0.001 --transfer_epochs 30 --transfer_hwemb --space fbnet --gnn_type ensemble --sampling_metric a2vcatezcp --ensemble_fuse_method add
python main_trf.py --seed 42 --name_desc arch_abl --sample_sizes 512 --representation adj_gin --num_trials 7 --transfer_sample_sizes 5 10 20 --transfer_lr 0.0001 --transfer_epochs 20 --transfer_hwemb --hwemb_to_mlp --task_index 4 --space nb201
python main_trf.py --seed 42 --name_desc arch_abl --sample_sizes 512 --representation adj_gin --num_trials 7 --transfer_sample_sizes 5 10 20 --transfer_lr 0.0001 --transfer_epochs 20 --transfer_hwemb --task_index 4 --space nb201
python main_trf.py --seed 42 --name_desc arch_abl --sample_sizes 512 --representation adj_gin --num_trials 7 --transfer_sample_sizes 5 10 20 --transfer_lr 0.0001 --transfer_epochs 20 --task_index 4 --space nb201
python main_trf.py --seed 42 --name_desc arch_abl --sample_sizes 512 --representation adj_gin --num_trials 7 --transfer_sample_sizes 5 10 20 --transfer_lr 0.0001 --transfer_epochs 20 --hwemb_to_mlp --task_index 4 --space nb201
python main_trf.py --seed 42 --name_desc study_6_3_1_t2 --sample_sizes 512 --representation adj_gin --num_trials 7 --transfer_sample_sizes 5 10 15 20 30 --transfer_lr 0.0001 --transfer_epochs 20 --transfer_hwemb --task_index 1 --space nb201 --gnn_type ensemble --sampling_metric [random/params/arch2vec/cate/zcp/a2vcatezcp/latency]
python main_trf.py --seed 42 --name_desc study_6_3_2 --sample_sizes 512 --representation [adj_gin/adj_gin_arch2vec/adj_gin_zcp/adj_gin_a2vcatezcp/adj_gin_cate] --num_trials 5 --transfer_sample_sizes 5 10 20 --transfer_lr 0.0001 --transfer_epochs 20 --transfer_hwemb --task_index 3 --space nb201 --gnn_type ensemble --sampling_metric a2vcatezcp
```
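The bracketed arguments above denote mutually exclusive choices, one value per run. A sweep over the sampling metrics could be scripted as in this sketch (it only echoes the commands; remove the `echo` to execute them):

```shell
# Print one training command per sampling metric (values from the README above).
for metric in random params arch2vec cate zcp a2vcatezcp latency; do
  echo python main_trf.py --seed 42 --name_desc study_6_3_1_t2 \
    --sample_sizes 512 --representation adj_gin --num_trials 7 \
    --transfer_sample_sizes 5 10 15 20 30 --transfer_lr 0.0001 \
    --transfer_epochs 20 --transfer_hwemb --task_index 1 --space nb201 \
    --gnn_type ensemble --sampling_metric "$metric"
done
```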
All commands can be found in `correlation_trainer/large_run_slurms/multipredict_unified_joblist.log`. For example:

```shell
python fsh_advanced_training.py --name_desc multipredict_baseline_r --task_index 0 --space fbnet --emb_transfer_samples 16
```
All commands can be found in `correlation_trainer/large_run_slurms/help_unified_joblist.log`. For example:

```shell
python main.py --gpu 0 --mode 'meta-train' --seed 42 --num_trials 3 --name_desc 'help_baselines_r' --num_meta_train_sample 4000 --mc_sampling 10 --num_episodes 2000 --task_index 5 --search_space fbnet --num_samples 10
```
If you use the code or data in your research, please use the following BibTeX entry:
```bibtex
@misc{akhauri2024latency,
      title={On Latency Predictors for Neural Architecture Search},
      author={Yash Akhauri and Mohamed S. Abdelfattah},
      year={2024},
      eprint={2403.02446},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```