Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 4 additions & 136 deletions README.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
PyAutoLens: Open-Source Strong Lensing
======================================
PyAutoLens-JAX: Open-Source Strong Lensing
==========================================

.. |nbsp| unicode:: 0xA0
:trim:
Expand Down Expand Up @@ -50,7 +50,7 @@ PyAutoLens: Open-Source Strong Lensing

When two or more galaxies are aligned perfectly down our line-of-sight, the background galaxy appears multiple times.

This is called strong gravitational lensing and **PyAutoLens** makes it simple to model strong gravitational lenses.
This is called strong gravitational lensing and **PyAutoLens** makes it **simple** to model strong gravitational lenses, using JAX to **accelerate lens modeling on GPUs**.

Getting Started
---------------
Expand Down Expand Up @@ -80,136 +80,4 @@ For users less familiar with gravitational lensing, Bayesian inference and scien
you may wish to read through the **HowToLens** lectures. These teach you the basic principles of gravitational lensing
and Bayesian inference, with the content pitched at undergraduate level and above.

A complete overview of the lectures `is provided on the HowToLens readthedocs page <https://pyautolens.readthedocs.io/en/latest/howtolens/howtolens.html>`_

API Overview
------------

Lensing calculations are performed in **PyAutoLens** by building a ``Tracer`` object from ``LightProfile``,
``MassProfile`` and ``Galaxy`` objects. We create a simple strong lens system where a redshift 0.5
lens ``Galaxy`` with an ``Isothermal`` ``MassProfile`` lenses a background source at redshift 1.0 with an
``Exponential`` ``LightProfile`` representing a disk.

.. code-block:: python

import autolens as al
import autolens.plot as aplt
from astropy import cosmology as cosmo

"""
To describe the deflection of light by mass, two-dimensional grids of (y,x) Cartesian
coordinates are used.
"""
grid = al.Grid2D.uniform(
shape_native=(50, 50),
pixel_scales=0.05, # <- Conversion from pixel units to arc-seconds.
)

"""
The lens galaxy has an elliptical isothermal mass profile and is at redshift 0.5.
"""
mass = al.mp.Isothermal(
centre=(0.0, 0.0), ell_comps=(0.1, 0.05), einstein_radius=1.6
)

lens_galaxy = al.Galaxy(redshift=0.5, mass=mass)

"""
The source galaxy has an elliptical exponential light profile and is at redshift 1.0.
"""
disk = al.lp.Exponential(
centre=(0.3, 0.2),
ell_comps=(0.05, 0.25),
intensity=0.05,
effective_radius=0.5,
)

source_galaxy = al.Galaxy(redshift=1.0, disk=disk)

"""
We create the strong lens using a Tracer, which uses the galaxies, their redshifts
and an input cosmology to determine how light is deflected on its path to Earth.
"""
tracer = al.Tracer(
galaxies=[lens_galaxy, source_galaxy],
cosmology = al.cosmo.Planck15()
)

"""
We can use the Grid2D and Tracer to perform many lensing calculations, for example
plotting the image of the lensed source.
"""
tracer_plotter = aplt.TracerPlotter(tracer=tracer, grid=grid)
tracer_plotter.figures_2d(image=True)

With **PyAutoLens**, you can begin modeling a lens in minutes. The example below demonstrates a simple analysis which
fits the lens galaxy's mass with an ``Isothermal`` and the source galaxy's light with a ``Sersic``.

.. code-block:: python

import autofit as af
import autolens as al
import autolens.plot as aplt

"""
Load Imaging data of the strong lens from the dataset folder of the workspace.
"""
dataset = al.Imaging.from_fits(
data_path="/path/to/dataset/image.fits",
noise_map_path="/path/to/dataset/noise_map.fits",
psf_path="/path/to/dataset/psf.fits",
pixel_scales=0.1,
)

"""
Create a mask for the imaging data, which we setup as a 3.0" circle, and apply it.
"""
mask = al.Mask2D.circular(
shape_native=dataset.shape_native,
pixel_scales=dataset.pixel_scales,
radius=3.0
)
dataset = dataset.apply_mask(mask=mask)

"""
We model the lens galaxy using an elliptical isothermal mass profile and
the source galaxy using an elliptical sersic light profile.

To setup these profiles as model components whose parameters are free & fitted for
we set up each Galaxy as a `Model` and define the model as a `Collection` of all galaxies.
"""
# Lens:

