Skip to content

Commit

Permalink
Merge pull request #260 from teutoburg/fh/plots-and-reprs
Browse files Browse the repository at this point in the history
Some plotting and representation improvements
  • Loading branch information
teutoburg committed Aug 9, 2023
2 parents c700bc4 + 331e12a commit 28dc5dc
Show file tree
Hide file tree
Showing 7 changed files with 242 additions and 93 deletions.
15 changes: 15 additions & 0 deletions scopesim/detector/detector_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,21 @@ def readout(self, image_planes, array_effects=[], dtcr_effects=[], **kwargs):

return self.latest_exposure

def __repr__(self):
msg = (f"{self.__class__.__name__}"
f"({self.detector_list!r}, **{self.meta!r})")
return msg

def __str__(self):
return f"{self.__class__.__name__} with {self.detector_list!s}"

def _repr_pretty_(self, p, cycle):
"""For ipython"""
if cycle:
p.text(f"{self.__class__.__name__}(...)")
else:
p.text(str(self))


def make_primary_hdu(meta):
"""Create the primary header from meta data"""
Expand Down
23 changes: 12 additions & 11 deletions scopesim/effects/detector_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from astropy import units as u
from astropy.table import Table

from matplotlib import pyplot as plt

from ..base_classes import FOVSetupBase
from .effects import Effect
from .apertures import ApertureMask
Expand Down Expand Up @@ -251,22 +253,21 @@ def detector_headers(self, ids=None):

return hdrs

def plot(self):
import matplotlib.pyplot as plt
plt.gcf().clf()
def plot(self, axes=None):
if axes is None:
_, axes = plt.subplots()

for hdr in self.detector_headers():
x_mm, y_mm = calc_footprint(hdr, "D")
x_cen, y_cen = np.average(x_mm), np.average(y_mm)
x_mm = list(x_mm) + [x_mm[0]]
y_mm = list(y_mm) + [y_mm[0]]
plt.gca().plot(x_mm, y_mm)
plt.gca().text(x_cen, y_cen, hdr["ID"])
axes.plot(np.append(x_mm, x_mm[0]), np.append(y_mm, y_mm[0]))
axes.text(*np.mean((x_mm, y_mm), axis=1), hdr["ID"],
ha="center", va="center")

plt.gca().set_aspect("equal")
plt.ylabel("Size [mm]")
axes.set_aspect("equal")
axes.set_xlabel("Size [mm]")
axes.set_ylabel("Size [mm]")

return plt.gcf()
return axes


class DetectorWindow(DetectorList):
Expand Down
40 changes: 20 additions & 20 deletions scopesim/effects/effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,12 @@ def display_name(self):

@property
def meta_string(self):
meta_str = ""
max_key_len = max(len(key) for key in self.meta.keys())
padlen = max_key_len + 4
for key in self.meta:
if key not in {"comments", "changes", "description", "history",
"report_table_caption", "report_plot_caption",
"table"}:
meta_str += f"{key:>{padlen}} : {self.meta[key]}\n"

padlen = 4 + len(max(self.meta, key=len))
exclude = {"comments", "changes", "description", "history",
"report_table_caption", "report_plot_caption", "table"}
meta_str = "\n".join(f"{key:>{padlen}} : {value}"
for key, value in self.meta.items()
if key not in exclude)
return meta_str

