Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
1e96fd5
fix JAX env error
Oct 20, 2025
0f8ae89
'Updated version in __init__ to 2025.10.20.2
rhayes777 Oct 20, 2025
eb17408
'Updated version in __init__ to 2025.10.20.4
rhayes777 Oct 20, 2025
ef79887
'Updated version in __init__ to 2025.10.20.5
rhayes777 Oct 20, 2025
53ba169
floats on hpc mode iterations
Oct 21, 2025
38c94e4
Merge branch 'feature/build_fixes' of github.com:rhayes777/PyAutoFit …
Oct 21, 2025
1890a0f
import exception on emcee
Oct 21, 2025
7d1bf63
fix folder making in fits agg
Oct 21, 2025
846da82
erm
Oct 21, 2025
6797536
conflicts
Oct 21, 2025
bf92432
'Updated version in __init__ to 2025.10.21.1
rhayes777 Oct 21, 2025
f58ef53
license in docs
Oct 27, 2025
8506ac5
space
Oct 29, 2025
f298cd8
small workflow fixes
Nov 4, 2025
b8f7a45
undo change which broke something
Nov 4, 2025
690bfb3
'Updated version in __init__ to 2025.11.5.1
rhayes777 Nov 5, 2025
a77f7ef
minor changes
Nov 10, 2025
0fa1267
git push
Nov 10, 2025
ded081d
remove likelihood evaluation time no jax
Nov 10, 2025
8f65432
fix indentation causing plot bug
Nov 10, 2025
3328335
fix temporary latent bug
Nov 10, 2025
1ed2696
most jax imports cleaned up and moved
Nov 12, 2025
bffa8ca
all jax imports except wrapper and pytrees deferred
Nov 12, 2025
843b11b
remove samples_jax from initializer
Nov 13, 2025
0868c5e
remove use jax in config
Nov 13, 2025
0534220
fix bug with array allocation
Nov 13, 2025
27c6966
fix final unit test
Nov 13, 2025
df81d94
finish
Nov 13, 2025
c744db2
Merge pull request #1161 from rhayes777/feature/xp_no_autofit_import
Jammy2211 Nov 13, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions autofit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from autoconf import jax_wrapper
from autoconf.dictable import register_parser
from . import conf

Expand Down Expand Up @@ -140,6 +141,4 @@ def save_abc(pickler, obj):
pickle._Pickler.save_type(pickler, obj)




__version__ = "2025.10.20.1"
__version__ = "2025.11.5.1"
11 changes: 11 additions & 0 deletions autofit/aggregator/search_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,17 @@ def samples_summary(self) -> SamplesSummary:
summary.model = self.model
return summary

@property
def latent_summary(self) -> SamplesSummary:
    """
    The summary of the latent samples associated with this search output.

    This is loaded via the ``latent.latent_summary`` JSON value and has this
    output's model attached before being returned.

    Returns
    -------
    The latent ``SamplesSummary`` with its ``model`` attribute set.
    """
    summary = self.value("latent.latent_summary")
    summary.model = self.model
    return summary

@property
def instance(self):
"""
Expand Down
3 changes: 2 additions & 1 deletion autofit/aggregator/summary/aggregate_csv/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ def __init__(self, name: str, compute: Callable):
self.compute = compute

def value(self, row: "Row"):

try:
return self.compute(row.result.samples)
return self.compute(row.result)
except AttributeError as e:
raise AssertionError(
"Cannot compute additional fields if no samples.json present"
Expand Down
2 changes: 2 additions & 0 deletions autofit/aggregator/summary/aggregate_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ def output_to_folder(
else:
output_name = name[i]

output_path = folder / output_name
output_path.parent.mkdir(parents=True, exist_ok=True)
image.save(folder / f"{output_name}.png")

@staticmethod
Expand Down
2 changes: 0 additions & 2 deletions autofit/config/general.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
jax:
use_jax: false # If True, PyAutoFit uses JAX internally, whereas False uses normal Numpy.
updates:
iterations_per_quick_update: 1e99 # Non-linear search iterations between every quick update, which just displays the maximum likelihood model fit.
iterations_per_full_update: 1e99 # Non-linear search iterations between every full update, which outputs all visuals and result fits (e.g. model.result, search.summary), this exits the search and can be slow.
Expand Down
15 changes: 7 additions & 8 deletions autofit/example/analysis.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import numpy as np
from typing import Dict, Optional

from autofit.jax_wrapper import numpy as xp

import autofit as af

from autofit.example.result import ResultExample
Expand Down Expand Up @@ -38,7 +36,7 @@ class Analysis(af.Analysis):

LATENT_KEYS = ["gaussian.fwhm"]

def __init__(self, data: np.ndarray, noise_map: np.ndarray):
def __init__(self, data: np.ndarray, noise_map: np.ndarray, use_jax=False):
"""
In this example the `Analysis` object only contains the data and noise-map. It can be easily extended,
for more complex data-sets and model fitting problems.
Expand All @@ -51,12 +49,12 @@ def __init__(self, data: np.ndarray, noise_map: np.ndarray):
A 1D numpy array containing the noise values of the data, used for computing the goodness of fit
metric.
"""
super().__init__()
super().__init__(use_jax=use_jax)

self.data = data
self.noise_map = noise_map

def log_likelihood_function(self, instance: af.ModelInstance) -> float:
def log_likelihood_function(self, instance: af.ModelInstance, xp=np) -> float:
"""
Determine the log likelihood of a fit of multiple profiles to the dataset.

