# Natural Computing - Project
#### Submission by group 25 (Chihab Amghane, Max Driessen, Jordy Naus)

The code below uses the [DEAP framework](https://github.com/deap/deap), an intuitive framework for evolutionary algorithms and genetic programming. We adapted several components of this framework to more closely match the [WANN implementation](https://github.com/google/brain-tokyo-workshop/tree/master/WANNRelease).

## Imports

In [1]:
# DEAP
from deap import gp, base, tools, creator, algorithms

# Data processing and plotting
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Requirements for the algorithm
from operator import attrgetter
from functools import partial

# Standard python imports
import random, pickle, math, re, os
import numpy as np

# Magic for inline plots
%matplotlib inline

## Helper functions

In [2]:
def exp(x):
    """Overflow-safe exponential.

    Inputs are capped at 709 before calling np.exp, because exp(709) is still
    representable as a float64 while slightly larger arguments overflow to inf.
    """
    capped = np.minimum(x, 709.)
    return np.exp(capped)

## Global parameters

In [3]:
# Dataset
DATASET = "Fashion-MNIST" # choose from {"MNIST", "Fashion-MNIST"} 
N_CLASSES_TO_USE = 10 # at most 10

# Individual trees
# Probability that each input -> output connection starts out enabled when an individual is created
P_INITIAL_CONNECTION = 0.05

# Fitness
# Number of training samples each individual is evaluated on per generation (split evenly over the classes)
SAMPLE_SIZE = 200
# Shared weight values every network is evaluated with
WEIGHTS_TO_TEST = [-2, -1, 1, 2] # -0.5 and 0.5 are not used due to long runtime

# Parent selection
# Number of competitors per tournament (they are compared on their NSGA-II based rank)
TOURNAMENT_SIZE = 32

# Mutation (probabilities should sum to 1); exactly one of these operators is applied per child
P_MUTATE_ACTIVATION = 0.5
P_ADD_NODE = 0.25
P_ADD_CONNECTION = 0.2
P_ENABLE_CONNECTION = 0.05

# Evolution
POPULATION_SIZE = 250
N_GENERATIONS = 1000
# Fraction of the worst-ranked individuals removed each generation before parent selection
CULL_RATIO = 0.2
# Fraction of the best-ranked individuals copied unchanged into the next generation
ELITE_RATIO = 0.2

# Filenames
# Pickle file that receives the final population, logbook and hall of fame
RESULTS_FILENAME = f"DEAPWANN-{DATASET}-results.pkl"

## Loading preprocessed data

In [4]:
import subprocess
import sys

# Set the correct data filename
filename = f"{DATASET}-{N_CLASSES_TO_USE}.pkl"
data_path = os.path.join("data", filename)

# If the data has not yet been preprocessed in the specified way, do so now
if not os.path.exists(data_path):
    print("Preprocessed dataset does not exist yet, creating now.")
    # Run the preprocessing script with the interpreter running this notebook (not whatever
    # "python" is first on PATH), without a shell, and stop right away if it fails instead of
    # hitting a confusing FileNotFoundError on the open() below.
    subprocess.run(
        [sys.executable, "Preprocessing.py", "-d", DATASET, "-c", str(N_CLASSES_TO_USE)],
        check=True,
    )

# Load the preprocessed data.
# NOTE: pickle can execute arbitrary code - only load files produced by Preprocessing.py.
with open(data_path, "rb") as f:
    (X_train, Y_train), (X_test, Y_test) = pickle.load(f)

## Defining operators

In [5]:
# Define aggregator: all incoming values share the single weight w,
# so the weighted sum reduces to w * sum(args)
def aggregate(w, args):
    return w*sum(args)

# Define operators (activation functions with a variable number of inputs).
# Every operator applies its activation to aggregate(w, args).
def linear(w, *args):
    return aggregate(w, args)

def step(w, *args):
    # 1.0 if the aggregate is positive, otherwise 0.0
    return float(aggregate(w, args) > 0)

def sine(w, *args):
    return np.sin(np.pi*aggregate(w, args))

def gaussian(w, *args):
    # exp(-x^2 / 2) of the aggregate x; x is computed once instead of summing the inputs twice.
    # exp() is the overflow-guarded helper defined in the helper-functions cell.
    x = aggregate(w, args)
    return exp(-np.multiply(x, x)/2.0)

def tanh(w, *args):
    return np.tanh(aggregate(w, args))

def sigmoid(w, *args):
    # Logistic sigmoid written via tanh: (tanh(x/2) + 1) / 2
    return (np.tanh(aggregate(w, args)/2.0) + 1.0)/2.0

def inverse(w, *args):
    return -aggregate(w, args)

def absolute(w, *args):
    return abs(aggregate(w, args))

def relu(w, *args):
    return np.maximum(0.0, aggregate(w, args))

def cosine(w, *args):
    return np.cos(np.pi*aggregate(w, args))

In [6]:
# Map each activation-function name to its implementation; this dict is the globals
# namespace handed to eval() when trees are compiled, so the keys must equal the names
# that appear in the tree strings.
function_context = {fn.__name__: fn for fn in (linear, relu, step, sine, gaussian, tanh,
                                               sigmoid, inverse, absolute, cosine)}

# Names of the available activation functions, and one argument name per input feature column
function_names = list(function_context)
argument_names = [f"ARG{i}" for i in range(X_train.shape[1])]

## Defining individuals

##### Defining nodes

In [7]:
# Generic Node class
class Node:
    """Base class for every node of a MultiClassTree.

    Attributes:
    name: input name (e.g. "ARG42") or activation-function name
    parents: nodes that take this node as an input
    """

    def __init__(self, name):
        self.name = name
        self.parents = []

    def __str__(self):
        raise NotImplementedError("String function is only implemented for subclasses")


# Class for terminal nodes (inputs)
class TerminalNode(Node):
    """Input node; its string form is just its name, e.g. "ARG42"."""

    def __str__(self):
        return self.name


# Class for non-terminal nodes (hidden + outputs)
class NonterminalNode(Node):
    """Hidden or output node.

    Attributes (in addition to Node's):
    children: enabled inputs of this node
    disabled: inputs whose connection is currently disabled
    """

    def __init__(self, name):
        super().__init__(name)
        self.children = []
        self.disabled = []

    def __str__(self):
        # Formatted as "name(w, child1, child2, ...)": the shared weight w is passed first,
        # followed by the recursively formatted enabled children.
        child_strings = ", ".join(map(str, self.children))
        return f"{self.name}(w, {child_strings})"

##### Defining individuals/multi-class trees

In [8]:
# Class for multi-output trees
class MultiClassTree:
    """Network with one output node per class; outputs share the input and hidden nodes.

    Parameters:
    n_inputs: number of input (terminal) nodes, named "ARG0", "ARG1", ...
    n_outputs: number of output nodes (one per class), each with the "linear" activation
    p_initial_connection: probability that each input -> output connection starts enabled
    """

    def __init__(self, n_inputs, n_outputs, p_initial_connection):
        # Initialize lists of input, output and internal nodes.
        # Inputs are named "ARG{i}", the same convention used for the parameters of compiled trees.
        self.inputs = [TerminalNode(f"ARG{i}") for i in range(n_inputs)]
        self.outputs = [NonterminalNode("linear") for _ in range(n_outputs)]
        self.hidden = []
        self.born = -1  # generation of creation; only shown by __str__ once it is >= 0

        # Add initial connections
        self.n_connections = 0
        for output in self.outputs:
            # With a chance of p_initial_connection the connection is enabled, otherwise it is
            # disabled. Either way the input records the output as a parent.
            for child in self.inputs:
                if random.random() < p_initial_connection:
                    output.children.append(child)
                else:
                    output.disabled.append(child)
                child.parents.append(output)

            # If an output has no enabled children, one of the children is enabled to make the tree valid
            if not output.children:
                child = random.choice(self.inputs)
                output.disabled.remove(child)
                output.children.append(child)

            # Update the number of enabled connections in the tree
            self.n_connections += len(output.children)

    def __str__(self):
        # Printing the tree only prints the number of hidden nodes and enabled connections
        return f"MultiClassTree with {len(self.hidden)} hidden nodes and {self.n_connections} connections"\
                + (f", born in generation {self.born}" if self.born >= 0 else "")

    def get_strings(self):
        """Return one expression string per output (built recursively via str()).

        Returns None (after printing a message) if the tree is too deep to stringify.
        """
        try:
            return [str(output) for output in self.outputs]
        except RecursionError:
            print("Maximum recursion depth reached")
            return None

##### Initializing the DEAP toolbox

In [9]:
# Initialize the toolbox which will contain all sorts of functions for the genetic programming process
# (individual/population creation, compile, get_sample, evaluate, select, mutate)
toolbox = base.Toolbox()

In [10]:
# Define classes for fitness and individuals (using DEAP's creator module).
# Two objectives, both minimized (negative weights): the average and the best cross-entropy
# loss returned by the fitness function.
creator.create("Fitness", base.Fitness, weights=(-1.0, -1.0))
# `rank` is a class-level default; eaWann overwrites it per individual with its NSGA-II based
# rank, which tournament selection then compares.
creator.create("Individual", MultiClassTree, fitness=creator.Fitness, rank=-1) # An individual is a multi-class tree

In [11]:
# Define how to initialize an individual or population.
# toolbox.individual() builds a tree with one input per feature column of X_train,
# N_CLASSES_TO_USE outputs and P_INITIAL_CONNECTION as initial connection probability;
# toolbox.population(n) calls it n times and collects the results in a list.
toolbox.register("individual", creator.Individual, X_train.shape[1], N_CLASSES_TO_USE, P_INITIAL_CONNECTION)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)

