# Natural Computing - Project
#### Submission by group 25 (Chihab Amghane, Max Driessen, Jordy Naus)

The code below uses the [DEAP framework](https://github.com/deap/deap), an intuitive framework for evolutionary algorithms and genetic programming. We adapted several components of this framework to more closely match the [WANN implementation](https://github.com/google/brain-tokyo-workshop/tree/master/WANNRelease).

## Imports

In [1]:
# DEAP
from deap import gp, base, tools, creator, algorithms

# Data processing and plotting
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Requirements for the algorithm
from operator import attrgetter
from functools import partial

# Standard python imports
import random, pickle, math, re, os
import numpy as np

# Magic for inline plots
%matplotlib inline

## Helper functions

In [2]:
def exp(x):
    """Exponential with the input capped at 709 to keep np.exp from overflowing to inf."""
    capped = np.clip(x, -float('inf'), 709.)
    return np.exp(capped)

## Global parameters

In [3]:
# Dataset (determines which preprocessed data file is loaded and how the results file is named)
DATASET = "MNIST" # choose from {"MNIST", "Fashion-MNIST"}
N_CLASSES_TO_USE = 10 # at most 10

# Individual trees
P_INITIAL_CONNECTION = 0.05 # probability that an input-output connection is enabled at initialization

# Fitness
SAMPLE_SIZE = 200 # number of training samples per evaluation, divided equally over the classes
WEIGHTS_TO_TEST = [-2, -1, 1, 2] # shared weights to evaluate with; -0.5 and 0.5 are not used due to long runtime

# Parent selection
TOURNAMENT_SIZE = 32 # number of individuals competing in each tournament

# Mutation (probabilities should sum to 1)
P_MUTATE_ACTIVATION = 0.5
P_ADD_NODE = 0.25
P_ADD_CONNECTION = 0.2
P_ENABLE_CONNECTION = 0.05

# Evolution
POPULATION_SIZE = 250
N_GENERATIONS = 1000
CULL_RATIO = 0.2 # fraction of the population (worst individuals) removed every generation
ELITE_RATIO = 0.2 # fraction of the population (best individuals) copied unchanged to the next generation

# Filenames
RESULTS_FILENAME = f"DEAPWANN-{DATASET}-results.pkl" # pickle file in which the results are stored

## Loading preprocessed data

In [4]:
# Only this cell needs these modules, so they are imported locally
import subprocess
import sys

# Set the correct data filename and the path it is expected at
filename = f"{DATASET}-{N_CLASSES_TO_USE}.pkl"
data_path = os.path.join("data", filename)

# If the data has not yet been preprocessed in the specified way, do so now
if not os.path.exists(data_path):
    print("Preprocessed dataset does not exist yet, creating now.")
    # Run the preprocessing script with the interpreter that runs this notebook, passing the arguments
    # as a list (no shell string). check=True raises CalledProcessError if the script fails, instead of
    # continuing and hitting a less informative FileNotFoundError below.
    subprocess.run(
        [sys.executable, "Preprocessing.py", "-d", DATASET, "-c", str(N_CLASSES_TO_USE)],
        check=True,
    )

# Load the preprocessed data
# NOTE: unpickling can execute arbitrary code; only load data files you generated yourself
with open(data_path, "rb") as f:
    (X_train, Y_train), (X_test, Y_test) = pickle.load(f)

## Defining operators

In [5]:
# Aggregator shared by all operators: the weight times the sum of the incoming values
def aggregate(w, args):
    total = sum(args)
    return w*total


# Operators: each one takes the shared weight `w` followed by a variable number of input values,
# aggregates the inputs and applies its activation to the result
def linear(w, *args):
    """Identity activation: returns the aggregated value itself."""
    return aggregate(w, args)


def step(w, *args):
    """Returns 1.0 when the aggregated value is positive and 0.0 otherwise."""
    z = aggregate(w, args)
    return float(z > 0)


def sine(w, *args):
    """Sine of pi times the aggregated value."""
    z = aggregate(w, args)
    return np.sin(np.pi*z)


def gaussian(w, *args):
    """Gaussian bump exp(-z^2/2) of the aggregated value z."""
    z = aggregate(w, args)
    return exp(-np.multiply(z, z)/2.0)


def tanh(w, *args):
    """Hyperbolic tangent of the aggregated value."""
    z = aggregate(w, args)
    return np.tanh(z)


def sigmoid(w, *args):
    """Logistic sigmoid of the aggregated value, expressed through tanh."""
    z = aggregate(w, args)
    return (np.tanh(z/2.0) + 1.0)/2.0


def inverse(w, *args):
    """Negation of the aggregated value."""
    z = aggregate(w, args)
    return -z


def absolute(w, *args):
    """Absolute value of the aggregated value."""
    z = aggregate(w, args)
    return abs(z)


def relu(w, *args):
    """Rectifier: the aggregated value, floored at 0.0."""
    z = aggregate(w, args)
    return np.maximum(0.0, z)


def cosine(w, *args):
    """Cosine of pi times the aggregated value."""
    z = aggregate(w, args)
    return np.cos(np.pi*z)

In [6]:
# Mapping from operator name to implementation; used as the evaluation context when compiling trees
function_context = {
    'linear': linear,
    'relu': relu,
    'step': step,
    'sine': sine,
    'gaussian': gaussian,
    'tanh': tanh,
    'sigmoid': sigmoid,
    'inverse': inverse,
    'absolute': absolute,
    'cosine': cosine,
}

# Names of all operators (in definition order) and of all input arguments (one per input feature)
function_names = list(function_context)
argument_names = [f"ARG{i}" for i in range(X_train.shape[1])]

## Defining individuals

##### Defining nodes

In [7]:
# Base class shared by all nodes of a tree
class Node:
    def __init__(self, name):
        # Every node stores its name and the nodes that use it as an input (its parents)
        self.name = name
        self.parents = []

    def __str__(self):
        # Only the concrete node types know how to format themselves
        raise NotImplementedError("String function is only implemented for subclasses")


# Terminal nodes are the inputs of the tree
class TerminalNode(Node):
    # No extra state is needed, so the constructor of Node is inherited unchanged

    def __str__(self):
        # A terminal is written as its bare name (e.g. "ARG42")
        return self.name


# Non-terminal nodes are the hidden and output nodes of the tree
class NonterminalNode(Node):
    def __init__(self, name):
        super().__init__(name)
        # Enabled inputs of this node, and inputs whose connection is currently disabled
        self.children = []
        self.disabled = []

    def __str__(self):
        # A non-terminal is written as a call on its enabled children: "name(w, child1, child2, ...)"
        arguments = ', '.join(str(child) for child in self.children)
        return f"{self.name}(w, {arguments})"

##### Defining individuals/multi-class trees

In [8]:
# Class for multi-output trees
class MultiClassTree:
    """A network with one output node per class, all built on the same set of input nodes.

    Attributes:
        inputs: list of TerminalNode objects, one per input feature
        outputs: list of NonterminalNode objects ("linear"), one per class
        hidden: list of hidden NonterminalNode objects (empty at initialization)
        born: generation in which the individual was created (-1 while unknown)
        n_connections: number of enabled connections in the tree
    """

    def __init__(self, n_inputs, n_outputs, p_initial_connection):
        """Create a tree with random initial input-output connections.

        Parameters:
        n_inputs: number of input nodes
        n_outputs: number of output nodes
        p_initial_connection: probability that an input-output connection starts out enabled
        """
        # Initialize lists of input, output and internal nodes
        self.inputs = [TerminalNode(argument_names[i]) for i in range(n_inputs)]
        self.outputs = [NonterminalNode("linear") for _ in range(n_outputs)]
        self.hidden = []
        self.born = -1

        # Add initial connections
        self.n_connections = 0
        for output in self.outputs:
            # With a chance of p_initial_connection, the connection is enabled, otherwise it is disabled
            # (the constructor argument is used here; previously it was silently ignored in favour of
            # the global P_INITIAL_CONNECTION)
            for child in self.inputs:
                if random.random() < p_initial_connection:
                    output.children.append(child)
                else:
                    output.disabled.append(child)
                # The output is registered as a parent for disabled connections as well
                child.parents.append(output)

            # If an output has no enabled children, one of the children is enabled to make the tree valid
            if not output.children:
                child = random.choice(self.inputs)
                output.disabled.remove(child)
                output.children.append(child)

            # Update the number of enabled connections in the tree
            self.n_connections += len(output.children)

    def __str__(self):
        # Printing the tree only prints the number of hidden nodes and enabled connections
        return f"MultiClassTree with {len(self.hidden)} hidden nodes and {self.n_connections} connections"\
                + (f", born in generation {self.born}" if self.born >= 0 else "")

    def get_strings(self):
        """Return one expression string per output node, or None if the tree is too deep to format."""
        # The output strings are built recursively, so a very deep tree can exceed the recursion limit
        try:
            return [str(output) for output in self.outputs]
        except RecursionError:
            print("Maximum recursion depth reached")
            return None

##### Initializing the DEAP toolbox

In [9]:
# Initialize the toolbox, which will hold the functions used throughout the genetic programming process
toolbox = base.Toolbox()

In [10]:
# Define classes for fitness and individuals (using DEAP's creator module)
# The fitness has two objectives, both with weight -1.0 (i.e. both are to be minimized)
creator.create("Fitness", base.Fitness, weights=(-1.0, -1.0))
creator.create("Individual", MultiClassTree, fitness=creator.Fitness, rank=-1) # An individual is a multi-class tree; rank is overwritten in eaWann

In [11]:
# Define how to initialize an individual or population
# An individual has one input per feature column of X_train and one output per class
toolbox.register("individual", creator.Individual, X_train.shape[1], N_CLASSES_TO_USE, P_INITIAL_CONNECTION)
# A population is a list filled by repeatedly calling toolbox.individual
toolbox.register("population", tools.initRepeat, list, toolbox.individual)

## Fitness function

##### Compiling multi-class trees into functions

In [12]:
# Compiling a tree into a function
def compile_multiclasstree(tree):
    """Turn a multi-class tree into a callable func(w, args) that returns softmax class probabilities.

    `w` is the shared weight and `args` holds one value per input argument.
    """
    # Every output string becomes the body of a lambda taking the weight and all input arguments
    # NOTE: eval() is applied to strings generated from the tree itself, never to external input
    signature = ', '.join(argument_names)
    output_funcs = []
    for string in tree.get_strings():
        output_funcs.append(eval(f"lambda w, {signature}: {string}", function_context, {}))

    # The compiled function evaluates all outputs and normalizes them with a softmax
    def func(w, args):
        exponentials = exp([output_func(w, *args) for output_func in output_funcs])
        return exponentials/np.sum(exponentials, axis=0)

    return func

In [13]:
# Add the compile function to toolbox (available as toolbox.compile)
toolbox.register("compile", compile_multiclasstree)

##### Defining the fitness function

In [14]:
def get_sample(seed, sample_size):
    """Draw a class-balanced list of training sample indices.

    A non-negative seed re-seeds numpy's global random generator first, so that every individual
    in a generation can be tested on the same samples; a negative seed leaves the generator as is.
    """
    if seed >= 0:
        np.random.seed(seed)

    # Every class contributes the same number of samples
    samples_per_class = int(sample_size/N_CLASSES_TO_USE)
    sample_indices = []
    for c in range(N_CLASSES_TO_USE):
        c_indices = np.where(Y_train == c)[0]
        assert len(c_indices) >= samples_per_class, \
            f"Class {c} has too few elements to reach the desired number of evaluation samples"
        # Shuffle the indices of this class and keep the first few
        shuffled = np.random.permutation(c_indices)
        sample_indices.extend(shuffled[:samples_per_class])

    return sample_indices

In [15]:
# Define fitness function (average and best cross-entropy loss over the tested weights)
def fitness(individual, sample_indices, weights_to_test):
    """Evaluate an individual on the given training samples for every shared weight.

    Parameters:
    individual: the tree to evaluate
    sample_indices: indices into X_train/Y_train of the samples to test on
    weights_to_test: shared weight values to evaluate the tree with

    Returns a tuple (average loss over all weights, loss of the best weight), where the loss of a
    weight is the mean cross-entropy over the samples.
    """
    # Compile the functions corresponding to the individual
    func = toolbox.compile(individual)

    # Cross-entropy of a single prediction: negative log-probability assigned to the true label
    def _cross_entropy(pred, label):
        return -np.log(pred[label])

    # Compute the mean cross-entropy loss over the samples for each weight
    results = []
    for w in weights_to_test:
        w_results = [_cross_entropy(func(w, X_train[i]), Y_train[i]) for i in sample_indices]
        results.append(np.average(w_results))

    # Return the average and best cross-entropy loss
    return (np.average(results), np.min(results))

In [16]:
# Add the sample generator and fitness function to the toolbox
# The sample size and the weights to test are fixed here; callers supply the seed and the sample indices
toolbox.register("get_sample", get_sample, sample_size=SAMPLE_SIZE)
toolbox.register("evaluate", fitness, weights_to_test=WEIGHTS_TO_TEST)

## Evolution components

##### Parent selection

In [17]:
# Define how to select parents (tournament selection based on NSGA2 rank)
# The tournament compares the "rank" attribute of the individuals, which is assigned in eaWann
toolbox.register("select", tools.selTournament, tournsize=TOURNAMENT_SIZE, fit_attr="rank")

##### Mutation

In [18]:
# Mutating the activation function of a hidden node
def _mutate_activation(tree):
    # If the tree has no hidden nodes, do nothing
    if not tree.hidden:
        return
        
    # Select a random node and give it a different activation function
    node = random.choice(tree.hidden)
    node.name = random.choice([name for name in function_names if not name == node.name])


# Adding a node to the tree
def _add_node(tree):
    """Split a random enabled connection by inserting a new hidden node in between (in place)."""
    # Pick the connection to split: a random parent and one of its enabled children
    candidates = tree.outputs + tree.hidden
    parent = random.choice(candidates)
    child = random.choice(parent.children)

    # The direct connection between parent and child becomes disabled
    parent.children.remove(child)
    parent.disabled.append(child)

    # The new node gets a random activation function
    new_node = NonterminalNode(random.choice(function_names))

    # Wire the new node between parent and child: parent -> new_node -> child
    new_node.parents.append(parent)
    new_node.children.append(child)
    parent.children.append(new_node)
    child.parents.append(new_node)

    # Register the node; one connection was disabled and two were added, so the net change is +1
    tree.hidden.append(new_node)
    tree.n_connections += 1

    
# Adding a connection in the tree
def _add_connection(tree):
    # Function that checks if node1 is an ancestor of node 2
    def _is_ancestor(node1, node2):
        return node1 in [node2] + node2.parents or any([_is_ancestor(node1, node3) for node3 in node2.parents])
    
    # Find a valid parent-child pair for a connection, respecting the feed-forward property (no loops)
    valid_connection = False
    n_attempts = 0
    while not valid_connection and n_attempts < 500:
        parent = random.choice(tree.outputs + tree.hidden)
        child = random.choice(tree.inputs + tree.hidden)
        valid_connection = not _is_ancestor(child, parent) # Connection is valid if child is not an ancestor of parent
        n_attempts += 1
    
    # If a valid connection was found, update the parent/child relations and the number of enabled connections
    if n_attempts < 500:
        child.parents.append(parent)
        parent.children.append(child)
        tree.n_connections += 1


# Enable a disabled connection (created during initialization or when adding a node)
def _enable_connection(tree):
    # Check if there are any nodes with disabled connections; if not, do nothing
    parents_with_disabled = [node for node in tree.outputs + tree.hidden if node.disabled]
    if not parents_with_disabled:
        return
    
    # Select a random disabled parent-child pair
    parent = random.choice(parents_with_disabled)
    child = random.choice(parent.disabled)
    
    # Enable the corresponding connection and update the number of enabled connections
    parent.disabled.remove(child)
    parent.children.append(child)
    tree.n_connections += 1

In [19]:
def mutate(tree):
    """Return a 1-tuple with a mutated copy of `tree`; the given tree itself is not modified.

    Exactly one of the four mutation operators is applied, chosen according to the global
    mutation probabilities.
    """
    # Copy the parent tree
    tree = toolbox.clone(tree)

    # Create lists of the various mutation functions and the corresponding probabilities
    mutation_functions = [_mutate_activation, _add_node, _add_connection, _enable_connection]
    probabilities = [P_MUTATE_ACTIVATION, P_ADD_NODE, P_ADD_CONNECTION, P_ENABLE_CONNECTION]

    # Ensure probabilities sum to 1 (compared with a tolerance, as sums of floats are rarely exact)
    assert math.isclose(sum(probabilities), 1.0), "Mutation probabilities should sum to 1"

    # Choose a mutation function using the provided probabilities and execute it
    mutation_function, = random.choices(mutation_functions, probabilities, k=1)
    mutation_function(tree)

    # Return the resulting tree (as a tuple, which is what eaWann unpacks)
    return tree,

In [20]:
# Add the mutate function to the toolbox (available as toolbox.mutate)
toolbox.register("mutate", mutate)

## Defining statistics

In [21]:
# Describe which kinds of statistics to keep track of:
# the two fitness values (average and best loss over the tested weights), the number of enabled
# connections and the number of hidden nodes of each individual
stats_avgfit = tools.Statistics(key = lambda ind: ind.fitness.values[0])
stats_bestfit = tools.Statistics(key = lambda ind: ind.fitness.values[1])
stats_connections = tools.Statistics(key = lambda ind: ind.n_connections)
stats_hidden = tools.Statistics(key = lambda ind: len(ind.hidden))

# Combine statistics into a single multistatistics object (the keyword names become the logbook chapters)
mstats = tools.MultiStatistics(avg_fitness=stats_avgfit, best_fitness=stats_bestfit, 
                               hidden=stats_hidden, connections=stats_connections)

In [22]:
# Describe metrics to keep track of for each statistic (mean, standard deviation and minimum over the population)
mstats.register("avg", np.mean)
mstats.register("std", np.std)
mstats.register("min", np.min)

## Running the genetic programming algorithm

##### Defining the algorithm

In [23]:
def eaWann(population, toolbox, ngen, cull_ratio, elite_ratio, stats=None, halloffame=None, verbose=True):
    """
    Evolutionary algorithm for Weight Agnostic Neural Networks (WANNs)
    Based on the algorithms provided by DEAP, as well as the WANN implementation
    
    The basic idea is as follows:
    1. Sort the population based on fitness (using NSGA2)
    2. Remove the worst individuals
    3. Copy the best individuals directly to the next generation
    4. Perform tournament selection to create the remaining offspring
    5. Evaluate all individuals in the new population (each individual is tested on the same sample)
    6. Repeat from 1
    
    Parameters:
    population: the initial population (replaced in place by each new generation)
    toolbox: the DEAP toolbox containing functions for parent selection, mutation etc.
    ngen: number of generations to run the algorithm for
    cull_ratio: fraction of the population that will be thrown away every generation (worst individuals)
    elite_ratio: fraction of the population that will be directly copied to the next generation (best individuals)
    stats: (Multi)Statistics object, keeping track of evolution statistics (optional)
    halloffame: List containing the best individuals that ever lived (optional)
    verbose: whether or not to print statistics
    
    Returns:
    The final population and the logbook with the recorded statistics
    """
    
    # Initialize logbook and set the correct headers
    # Without statistics, only the generation number is logged and there are no chapters to configure
    logbook = tools.Logbook()
    logbook.header = ['gen'] + (['avg_fitness', 'best_fitness', 'connections', 'hidden'] if stats else [])
    if stats:
        for field in stats.fields:
            if "fitness" in field:
                logbook.chapters[field].header = "min", "avg", "std"
            else:
                logbook.chapters[field].header = "avg", "std"

    # Evaluate all individuals using the same sample (the generation number is used as seed)
    sample = toolbox.get_sample(0)
    fitnesses = toolbox.map(partial(toolbox.evaluate, sample_indices=sample), population)
    for ind, fit in zip(population, fitnesses):
        ind.fitness.values = fit
        ind.born = 0

    # Update hall of fame
    if halloffame is not None:
        halloffame.update(population)
        
    # Record and print performance if applicable
    record = stats.compile(population) if stats else {}
    logbook.record(gen=0, **record)
    if verbose:
        print(logbook.stream)

    # Begin the generational process
    for gen in range(1, ngen + 1):
        
        # Initialize offspring and determine offspring size
        offspring = []
        population_size = len(population)
        
        # Rank the population and update rank values (tournament selection prefers individuals with bigger rank)
        ranked_population = tools.selNSGA2(population, population_size)
        for ind, rank in zip(ranked_population, reversed(range(population_size))):
            ind.rank = rank
        
        # Culling - remove worst performing individuals
        number_to_cull = int(cull_ratio*population_size)
        ranked_population = ranked_population[:population_size-number_to_cull]
        
        # Elitism - select and copy best performing individuals
        number_of_elites = int(elite_ratio*population_size)
        for i in range(number_of_elites):
            elite = toolbox.clone(ranked_population[i])
            del elite.fitness.values # Will be re-evaluated using this generation's sample
            offspring.append(elite)
            
        # Compute number of offspring that still need to be generated
        offspring_to_generate = population_size - number_of_elites
            
        # Select parents via (NSGA2 rank-based) tournament selection
        parents = toolbox.select(ranked_population, offspring_to_generate)
        
        # Mutate parents to obtain children (mutation returns a modified copy, the parent is untouched)
        for parent in parents:
            child, = toolbox.mutate(parent)
            del child.fitness.values
            child.born = gen
            offspring.append(child)

        # Evaluate all individuals in the offspring using the same sample
        sample = toolbox.get_sample(gen)
        fitnesses = toolbox.map(partial(toolbox.evaluate, sample_indices=sample), offspring)
        for ind, fit in zip(offspring, fitnesses):
            ind.fitness.values = fit

        # Update the hall of fame with the generated individuals
        if halloffame is not None:
            halloffame.update(offspring)

        # Replace the current population by the offspring
        population[:] = offspring

        # Append the current generation statistics to the logbook
        record = stats.compile(population) if stats else {}
        logbook.record(gen=gen, **record)
        if verbose:
            print(logbook.stream)

    return population, logbook

##### Running the algorithm

In [None]:
# Run the evolutionary algorithm
# The hall of fame keeps track of the 10 best individuals encountered during the run
pop = toolbox.population(POPULATION_SIZE)
hof = tools.HallOfFame(10)
pop, log = eaWann(population=pop, toolbox=toolbox, ngen=N_GENERATIONS, cull_ratio=CULL_RATIO, 
                  elite_ratio=ELITE_RATIO, stats=mstats, halloffame=hof, verbose=True)

   	      avg_fitness      	      best_fitness      	  connections  	   hidden  
   	-----------------------	------------------------	---------------	-----------
gen	min    	avg    	std    	min    	avg   	std     	avg    	std    	avg	std
0  	2.79939	3.35861	0.27448	2.21612	2.6656	0.157292	128.352	10.3207	0  	0  
1  	2.80176	2.97141	0.128445	2.20243	2.39618	0.0959565	120.104	7.35535	0.164	0.370276
2  	2.75592	2.81704	0.0336058	2.24304	2.47257	0.0773412	115.48 	4.03306	0.192	0.393873
3  	2.7585 	2.83913	0.0370141	2.18018	2.28419	0.122425 	122.456	4.28673	0.692	0.655085
4  	2.71819	2.81537	0.0625957	2.14902	2.25897	0.10826  	125.388	2.3642 	1.34 	0.769675
5  	2.72317	2.80961	0.0518831	2.14181	2.33076	0.142228 	126.088	2.96652	1.716	0.855187
6  	2.7025 	2.80545	0.0721414	2.08927	2.18626	0.0801644	127.212	0.773987	2.3  	0.956033
7  	2.71571	2.81284	0.0554451	2.10404	2.15665	0.036581 	127.672	0.777442	2.728	0.77847 
8  	2.72419	2.83397	0.041427 	2.13665	2.20326	0.0317034	128.632	0.83939 	2.7

90 	2.12092	2.39349	0.214403 	1.38985	1.52497	0.104625 	171.476	3.39786 	17.66 	5.24751 
91 	2.24248	2.35919	0.121101 	1.60548	1.80215	0.0919348	175.672	3.59617 	22.184	5.08076 
92 	2.22863	2.43891	0.13828  	1.51117	1.66072	0.0909391	171.624	2.78184 	15.62 	3.65781 
93 	2.29362	2.47175	0.224123 	1.41689	1.55359	0.139696 	172.184	2.62567 	15.352	3.30395 
94 	2.25928	2.37097	0.109367 	1.33889	1.57186	0.148435 	172.564	3.46611 	16.016	4.12113 
95 	2.14784	2.30174	0.163644 	1.31469	1.51868	0.12068  	175.152	4.16472 	19.128	5.52228 
96 	2.17685	2.40201	0.164641 	1.47764	1.6337 	0.110446 	173.576	2.82351 	16.012	3.53664 
97 	2.22729	2.36479	0.195308 	1.4708 	1.72606	0.119709 	177.428	3.87876 	20.412	5.90815 
98 	2.19117	2.32075	0.158509 	1.42881	1.65124	0.138151 	177.016	4.16362 	20.328	5.63244 
99 	2.25558	2.38526	0.206905 	1.35486	1.62839	0.140698 	178.768	4.66928 	22.776	6.1121  
100	2.19781	2.37308	0.116513 	1.37971	1.5933 	0.117257 	177.392	3.6175  	19.472	5.40825 
101	2.08286	2.34002	0

183	1.94264	2.03349	0.148909 	1.24082	1.39416	0.0946269	221.44 	1.18929 	41.936	0.797436
184	1.95097	2.10943	0.262665 	1.06849	1.21705	0.139937 	221.168	1.22138 	41.312	1.0073  
185	1.96944	2.1241 	0.267867 	1.269  	1.35407	0.100335 	222.112	1.3128  	41.808	1.23091 
186	1.90108	2.05232	0.222275 	1.15152	1.2693 	0.105888 	222.58 	0.989747	42.356	1.43571 
187	1.98189	2.1771 	0.297082 	1.25669	1.37818	0.113008 	223.364	0.857615	41.376	1.53448 
188	1.9256 	2.10216	0.191439 	1.18234	1.31997	0.106403 	223.688	0.871009	41.484	1.54459 
189	1.91346	2.05095	0.216219 	1.12545	1.27236	0.113831 	224.632	1.01616 	42.176	1.032   
190	2.05447	2.21049	0.120142 	1.2286 	1.3177 	0.0767479	224.672	1.30706 	41.728	1.27044 
191	1.98211	2.18926	0.142107 	1.23693	1.31817	0.0901665	223.648	1.26653 	41.132	1.18092 
192	2.06539	2.18625	0.138593 	1.20493	1.34428	0.0985025	226.228	1.51262 	43.036	1.50556 
193	1.92925	2.08988	0.144187 	1.07471	1.17149	0.109763 	224.86 	1.82658 	41.936	1.63826 
194	1.94938	2.10135	0

275	1.72162	1.8379 	0.406045 	0.880293	1.0016  	0.165904 	268.236	0.836842	56.68 	0.79599 
276	1.7569 	1.87282	0.193999 	1.04261 	1.16628 	0.126718 	268.512	0.89992 	56.72 	1.09252 
277	1.66503	1.84003	0.385149 	0.836228	0.967027	0.182432 	269.7  	0.972625	57.236	1.58629 
278	1.65391	1.78137	0.208912 	0.923684	1.02639 	0.127373 	270.148	1.32593 	58.456	1.89738 
279	1.7582 	1.92714	0.213409 	1.00701 	1.13346 	0.141942 	270.564	1.10901 	57.436	1.19579 
280	1.73839	1.84584	0.302479 	1.06695 	1.17617 	0.111882 	271.132	1.6008  	59.452	1.58483 
281	1.71665	1.85325	0.302127 	0.901523	1.06934 	0.158178 	271.544	1.21658 	59.196	1.6702  
282	1.68478	1.77668	0.116477 	0.956938	1.06844 	0.109197 	271.8  	1.35056 	59.052	1.43433 
283	1.71671	1.79546	0.207158 	0.99205 	1.0891  	0.123409 	273.852	1.06118 	60.62 	1.0824  
284	1.68875	1.77449	0.201377 	0.93098 	1.02589 	0.107061 	274.572	1.19198 	60.26 	1.0735  
285	1.69047	1.76276	0.162631 	0.955121	1.04827 	0.106979 	276.016	1.31139 	61.116	1.04237 

366	1.60847	1.81715	0.745879 	0.848115	0.94477 	0.151528 	316.84 	0.945727	76.212	0.924692
367	1.61241	1.76821	0.330264 	0.828175	0.918347	0.105505 	318.032	1.09132 	76.208	0.927759
368	1.52822	1.6987 	0.577236 	0.760919	0.903994	0.202725 	317.656	0.960033	75.9  	1.01686 
369	1.61019	1.82445	0.393837 	0.800216	0.870243	0.106979 	319.176	1.13535 	77.552	1.39975 
370	1.55305	1.69776	0.547983 	0.789386	0.901105	0.182234 	318.012	0.969462	75.804	0.941055
371	1.62717	1.73721	0.333066 	0.856361	0.926478	0.0917438	318.88 	0.780769	76.236	1.40723 
372	1.5533 	1.68738	0.275386 	0.77524 	0.858697	0.103131 	319.316	0.929593	76.664	1.39682 
373	1.57183	1.72323	0.495555 	0.784869	0.889181	0.154884 	319.572	0.837148	76.196	0.702555
374	1.42596	1.59456	0.412438 	0.635392	0.714442	0.119984 	319.7  	0.926283	77.156	0.787187
375	1.54733	1.73799	0.709636 	0.781518	0.881954	0.147405 	320.336	0.950318	77.56 	1.06508 
376	1.59147	1.72581	0.514674 	0.85414 	0.932317	0.120663 	320.98 	0.84119 	77.048	0.875041

457	1.44548	1.60073	0.655914 	0.821418	0.887815	0.116349 	352.144	0.807009	92.344	1.1107  
458	1.61595	1.69186	0.137137 	0.910916	0.9745  	0.0894477	352.26 	0.946784	92    	1.43666 
459	1.62857	1.74971	0.514064 	0.91189 	0.990513	0.105085 	353.484	0.835311	93.016	0.838895
460	1.55552	1.74627	1.12359  	0.883182	0.957898	0.120395 	353.38 	1.12942 	92.46 	1.09563 
461	1.58625	1.66466	0.190243 	0.928187	0.995529	0.0972965	354.304	1.32498 	93.032	0.752978
462	1.53653	1.65973	0.790818 	0.839629	0.908576	0.100448 	356.4  	1.39427 	93.536	0.681692
463	1.5796 	1.65624	0.196066 	0.941632	1.01495 	0.116974 	357.344	0.899813	93.72 	0.705408
464	1.47891	1.65111	0.973528 	0.896651	0.989198	0.168592 	358.444	1.12733 	94.268	0.841532
465	1.47262	1.62225	0.4581   	0.793552	0.875927	0.123246 	359.012	0.846083	94.628	0.790959
466	1.45442	1.5758 	0.400697 	0.758265	0.82905 	0.110562 	359.484	0.939012	94.884	0.928732
467	1.4477 	1.52461	0.139303 	0.77784 	0.859489	0.106906 	359.912	0.96346 	94.704	0.81999 

547	1.40183	1.51336	0.446089 	0.752497	0.801183	0.0918732	399.888	0.86917 	116.148	1.14808 
548	1.40292	1.47836	0.177635 	0.804125	0.897935	0.0851869	400.108	0.748556	115.824	0.992484
549	1.43856	1.53643	0.292429 	0.88904 	0.994217	0.0921369	400.424	0.832   	116.34 	1.2963  
550	1.40928	1.63898	1.85911  	0.804164	0.876621	0.0733579	400.568	0.936684	116.664	1.56688 
551	1.26757	1.55304	1.91922  	0.649111	0.726366	0.107222 	401.384	0.969816	117.164	1.22683 
552	1.52147	1.59229	0.127154 	0.907284	0.99155 	0.0840229	401.232	1.03256 	117.056	1.19535 
553	1.40612	1.48751	0.166173 	0.855418	0.940599	0.100254 	401.572	0.957505	117.044	1.20419 
554	1.45095	1.66971	1.61388  	0.848785	0.949357	0.0951788	402.024	1.20309 	117.656	1.08888 
555	1.43626	1.64946	0.963929 	0.81993 	0.921022	0.189476 	402.636	1.12406 	117.632	1.17328 
556	1.37516	1.43553	0.14433  	0.76961 	0.82393 	0.0911399	404.048	1.1652  	118.048	1.21561 
557	1.4008 	1.49014	0.415771 	0.75683 	0.817906	0.0712534	404.46 	0.907965	118.2

##### Wrapping up

In [None]:
# Store results (final population, logbook and hall of fame) as a single pickled tuple
with open(RESULTS_FILENAME, "wb") as f:
    pickle.dump((pop, log, hof), f)

## Plotting statistics

In [None]:
# Extract generation IDs, the minimum (i.e. best) average fitness and the average number of connections per generation
gen = log.select("gen")
fitness_best = log.chapters["avg_fitness"].select("min") 
conn_avg = log.chapters["connections"].select("avg")

In [None]:
# Plot line for minimum fitness (left y-axis, blue)
fig, fit_ax = plt.subplots()
fit_line = fit_ax.plot(gen, fitness_best, "b-", label="Best Fitness")
fit_ax.set_xlabel("Generation")
fit_ax.set_ylabel(f"Cross-entropy Loss", color="b")
for tl in fit_ax.get_yticklabels():
    tl.set_color("b")

# Plot line for the average number of connections (right y-axis, red), sharing the x-axis with the fitness plot
height_ax = fit_ax.twinx()
height_line = height_ax.plot(gen, conn_avg, "r-", label="Average Number of Connections")
height_ax.set_ylabel("# Connections", color="r")
for tl in height_ax.get_yticklabels():
    tl.set_color("r")

# Add a single legend covering the lines of both axes
lines = fit_line + height_line
labs = [l.get_label() for l in lines]
fit_ax.legend(lines, labs, bbox_to_anchor=(0.8, -0.15))

# Show the result
plt.show()

## Inspecting the best individual

In [None]:
# Take the first entry of the hall of fame as the best individual and print its summary and fitness
# NOTE(review): individuals in the hall of fame were evaluated on different per-generation samples,
# so their fitness values are not strictly comparable — confirm this is acceptable
best_ind = hof[0]
print(best_ind)
print(f"Fitness of best individual: {best_ind.fitness}")

##### Printing trees

In [None]:
# Print the trees of the best individual (one expression string per output node, separated by blank lines)
for string in best_ind.get_strings():
    print(f"{string}\n")

##### Computing training & validation accuracy

In [None]:
# Retrieving predictions from an individual
def get_predictions(individual, X, weight):
    """Return, for every row of X, the class with the highest output when using the given shared weight."""
    func = compile_multiclasstree(individual)
    # The predicted class of a sample is the index of its largest softmax output
    return [np.argmax(func(weight, X[i])) for i in range(X.shape[0])]

In [None]:
# Compute accuracy of predictions
def compute_accuracy(Y_pred, Y_true):
    """Return the fraction of entries of Y_pred that equal the corresponding entry of Y_true."""
    matches = Y_pred == Y_true
    n_samples = Y_true.shape[0]
    return np.sum(matches)/n_samples

In [None]:
# Retrieve predictions of the best individual on the training and validation sets, for all weights
# Each result has one row per weight in WEIGHTS_TO_TEST and one column per sample
Y_train_pred = np.array([get_predictions(best_ind, X_train, w) for w in WEIGHTS_TO_TEST])
Y_test_pred = np.array([get_predictions(best_ind, X_test, w) for w in WEIGHTS_TO_TEST])

In [None]:
# Compute training and validation accuracies for all weights and extract the best weight for each set
train_accs = [compute_accuracy(Y_train_pred[i], Y_train) for i in range(len(WEIGHTS_TO_TEST))]
test_accs = [compute_accuracy(Y_test_pred[i], Y_test) for i in range(len(WEIGHTS_TO_TEST))]
best_weight_idx_train = np.argmax(train_accs)
best_weight_idx_test = np.argmax(test_accs)

# Print best training and validation accuracies of the best individual
# Note that the best weight is determined separately for the training and the validation set
print(f"Best training accuracy (weight {WEIGHTS_TO_TEST[best_weight_idx_train]}): {np.max(train_accs)}")
print(f"Best validation accuracy (weight {WEIGHTS_TO_TEST[best_weight_idx_test]}): {np.max(test_accs)}")

In [None]:
# Obtain majority votes over the classifiers (one classifier per row of `predictions`)
def get_majority_predictions(predictions):
    """Return, for every column (sample) of `predictions`, the label predicted most often.

    Ties are resolved by the iteration order of the set of candidate labels; no particular
    classifier is preferred.
    """
    majority_votes = []
    # Transposing makes every row hold the votes of all classifiers for a single sample
    for sample_votes in predictions.T:
        votes = list(sample_votes)
        majority_votes.append(max(set(votes), key=votes.count))
    return majority_votes

In [None]:
# Compute majority predictions (over all tested weights) for train and validation sets
Y_train_majpred = get_majority_predictions(Y_train_pred)
Y_test_majpred = get_majority_predictions(Y_test_pred)

# Print the accuracies of the majority votes
print(f"Majority training accuracy: {compute_accuracy(Y_train_majpred, Y_train)}")
print(f"Majority validation accuracy: {compute_accuracy(Y_test_majpred, Y_test)}")

##### Confusion matrices

In [None]:
# Compute confusion matrices (true labels first, predictions second) for the best weight and the majority vote
# Passing the labels explicitly keeps every matrix at N_CLASSES_TO_USE x N_CLASSES_TO_USE
cm_train_best = confusion_matrix(Y_train, Y_train_pred[best_weight_idx_train], labels=range(N_CLASSES_TO_USE))
cm_test_best = confusion_matrix(Y_test, Y_test_pred[best_weight_idx_test], labels=range(N_CLASSES_TO_USE))
cm_train_maj = confusion_matrix(Y_train, Y_train_majpred, labels=range(N_CLASSES_TO_USE))
cm_test_maj = confusion_matrix(Y_test, Y_test_majpred, labels=range(N_CLASSES_TO_USE))

In [None]:
# Create a 2x2 figure and flatten the axes so they can be paired with the matrices below
fig, ax = plt.subplots(2,2, figsize=(11,10))
ax = ax.ravel()

# Confusion matrices with their titles, in plotting order:
# training/validation data using the best weight, then training/validation data using the majority vote
panels = [
    (cm_train_best, 'Confusion matrix for training data (best weight)'),
    (cm_test_best, 'Confusion matrix for validation data (best weight)'),
    (cm_train_maj, 'Confusion matrix for training data (majority vote)'),
    (cm_test_maj, 'Confusion matrix for validation data (majority vote)'),
]

# Draw every confusion matrix as an annotated heatmap on its own axis
for panel_ax, (matrix, title) in zip(ax, panels):
    sns.heatmap(matrix, annot=True, fmt='g', ax=panel_ax, cmap="Blues")
    panel_ax.set_xlabel('Predicted labels')
    panel_ax.set_ylabel('True labels')
    panel_ax.set_title(title)

# Show the result
plt.show()

##### Used features

In [None]:
# Extract input tallies from best individual: for every output, count how often each input occurs in its tree
input_tallies = []
for tree in best_ind.get_strings():
    # Every number in an output string is taken to be the index of an input argument ("ARG<index>")
    inputs_used = list(map(int, re.findall("[0-9]+", tree)))
    input_tally = np.zeros(X_train.shape[1])
    for arg in inputs_used:
        input_tally[arg] += 1
    input_tallies.append(input_tally)

# Create plots of the inputs (pixels) used in the tree of the best individual for each class
# Note: for the best-looking plot, this implementation assumes that N_CLASSES_TO_USE is set to 10
fig, ax = plt.subplots(2, 5, figsize=(16,5))
ax = ax.ravel()
for i, tally in enumerate(input_tallies):
    # The inputs are reshaped into a square image, so the number of inputs is assumed to be a perfect square
    img_shape = int(math.sqrt(X_train.shape[1]))
    # All subplots share the same color range, so that counts are comparable between classes
    ax[i].imshow(np.array(tally).reshape(img_shape, img_shape), clim=(0,np.max(input_tallies)))
    ax[i].axis("off")
    ax[i].set_title(f"Pixels used for class {i}")
plt.show()

In [None]:
# Plot image of all inputs used in the tree (tallies summed over all classes)
# The image side length is derived from the number of inputs (assumed to be a perfect square)
# instead of being hard-coded, so the plot keeps working for other input sizes
img_side = int(math.sqrt(X_train.shape[1]))
plt.imshow(np.sum(input_tallies, axis=0).reshape(img_side, img_side))
plt.axis("off")
plt.title("Pixels used in the entire tree")
plt.show()