Some of this code is Copyright 2019 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [1]:
# Uncomment for just the 108 epoch dataset (~500mb)
# !curl -O https://storage.googleapis.com/nasbench/nasbench_only108.tfrecord
# Uncomment for the full dataset (~2gb)
# !curl -O https://storage.googleapis.com/nasbench/nasbench_full.tfrecord

In [2]:
import copy
import math
from typing import Set
from typing import Callable

from trials.Constants import *
from trials.Evolution import *
from trials.ModelSpec import *
from trials.Search import *
from trials.Selection import *
from trials.Utilities import *

from IPython.core.display import clear_output




Loading dataset from file... This may take a few minutes...
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`
Loaded dataset in 35 seconds


In [39]:
specs = [get_spec(hsh) for hsh in nasbench.hash_iterator()]

In [40]:
ALL_SPECS = dict()
for spec in specs:
    ALL_SPECS[spec.get_hash()] = spec

In [41]:
len(ALL_SPECS)

423624

In [None]:
def search(
        max_time: float,
        num_best: int,
        num_epochs: int,
        initial_population: List[SpecWrapper],
        mut_fn: Callable[[List[SpecWrapper]], List[SpecWrapper]],
        sel_fn: Callable[[List[SpecWrapper]], List[SpecWrapper]],
        drp_fn: Callable[[List[SpecWrapper]], List[SpecWrapper]]
    ):


    # print(max_time)
    # print(num_best)
    # print(num_epochs)

    sel_best = sel_best_fn(num_best)

    def run_epoch(epoch_num: int, population: List[SpecWrapper]):
        reset_trial_stats(RNG_SEED + epoch_num)

        # initialize our population list, update budget counters
        # this also effectively "trains" the initial population
        population = [get_spec(ind.get_hash()) for ind in population]

        # desired size of the population
        p_size = len(population)

        # list of hashes that we have previously evaluated which can be skipped in the future
        done: Set[str] = set()

        # running cumulative time total of all epochs
        cur_time: float = 0

        def print_update():
            nonlocal cur_time, max_time, done
            print(f"{cur_time/1000:0.2f}k/{max_time/1000:0.0f}k ({(cur_time/max_time)*100:0.2f}%) seconds simulated, {len(done)} unique models")

        # update done, adding any new hashes to our set
        def update_done(items: List[SpecWrapper]):
            nonlocal cur_time
            done.update(map(lambda x: x.get_hash(), items))
            cur_time = sum([ind.get_data().train_time for ind in map(lambda x: get_spec(x), done)])

        update_done(population)
        print_update()

        while cur_time < max_time:
            # drop some candidates base on the specified function
            population = drp_fn(population)
            assert len(population) > 0
            update_done(population)

            # number of new specs to generate through mutation or crossover
            num_new = p_size - len(population)
            new_specs: List[SpecWrapper] = []
            while len(new_specs) < num_new:
                # select some candidates to mutate
                candidates = sel_fn(population)
                # mutate the candidates with the fn
                candidates = mut_fn(candidates)

                # only "evaluate" candidates who were not evaluate
                # so if we hit a duplicate candidate, we should not
                # add more training time, as it is wasteful
                cand_hashes = [cand.get_hash() for cand in candidates]
                cand_hashes = [hsh for hsh in cand_hashes if hsh not in done]

                # remove dupe hashes
                candidates = [get_spec(hsh) for hsh in cand_hashes]

                # update the list of new specs
                new_specs = [*new_specs, *candidates]

            population = [*population, *new_specs][:p_size]

            # [ind.get_data() for ind in population]
            update_done(population)

            print_update()

        best = sel_best(population)
        best.sort()

        abs_best = best[-1].get_data()
        print(f'Best in trial -- Test: {abs_best.test_accuracy:0.7f}, Valid: {abs_best.valid_accuracy:0.7f}')

        return population, best, done

    results: [List[List[SpecWrapper], List[SpecWrapper], Set[str]]] = []
    for epoch_num in range(num_epochs):
        print('-'*50)
        population, best, done = run_epoch(epoch_num, initial_population)
        results.append([population, best, done])
        print(f'Finished evaluation, {len(done)} models evaluated')
        print('-'*50)

    return results



In [38]:
# 365 days of GPU time
MAX_TIME = 60*60*24*31
N_BEST = 10
N_POP = 250
N_EPOCH = 1
MUT_RATE = 1

INITIAL_POPULATION = [random_spec() for idx in range(N_POP)]

trials = {
    'test': {
        "max_time": MAX_TIME,
        "num_best": N_BEST,
        "num_epochs": N_EPOCH,
        "initial_population": copy.deepcopy(INITIAL_POPULATION),
        "mut_fn": mutate_fn(1.0),
        "sel_fn": sel_tournament_fn(int(N_POP*0.25), int(N_POP*0.1)),
        "drp_fn": drop_worst_fn(N_POP*0.1)
    }
}

for key, values in trials.items():
    print(key)
    search(**values)



test
--------------------------------------------------
Reset stats.
476.40k/2678k (17.79%) seconds simulated, 250 unique models
516.83k/2678k (19.30%) seconds simulated, 275 unique models
549.75k/2678k (20.53%) seconds simulated, 299 unique models
592.03k/2678k (22.10%) seconds simulated, 324 unique models
636.27k/2678k (23.76%) seconds simulated, 349 unique models
680.67k/2678k (25.41%) seconds simulated, 373 unique models
731.42k/2678k (27.31%) seconds simulated, 398 unique models
777.08k/2678k (29.01%) seconds simulated, 423 unique models
822.38k/2678k (30.70%) seconds simulated, 448 unique models
869.34k/2678k (32.46%) seconds simulated, 473 unique models
906.39k/2678k (33.84%) seconds simulated, 498 unique models
947.43k/2678k (35.37%) seconds simulated, 522 unique models
988.68k/2678k (36.91%) seconds simulated, 547 unique models
1035.75k/2678k (38.67%) seconds simulated, 572 unique models
1079.89k/2678k (40.32%) seconds simulated, 596 unique models
1125.84k/2678k (42.03%) secon

KeyboardInterrupt: 