Some of this code is Copyright 2019 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [4]:
# Uncomment the below line to download the dataset (approx. 2gb)
# !curl -O https://storage.googleapis.com/nasbench/nasbench_full.tfrecord

In [5]:
# Initialize the NASBench object which parses the raw data into memory
# This may take a few minutes, about 2 on my machine
from typing import Set

import numpy as np
import random

import pandas as pd

from nasbench import api
from typing import List
from nasbench.api import ModelSpec
from dataclasses import dataclass
from functools import lru_cache
from tqdm.notebook import tqdm
from IPython.display import clear_output

In [6]:
nasbench = api.NASBench("nasbench_full.tfrecord")

Loading dataset from file... This may take a few minutes...
Loaded dataset in 118 seconds


In [31]:
INPUT = "input"
OUTPUT = "output"
CONV3X3 = "conv3x3-bn-relu"
CONV1X1 = "conv1x1-bn-relu"
MAXPOOL3X3 = "maxpool3x3"

NUM_VERTICES = 7
MAX_EDGES = 9

EDGE_SPOTS = NUM_VERTICES * (NUM_VERTICES - 1) / 2  # Upper triangular matrix
OP_SPOTS = NUM_VERTICES - 2  # Input/output vertices are fixed

ALLOWED_OPS = [CONV3X3, CONV1X1, MAXPOOL3X3]
ALLOWED_EDGES = [0, 1]  # Binary adjacency matrix

RNG_SEED = 42

In [8]:
@dataclass
class ModelData:
    """
    Wraps up resulting data from NASBench.query in a class for code cleanliness.
    """
    hash: str
    matrix: List[List[int]]
    operations: List[str]
    parameters: int
    train_time: float
    train_accuracy: float
    valid_accuracy: float
    test_accuracy: float


class SpecWrapper(ModelSpec):
    """
    Wraps a model to allow easier access to the resulting data of the model.
    """

    def __lt__(self, other: "SpecWrapper"):
        if not (isinstance(other, SpecWrapper)):
            return False

        return self.get_data().test_accuracy < other.get_data().test_accuracy

    @lru_cache(maxsize=1)
    def get_data(self):
        """
        Get resultant data from NASBench.
        The LRU cache ensures we don"t add time to the budget counters multiple times.
        :return:
        """
        data = nasbench.query(self)
        return ModelData(
            self.hash_spec(nasbench.config["available_ops"]),
            data["module_adjacency"],
            data["module_operations"],
            data["trainable_parameters"],
            data["training_time"],
            data["train_accuracy"],
            data["validation_accuracy"],
            data["test_accuracy"]
        )

    def __repr__(self):
        return self.hash_spec(nasbench.config["available_ops"])

In [50]:
# we'll precache all the models for convenience of code later, then we don't have to build them at runtime
# and can rely on the lru cache in the get_data method to not double up on queries. this way we don't have
# to maintain a list of which models we have queried so far. this is a few hundred mb of ram, not significant
def precache_specs():
    n_specs = len(list(nasbench.hash_iterator()))
    all_specs = tqdm(map(lambda x: nasbench.get_metrics_from_hash(x), nasbench.hash_iterator()), total=n_specs, unit='spec')
    all_specs = [SpecWrapper(matrix=spec[0]["module_adjacency"], ops=spec[0]["module_operations"]) for spec in all_specs]
    return all_specs

In [None]:
all_specs = precache_specs()

In [51]:
def random_spec() -> SpecWrapper:
    return random.choice(all_specs)

In [47]:
def sel_best(population: List[SpecWrapper], k: int) -> List[SpecWrapper]:
    """
    Select the top k candidates amongst the entire population.
    :param population: The population to select from.
    :param k: The number of top individuals to select.
    :return: The selected individuals.
    """
    # clone the list so we don't modify the passed list
    population = list(population)
    population.sort()
    return population[-k:]


def sel_random(population: List[SpecWrapper], k: int) -> List[SpecWrapper]:
    """
    Select k individuals from the population at random.
    :param population: The population to select from.
    :param k: The number of individuals to select.
    :return: The selected individuals.
    """
    return random.sample(population, k)


def sel_tournament(population: List[SpecWrapper], k: int, n: int) -> List[SpecWrapper]:
    """
    Select the n best individuals from amongst k randomly selected candidates.
    :param population: The population to select from.
    :param k: The number of candidates to randomly select.
    :param n: The number of top candidates to be returned.
    :return: The selected individuals.
    """
    candidates = sel_random(population, k)
    return sel_best(candidates, n)