Expand Down Expand Up @@ -98,14 +96,15 @@ def model_data_1d_from(self, instance: af.ModelInstance) -> np.ndarray:
The model data of the profiles.
"""

xvalues = xp.arange(self.data.shape[0])
model_data_1d = xp.zeros(self.data.shape[0])
xvalues = self._xp.arange(self.data.shape[0])
model_data_1d = self._xp.zeros(self.data.shape[0])

try:
for profile in instance:
try:
model_data_1d += profile.model_data_from(
xvalues=xvalues
xvalues=xvalues,
xp=self._xp
)
except AttributeError:
pass
Expand Down
10 changes: 4 additions & 6 deletions autofit/example/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import numpy as np
from typing import Tuple

from autofit.jax_wrapper import numpy as xp

"""
The `Gaussian` class in this module is the model components that is fitted to data using a non-linear search. The
inputs of its __init__ constructor are the parameters which can be fitted for.
Expand Down Expand Up @@ -47,7 +45,7 @@ def fwhm(self) -> float:
the free parameters of the model which we are interested and may want to store the full samples information
on (e.g. to create posteriors).
"""
return 2 * xp.sqrt(2 * xp.log(2)) * self.sigma
return 2 * np.sqrt(2 * np.log(2)) * self.sigma

def _tree_flatten(self):
return (self.centre, self.normalization, self.sigma), None
Expand All @@ -64,7 +62,7 @@ def __eq__(self, other):
and self.sigma == other.sigma
)

def model_data_from(self, xvalues: np.ndarray) -> np.ndarray:
def model_data_from(self, xvalues: np.ndarray, xp=np) -> np.ndarray:
"""
Calculate the normalization of the profile on a 1D grid of Cartesian x coordinates.

Expand All @@ -82,7 +80,7 @@ def model_data_from(self, xvalues: np.ndarray) -> np.ndarray:
xp.exp(-0.5 * xp.square(xp.divide(transformed_xvalues, self.sigma))),
)

