Skip to content

Commit

Permalink
Merge pull request #169 from jakobj/maint/refactor-mu-plus-lambda-sort
Browse files Browse the repository at this point in the history
Simplify sorting of individuals in `mu_plus_lambda`
  • Loading branch information
mschmidt87 committed Jul 4, 2020
2 parents a842483 + 9a13d85 commit 69dd1ca
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 20 deletions.
26 changes: 8 additions & 18 deletions cgp/ea/mu_plus_lambda.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import concurrent.futures
import numpy as np

from typing import Callable, List, Tuple, Union
from typing import Callable, List, Union

from ..individual import IndividualBase
from ..population import Population
Expand Down Expand Up @@ -168,33 +168,23 @@ def _compute_fitness(
return combined

def _sort(self, combined: List[IndividualBase]) -> List[IndividualBase]:
# create copy of population
combined_copy = [ind.clone() for ind in combined]
def sort_func(ind: IndividualBase) -> float:
"""Return fitness of an individual, return -infinity for an individual
with fitness equal nan, or raise error if the fitness is
not a float.
# replace all nan by -inf to make sure they end up at the end
# after sorting
for ind in combined_copy:
"""
if np.isnan(ind.fitness):
ind.fitness = -np.inf
return -np.inf

def sort_func(zipped_ind: Tuple[int, IndividualBase]) -> float:
"""Return fitness of an individual or raise error if it is None.
"""
_, ind = zipped_ind
if isinstance(ind.fitness, float):
return ind.fitness
else:
raise ValueError(
f"IndividualBase fitness value is of wrong type {type(ind.fitness)}."
)

# get list of indices that sorts combined_copy ("argsort") in descending order
combined_sorted_indices = [
idx for (idx, _) in sorted(enumerate(combined_copy), key=sort_func, reverse=True)
]

# return original list of individuals sorted in descending order
return [combined[idx] for idx in combined_sorted_indices]
return sorted(combined, key=sort_func, reverse=True)

def _create_new_parent_population(
self, n_parents: int, combined: List[IndividualBase]
Expand Down
6 changes: 4 additions & 2 deletions test/test_ea_mu_plus_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ def objective_with_label(individual, label):
assert pop.champion.fitness == pytest.approx(-1.0)


def test_fitness_contains_nan(population_params, genome_params):
def test_fitness_contains_and_maintains_nan(population_params, genome_params):
def objective(individual):
if np.random.rand() < 0.5:
if np.random.rand() < 0.95:
individual.fitness = np.nan
else:
individual.fitness = np.random.rand()
Expand All @@ -42,6 +42,8 @@ def objective(individual):
ea.initialize_fitness_parents(pop, objective)
ea.step(pop, objective)

assert np.nan in [ind.fitness for ind in pop]


def test_offspring_individuals_are_assigned_correct_indices(population_params, genome_params):
def objective(ind):
Expand Down

0 comments on commit 69dd1ca

Please sign in to comment.