## Fitness function

##### Compiling multi-class trees into functions

In [12]:
# Compiling a tree into a function
def compile_multiclasstree(tree):
    """Compile a MultiClassTree into a callable that returns class probabilities.

    Parameters:
    tree: individual whose get_strings() yields one expression string per output

    Returns:
    func(w, args): evaluates every output expression with the shared weight w on the input
    vector args (one value per name in argument_names) and returns the softmax over outputs.

    Raises:
    RecursionError: if the tree is too deep to be turned into expression strings.
    """
    # Parse trees to strings for all outputs
    strings = tree.get_strings()
    if strings is None:
        # get_strings() swallows RecursionError and returns None; fail with a clear error
        # rather than an opaque "'NoneType' object is not iterable".
        raise RecursionError("Tree is too deep to compile")

    # Convert the strings to lambda functions using eval() and the proper function context.
    # The strings are generated internally from the tree (not user input); never eval untrusted text.
    params = ", ".join(argument_names)
    funcs = [eval(f"lambda w, {params}: {string}", function_context, {}) for string in strings]

    def _softmax(x):
        # Numerically stable softmax: cap +inf/huge values at 709 and subtract the maximum, so
        # exp() never overflows. (Without the shift, several outputs at the cap sum to inf and
        # the division yields NaN.)
        x = np.clip(np.asarray(x, dtype=float), None, 709.)
        e = np.exp(x - np.max(x, axis=0))
        return e/np.sum(e, axis=0)

    # Create the function, which applies softmax over the outputs of the created lambda functions
    def func(w, args):
        return _softmax([f(w, *args) for f in funcs])

    # Return the created function
    return func