def f(self, x: float):
def f(self, x: float, xp=np):
return (
self.normalization
/ (self.sigma * xp.sqrt(2 * math.pi))
Expand Down Expand Up @@ -137,7 +135,7 @@ def __init__(
self.normalization = normalization
self.rate = rate

def model_data_from(self, xvalues: np.ndarray) -> np.ndarray:
def model_data_from(self, xvalues: np.ndarray, xp=np) -> np.ndarray:
"""
Calculate the 1D Gaussian profile on a 1D grid of Cartesian x coordinates.

Expand Down
4 changes: 3 additions & 1 deletion autofit/graphical/declarative/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ class AbstractDeclarativeFactor(Analysis, ABC):
optimiser: AbstractFactorOptimiser
_plates: Tuple[Plate, ...] = ()

def __init__(self, include_prior_factors=False):
def __init__(self, include_prior_factors=False, use_jax : bool = False):
self.include_prior_factors = include_prior_factors

super().__init__(use_jax=use_jax)

@property
@abstractmethod
def name(self):
Expand Down
18 changes: 15 additions & 3 deletions autofit/graphical/declarative/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@
from autofit.mapper.model import ModelInstance
from autofit.mapper.prior_model.prior_model import Model

from autofit.jax_wrapper import register_pytree_node_class
from ...non_linear.combined_result import CombinedResult


@register_pytree_node_class
class FactorGraphModel(AbstractDeclarativeFactor):
def __init__(
self,
*model_factors: Union[AbstractDeclarativeFactor, HierarchicalFactor],
name=None,
include_prior_factors=True,
use_jax : bool = False
):
"""
A collection of factors that describe models, which can be
Expand All @@ -36,6 +34,7 @@ def __init__(
"""
super().__init__(
include_prior_factors=include_prior_factors,
use_jax=use_jax,
)
self._model_factors = list(model_factors)
self._name = name or namer(self.__class__.__name__)
Expand Down Expand Up @@ -279,3 +278,16 @@ def visualize_combined(
instance,
during_analysis=during_analysis,
)

def perform_quick_update(self, paths, instance):

try:
self.model_factors[0].visualize_combined(
analyses=self.model_factors,
paths=paths,
instance=instance,
during_analysis=True,
quick_update=True,
)
except Exception as e:
pass
4 changes: 0 additions & 4 deletions autofit/graphical/declarative/factor/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from autofit.non_linear.paths.abstract import AbstractPaths
from .abstract import AbstractModelFactor

from autofit.jax_wrapper import register_pytree_node_class


class FactorCallable:
def __init__(
Expand Down Expand Up @@ -45,8 +43,6 @@ def __call__(self, **kwargs: np.ndarray) -> float:
instance = self.prior_model.instance_for_arguments(arguments)
return self.analysis.log_likelihood_function(instance)


@register_pytree_node_class
class AnalysisFactor(AbstractModelFactor):
@property
def prior_model(self):
Expand Down
3 changes: 2 additions & 1 deletion autofit/graphical/declarative/factor/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def __call__(self, **kwargs):

class _HierarchicalFactor(AbstractModelFactor):
def __init__(
self, distribution_model: HierarchicalFactor, drawn_prior: Prior,
self, distribution_model: HierarchicalFactor, drawn_prior: Prior, use_jax : bool = False
):
"""
A factor that links a variable to a parameterised distribution.
Expand All @@ -159,6 +159,7 @@ def __init__(
"""
self.distribution_model = distribution_model
self.drawn_prior = drawn_prior
self.use_jax = use_jax

prior_variable_dict = {prior.name: prior for prior in distribution_model.priors}

Expand Down
4 changes: 3 additions & 1 deletion autofit/graphical/factor_graphs/factor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from copy import deepcopy
from inspect import getfullargspec
import jax
from typing import Tuple, Dict, Any, Callable, Union, List, Optional, TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -285,6 +284,8 @@ def _set_jacobians(
numerical_jacobian=True,
jacfwd=True,
):
import jax

self._vjp = vjp
self._jacfwd = jacfwd
if vjp or factor_vjp:
Expand Down Expand Up @@ -327,6 +328,7 @@ def __call__(self, values: VariableData) -> FactorValue:
return self._cache[key]

def _jax_factor_vjp(self, *args) -> Tuple[Any, Callable]:
import jax
return jax.vjp(self._factor, *args)

_factor_vjp = _jax_factor_vjp
Expand Down
2 changes: 1 addition & 1 deletion autofit/graphical/laplace/newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def take_quasi_newton_step(
) -> Tuple[Optional[float], OptimisationState]:
""" """
state.search_direction = search_direction(state, **(search_direction_kws or {}))
if state.search_direction.vecnorm(np.Inf) == 0:
if state.search_direction.vecnorm(np.inf) == 0:
# if gradient is zero then at maximum already
return 0.0, state

Expand Down
3 changes: 3 additions & 0 deletions autofit/interpolator/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
x: np.ndarray,
y: np.ndarray,
inverse_covariance_matrix: np.ndarray,
use_jax : bool = False
):
"""
An analysis class that describes a linear relationship between x and y, y = mx + c
Expand All @@ -30,6 +31,8 @@ def __init__(
The y values. This is a matrix comprising all the variables in the model at each x value
inverse_covariance_matrix
"""
super().__init__(use_jax=use_jax)

self.x = x
self.y = y
self.inverse_covariance_matrix = inverse_covariance_matrix
Expand Down
88 changes: 0 additions & 88 deletions autofit/jax_wrapper.py

This file was deleted.

Loading
Loading