def drop_worst(population: List[SpecWrapper], k: int) -> List[SpecWrapper]:
    """
    Drop the k worst candidates and return the remaining population.
    :param population: The population to drop from.
    :param k: Number of worst candidates to drop.
    """
    # clone the list so we don't modify the passed list
    population = list(population)
    population.sort()
    return population[k:]



In [26]:
def run_random_search(k: int, max_time: float):
    """
    Runs a simulated random search for a fixed amount of time. Actual time taken will be
    relatively quick, only the simulated time is considered. Will return all evaluated models,
    along with the k best models found.
    :param k: Number of best models to return
    :param max_time_budget: The amount of simulated time to expend.
    :return: All models evaluated, along with the best k models.
    """
    # resetting seeds at the top of the search functions is required for reproducibility
    np.random.seed(RNG_SEED)
    random.seed(RNG_SEED)

    # nasbench tracks time taken with each query, effectively simulating full
    # training, as we can query how long into the training we would be if we
    # were training ourselves. we reset the counters at the top of each experiment
    nasbench.reset_budget_counters()

    best: List[SpecWrapper] = []
    models: List[SpecWrapper] = []
    unique_hashes: Set[str] = set()
    collisions: int = 0

    time_spent, _ = nasbench.get_budget_counters()
    while time_spent < max_time:
        spec = random_spec()
        spec_hash = spec.hash_spec(nasbench.config["available_ops"])

        if spec_hash in unique_hashes:
            collisions += 1
            continue
        else:
            unique_hashes.add(spec_hash)

        models.append(spec)
        best = sel_best(models, k)
        time_spent, _ = nasbench.get_budget_counters()
        clear_output(wait=True)
        print(f"{time_spent/1000:0.2f}/{max_time/1000:0.0f}k ({(time_spent/max_time)*100:0.2f}%) seconds simulated")

    print(f"{len(unique_hashes)} unique models during random search.")
    print(f"{collisions} collisions during random search.")

    return models, best



In [30]:
all, best = run_random_search(10, 5e8)

502534.95/500000k (100.51%) seconds simulated
248 unique models during random search.
0 collisions during random search.


In [34]:
adf = pd.DataFrame(list(map(lambda x: x.get_data(), all)))
adf = adf.drop(columns=['matrix', 'operations'])
adf

Unnamed: 0,hash,parameters,train_time,train_accuracy,valid_accuracy,test_accuracy
0,e07831948fdf884f89dc7eadd69c4592,5032842,1737.491943,0.857772,0.780048,0.776042
1,5887b483c7eba6fb0e420a29f7007df2,1560375,976.177002,0.933193,0.837340,0.827925
2,695db6c37f60c2fee65c073ba60a88ea,1749386,880.635986,0.442508,0.430188,0.434796
3,8c0da3ffb5fc98ee51397f741b98ef1f,20510346,3132.138916,0.995092,0.861378,0.851162
4,e6b06f0daa891a06fe96da1c6daad856,3831434,2292.532959,0.999099,0.856270,0.849059
...,...,...,...,...,...,...
243,411f76b524a5cae02ac47d0e55209fb1,15037834,2688.839844,1.000000,0.940505,0.932191
244,fd24de2c46bab257f6a756b7a319a9bb,38936714,4517.019043,1.000000,0.942408,0.934195
245,6270d259a8dea747d3854e620c11b453,23295370,3486.271973,1.000000,0.935998,0.925080
246,6a157e9ea9f72c8c24d7259ce97c7b1e,23131530,3474.256836,1.000000,0.942608,0.938802


In [35]:
bdf = pd.DataFrame(list(map(lambda x: x.get_data(), best)))
bdf = bdf.drop(columns=['matrix', 'operations'])
bdf