In [13]:
# Add the compile function to toolbox (toolbox.compile(individual) -> func(w, args) returning class probabilities)
toolbox.register("compile", compile_multiclasstree)

##### Defining the fitness function

In [14]:
def get_sample(seed, sample_size):
    """Draw a class-balanced sample of training indices.

    Parameters:
    seed: if >= 0, a dedicated RNG seeded with this value is used, so every individual in a
          generation can be tested on the same samples; if negative, numpy's global RNG is used
    sample_size: total number of samples wanted; each class contributes
                 sample_size // N_CLASSES_TO_USE indices (any remainder is dropped)

    Returns:
    list of indices into X_train / Y_train, grouped by class

    Raises:
    ValueError: if a class has fewer elements than the per-class sample size
    """
    # RandomState(seed) yields the same stream as np.random.seed(seed) followed by
    # np.random.permutation(...), but without resetting numpy's global RNG as a side effect.
    rng = np.random.RandomState(seed) if seed >= 0 else np.random

    # Create a list of indices of samples to test, ensuring an equal number of samples from each class
    samples_per_class = sample_size // N_CLASSES_TO_USE
    sample_indices = []
    for c in range(N_CLASSES_TO_USE):
        c_indices = np.where(Y_train == c)[0]
        if len(c_indices) < samples_per_class:
            raise ValueError(f"Class {c} has too few elements to reach the desired number of evaluation samples")
        sample_indices.extend(rng.permutation(c_indices)[:samples_per_class])

    # Return the list of sample indices
    return sample_indices

In [15]:
# Define fitness function (two objectives: average and best cross-entropy loss over the tested weights)
def fitness(individual, sample_indices, weights_to_test, eps=1e-12):
    """Evaluate an individual on a sample of the training data for several shared weights.

    Parameters:
    individual: the tree to evaluate (compiled with toolbox.compile)
    sample_indices: indices into X_train / Y_train to evaluate on
    weights_to_test: shared weight values; the network is evaluated once per weight
    eps: lower bound on the predicted probability of the true class, so that one confidently
         wrong prediction gives a large but finite loss instead of inf

    Returns:
    (mean over weights of the average loss, minimum over weights of the average loss)
    """
    # Compile the functions corresponding to the individual
    func = toolbox.compile(individual)

    # Cross-entropy of the true label; the clamp avoids -log(0) = inf
    def _cross_entropy(pred, label):
        return -np.log(np.maximum(pred[label], eps))

    # Compute the average cross-entropy loss over the samples for each weight
    results = []
    for w in weights_to_test:
        w_results = [_cross_entropy(func(w, X_train[i]), Y_train[i]) for i in sample_indices]
        results.append(np.average(w_results))

    # Return the average and best cross-entropy loss
    return (np.average(results), np.min(results))

