Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions autofit/non_linear/fitness.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import numpy as np

import os
from typing import Optional

from autoconf import conf

from autofit import exc

from autofit.jax_wrapper import numpy as np

from autofit.mapper.prior_model.abstract import AbstractPriorModel
from autofit.non_linear.paths.abstract import AbstractPaths
from autofit.non_linear.analysis import Analysis
Expand Down Expand Up @@ -154,9 +156,7 @@ def __call__(self, parameters, *kwargs):
try:
instance = self.model.instance_from_vector(vector=parameters)
log_likelihood = self.log_likelihood_function(instance=instance)

if np.isnan(log_likelihood):
return self.resample_figure_of_merit
log_likelihood = np.where(np.isnan(log_likelihood), self.resample_figure_of_merit, log_likelihood)

except exc.FitException:
return self.resample_figure_of_merit
Expand Down
114 changes: 95 additions & 19 deletions autofit/non_linear/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from autofit.mapper.prior_model.abstract import AbstractPriorModel
from autofit.non_linear.parallel import SneakyPool

from autofit import jax_wrapper

logger = logging.getLogger(__name__)


Expand All @@ -39,14 +41,14 @@ def figure_of_metric(args) -> Optional[float]:
return None

def samples_from_model(
self,
total_points: int,
model: AbstractPriorModel,
fitness,
paths: AbstractPaths,
use_prior_medians: bool = False,
test_mode_samples: bool = True,
n_cores: int = 1,
self,
total_points: int,
model: AbstractPriorModel,
fitness,
paths: AbstractPaths,
use_prior_medians: bool = False,
test_mode_samples: bool = True,
n_cores: int = 1,
):
"""
Generate the initial points of the non-linear search, by randomly drawing unit values from a uniform
Expand All @@ -64,6 +66,14 @@ def samples_from_model(
if os.environ.get("PYAUTOFIT_TEST_MODE") == "1" and test_mode_samples:
return self.samples_in_test_mode(total_points=total_points, model=model)

if jax_wrapper.use_jax:
return self.samples_jax(
total_points=total_points,
model=model,
fitness=fitness,
use_prior_medians=use_prior_medians
)

unit_parameter_lists = []
parameter_lists = []
figures_of_merit_list = []
Expand Down Expand Up @@ -92,30 +102,95 @@ def samples_from_model(
unit_parameter_lists_.append(unit_parameter_list)

for figure_of_merit, unit_parameter_list, parameter_list in zip(
sneaky_pool.map(
function=self.figure_of_metric,
args_list=[(fitness, parameter_list) for parameter_list in parameter_lists_],
log_info=False
),
unit_parameter_lists_,
parameter_lists_,
sneaky_pool.map(
function=self.figure_of_metric,
args_list=[(fitness, parameter_list) for parameter_list in parameter_lists_],
log_info=False
),
unit_parameter_lists_,
parameter_lists_,
):
if figure_of_merit is not None:
unit_parameter_lists.append(unit_parameter_list)
parameter_lists.append(parameter_list)
figures_of_merit_list.append(figure_of_merit)

if total_points > 1 and np.allclose(
a=figures_of_merit_list[0], b=figures_of_merit_list[1:]
a=figures_of_merit_list[0], b=figures_of_merit_list[1:]
):
raise exc.InitializerException(
"""
The initial samples all have the same figure of merit (e.g. log likelihood values).

The non-linear search will therefore not progress correctly.

Possible causes for this behaviour are:


- The `log_likelihood_function` of the analysis class is defined incorrectly.
- The model parameterization creates numerically inaccurate log likelihoods.
- The`log_likelihood_function` is always returning `nan` values.
"""
)

logger.info(f"Initial samples generated, starting non-linear search")

return unit_parameter_lists, parameter_lists, figures_of_merit_list

def samples_jax(
self,
total_points: int,
model: AbstractPriorModel,
fitness,
use_prior_medians: bool = False,
):
"""
Generate the initial points of the non-linear search, by randomly drawing unit values from a uniform
distribution between the ball_lower_limit and ball_upper_limit values.

Parameters
----------
total_points
The number of points in non-linear paramemter space which initial points are created for.
model
An object that represents possible instances of some model with a given dimensionality which is the number
of free dimensions of the model.
"""

unit_parameter_lists = []
parameter_lists = []
figures_of_merit_list = []

logger.info(f"Generating initial samples of model using JAX LH Function cores")

while len(figures_of_merit_list) < total_points:

if not use_prior_medians:
unit_parameter_list = self._generate_unit_parameter_list(model)
else:
unit_parameter_list = [0.5] * model.prior_count

parameter_list = model.vector_from_unit_vector(
unit_vector=unit_parameter_list
)

figure_of_merit = self.figure_of_metric((fitness, parameter_list))

if figure_of_merit is not None:
unit_parameter_lists.append(unit_parameter_list)
parameter_lists.append(parameter_list)
figures_of_merit_list.append(figure_of_merit)

if total_points > 1 and np.allclose(
a=figures_of_merit_list[0], b=figures_of_merit_list[1:]
):
raise exc.InitializerException(
"""
The initial samples all have the same figure of merit (e.g. log likelihood values).

The non-linear search will therefore not progress correctly.

Possible causes for this behaviour are:

- The `log_likelihood_function` of the analysis class is defined incorrectly.
- The model parameterization creates numerically inaccurate log likelihoods.
- The`log_likelihood_function` is always returning `nan` values.
Expand Down Expand Up @@ -321,6 +396,7 @@ def info_value_from(self, value : Tuple[float, float]) -> float:
"""
return (value[1] + value[0]) / 2.0


class Initializer(AbstractInitializer):
def __init__(self, lower_limit: float, upper_limit: float):
"""
Expand Down
Loading