From 2042586ce3f854a88d37d0372d9651447cc05c36 Mon Sep 17 00:00:00 2001 From: Jakob Jordan Date: Wed, 1 Jul 2020 22:21:39 +0200 Subject: [PATCH] Simplify sorting of individuals in mu_plus_lambda --- cgp/ea/mu_plus_lambda.py | 26 ++++++++------------------ test/test_ea_mu_plus_lambda.py | 6 ++++-- 2 files changed, 12 insertions(+), 20 deletions(-) diff --git a/cgp/ea/mu_plus_lambda.py b/cgp/ea/mu_plus_lambda.py index d3623bd3..8f13afde 100644 --- a/cgp/ea/mu_plus_lambda.py +++ b/cgp/ea/mu_plus_lambda.py @@ -1,7 +1,7 @@ import concurrent.futures import numpy as np -from typing import Callable, List, Tuple +from typing import Callable, List from ..individual import IndividualBase from ..population import Population @@ -151,19 +151,15 @@ def _compute_fitness( return combined def _sort(self, combined: List[IndividualBase]) -> List[IndividualBase]: - # create copy of population - combined_copy = [ind.clone() for ind in combined] + def sort_func(ind: IndividualBase) -> float: + """Return fitness of an individual, return -infinity for an individual + with fitness equal nan, or raise error if the fitness is + not a float. - # replace all nan by -inf to make sure they end up at the end - # after sorting - for ind in combined_copy: + """ if np.isnan(ind.fitness): - ind.fitness = -np.inf + return -np.inf - def sort_func(zipped_ind: Tuple[int, IndividualBase]) -> float: - """Return fitness of an individual or raise error if it is None. - """ - _, ind = zipped_ind if isinstance(ind.fitness, float): return ind.fitness else: @@ -171,13 +167,7 @@ def sort_func(zipped_ind: Tuple[int, IndividualBase]) -> float: f"IndividualBase fitness value is of wrong type {type(ind.fitness)}." ) - # get list of indices that sorts combined_copy ("argsort") in descending order - combined_sorted_indices = [ - idx for (idx, _) in sorted(enumerate(combined_copy), key=sort_func, reverse=True) - ] - - # return original list of individuals sorted in descending order - return [combined[idx] for idx in combined_sorted_indices] + return sorted(combined, key=sort_func, reverse=True) def _create_new_parent_population( self, n_parents: int, combined: List[IndividualBase] diff --git a/test/test_ea_mu_plus_lambda.py b/test/test_ea_mu_plus_lambda.py index a7dd56c3..fa45c539 100644 --- a/test/test_ea_mu_plus_lambda.py +++ b/test/test_ea_mu_plus_lambda.py @@ -28,9 +28,9 @@ def objective_with_label(individual, label): assert pop.champion.fitness == pytest.approx(-1.0) -def test_fitness_contains_nan(population_params, genome_params): +def test_fitness_contains_and_maintains_nan(population_params, genome_params): def objective(individual): - if np.random.rand() < 0.5: + if np.random.rand() < 0.95: individual.fitness = np.nan else: individual.fitness = np.random.rand() @@ -41,3 +41,5 @@ def objective(individual): ea = cgp.ea.MuPlusLambda(10, 10, 1) ea.initialize_fitness_parents(pop, objective) ea.step(pop, objective) + + assert np.nan in [ind.fitness for ind in pop]