diff --git a/gp/ea/mu_plus_lambda.py b/gp/ea/mu_plus_lambda.py
index df3e880c..afe2dd51 100644
--- a/gp/ea/mu_plus_lambda.py
+++ b/gp/ea/mu_plus_lambda.py
@@ -24,7 +24,7 @@ def __init__(
         tournament_size: int,
         *,
         n_processes: int = 1,
-        local_search: Callable[[List[Individual]], None] = lambda combined: None
+        local_search: Callable[[Individual], None] = lambda combined: None
     ):
         """Init function
 
@@ -39,7 +39,7 @@ def __init__(
         n_processes : int, optional
            Number of parallel processes to be used. If greater than 1,
            parallel evaluation of the objective is supported. Defaults to 1.
-        local_search : Callable[[List[gp.Individua]], None], optional
+        local_search : Callable[[Individual], None], optional
            Called before each fitness evaluation with a joint list of
            offsprings and parents to optimize numeric leaf values of
            the graph. Defaults to identity function.
@@ -106,7 +106,8 @@ def step(self, pop: Population, objective: Callable[[Individual], Individual]) -
         # population instead of the other way around
         combined = offsprings + pop.parents
 
-        self.local_search(combined)
+        for ind in combined:
+            self.local_search(ind)
 
         combined = self._compute_fitness(combined, objective)
 
diff --git a/gp/local_search/gradient_based.py b/gp/local_search/gradient_based.py
index c2b5792a..10e4f4d1 100644
--- a/gp/local_search/gradient_based.py
+++ b/gp/local_search/gradient_based.py
@@ -8,28 +8,28 @@
 except ModuleNotFoundError:
     torch_available = False
 
-from typing import Callable, List, Optional
+from typing import Callable, Optional
 
 from ..individual import Individual  # noqa: F401
 
 
 def gradient_based(
-    individuals: List[Individual],
+    individual: Individual,
     objective: Callable[[torch.nn.Module], torch.Tensor],
     lr: float,
     gradient_steps: int,
     optimizer: Optional[Optimizer] = None,
     clip_value: Optional[float] = None,
 ) -> None:
-    """Perform a local search for numeric leaf values for the list of
-    individuals based on gradient information obtained via automatic
+    """Perform a local search for numeric leaf values for an individual
+    based on gradient information obtained via automatic
     differentiation.
 
     Parameters
     ----------
-    individuals : List
-        List of individuals for which to perform local search.
+    individual : Individual
+        Individual for which to perform local search.
     objective : Callable
         Objective function that is called with a differentiable graph
         and returns a differentiable loss.
@@ -55,23 +55,22 @@ def gradient_based(
     if clip_value is None:
         clip_value = 0.1 * 1.0 / lr
 
-    for ind in individuals:
-        f = ind.to_torch()
+    f = individual.to_torch()
 
-        if len(list(f.parameters())) > 0:
-            optimizer = optimizer_class(f.parameters(), lr=lr)
+    if len(list(f.parameters())) > 0:
+        optimizer = optimizer_class(f.parameters(), lr=lr)
 
-            for i in range(gradient_steps):
-                loss = objective(f)
-                if not torch.isfinite(loss):
-                    continue
+        for i in range(gradient_steps):
+            loss = objective(f)
+            if not torch.isfinite(loss):
+                continue
 
-                f.zero_grad()
-                loss.backward()
-                if clip_value is not np.inf:
-                    torch.nn.utils.clip_grad.clip_grad_value_(f.parameters(), clip_value)
-                optimizer.step()
+            f.zero_grad()
+            loss.backward()
+            if clip_value is not np.inf:
+                torch.nn.utils.clip_grad.clip_grad_value_(f.parameters(), clip_value)
+            optimizer.step()
 
-            assert all(torch.isfinite(t) for t in f.parameters())
+        assert all(torch.isfinite(t) for t in f.parameters())
 
-        ind.update_parameters_from_torch_class(f)
+    individual.update_parameters_from_torch_class(f)
diff --git a/test/test_local_search.py b/test/test_local_search.py
index 0b4feb19..5ccfa955 100644
--- a/test/test_local_search.py
+++ b/test/test_local_search.py
@@ -20,15 +20,15 @@ def objective(f):
 
     # test increase parameter value if too small
     ind.parameter_names_to_values["<p1>"] = 0.9
-    gp.local_search.gradient_based([ind], objective, 0.05, 1)
+    gp.local_search.gradient_based(ind, objective, 0.05, 1)
     assert ind.parameter_names_to_values["<p1>"] == pytest.approx(0.91)
 
     # test decrease parameter value if too large
     ind.parameter_names_to_values["<p1>"] = 1.1
-    gp.local_search.gradient_based([ind], objective, 0.05, 1)
+    gp.local_search.gradient_based(ind, objective, 0.05, 1)
     assert ind.parameter_names_to_values["<p1>"] == pytest.approx(1.09)
 
     # test no change of parameter value if at optimum
     ind.parameter_names_to_values["<p1>"] = 1.0
-    gp.local_search.gradient_based([ind], objective, 0.05, 1)
+    gp.local_search.gradient_based(ind, objective, 0.05, 1)
    assert ind.parameter_names_to_values["<p1>"] == pytest.approx(1.0)
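
Usage sketch (not part of this patch): with the new per-individual signature, the gradient-based local search can be supplied to MuPlusLambda by binding all arguments except the individual, e.g. via functools.partial. The objective below is a hypothetical stand-in for a differentiable loss on the torch module returned by Individual.to_torch(); its input shape is an assumption for illustration only.

    import functools

    import torch

    import gp


    def objective(f):
        # hypothetical differentiable loss; assumes f accepts a (batch, n_inputs) tensor
        return ((f(torch.ones(1, 1)) - 1.0) ** 2).sum()


    # bind everything except the individual so the result matches the
    # Callable[[Individual], None] signature expected by MuPlusLambda
    local_search = functools.partial(
        gp.local_search.gradient_based, objective=objective, lr=0.05, gradient_steps=1
    )

MuPlusLambda.step then invokes this callback once per individual in the combined offspring and parent population, as shown in the first hunk above.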