In [16]:
# Add the sample generator and fitness function to the toolbox.
# toolbox.get_sample(seed) draws SAMPLE_SIZE indices; toolbox.evaluate still needs
# sample_indices, which eaWann supplies per generation via functools.partial.
toolbox.register("get_sample", get_sample, sample_size=SAMPLE_SIZE)
toolbox.register("evaluate", fitness, weights_to_test=WEIGHTS_TO_TEST)

## Evolution components

##### Parent selection

In [17]:
# Define how to select parents (tournament selection based on NSGA2 rank).
# fit_attr="rank" makes tournaments compare the `rank` attribute instead of the fitness;
# eaWann gives better individuals a larger rank, so the larger rank wins.
toolbox.register("select", tools.selTournament, tournsize=TOURNAMENT_SIZE, fit_attr="rank")

##### Mutation

In [18]:
# Mutating the activation function of a hidden node
def _mutate_activation(tree):
    # If the tree has no hidden nodes, do nothing
    if not tree.hidden:
        return
        
    # Select a random node and give it a different activation function
    node = random.choice(tree.hidden)
    node.name = random.choice([name for name in function_names if not name == node.name])


# Adding a node to the tree
def _add_node(tree):
    # Select a random parent-child pair between which to place a node
    parent = random.choice(tree.outputs + tree.hidden)
    child = random.choice(parent.children)
        
    # Disable the connection between the parent and the child
    parent.children.remove(child)
    parent.disabled.append(child)

    # Create a new node with a random activation function and add it to the tree
    new_node = NonterminalNode(random.choice(function_names))
    tree.hidden.append(new_node)
        
    # Update the parent/child relations
    parent.children.append(new_node)
    new_node.parents.append(parent)
    new_node.children.append(child)
    child.parents.append(new_node)
        
    # Update the number of enabled connections
    tree.n_connections += 1

    
# Adding a connection in the tree
def _add_connection(tree):
    # Function that checks if node1 is an ancestor of node 2
    def _is_ancestor(node1, node2):
        return node1 in [node2] + node2.parents or any([_is_ancestor(node1, node3) for node3 in node2.parents])
    
    # Find a valid parent-child pair for a connection, respecting the feed-forward property (no loops)
    valid_connection = False
    n_attempts = 0
    while not valid_connection and n_attempts < 500:
        parent = random.choice(tree.outputs + tree.hidden)
        child = random.choice(tree.inputs + tree.hidden)
        valid_connection = not _is_ancestor(child, parent) # Connection is valid if child is not an ancestor of parent
        n_attempts += 1
    
    # If a valid connection was found, update the parent/child relations and the number of enabled connections
    if n_attempts < 500:
        child.parents.append(parent)
        parent.children.append(child)
        tree.n_connections += 1


# Enable a disabled connection (created during initialization or when adding a node)
def _enable_connection(tree):
    # Check if there are any nodes with disabled connections; if not, do nothing
    parents_with_disabled = [node for node in tree.outputs + tree.hidden if node.disabled]
    if not parents_with_disabled:
        return
    
    # Select a random disabled parent-child pair
    parent = random.choice(parents_with_disabled)
    child = random.choice(parent.disabled)
    
    # Enable the corresponding connection and update the number of enabled connections
    parent.disabled.remove(child)
    parent.children.append(child)
    tree.n_connections += 1

In [19]:
def mutate(tree):
    """Return a mutated copy of tree as a 1-tuple (DEAP's mutation convention).

    Exactly one operator is applied, chosen with the P_* probabilities: change an activation,
    add a node, add a connection or enable a disabled connection. The input tree is not modified.

    Raises:
    ValueError: if the mutation probabilities do not sum to 1 (within float tolerance).
    """
    # Copy the parent tree
    tree = toolbox.clone(tree)

    # Create lists of the various mutation functions and the corresponding probabilities
    mutation_functions = [_mutate_activation, _add_node, _add_connection, _enable_connection]
    probabilities = [P_MUTATE_ACTIVATION, P_ADD_NODE, P_ADD_CONNECTION, P_ENABLE_CONNECTION]

    # Ensure probabilities sum to 1. Compare with a tolerance: sums of decimal fractions are
    # often not exactly 1.0 in floating point.
    if not math.isclose(sum(probabilities), 1.0):
        raise ValueError("Mutation probabilities should sum to 1")

    # Choose a mutation function using the provided probabilities and execute it
    mutation_function, = random.choices(mutation_functions, probabilities, k=1)
    mutation_function(tree)

    # Return the resulting tree
    return tree,

In [20]:
# Add the mutate function to the toolbox (toolbox.mutate(ind) returns a mutated copy as a 1-tuple)
toolbox.register("mutate", mutate)

## Defining statistics

In [21]:
# Statistics to track per generation: both fitness objectives (average and best cross-entropy
# over the tested weights) and two size measures of the networks
stats_avgfit = tools.Statistics(key=lambda ind: ind.fitness.values[0])
stats_bestfit = tools.Statistics(key=lambda ind: ind.fitness.values[1])
stats_connections = tools.Statistics(key=attrgetter("n_connections"))
stats_hidden = tools.Statistics(key=lambda ind: len(ind.hidden))

