diff --git a/autogalaxy/abstract_fit.py b/autogalaxy/abstract_fit.py index 262b7e34d..5f7223592 100644 --- a/autogalaxy/abstract_fit.py +++ b/autogalaxy/abstract_fit.py @@ -15,12 +15,7 @@ class AbstractFitInversion: - def __init__( - self, - model_obj, - settings_inversion: aa.SettingsInversion, - xp=np - ): + def __init__(self, model_obj, settings_inversion: aa.SettingsInversion, xp=np): """ An abstract fit object which fits to datasets (e.g. imaging, interferometer) inherit from. @@ -131,7 +126,9 @@ def linear_light_profile_intensity_dict( for i, light_profile in enumerate(linear_obj_func.light_profile_list): if self.use_jax: - linear_light_profile_intensity_dict[light_profile] = reconstruction[i] + linear_light_profile_intensity_dict[light_profile] = reconstruction[ + i + ] else: linear_light_profile_intensity_dict[light_profile] = float( reconstruction[i] diff --git a/autogalaxy/galaxy/galaxies.py b/autogalaxy/galaxy/galaxies.py index c21f5a863..0544899b5 100644 --- a/autogalaxy/galaxy/galaxies.py +++ b/autogalaxy/galaxy/galaxies.py @@ -119,7 +119,10 @@ def image_2d_from( ) def galaxy_image_2d_dict_from( - self, grid: aa.type.Grid2DLike, xp=np, operated_only: Optional[bool] = None + self, + grid: aa.type.Grid2DLike, + xp=np, + operated_only: Optional[bool] = None, ) -> {Galaxy: np.ndarray}: """ Returns a dictionary associating every `Galaxy` object with its corresponding 2D image, using the instance diff --git a/autogalaxy/imaging/fit_imaging.py b/autogalaxy/imaging/fit_imaging.py index 14e132f40..e4ea1bce7 100644 --- a/autogalaxy/imaging/fit_imaging.py +++ b/autogalaxy/imaging/fit_imaging.py @@ -78,7 +78,7 @@ def __init__( self=self, model_obj=self.galaxies, settings_inversion=settings_inversion, - xp=xp + xp=xp, ) self.adapt_images = adapt_images @@ -184,7 +184,7 @@ def galaxy_image_dict(self) -> Dict[Galaxy, np.ndarray]: """ galaxy_image_2d_dict = self.galaxies.galaxy_image_2d_dict_from( - grid=self.grids.lp, + grid=self.grids.lp, xp=self._xp ) galaxy_linear_obj_image_dict = self.galaxy_linear_obj_data_dict_from( @@ -212,6 +212,7 @@ def galaxy_model_image_dict(self) -> Dict[Galaxy, np.ndarray]: grid=self.grids.lp, psf=self.dataset.psf, blurring_grid=self.grids.blurring, + xp=self._xp, ) galaxy_linear_obj_image_dict = self.galaxy_linear_obj_data_dict_from( diff --git a/autogalaxy/interferometer/fit_interferometer.py b/autogalaxy/interferometer/fit_interferometer.py index 14473e479..72b4e0f05 100644 --- a/autogalaxy/interferometer/fit_interferometer.py +++ b/autogalaxy/interferometer/fit_interferometer.py @@ -72,7 +72,7 @@ def __init__( self=self, model_obj=self.galaxies, settings_inversion=settings_inversion, - xp=xp + xp=xp, ) self.adapt_images = adapt_images @@ -161,7 +161,9 @@ def galaxy_image_dict(self) -> Dict[Galaxy, np.ndarray]: For modeling, this dictionary is used to set up the `adapt_images` that adapt certain pixelizations to the data being fitted. """ - galaxy_image_dict = self.galaxies.galaxy_image_2d_dict_from(grid=self.grids.lp) + galaxy_image_dict = self.galaxies.galaxy_image_2d_dict_from( + grid=self.grids.lp, xp=self._xp + ) galaxy_linear_obj_image_dict = self.galaxy_linear_obj_data_dict_from( use_operated=False @@ -184,7 +186,7 @@ def galaxy_model_visibilities_dict(self) -> Dict[Galaxy, np.ndarray]: data being fitted. """ galaxy_model_visibilities_dict = self.galaxies.galaxy_visibilities_dict_from( - grid=self.grids.lp, transformer=self.dataset.transformer + grid=self.grids.lp, transformer=self.dataset.transformer, xp=self._xp ) galaxy_linear_obj_data_dict = self.galaxy_linear_obj_data_dict_from( diff --git a/autogalaxy/operate/image.py b/autogalaxy/operate/image.py index 38d4fbbc4..0973002b9 100644 --- a/autogalaxy/operate/image.py +++ b/autogalaxy/operate/image.py @@ -363,12 +363,12 @@ class OperateImageGalaxies(OperateImageList): """ def galaxy_image_2d_dict_from( - self, grid: aa.Grid2D, operated_only: Optional[bool] = None + self, grid: aa.Grid2D, xp=np, operated_only: Optional[bool] = None ) -> Dict[Galaxy, aa.Array2D]: raise NotImplementedError def galaxy_blurred_image_2d_dict_from( - self, grid, psf, blurring_grid + self, grid, psf, blurring_grid, xp=np ) -> Dict[Galaxy, aa.Array2D]: """ Evaluate the light object's dictionary mapping galaixes to their corresponding 2D images and convolve each @@ -392,15 +392,15 @@ def galaxy_blurred_image_2d_dict_from( """ galaxy_image_2d_not_operated_dict = self.galaxy_image_2d_dict_from( - grid=grid, operated_only=False + grid=grid, operated_only=False, xp=xp ) galaxy_blurring_image_2d_not_operated_dict = self.galaxy_image_2d_dict_from( - grid=blurring_grid, operated_only=False + grid=blurring_grid, operated_only=False, xp=xp ) galaxy_image_2d_operated_dict = self.galaxy_image_2d_dict_from( - grid=grid, operated_only=True + grid=grid, operated_only=True, xp=xp ) galaxy_blurred_image_2d_dict = {} @@ -414,6 +414,7 @@ def galaxy_blurred_image_2d_dict_from( blurred_image_2d = psf.convolved_image_from( image=image_2d_not_operated, blurring_image=blurring_image_2d_not_operated, + xp=xp, ) image_2d_operated = galaxy_image_2d_operated_dict[galaxy_key] @@ -451,7 +452,7 @@ def galaxy_visibilities_dict_from( in the uv-plane. """ - galaxy_image_2d_dict = self.galaxy_image_2d_dict_from(grid=grid) + galaxy_image_2d_dict = self.galaxy_image_2d_dict_from(grid=grid, xp=xp) galaxy_visibilities_dict = {}