def report(self, filename=None, output="rst", rst_title_chars="*+",
Expand Down Expand Up @@ -199,7 +196,7 @@ def report(self, filename=None, output="rst", rst_title_chars="*+",
"""
changes = self.meta.get("changes", [])
changes_str = "- " + "\n- ".join([str(entry) for entry in changes])
changes_str = "- " + "\n- ".join(str(entry) for entry in changes)
cls_doc = self.__doc__ if self.__doc__ is not None else "<no docstring>"
cls_descr = cls_doc.lstrip().splitlines()[0]

Expand Down Expand Up @@ -239,7 +236,12 @@ def report(self, filename=None, output="rst", rst_title_chars="*+",
"""

if params["report_plot_include"] and hasattr(self, "plot"):
from matplotlib.figure import Figure
fig = self.plot()
# HACK: plot methods should always return the same, while this is
# not sorted out, deal with both fig and ax
if not isinstance(fig, Figure):
fig = fig.figure

if fig is not None:
path = params["report_image_path"]
Expand All @@ -257,6 +259,9 @@ def report(self, filename=None, output="rst", rst_title_chars="*+",
# params["report_rst_path"])
# rel_file_path = os.path.join(rel_path, fname)

# TODO: fname is set in a loop above, so using it here in the
# fstring will only access the last value from the loop,
# is that intended?
rst_str += f"""
.. figure:: {fname}
:name: {"fig:" + params.get("name", "<unknown Effect>")}
Expand Down Expand Up @@ -287,16 +292,11 @@ def report(self, filename=None, output="rst", rst_title_chars="*+",
return rst_str

def info(self):
"""
Prints basic information on the effect, notably the description
"""
text = str(self)

desc = self.meta.get("description")
if desc is not None:
text += f"\nDescription: {desc}"

print(text)
"""Print basic information on the effect, notably the description."""
if (desc := self.meta.get("description")) is not None:
print(f"{self}\nDescription: {desc}")
else:
print(self)

def __repr__(self):
return f"{self.__class__.__name__}(**{self.meta!r})"
Expand Down
58 changes: 39 additions & 19 deletions scopesim/effects/spectral_trace_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

from pathlib import Path
import logging
from itertools import cycle

import numpy as np
from matplotlib import pyplot as plt

from astropy.io import fits
from astropy.table import Table
Expand Down Expand Up @@ -179,7 +181,7 @@ def apply_to(self, obj, **kwargs):
ex_vol["meta"].update(vol)
ex_vol["meta"].pop("wave_min")
ex_vol["meta"].pop("wave_max")
new_vols_list += extracted_vols
new_vols_list.extend(extracted_vols)

obj.volumes = new_vols_list

Expand All @@ -195,8 +197,7 @@ def apply_to(self, obj, **kwargs):
logging.info("Making cube")
obj.cube = obj.make_cube_hdu()

trace_id = obj.meta["trace_id"]
spt = self.spectral_traces[trace_id]
spt = self.spectral_traces[obj.meta["trace_id"]]
obj.hdu = spt.map_spectra_to_focal_plane(obj)

return obj
Expand All @@ -208,11 +209,11 @@ def footprint(self):
xfoot, yfoot = [], []
for spt in self.spectral_traces.values():
xtrace, ytrace = spt.footprint()
xfoot += xtrace
yfoot += ytrace
xfoot.extend(xtrace)
yfoot.extend(ytrace)

xfoot = [np.min(xfoot), np.max(xfoot), np.max(xfoot), np.min(xfoot)]
yfoot = [np.min(yfoot), np.min(yfoot), np.max(yfoot), np.max(yfoot)]
xfoot = [min(xfoot), max(xfoot), max(xfoot), min(xfoot)]
yfoot = [min(yfoot), min(yfoot), max(yfoot), max(yfoot)]

return xfoot, yfoot

Expand Down Expand Up @@ -299,42 +300,61 @@ def rectify_traces(self, hdulist, xi_min=None, xi_max=None, interps=None,
#pdu.header['FILTER'] = from_currsys("!OBS.filter_name_fw1")
outhdul = fits.HDUList([pdu])

for i, trace_id in enumerate(self.spectral_traces):
for i, trace_id in enumerate(self.spectral_traces, start=1):
hdu = self[trace_id].rectify(hdulist,
interps=interps,
bin_width=bin_width,
xi_min=xi_min, xi_max=xi_max,
wave_min=wave_min, wave_max=wave_max)
if hdu is not None: # ..todo: rectify does not do that yet
outhdul.append(hdu)
outhdul[0].header[f"EXTNAME{i+1}"] = trace_id
outhdul[0].header[f"EXTNAME{i}"] = trace_id

outhdul[0].header.update(inhdul[0].header)

return outhdul


def rectify_cube(self, hdulist):
"""Rectify traces and combine into a cube"""
raise(NotImplementedError)

def plot(self, wave_min=None, wave_max=None, **kwargs):
def plot(self, wave_min=None, wave_max=None, axes=None, **kwargs):
"""Plot every spectral trace in the spectral trace list.
Parameters
----------
wave_min : float, optional
Minimum wavelength, if any. If None, value from_currsys is used.
wave_max : float, optional
Maximum wavelength, if any. If None, value from_currsys is used.
axes : matplotlib axes, optional
The axes object to use for the plot. If None (default), a new
figure with one axes will be created.
**kwargs : dict
Any other parameters passed along to the plot method of the
individual spectral traces.
Returns
-------
fig : matplotlib figure
DESCRIPTION.
"""
if wave_min is None:
wave_min = from_currsys("!SIM.spectral.wave_min")
if wave_max is None:
wave_max = from_currsys("!SIM.spectral.wave_max")

from matplotlib import pyplot as plt
from matplotlib._pylab_helpers import Gcf
if len(Gcf.figs) == 0:
plt.figure(figsize=(12, 12))
if axes is None:
fig, axes = plt.subplots(figsize=(12, 12))
else:
fig = axes.figure

if self.spectral_traces is not None:
clrs = "rgbcymk" * (1 + len(self.spectral_traces) // 7)
for spt, c in zip(self.spectral_traces.values(), clrs):
spt.plot(wave_min, wave_max, c=c)
for spt, c in zip(self.spectral_traces.values(), cycle("rgbcymk")):
spt.plot(wave_min, wave_max, c=c, axes=axes, **kwargs)

return plt.gcf()
return fig

def __repr__(self):
# "\n".join([spt.__repr__() for spt in self.spectral_traces])
Expand Down

0 comments on commit 28dc5dc

Please sign in to comment.