# Bundle them; the keyword names become the logbook chapter names
mstats = tools.MultiStatistics(
    avg_fitness=stats_avgfit,
    best_fitness=stats_bestfit,
    hidden=stats_hidden,
    connections=stats_connections,
)

In [22]:
# Metrics computed for every statistic each generation (registered in this order)
for metric_name, metric_fn in (("avg", np.mean), ("std", np.std), ("min", np.min)):
    mstats.register(metric_name, metric_fn)

## Running the genetic programming algorithm

##### Defining the algorithm

In [23]:
def _evaluate_on_sample(individuals, toolbox, seed):
    """Evaluate all individuals on one shared sample (drawn with toolbox.get_sample(seed)) and store their fitness."""
    sample = toolbox.get_sample(seed)
    fitnesses = toolbox.map(partial(toolbox.evaluate, sample_indices=sample), individuals)
    for ind, fit in zip(individuals, fitnesses):
        ind.fitness.values = fit


def eaWann(population, toolbox, ngen, cull_ratio, elite_ratio, stats=None, halloffame=None, verbose=True):
    """
    Evolutionary algorithm for Weight Agnostic Neural Networks (WANNs)
    Based on the algorithms provided by DEAP, as well as the WANN implementation

    The basic idea is as follows:
    1. Sort the population based on fitness (using NSGA2)
    2. Remove the worst individuals
    3. Copy the best individuals directly to the next generation
    4. Perform tournament selection to create the remaining offspring
    5. Evaluate all individuals in the new population (each individual is tested on the same sample)
    6. Repeat from 1

    Parameters:
    population: the initial population (replaced in place by the final population)
    toolbox: the DEAP toolbox containing functions for parent selection, mutation etc.
    ngen: number of generations to run the algorithm for
    cull_ratio: fraction of the population that will be thrown away every generation (worst individuals)
    elite_ratio: fraction of the population that will be directly copied to the next generation (best individuals)
    stats: MultiStatistics object keeping track of evolution statistics (optional)
    halloffame: HallOfFame keeping the best individuals that ever lived (optional)
    verbose: whether or not to print statistics

    Returns:
    (final population, logbook)

    Raises:
    ValueError: if culling leaves no individuals, or more elites are requested than survive culling
    """

    # Initialize logbook and set the correct headers.
    # The conditional is parenthesized so 'gen' is always a column, and chapter headers are only
    # set when a statistics object is given (stats=None must not crash here).
    logbook = tools.Logbook()
    logbook.header = ['gen'] + (list(stats.fields) if stats else [])
    if stats:
        for field in stats.fields:
            if "fitness" in field:
                logbook.chapters[field].header = "min", "avg", "std"
            else:
                logbook.chapters[field].header = "avg", "std"

    # Evaluate all individuals of generation 0 using the same sample
    _evaluate_on_sample(population, toolbox, 0)
    for ind in population:
        ind.born = 0

    # Update hall of fame
    if halloffame is not None:
        halloffame.update(population)

    # Record and print performance if applicable
    record = stats.compile(population) if stats else {}
    logbook.record(gen=0, **record)
    if verbose:
        print(logbook.stream)

    # Begin the generational process
    for gen in range(1, ngen + 1):

        # Initialize offspring and determine offspring size
        offspring = []
        population_size = len(population)

        # Rank the population and update rank values (tournament selection prefers individuals with bigger rank)
        ranked_population = tools.selNSGA2(population, population_size)
        for ind, rank in zip(ranked_population, reversed(range(population_size))):
            ind.rank = rank

        # Culling - remove worst performing individuals
        number_to_cull = int(cull_ratio*population_size)
        ranked_population = ranked_population[:population_size-number_to_cull]
        if not ranked_population:
            raise ValueError("cull_ratio removes the entire population")

        # Elitism - select and copy best performing individuals
        number_of_elites = int(elite_ratio*population_size)
        if number_of_elites > len(ranked_population):
            raise ValueError("elite_ratio requests more elites than individuals survive culling")
        for elite in ranked_population[:number_of_elites]:
            copy = toolbox.clone(elite)
            del copy.fitness.values # Will be re-evaluated using this generation's sample
            offspring.append(copy)

        # Compute number of offspring that still need to be generated
        offspring_to_generate = population_size - number_of_elites

        # Select parents via (NSGA2 rank-based) tournament selection
        parents = toolbox.select(ranked_population, offspring_to_generate)

        # Mutate parents to obtain children
        for parent in parents:
            child, = toolbox.mutate(parent)
            del child.fitness.values
            child.born = gen
            offspring.append(child)

        # Evaluate all individuals in the offspring using the same sample
        _evaluate_on_sample(offspring, toolbox, gen)

        # Update the hall of fame with the generated individuals
        if halloffame is not None:
            halloffame.update(offspring)

        # Replace the current population by the offspring
        population[:] = offspring

        # Append the current generation statistics to the logbook
        record = stats.compile(population) if stats else {}
        logbook.record(gen=gen, **record)
        if verbose:
            print(logbook.stream)

    return population, logbook

