Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions autofit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from autoconf.dictable import register_parser
from . import conf

conf.instance.register(__file__)

import abc
import pickle

from dill import register

from autoconf.dictable import register_parser
from .non_linear.grid.grid_search import GridSearch as SearchGridSearch
from . import conf
from . import exc
from . import mock as m
from .non_linear.grid.grid_search import GridSearch as SearchGridSearch
from .aggregator.base import AggBase
from .database.aggregator.aggregator import GridSearchAggregator
from .graphical.expectation_propagation.history import EPHistory
Expand Down Expand Up @@ -136,6 +138,6 @@ def save_abc(pickler, obj):
pickle._Pickler.save_type(pickler, obj)


conf.instance.register(__file__)


__version__ = "2025.5.10.1"
2 changes: 2 additions & 0 deletions autofit/config/general.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
jax:
use_jax: false # If True, PyAutoFit uses JAX internally, whereas False uses normal Numpy.
analysis:
n_cores: 1 # The number of cores a parallelized sum of Analysis classes uses by default.
hpc:
Expand Down
25 changes: 6 additions & 19 deletions autofit/graphical/factor_graphs/factor.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
from copy import deepcopy
from inspect import getfullargspec
import jax
from typing import Tuple, Dict, Any, Callable, Union, List, Optional, TYPE_CHECKING

import numpy as np

try:
import jax

_HAS_JAX = True
except ImportError:
_HAS_JAX = False