mass = af.Model(al.mp.Isothermal)
lens = af.Model(al.Galaxy, redshift=0.5, mass=lens_mass_profile)

# Source:

disk = af.Model(al.lp.Sersic)
source = af.Model(al.Galaxy, redshift=1.0, disk=disk)

# Overall Lens Model:
model = af.Collection(galaxies=af.Collection(lens=lens, source=source))

"""
We define the non-linear search used to fit the model to the data (in this case, Dynesty).
"""
search = af.Nautilus(name="search[example]", n_live=50)

"""
We next set up the `Analysis`, which contains the `log likelihood function` that the
non-linear search calls to fit the lens model to the data.
"""
analysis = al.AnalysisImaging(dataset=dataset)

"""
To perform the model-fit we pass the model and analysis to the search's fit method. This will
output results (e.g., dynesty samples, model parameters, visualization) to hard-disk.
"""
result = search.fit(model=model, analysis=analysis)

"""
The results contain information on the fit, for example the maximum likelihood
model from the Dynesty parameter space search.
"""
print(result.samples.max_log_likelihood())
A complete overview of the lectures `is provided on the HowToLens readthedocs page <https://pyautolens.readthedocs.io/en/latest/howtolens/howtolens.html>`_
9 changes: 0 additions & 9 deletions autolens/config/non_linear.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@ nest:
slices: 5
update_interval: null
walks: 5
updates:
iterations_per_update: 2500
remove_state_files_at_end: true
DynestyStatic:
initialize:
method: prior
Expand All @@ -58,9 +55,3 @@ nest:
slices: 5
update_interval: null
walks: 5
updates:
iterations_per_update: 5000
log_every_update: 1
model_results_every_update: 1
remove_state_files_at_end: true
visualize_every_update: 1
2 changes: 2 additions & 0 deletions autolens/imaging/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,5 @@ def save_attributes(self, paths: af.DirectoryPaths):
)

analysis.save_attributes(paths=paths)


30 changes: 16 additions & 14 deletions autolens/imaging/model/plotter_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class PlotterInterfaceImaging(PlotterInterface):
imaging_combined = AgPlotterInterfaceImaging.imaging_combined

def fit_imaging(
self, fit: FitImaging, visuals_2d_of_planes_list : Optional[aplt.Visuals2D] = None
self, fit: FitImaging, visuals_2d_of_planes_list : Optional[aplt.Visuals2D] = None, quick_update: bool = False
):
"""
Visualizes a `FitImaging` object, which fits an imaging dataset.
Expand All @@ -40,28 +40,18 @@ def fit_imaging(
The maximum log likelihood `FitImaging` of the non-linear search which is used to plot the fit.
"""

if plot_setting(section="tracer", name="subplot_tracer"):

mat_plot_2d = self.mat_plot_2d_from()

fit_plotter = FitImagingPlotter(
fit=fit, mat_plot_2d=mat_plot_2d, visuals_2d_of_planes_list=visuals_2d_of_planes_list,
)

fit_plotter.subplot_tracer()

def should_plot(name):
return plot_setting(section=["fit", "fit_imaging"], name=name)

mat_plot_2d = self.mat_plot_2d_from()
mat_plot_2d = self.mat_plot_2d_from(quick_update=quick_update)

fit_plotter = FitImagingPlotter(
fit=fit, mat_plot_2d=mat_plot_2d, visuals_2d_of_planes_list=visuals_2d_of_planes_list,
)

plane_indexes_to_plot = [i for i in fit.tracer.plane_indexes_with_images if i != 0]

if should_plot("subplot_fit"):
if should_plot("subplot_fit") or quick_update:

# This loop means that multiple subplot_fit objects are output for a double source plane lens.

Expand All @@ -71,6 +61,19 @@ def should_plot(name):
else:
fit_plotter.subplot_fit()

if quick_update:
return

if plot_setting(section="tracer", name="subplot_tracer"):

mat_plot_2d = self.mat_plot_2d_from()

fit_plotter = FitImagingPlotter(
fit=fit, mat_plot_2d=mat_plot_2d, visuals_2d_of_planes_list=visuals_2d_of_planes_list,
)

fit_plotter.subplot_tracer()

if should_plot("subplot_fit_log10"):