##### Running the algorithm

In [None]:
# Run the evolutionary algorithm
pop = toolbox.population(POPULATION_SIZE)
# Keep the 10 best individuals seen over the whole run
hof = tools.HallOfFame(10)
pop, log = eaWann(population=pop, toolbox=toolbox, ngen=N_GENERATIONS, cull_ratio=CULL_RATIO, 
                  elite_ratio=ELITE_RATIO, stats=mstats, halloffame=hof, verbose=True)

   	      avg_fitness       	      best_fitness      	  connections  	   hidden  
   	------------------------	------------------------	---------------	-----------
gen	min    	avg    	std     	min    	avg   	std     	avg    	std    	avg	std
0  	3.20063	4.25901	0.529687	2.38362	3.0483	0.278345	127.792	10.5214	0  	0  
1  	3.1211 	3.52496	0.310917	2.27684	2.55651	0.130452	118.78 	10.6604	0.208	0.405877
2  	3.01676	3.40438	0.33538 	2.26767	2.4417 	0.0766381	117.288	6.8303 	0.608	0.649874
3  	2.9871 	3.46817	0.310269	2.2516 	2.37516	0.0748565	115.252	7.92947	1.568	0.708079
4  	2.97715	3.37534	0.29308 	2.21707	2.39658	0.057759 	116.288	7.89817	2.408	0.744   
5  	2.95145	3.26687	0.273054	2.14717	2.358  	0.0978474	118.444	7.82808	2.692	0.914951
6  	2.89898	3.30452	0.217622	2.11131	2.30187	0.112612 	119.952	9.08348	3.592	1.01663 
7  	2.86721	3.1897 	0.221555	2.11939	2.30278	0.086225 	121.916	7.89816	3.784	0.988607
8  	2.89904	3.18042	0.202987	2.17339	2.29641	0.0752939	125.492	6.70775	4.564	0.96

92 	2.08328	2.28031	0.166142	1.44254	1.64381	0.0934067	183.572	2.26204	22.236	0.91886 
93 	2.04647	2.24177	0.219722	1.56254	1.68561	0.0740485	183.544	2.15594	22.808	0.816784
94 	2.0017 	2.24296	0.364008	1.50621	1.65291	0.131719 	183.652	2.1769 	22.836	0.795678
95 	2.00553	2.23481	0.257758	1.38919	1.58375	0.110183 	184.488	2.27549	23.088	0.967603
96 	1.90886	2.24602	0.199281	1.34248	1.49971	0.110612 	184.468	2.46759	23.584	1.04448 
97 	1.96312	2.22615	0.228906	1.42677	1.56574	0.0977747	184.224	1.93231	23.244	0.681516
98 	1.97045	2.28567	0.646537	1.65301	1.73604	0.0978543	184.532	2.11494	23.368	0.632911
99 	1.89157	2.18661	0.449118	1.43368	1.58947	0.123501 	185.94 	2.35635	23.692	0.655085
100	1.85214	2.24405	0.359259	1.47896	1.59042	0.096511 	185.804	1.71277	23.78 	0.628967
101	1.90353	2.22794	0.198092	1.47445	1.61607	0.0903438	185.324	1.57322	23.736	0.665059
102	1.96653	2.23644	0.213139	1.41782	1.56603	0.101848 	186.92 	1.38333	24.04 	0.793977
103	1.85666	2.2214 	0.31177 	1.26506	1.4497

186	1.56881	1.73692	0.357015	1.24218	1.34666	0.102036 	227.408	2.14325 	34.688	0.945863
187	1.62067	1.85282	0.576574	1.22709	1.36182	0.135548 	228.292	2.22323 	34.66 	1.37419 
188	1.50844	1.78062	0.499651	1.17326	1.30463	0.118223 	229.756	2.41008 	35.668	1.92192 
189	1.62753	1.78482	0.320235	1.40905	1.49535	0.0738453	230.04 	2.47677 	35.756	2.09773 
190	1.55907	1.69541	0.261254	1.34332	1.43397	0.0809894	233.4  	1.58997 	38.112	1.32191 
191	1.6258 	1.77743	0.3989  	1.35893	1.46313	0.0727915	233.716	1.82739 	37.968	1.66823 
192	1.57612	1.73372	0.238283	1.22779	1.35545	0.112153 	234.62 	1.06752 	38.872	0.758694
193	1.54766	1.6695 	0.288312	1.27844	1.39814	0.0889206	235.328	0.914558	39.352	0.596738
194	1.6005 	1.73394	0.352963	1.13224	1.29867	0.11072  	235.728	0.768125	39.292	0.471949
195	1.55965	1.6589 	0.258045	1.23725	1.36301	0.0961304	236.22 	1.08977 	39.784	0.658289
196	1.49872	1.64872	0.265866	1.21653	1.33153	0.124781 	236.2  	0.774597	40.036	0.683157
197	1.53301	1.65659	0.37467 	1.2