from autofit.graphical.utils import (
nested_filter,
to_variabledata,
Expand Down Expand Up @@ -294,13 +288,7 @@ def _set_jacobians(
self._vjp = vjp
self._jacfwd = jacfwd
if vjp or factor_vjp:
if factor_vjp:
self._factor_vjp = factor_vjp
elif not _HAS_JAX:
raise ModuleNotFoundError(
"jax needed if `factor_vjp` not passed with vjp=True"
)

self._factor_vjp = factor_vjp
self.func_jacobian = self._vjp_func_jacobian
else:
# This is set by default
Expand All @@ -312,11 +300,10 @@ def _set_jacobians(
self._jacobian = jacobian
elif numerical_jacobian:
self._factor_jacobian = self._numerical_factor_jacobian
elif _HAS_JAX:
if jacfwd:
self._jacobian = jax.jacfwd(self._factor, range(self.n_args))
else:
self._jacobian = jax.jacobian(self._factor, range(self.n_args))
elif jacfwd:
self._jacobian = jax.jacfwd(self._factor, range(self.n_args))
else:
self._jacobian = jax.jacobian(self._factor, range(self.n_args))

def _factor_value(self, raw_fval) -> FactorValue:
"""Converts the raw output of the factor into a `FactorValue`
Expand Down
8 changes: 0 additions & 8 deletions autofit/graphical/factor_graphs/jacobians.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
try:
import jax

_HAS_JAX = True
except ImportError:
_HAS_JAX = False


import numpy as np

from autoconf import cached_property
Expand Down
66 changes: 35 additions & 31 deletions autofit/jax_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,45 @@
"""
Allows the user to switch between using NumPy and JAX for linear algebra operations.

If USE_JAX=1 then JAX's NumPy is used, otherwise vanilla NumPy is used.
If use_jax is set to true under the jax section of general.yaml then JAX's NumPy is used, otherwise vanilla NumPy is used.
"""
from os import environ
import jax

use_jax = environ.get("USE_JAX", "0") == "1"
from autoconf import conf

use_jax = conf.instance["general"]["jax"]["use_jax"]

if use_jax:
try:
import jax
from jax import numpy

def jit(function, *args, **kwargs):
return jax.jit(function, *args, **kwargs)
from jax import numpy

print(

"""
***JAX ENABLED***

Using JAX for grad/jit and GPU/TPU acceleration.
To disable JAX, set: config -> general -> jax -> use_jax = false
""")

def grad(function, *args, **kwargs):
return jax.grad(function, *args, **kwargs)
def jit(function, *args, **kwargs):
return jax.jit(function, *args, **kwargs)

def grad(function, *args, **kwargs):
return jax.grad(function, *args, **kwargs)

from jax._src.scipy.special import erfinv

print("JAX mode enabled")
except ImportError:
raise ImportError(
"JAX is not installed. Please install it with `pip install jax`."
)
else:

print(
"""
***JAX DISABLED***

Falling back to standard NumPy (no grad/jit or GPU support).
To enable JAX (if supported), set: config -> general -> jax -> use_jax = true
""")

import numpy # noqa
from scipy.special.cython_special import erfinv # noqa

Expand All @@ -33,20 +49,8 @@ def jit(function, *_, **__):
def grad(function, *_, **__):
return function

try:
from jax._src.tree_util import (
register_pytree_node_class as register_pytree_node_class,
register_pytree_node as register_pytree_node,
)
from jax._src.scipy.special import erfinv

except ImportError:

def register_pytree_node_class(cls):
return cls

def register_pytree_node(*_, **__):
def decorator(cls):
return cls
from jax._src.tree_util import (
register_pytree_node_class as register_pytree_node_class,
register_pytree_node as register_pytree_node,
)

return decorator
3 changes: 2 additions & 1 deletion autofit/mapper/prior/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import random
from abc import ABC, abstractmethod
from copy import copy
import jax
from typing import Union, Tuple, Optional, Dict

from autoconf import conf
Expand Down Expand Up @@ -120,7 +121,7 @@ def exception_message():
)

if jax_wrapper.use_jax:
import jax

jax.lax.cond(
jax.numpy.logical_or(
value < self.lower_limit,
Expand Down
20 changes: 19 additions & 1 deletion autofit/mapper/prior_model/abstract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import copy
import inspect
import jax.numpy as jnp
import jax
import json
import logging
import random
Expand All @@ -12,6 +14,7 @@
from autoconf import conf
from autoconf.exc import ConfigException
from autofit import exc
from autofit import jax_wrapper
from autofit.mapper import model
from autofit.mapper.model import AbstractModel, frozen_cache
from autofit.mapper.prior import GaussianPrior
Expand Down Expand Up @@ -781,7 +784,22 @@ def instance_from_vector(self, vector, ignore_prior_limits=False):

if not ignore_prior_limits:
for prior, value in arguments.items():
prior.assert_within_limits(value)

if not jax_wrapper.use_jax:

prior.assert_within_limits(value)

else:

valid = prior.assert_within_limits(value)

return jax.lax.cond(
jnp.isnan(valid),
lambda _: jnp.nan, # or return -jnp.inf
lambda _: 0.0, # normal computation
operand=None,
)


return self.instance_for_arguments(
arguments,
Expand Down
8 changes: 2 additions & 6 deletions autofit/non_linear/fitness.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,26 @@
from autoconf import conf
from autoconf import cached_property

from autofit import jax_wrapper
from autofit.jax_wrapper import numpy as np
from autofit import exc

from autofit.jax_wrapper import numpy as np

from autofit.mapper.prior_model.abstract import AbstractPriorModel
from autofit.non_linear.paths.abstract import AbstractPaths
from autofit.non_linear.analysis import Analysis

from timeout_decorator import timeout

from autofit import jax_wrapper


def get_timeout_seconds():

try:
return conf.instance["general"]["test"]["lh_timeout_seconds"]
except KeyError:
pass


timeout_seconds = get_timeout_seconds()


class Fitness:
def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion autofit/non_linear/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def samples_from_model(
if os.environ.get("PYAUTOFIT_TEST_MODE") == "1" and test_mode_samples:
return self.samples_in_test_mode(total_points=total_points, model=model)

if jax_wrapper.use_jax:
if jax_wrapper.use_jax or n_cores == 1:
return self.samples_jax(
total_points=total_points,
model=model,
Expand Down
1 change: 1 addition & 0 deletions autofit/non_linear/parallel/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
job_queue: multiprocessing.Queue
The queue through which jobs are submitted
"""

super().__init__(name=name)
self.logger = logging.getLogger(
f"process {name}"
Expand Down
1 change: 1 addition & 0 deletions autofit/non_linear/parallel/sneaky.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(self, function, *args):
args
The arguments to that function
"""

super().__init__()
if _is_likelihood_function(function):
self.function = None
Expand Down
4 changes: 1 addition & 3 deletions autofit/non_linear/search/abstract_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import logging
import multiprocessing as mp
import os
import signal
import sys
import time
import warnings
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -256,7 +254,6 @@ def __init__(

if jax_wrapper.use_jax:
self.number_of_cores = 1
logger.warning(f"JAX is enabled. Setting number of cores to 1.")

self.number_of_cores = number_of_cores

Expand Down Expand Up @@ -1197,6 +1194,7 @@ def make_sneaky_pool(self, fitness: Fitness) -> Optional[SneakyPool]:
-------
An implementation of a multiprocessing pool
"""

self.logger.warning(
"...using SneakyPool. This copies the likelihood function "
"to each process on instantiation to avoid copying multiple "
Expand Down
4 changes: 2 additions & 2 deletions autofit/non_linear/search/nest/dynesty/search/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,6 @@ def _fit(
set of accepted samples of the fit.
"""

from dynesty.pool import Pool

fitness = Fitness(
model=model,
analysis=analysis,
Expand Down Expand Up @@ -152,6 +150,8 @@ def _fit(
):
raise RuntimeError

from dynesty.pool import Pool

with Pool(
njobs=self.number_of_cores,
loglike=fitness,
Expand Down
6 changes: 4 additions & 2 deletions autofit/non_linear/search/nest/dynesty/search/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def search_internal_from(
in the dynesty queue for samples.
"""

gradient = fitness.grad if self.use_gradient else None

if checkpoint_exists:
search_internal = StaticSampler.restore(
fname=self.checkpoint_file, pool=pool
Expand All @@ -127,7 +129,7 @@ def search_internal_from(
self.write_uses_pool(uses_pool=True)
return StaticSampler(
loglikelihood=pool.loglike,
gradient=fitness.grad,
gradient=gradient,
prior_transform=pool.prior_transform,
ndim=model.prior_count,
live_points=live_points,
Expand All @@ -139,7 +141,7 @@ def search_internal_from(
self.write_uses_pool(uses_pool=False)
return StaticSampler(
loglikelihood=fitness,
gradient=fitness.grad,
gradient=gradient,
prior_transform=prior_transform,
ndim=model.prior_count,
logl_args=[model, fitness],
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
"typing-inspect>=0.4.0",
"emcee>=3.1.6",
"gprof2dot==2021.2.21",
"jax==0.5.3",
"matplotlib",
"numpydoc>=1.0.0",
"pyprojroot==0.2.0",
Expand Down
2 changes: 2 additions & 0 deletions test_autofit/config/general.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
jax:
use_jax: false # If True, PyAutoFit uses JAX internally, whereas False uses normal Numpy.
analysis:
n_cores: 1 # The number of cores a parallelized sum of Analysis classes uses by default.
hpc:
Expand Down
1 change: 1 addition & 0 deletions test_autofit/config/non_linear/mcmc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Emcee:
ball_upper_limit: 0.51
method: prior
parallel:
force_x1_cpu: true
number_of_cores: 1
printing:
silence: false
Expand Down
2 changes: 1 addition & 1 deletion test_autofit/config/non_linear/nest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ DynestyStatic:
initialize:
method: prior
parallel:
force_x1_cpu: false
force_x1_cpu: true
number_of_cores: 1
printing:
silence: true
Expand Down
Loading
Loading