Unnamed: 0,hash,parameters,train_time,train_accuracy,valid_accuracy,test_accuracy
0,c38f726827d319f5b746975edba7d84f,6054282,1440.531006,1.0,0.931991,0.928786
1,32084efa106310985249a333673b37b9,5476618,2349.081055,1.0,0.940905,0.93109
2,8ed16854745becfe5c80f810b5c8fd76,31552906,4166.227051,1.0,0.939503,0.930789
3,1023dc2082850a44a31c3516d5ed8560,8555530,1647.256958,1.0,0.940204,0.933894
4,43b5ec90010e9da3b678156d50d49838,24169098,3596.868164,1.0,0.936999,0.934696
5,411f76b524a5cae02ac47d0e55209fb1,15037834,2688.839844,1.0,0.940505,0.932191
6,fd24de2c46bab257f6a756b7a319a9bb,38936714,4517.982422,1.0,0.9376,0.932692
7,6270d259a8dea747d3854e620c11b453,23295370,3486.841797,1.0,0.939103,0.932091
8,6a157e9ea9f72c8c24d7259ce97c7b1e,23131530,3468.629883,1.0,0.938301,0.935196
9,220d3ddac849eb7cd079405c7aaaccdf,8816266,1724.791016,1.0,0.940104,0.936098


In [None]:
all_specs = precache_specs()

In [38]:
def run_evolutionary_search(max_time: float, pop_size: int):
    # resetting seeds at the top of the search functions is required for reproducibility
    np.random.seed(RNG_SEED)
    random.seed(RNG_SEED)

    # nasbench tracks time taken with each query, effectively simulating full
    # training, as we can query how long into the training we would be if we
    # were training ourselves. we reset the counters at the top of each experiment
    nasbench.reset_budget_counters()

    population = sel_random(all_specs, pop_size)
    [p.get_data() for p in population] # update nasbench budgets

    time_spent, _ = nasbench.get_budget_counters()
    while time_spent < max_time:

        time_spent, _ = nasbench.get_budget_counters()

    return population, population

all, best = run_evolutionary_search(0, 10)

In [42]:
all.sort()
best.sort()

In [48]:
all = drop_worst(best, 1)

In [49]:
adf = pd.DataFrame(list(map(lambda x: x.get_data(), all)))
adf = adf.drop(columns=['matrix', 'operations'])
adf

Unnamed: 0,hash,parameters,train_time,train_accuracy,valid_accuracy,test_accuracy
0,2c60a08cadca7764ded5291311d4ec60,3031562,1015.057007,0.997196,0.873698,0.863482
1,46a057f9756748a3d5546b05c6bb137f,4705162,2500.507812,1.0,0.891126,0.888722
2,e96ce68c5c2fb2a34f9550fc043a3e75,2773816,1246.448975,0.9999,0.904748,0.897636
3,eafd21c3ca59ad9a72f637cda520e4c1,21547914,3183.962891,1.0,0.904547,0.901142
4,20b076959685b524b832ffac1ddc9fdd,3694666,1249.853027,1.0,0.919872,0.915264
5,cab83eedfd66bc3c48028cc7cd519c4f,3468426,1250.86499,1.0,0.928285,0.917668
6,573d39c8e8c80005bdf2d700a49e08ed,4166026,2026.727051,0.9999,0.930389,0.923177
7,4d9b210d90d8326eb22188475109a784,4849361,1423.86792,1.0,0.927985,0.921875
8,2379bf5524680737ba048a046950bfdc,41721738,4713.012695,1.0,0.942708,0.939002


In [44]:
bdf = pd.DataFrame(list(map(lambda x: x.get_data(), best)))
bdf = bdf.drop(columns=['matrix', 'operations'])
bdf

Unnamed: 0,hash,parameters,train_time,train_accuracy,valid_accuracy,test_accuracy
0,07e222992d9a0a54e18ae21a61fa97b3,5906570,1947.873047,0.999399,0.840645,0.839543
1,2c60a08cadca7764ded5291311d4ec60,3031562,1014.238037,0.998598,0.878105,0.871795
2,46a057f9756748a3d5546b05c6bb137f,4705162,2501.74707,1.0,0.892528,0.888321
3,e96ce68c5c2fb2a34f9550fc043a3e75,2773816,1243.06604,0.9998,0.90605,0.895533
4,eafd21c3ca59ad9a72f637cda520e4c1,21547914,3174.112793,1.0,0.915164,0.910557
5,20b076959685b524b832ffac1ddc9fdd,3694666,1249.853027,1.0,0.919872,0.915264
6,cab83eedfd66bc3c48028cc7cd519c4f,3468426,1247.676025,1.0,0.921374,0.919872
7,573d39c8e8c80005bdf2d700a49e08ed,4166026,2023.040039,1.0,0.926182,0.919571
8,4d9b210d90d8326eb22188475109a784,4849361,1423.86792,1.0,0.927985,0.921875
9,2379bf5524680737ba048a046950bfdc,41721738,4713.345703,1.0,0.944111,0.938201