##### Wrapping up

In [None]:
# Store results: final population, logbook and hall of fame, pickled into RESULTS_FILENAME
with open(RESULTS_FILENAME, "wb") as f:
    pickle.dump((pop, log, hof), f)

## Plotting statistics

In [None]:
# Extract, per generation: the generation IDs, the population minimum of the average-over-weights
# cross-entropy loss, and the average number of enabled connections
gen = log.select("gen")
fitness_best = log.chapters["avg_fitness"].select("min") 
conn_avg = log.chapters["connections"].select("avg")

In [None]:
# Left axis (blue): lowest average cross-entropy loss in the population per generation
fig, fit_ax = plt.subplots()
fit_line = fit_ax.plot(gen, fitness_best, "b-", label="Best Fitness")
fit_ax.set_xlabel("Generation")
fit_ax.set_ylabel("Cross-entropy Loss", color="b")
for tick_label in fit_ax.get_yticklabels():
    tick_label.set_color("b")

# Right axis (red, shared x-axis): average number of enabled connections per generation
conn_ax = fit_ax.twinx()
conn_line = conn_ax.plot(gen, conn_avg, "r-", label="Average Number of Connections")
conn_ax.set_ylabel("# Connections", color="r")
for tick_label in conn_ax.get_yticklabels():
    tick_label.set_color("r")

# One legend for both lines, anchored below the axes
lines = fit_line + conn_line
fit_ax.legend(lines, [line.get_label() for line in lines], bbox_to_anchor=(0.8, -0.15))

plt.show()

## Inspecting the best individual

In [None]:
# DEAP's HallOfFame keeps its entries ordered best-first, so hof[0] is the best individual found
best_ind = hof[0]
print(best_ind)
print(f"Fitness of best individual: {best_ind.fitness}")

##### Printing trees

In [None]:
# Print the trees of the best individual (one expression string per output/class)
for string in best_ind.get_strings():
    print(f"{string}\n")

##### Computing training & validation accuracy

In [None]:
# Retrieving predictions from an individual
def get_predictions(individual, X, weight):
    """Predict a class for every row of X: the argmax of the network's output probabilities
    when all connections use the shared weight `weight`."""
    predict = compile_multiclasstree(individual)
    return [np.argmax(predict(weight, row)) for row in X]

In [None]:
# Compute accuracy of predictions
def compute_accuracy(Y_pred, Y_true):
    """Return the fraction of positions where Y_pred equals Y_true.

    Both arguments may be lists or numpy arrays (e.g. the list returned by
    get_majority_predictions).

    Raises:
    ValueError: if the two inputs do not have the same shape (instead of silently broadcasting)
    """
    # Convert to arrays: comparing two plain lists with == yields a single bool, not an elementwise result
    Y_pred = np.asarray(Y_pred)
    Y_true = np.asarray(Y_true)
    if Y_pred.shape != Y_true.shape:
        raise ValueError(f"Shape mismatch: predictions {Y_pred.shape} vs labels {Y_true.shape}")
    n_correct = np.sum(Y_pred == Y_true)
    return n_correct/Y_true.shape[0]

In [None]:
# Retrieve predictions of the best individual on the training and validation sets, for all weights.
# Each result has one row per entry of WEIGHTS_TO_TEST and one column per sample.
Y_train_pred = np.array([get_predictions(best_ind, X_train, w) for w in WEIGHTS_TO_TEST])
Y_test_pred = np.array([get_predictions(best_ind, X_test, w) for w in WEIGHTS_TO_TEST])

In [None]:
# Accuracy per tested weight (one row of predictions per weight) on both sets,
# and the index of the best-scoring weight for each set
train_accs = [compute_accuracy(row_pred, Y_train) for row_pred in Y_train_pred]
test_accs = [compute_accuracy(row_pred, Y_test) for row_pred in Y_test_pred]
best_weight_idx_train, best_weight_idx_test = np.argmax(train_accs), np.argmax(test_accs)

# Report the best training and validation accuracies and the weights that achieved them
best_train_weight = WEIGHTS_TO_TEST[best_weight_idx_train]
best_test_weight = WEIGHTS_TO_TEST[best_weight_idx_test]
print(f"Best training accuracy (weight {best_train_weight}): {np.max(train_accs)}")
print(f"Best validation accuracy (weight {best_test_weight}): {np.max(test_accs)}")