try:
Expand All @@ -82,7 +85,6 @@ def should_plot(name):
except ValueError:
pass


if should_plot("subplot_of_planes"):
fit_plotter.subplot_of_planes()

Expand Down
69 changes: 49 additions & 20 deletions autolens/imaging/model/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def visualize(
paths: af.DirectoryPaths,
instance: af.ModelInstance,
during_analysis: bool,
quick_update: bool = False,
):
"""
Output images of the maximum log likelihood model inferred by the model-fit. This function is called throughout
Expand Down Expand Up @@ -91,8 +92,56 @@ def visualize(
via a non-linear search).
"""

import time

start_time = time.time()

fit = analysis.fit_from(instance=instance)

print(f"Fit From time: {time.time() - start_time} seconds")

start_time = time.time()

tracer = fit.tracer_linear_light_profiles_to_light_profiles

print(f"Tracer Linear Light Profiles time: {time.time() - start_time} seconds")

start_time = time.time()

visuals_2d_of_planes_list = tracer_util.visuals_2d_of_planes_list_from(
tracer=fit.tracer,
grid=fit.grids.lp.mask.derive_grid.all_false
)

print(f"Visuals 2D of planes list time: {time.time() - start_time} seconds")

start_time = time.time()

plotter_interface = PlotterInterfaceImaging(
image_path=paths.image_path,
title_prefix=analysis.title_prefix,
)

print(f"Plotter Interface Imaging time: {time.time() - start_time} seconds")

start = time.time()

try:
plotter_interface.fit_imaging(
fit=fit,
visuals_2d_of_planes_list=visuals_2d_of_planes_list,
quick_update=quick_update,
)
except exc.InversionException:
pass

print(f"Plotter Interface Fit Imaging time: {time.time() - start} seconds")

if quick_update:
return

# Full update based on configs.

if analysis.positions_likelihood_list is not None:

overwrite_file = True
Expand All @@ -111,33 +160,13 @@ def visualize(
except exc.InversionException:
return

tracer = fit.tracer_linear_light_profiles_to_light_profiles

zoom = ag.Zoom2D(mask=fit.mask)

extent = zoom.extent_from(buffer=0)
shape_native = zoom.shape_native

grid = ag.Grid2D.from_extent(extent=extent, shape_native=shape_native)

visuals_2d_of_planes_list = tracer_util.visuals_2d_of_planes_list_from(
tracer=fit.tracer,
grid=fit.grids.lp.mask.derive_grid.all_false
)

plotter_interface = PlotterInterfaceImaging(
image_path=paths.image_path,
title_prefix=analysis.title_prefix,
)

try:
plotter_interface.fit_imaging(
fit=fit,
visuals_2d_of_planes_list=visuals_2d_of_planes_list
)
except exc.InversionException:
pass

plotter_interface.tracer(
tracer=tracer,
grid=grid,
Expand Down
2 changes: 1 addition & 1 deletion autolens/imaging/plot/fit_imaging_plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ def subplot_fit(self, plane_index: Optional[int] = None):
self.set_title(label=None)

self.mat_plot_2d.output.subplot_to_figure(
auto_filename=f"subplot_fit{plane_index_tag}"
auto_filename=f"subplot_fit{plane_index_tag}", also_show=self.mat_plot_2d.quick_update
)
self.close_subplot_figure()

Expand Down
2 changes: 1 addition & 1 deletion autolens/interferometer/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,4 +233,4 @@ def save_attributes(self, paths: af.DirectoryPaths):
dataset=self.dataset,
)

analysis.save_attributes(paths=paths)
analysis.save_attributes(paths=paths)
4 changes: 4 additions & 0 deletions autolens/interferometer/model/plotter_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def fit_interferometer(
self,
fit: FitInterferometer,
visuals_2d_of_planes_list: Optional[aplt.Visuals2D] = None,
quick_update: bool = False,
):
"""
Visualizes a `FitInterferometer` object, which fits an interferometer dataset.
Expand Down Expand Up @@ -61,6 +62,9 @@ def should_plot(name):
if should_plot("subplot_fit_dirty_images"):
fit_plotter.subplot_fit_dirty_images()

if quick_update:
return

if should_plot("subplot_fit_real_space"):
fit_plotter.subplot_fit_real_space()

Expand Down
Loading
Loading