diff --git a/cgp/ea/mu_plus_lambda.py b/cgp/ea/mu_plus_lambda.py index afab0d04..399850b1 100644 --- a/cgp/ea/mu_plus_lambda.py +++ b/cgp/ea/mu_plus_lambda.py @@ -1,7 +1,7 @@ import concurrent.futures import numpy as np -from typing import Callable, List, Tuple, Union +from typing import Callable, List, Union from ..individual import IndividualBase from ..population import Population @@ -168,19 +168,15 @@ def _compute_fitness( return combined def _sort(self, combined: List[IndividualBase]) -> List[IndividualBase]: - # create copy of population - combined_copy = [ind.clone() for ind in combined] + def sort_func(ind: IndividualBase) -> float: + """Return fitness of an individual, return -infinity for an individual + with fitness equal nan, or raise error if the fitness is + not a float. - # replace all nan by -inf to make sure they end up at the end - # after sorting - for ind in combined_copy: + """ if np.isnan(ind.fitness): - ind.fitness = -np.inf + return -np.inf - def sort_func(zipped_ind: Tuple[int, IndividualBase]) -> float: - """Return fitness of an individual or raise error if it is None. - """ - _, ind = zipped_ind if isinstance(ind.fitness, float): return ind.fitness else: @@ -188,13 +184,7 @@ def sort_func(zipped_ind: Tuple[int, IndividualBase]) -> float: f"IndividualBase fitness value is of wrong type {type(ind.fitness)}." ) - # get list of indices that sorts combined_copy ("argsort") in descending order - combined_sorted_indices = [ - idx for (idx, _) in sorted(enumerate(combined_copy), key=sort_func, reverse=True) - ] - - # return original list of individuals sorted in descending order - return [combined[idx] for idx in combined_sorted_indices] + return sorted(combined, key=sort_func, reverse=True) def _create_new_parent_population( self, n_parents: int, combined: List[IndividualBase] diff --git a/test/test_ea_mu_plus_lambda.py b/test/test_ea_mu_plus_lambda.py index 8a258a4d..48ae79d6 100644 --- a/test/test_ea_mu_plus_lambda.py +++ b/test/test_ea_mu_plus_lambda.py @@ -28,9 +28,9 @@ def objective_with_label(individual, label): assert pop.champion.fitness == pytest.approx(-1.0) -def test_fitness_contains_nan(population_params, genome_params): +def test_fitness_contains_and_maintains_nan(population_params, genome_params): def objective(individual): - if np.random.rand() < 0.5: + if np.random.rand() < 0.95: individual.fitness = np.nan else: individual.fitness = np.random.rand() @@ -42,6 +42,8 @@ def objective(individual): ea.initialize_fitness_parents(pop, objective) ea.step(pop, objective) + assert np.nan in [ind.fitness for ind in pop] + def test_offspring_individuals_are_assigned_correct_indices(population_params, genome_params): def objective(ind):