In [None]:
# Obtain majority votes over the predictions made with the different weights
def get_majority_predictions(predictions, best_row=None):
    """Return the majority-vote label for every column of `predictions`.

    Parameters:
    predictions: 2-D array-like, one row of predicted labels per classifier (e.g. per weight)
    best_row: optional row index of the preferred classifier; on a tie its vote wins if it is
              among the tied labels. Otherwise (or when best_row is None) the tied label that
              appears first in row order wins, so ties are resolved deterministically.

    Returns:
    list with one label per column
    """
    from collections import Counter

    predictions = np.asarray(predictions)
    majority = []
    for column in predictions.T:
        votes = list(column)
        counts = Counter(votes)
        top = max(counts.values())
        if best_row is not None and counts[votes[best_row]] == top:
            majority.append(votes[best_row])
        else:
            # First vote (in row order) that reaches the top count
            majority.append(next(label for label in votes if counts[label] == top))
    return majority

In [None]:
# Compute majority predictions (one vote per tested weight) for train and validation sets
Y_train_majpred = get_majority_predictions(Y_train_pred)
Y_test_majpred = get_majority_predictions(Y_test_pred)

# Print the accuracies of the majority votes
print(f"Majority training accuracy: {compute_accuracy(Y_train_majpred, Y_train)}")
print(f"Majority validation accuracy: {compute_accuracy(Y_test_majpred, Y_test)}")

##### Confusion matrices

In [None]:
# Compute confusion matrices (rows: true labels, columns: predicted labels; all classes
# listed explicitly so every matrix is N_CLASSES_TO_USE x N_CLASSES_TO_USE)
cm_train_best = confusion_matrix(Y_train, Y_train_pred[best_weight_idx_train], labels=range(N_CLASSES_TO_USE))
cm_test_best = confusion_matrix(Y_test, Y_test_pred[best_weight_idx_test], labels=range(N_CLASSES_TO_USE))
cm_train_maj = confusion_matrix(Y_train, Y_train_majpred, labels=range(N_CLASSES_TO_USE))
cm_test_maj = confusion_matrix(Y_test, Y_test_majpred, labels=range(N_CLASSES_TO_USE))

In [None]:
# 2x2 grid: top row uses the best weight, bottom row the majority vote;
# left column is training data, right column validation data
fig, ax = plt.subplots(2,2, figsize=(11,10))
ax = ax.ravel()

panels = [
    (cm_train_best, 'Confusion matrix for training data (best weight)'),
    (cm_test_best, 'Confusion matrix for validation data (best weight)'),
    (cm_train_maj, 'Confusion matrix for training data (majority vote)'),
    (cm_test_maj, 'Confusion matrix for validation data (majority vote)'),
]
for panel_ax, (matrix, title) in zip(ax, panels):
    sns.heatmap(matrix, annot=True, fmt='g', ax=panel_ax, cmap="Blues")
    panel_ax.set_xlabel('Predicted labels')
    panel_ax.set_ylabel('True labels')
    panel_ax.set_title(title)

plt.show()

##### Used features

In [None]:
# Count, for each output (class) tree of the best individual, how often every input appears
n_inputs = X_train.shape[1]
input_tallies = []
for tree in best_ind.get_strings():
    # Match only argument names ("ARG<index>") rather than any run of digits in the expression
    inputs_used = [int(idx) for idx in re.findall(r"ARG(\d+)", tree)]
    input_tally = np.zeros(n_inputs)
    for arg in inputs_used:
        input_tally[arg] += 1
    input_tallies.append(input_tally)

# Create plots of the inputs (pixels) used in the tree of the best individual for each class.
# Assumes the inputs form a square image; the 2x5 grid fits up to 10 classes.
img_shape = int(math.sqrt(n_inputs))
max_tally = np.max(input_tallies)  # shared colour scale across all panels
fig, ax = plt.subplots(2, 5, figsize=(16,5))
ax = ax.ravel()
for i, tally in enumerate(input_tallies):
    ax[i].imshow(tally.reshape(img_shape, img_shape), clim=(0, max_tally))
    ax[i].axis("off")
    ax[i].set_title(f"Pixels used for class {i}")
# Hide the panels left over when fewer than 10 classes are used
for unused_ax in ax[len(input_tallies):]:
    unused_ax.axis("off")
plt.show()

In [None]:
# Plot image of all inputs used in the tree (tallies summed over all classes).
# The image side is derived from the number of inputs instead of a hard-coded 16, so this also
# works for other preprocessing resolutions (assumes a square image).
img_side = int(math.sqrt(X_train.shape[1]))
plt.imshow(np.sum(input_tallies, axis=0).reshape(img_side, img_side))
plt.axis("off")
plt.title("Pixels used in the entire tree")
plt.show()