In [None]:
import os
from dataclasses import fields
from itertools import product

import pygad as pg
import tensorflow as tf
from sklearn.datasets import load_digits

from network_architecture_optimization.genetic_algorithms.callbacks import OnGenerationCallback
from network_architecture_optimization.genetic_algorithms.teacher import Teacher
from network_architecture_optimization.mappers.layer_description import LayerDescription
from network_architecture_optimization.mappers.network_builder import NetworkBuilder
# Set the level of TensorFlow's Python logger.
# NOTE(review): 'INFO' sits below WARNING, so informational TF messages pass
# through; if the aim was quieter output (cf. VERBOSITY = 0), a higher level
# such as 'ERROR' may have been intended — confirm.
tf.get_logger().setLevel(level='INFO')

In [None]:
# Verbosity level handed to NetworkBuilder.
VERBOSITY = 0
# Number of LayerDescription slots encoded in one chromosome (presumably the
# maximum number of layers a candidate network can have — TODO confirm).
LAYERS_PER_CHROMOSOME = 3
# One gene per LayerDescription field, for every slot.
CHROMOSOME_SIZE = LAYERS_PER_CHROMOSOME * len(fields(LayerDescription))
# Forwarded to pygad.GA as num_generations / sol_per_pop.
NUM_GENERATIONS = 10
SOLUTIONS_PER_POPULATION = 6

In [None]:
def run_experiment(name, ga_params):
    """Run one PyGAD architecture search on the sklearn digits data and plot its fitness.

    Args:
        name: Title for the fitness plot drawn by ``plot_fitness``.
        ga_params: Extra keyword arguments for ``pygad.GA``. They are merged over
            the defaults below, so a key such as ``num_parents_mating`` can be
            overridden per experiment instead of raising ``TypeError`` for a
            duplicate keyword argument.

    Returns:
        The ``pygad.GA`` instance after ``run()``, so results (e.g. the best
        solution) can be inspected after the plot is drawn.
    """
    # Load the dataset once and take both arrays from the same result.
    digits = load_digits()
    X, y = digits['images'], digits['target']

    # NOTE(review): `verobse_level` looks misspelled, but it must match
    # NetworkBuilder's actual keyword — confirm there before renaming either.
    network_builder = NetworkBuilder(verobse_level=VERBOSITY)
    # NOTE(review): the meaning of the positional `3` is defined by Teacher — TODO name it.
    teacher = Teacher(X, y, 3, network_builder)

    default_ga_params = dict(
        num_genes=CHROMOSOME_SIZE,
        sol_per_pop=SOLUTIONS_PER_POPULATION,
        num_generations=NUM_GENERATIONS,
        gene_space=range(7),  # TODO make it depend on maximum number of genes in mapper
        gene_type=int,
        num_parents_mating=4,
        # Plain 2-argument (solution, solution_idx) callable handed to pygad.
        fitness_func=lambda solution, solution_idx: teacher.fitness_function(solution, solution_idx),
        # NOTE(review): a new OnGenerationCallback is constructed on every
        # generation, so any state it keeps does not carry over between generations.
        on_generation=lambda ga_instance: OnGenerationCallback().on_generation(ga_instance),
    )
    # Caller-supplied params win over the defaults.
    ga_instance = pg.GA(**{**default_ga_params, **ga_params})
    ga_instance.run()
    ga_instance.plot_fitness(title=name)
    return ga_instance

In [None]:
# Each experiment runs every combination of its `variable_params` values on top
# of its `fixed_params`; all keys are forwarded as keyword arguments to pygad.GA.
experiments = [{
    'name': 'Selection tests',
    'fixed_params': {
        'mutation_type': 'adaptive',
        # Adaptive mutation takes (low-quality, high-quality) *percentages* in
        # [0, 100]. The previous (0.6, 0.5) meant 0.6% / 0.5% of the genes, which
        # truncates to 0 genes for any chromosome shorter than ~167 genes; pygad
        # then forces both to 1 gene, making the two rates identical.
        # 60% / 50% is the presumed intent — confirm.
        'mutation_percent_genes': (60, 50),
    },
    'variable_params': {
        'parent_selection_type': [
            "sss",         # steady state selection
            "rws",         # roulette wheel selection
            "random",      # random selection
            "rank",        # rank selection
            "tournament",  # tournament selection
            "sus",         # stochastic universal selection
        ],
    },
}]

In [46]:
for experiment in experiments:
    # Read with .get() instead of setdefault() so the loop never inserts keys
    # into `experiments`; a missing section is treated as empty.
    fixed_params = experiment.get('fixed_params', {})
    variable_params = experiment.get('variable_params', {})
    param_names = list(variable_params)
    # product() over zero iterables yields a single empty tuple, so an experiment
    # without variable params still runs once with its fixed params.
    for combination in product(*variable_params.values()):
        variable_params_combination = dict(zip(param_names, combination))
        # Append the varied values to the title; otherwise every plot of an
        # experiment shares the same title and the runs cannot be told apart.
        suffix = ', '.join(f'{key}={value}' for key, value in variable_params_combination.items())
        title = f"{experiment['name']} ({suffix})" if suffix else experiment['name']
        # Variable params override fixed params on key collisions.
        run_experiment(title, {
            **fixed_params,
            **variable_params_combination
        })

15/15 - 0s - loss: 2.3060 - accuracy: 0.0844 - 160ms/epoch - 11ms/step
15/15 - 0s - loss: 2.3055 - accuracy: 0.0844 - 141ms/epoch - 9ms/step
15/15 - 1s - loss: 2.3047 - accuracy: 0.0844 - 647ms/epoch - 43ms/step
15/15 - 0s - loss: 2.3066 - accuracy: 0.0844 - 128ms/epoch - 9ms/step
15/15 - 0s - loss: 2.3063 - accuracy: 0.0844 - 138ms/epoch - 9ms/step
15/15 - 0s - loss: 2.3062 - accuracy: 0.0844 - 140ms/epoch - 9ms/step
15/15 - 0s - loss: 1.5599 - accuracy: 0.6933 - 145ms/epoch - 10ms/step
15/15 - 0s - loss: 1.6602 - accuracy: 0.6044 - 143ms/epoch - 10ms/step
15/15 - 0s - loss: 1.7405 - accuracy: 0.5889 - 147ms/epoch - 10ms/step
15/15 - 0s - loss: 0.3786 - accuracy: 0.9067 - 134ms/epoch - 9ms/step
15/15 - 0s - loss: 0.3045 - accuracy: 0.9511 - 149ms/epoch - 10ms/step
15/15 - 0s - loss: 0.2813 - accuracy: 0.9422 - 144ms/epoch - 10ms/step
15/15 - 0s - loss: 2.2960 - accuracy: 0.2022 - 186ms/epoch - 12ms/step
15/15 - 0s - loss: 2.3080 - accuracy: 0.0822 - 147ms/epoch - 10ms/step
15/15 - 0s 