Skip to content

Commit

Permalink
Merge pull request #273 from Jammy2211/feature/mapping_plot
Browse files Browse the repository at this point in the history
Feature/mapping plot
  • Loading branch information
Jammy2211 committed May 14, 2024
2 parents 3abd8b7 + bd39391 commit 3f7cee8
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 7 deletions.
6 changes: 3 additions & 3 deletions autolens/analysis/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,6 @@ def source_plane_inversion_centre(self) -> aa.Grid2DIrregular:
"""
if self.max_log_likelihood_fit.inversion is not None:
if self.max_log_likelihood_fit.inversion.has(cls=aa.AbstractMapper):
return self.max_log_likelihood_fit.inversion.brightest_reconstruction_pixel_centre_list[
0
]
return (
self.max_log_likelihood_fit.inversion.brightest_pixel_centre_list[0]
)
1 change: 1 addition & 0 deletions autolens/config/visualize/plots.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
potential: false
inversion: # Settings for plots of inversions (e.g. InversionPlotter).
subplot_inversion: true # Plot subplot of all quantities in each inversion (e.g. reconstrucuted image, reconstruction)?
subplot_mappings: true # Plot subplot of the image-to-source pixels mappings of each pixelization?
all_at_end_png: true # Plot all individual plots listed below as .png (even if False)?
all_at_end_fits: true # Plot all individual plots listed below as .fits (even if False)?
all_at_end_pdf: false # Plot all individual plots listed below as publication-quality .pdf (even if False)?
Expand Down
3 changes: 3 additions & 0 deletions autolens/imaging/model/plotter_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def should_plot(name):
if should_plot("subplot_of_planes"):
fit_plotter.subplot_of_planes()

if plot_setting(section="inversion", name="subplot_mappings"):
fit_plotter.subplot_mappings_of_plane(plane_index=len(fit.tracer.planes) - 1)

if not during_analysis and should_plot("all_at_end_png"):

mat_plot_2d = self.mat_plot_2d_from(
Expand Down
64 changes: 64 additions & 0 deletions autolens/imaging/plot/fit_imaging_plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import numpy as np
from typing import Optional

from autoconf import conf

import autoarray as aa
import autogalaxy.plot as aplt

Expand Down Expand Up @@ -584,6 +586,68 @@ def subplot_tracer(self):
self.include_2d._radial_critical_curves = include_radial_critical_curves_original
self.mat_plot_2d.use_log10 = use_log10_original

def subplot_mappings_of_plane(self, plane_index: Optional[int] = None, auto_filename: str = "subplot_mappings"):

plane_indexes = self.plane_indexes_from(plane_index=plane_index)

for plane_index in plane_indexes:

pixelization_index = 0

inversion_plotter = self.inversion_plotter_of_plane(plane_index=0)

inversion_plotter.open_subplot_figure(number_subplots=4)

inversion_plotter.figures_2d_of_pixelization(
pixelization_index=pixelization_index, data_subtracted=True
)

total_pixels = conf.instance["visualize"]["general"]["inversion"][
"total_mappings_pixels"
]

pix_indexes = inversion_plotter.inversion.brightest_pixel_list_from(
total_pixels=total_pixels, filter_neighbors=True
)

inversion_plotter.visuals_2d.pix_indexes = [
[index] for index in pix_indexes[pixelization_index]
]

inversion_plotter.visuals_2d.tangential_critical_curves = None
inversion_plotter.visuals_2d.radial_critical_curves = None

inversion_plotter.figures_2d_of_pixelization(
pixelization_index=pixelization_index, reconstructed_image=True
)

self.visuals_2d.pix_indexes = [
[index] for index in pix_indexes[pixelization_index]
]

self.figures_2d_of_planes(
plane_index=plane_index,
plane_image=True,
use_source_vmax=True
)

self.set_title(label="Source Reconstruction (Unzoomed)")
self.figures_2d_of_planes(
plane_index=plane_index,
plane_image=True,
zoom_to_brightest=False,
use_source_vmax=True
)
self.set_title(label=None)

self.visuals_2d.pix_indexes = None

inversion_plotter.mat_plot_2d.output.subplot_to_figure(
auto_filename=f"{auto_filename}_{pixelization_index}"
)

inversion_plotter.close_subplot_figure()

def figures_2d(
self,
data: bool = False,
Expand Down
3 changes: 3 additions & 0 deletions autolens/interferometer/model/plotter_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def should_plot(name):
dirty_chi_squared_map=should_plot("chi_squared_map"),
)

if plot_setting(section="inversion", name="subplot_mappings"):
fit_plotter.subplot_mappings_of_plane(plane_index=len(fit.tracer.planes) - 1)

if not during_analysis and should_plot("all_at_end_png"):
mat_plot_1d = self.mat_plot_1d_from(subfolders=path.join(subfolders, "end"))
mat_plot_2d = self.mat_plot_2d_from(subfolders=path.join(subfolders, "end"))
Expand Down
81 changes: 81 additions & 0 deletions autolens/interferometer/plot/fit_interferometer_plotters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Optional

from autoconf import conf

import autoarray as aa
import autogalaxy.plot as aplt

Expand Down Expand Up @@ -88,6 +90,25 @@ def __init__(
def get_visuals_2d_real_space(self) -> aplt.Visuals2D:
return self.get_2d.via_mask_from(mask=self.fit.dataset.real_space_mask)

def plane_indexes_from(self, plane_index: int):
"""
Returns a list of all indexes of the planes in the fit, which is iterated over in figures that plot
individual figures of each plane in a tracer.
Parameters
----------
plane_index
A specific plane index which when input means that only a single plane index is returned.
Returns
-------
list
A list of galaxy indexes corresponding to planes in the plane.
"""
if plane_index is None:
return range(len(self.fit.tracer.planes))
return [plane_index]

@property
def tracer(self) -> Tracer:
return self.fit.tracer_linear_light_profiles_to_light_profiles
Expand Down Expand Up @@ -187,6 +208,66 @@ def subplot_fit(self):
self.mat_plot_2d.output.subplot_to_figure(auto_filename="subplot_fit")
self.close_subplot_figure()

def subplot_mappings_of_plane(self, plane_index: Optional[int] = None, auto_filename: str = "subplot_mappings"):

if self.fit.inversion is None:
return

plane_indexes = self.plane_indexes_from(plane_index=plane_index)

for plane_index in plane_indexes:
pixelization_index = 0

inversion_plotter = self.inversion_plotter_of_plane(plane_index=0)

inversion_plotter.open_subplot_figure(number_subplots=4)

self.figures_2d(dirty_image=True)

total_pixels = conf.instance["visualize"]["general"]["inversion"][
"total_mappings_pixels"
]

pix_indexes = inversion_plotter.inversion.brightest_pixel_list_from(
total_pixels=total_pixels, filter_neighbors=True
)

inversion_plotter.visuals_2d.pix_indexes = [
[index] for index in pix_indexes[pixelization_index]
]

inversion_plotter.visuals_2d.tangential_critical_curves = None
inversion_plotter.visuals_2d.radial_critical_curves = None

inversion_plotter.figures_2d_of_pixelization(
pixelization_index=pixelization_index, reconstructed_image=True
)

self.visuals_2d.pix_indexes = [
[index] for index in pix_indexes[pixelization_index]
]

self.figures_2d_of_planes(
plane_index=plane_index,
plane_image=True,
)

self.set_title(label="Source Reconstruction (Unzoomed)")
self.figures_2d_of_planes(
plane_index=plane_index,
plane_image=True,
zoom_to_brightest=False,
)
self.set_title(label=None)

self.visuals_2d.pix_indexes = None

inversion_plotter.mat_plot_2d.output.subplot_to_figure(
auto_filename=f"{auto_filename}_{pixelization_index}"
)

inversion_plotter.close_subplot_figure()

def figures_2d(
self,
data: bool = False,
Expand Down
6 changes: 2 additions & 4 deletions test_autolens/analysis/test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,9 @@ def test__source_plane_inversion_centre(analysis_imaging_7x7):

assert (
result.source_plane_inversion_centre.in_list[0]
== result.max_log_likelihood_fit.inversion.brightest_reconstruction_pixel_centre_list[
== result.max_log_likelihood_fit.inversion.brightest_pixel_centre_list[
0
].in_list[
0
]
].in_list[0]
)

lens = al.Galaxy(redshift=0.5, light=al.lp.SersicSph(intensity=1.0))
Expand Down

0 comments on commit 3f7cee8

Please sign in to comment.