/
plotting.py
476 lines (407 loc) · 18.2 KB
/
plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Functions to plot light curve data and models."""
import numpy as np
from .bandpasses import get_bandpass
from .magsystems import get_magsystem
from .models import Model
from .photdata import photometric_data
from .utils import format_value
__all__ = ['plot_lc']
# Line styles assigned to models by index (cycling) when several models are
# drawn in the same panel, so each model can be told apart.
_model_ls = ["-", "--", ":", "-."]
def _add_errorbar(ax, x, y, yerr, filled, markersize=None, color=None):
    """Draw error-bar points on Axes `ax` with per-point marker filling.

    Points where the boolean mask `filled` is True are drawn with solid
    circle markers; all other points are drawn with hollow circles (marker
    face color ``'None'``). No connecting line is drawn. `markersize` and
    `color` are forwarded unchanged to both ``ax.errorbar`` calls.
    """
    style = dict(ls='None', marker='o', markersize=markersize, color=color)

    # Solid markers first, then the hollow ones.
    ax.errorbar(x[filled], y[filled], yerr[filled], **style)
    hollow = ~filled
    ax.errorbar(x[hollow], y[hollow], yerr[hollow], mfc='None', **style)
def _add_plot(ax, x, y, filled, markersize=None, color=None):
    """Scatter `y` against `x` on Axes `ax` as circle markers.

    Points selected by the boolean mask `filled` get solid markers and the
    rest get hollow markers (marker face color ``'None'``). Markers are not
    joined by a line. `markersize` and `color` apply to both groups.
    """
    common = {'marker': 'o', 'markersize': markersize, 'color': color,
              'ls': 'None'}
    # (mask, extra keyword arguments) for the solid then the hollow group.
    groups = ((filled, {}), (~filled, {'mfc': 'None'}))
    for selection, extra in groups:
        ax.plot(x[selection], y[selection], **extra, **common)
def plot_lc(data=None, model=None, bands=None, zp=25., zpsys='ab',
            pulls=True, xfigsize=None, yfigsize=None, figtext=None,
            model_label=None, errors=None, ncol=2, figtextsize=1.,
            show_model_params=True, tighten_ylim=False, color=None,
            cmap=None, cmap_lims=(3000., 10000.), fill_data_marker=None,
            fname=None, fill_percentiles=None, **kwargs):
    """Plot light curve data or model light curves.

    Parameters
    ----------
    data : astropy `~astropy.table.Table` or similar, optional
        Table of photometric data. Must include certain column names.
        See the "Photometric Data" section of the documentation for required
        columns.
    model : `~sncosmo.Model` or list thereof, optional
        If given, model light curve is plotted. If a list or tuple of
        `~sncosmo.Model`, multiple models are plotted. Anything that is not
        a `~sncosmo.Model` instance raises `TypeError`.
    model_label : str or list, optional
        If given, model(s) will be labeled in a legend in the upper left
        subplot. Must be same length as model.
    errors : dict, optional
        Uncertainty on model parameters. If given, along with exactly one
        model, uncertainty will be displayed with model parameters at the top
        of the figure.
    bands : str or list, optional
        Bandpass, or list of Bandpasses or names thereof, to plot. When data
        is given, only bands present in the data are plotted.
    zp : float, optional
        Zeropoint to normalize the flux in the plot (for the purpose of
        plotting all observations on a common flux scale). Default is 25.
    zpsys : str, optional
        Zeropoint system to normalize the flux in the plot (for the purpose of
        plotting all observations on a common flux scale).
        Default is ``'ab'``.
    pulls : bool, optional
        If True (and if model and data are given), plot pulls. Pulls are the
        deviation of the data from the model divided by the data uncertainty.
        Default is ``True``.
    figtext : str, optional
        Text to add to top of figure. If a list of strings, each item is
        placed in a separate "column". Use newline separators for multiple
        lines. A list passed here is not modified.
    ncol : int, optional
        Number of columns of axes. Default is 2.
    xfigsize, yfigsize : float, optional
        figure size in inches in x or y. Specify one or the other, not both.
        Default is to set axes panel size to 3.0 x 2.25 inches.
    figtextsize : float, optional
        Space to reserve at top of figure for figtext (if not None).
        Default is 1 inch.
    show_model_params : bool, optional
        If there is exactly one model plotted, the parameters of the model
        are added to ``figtext`` by default (as two additional columns) so
        that they are printed at the top of the figure. Set this to False to
        disable this behavior.
    tighten_ylim : bool, optional
        If true, tighten the y limits so that the model is visible (if any
        models are plotted).
    color : str or mpl_color, optional
        Color of data and model lines in each band. Can be any type of color
        that matplotlib understands. If None (default) a colormap will be used
        to choose a color for each band according to its central wavelength.
    cmap : Colormap or str, optional
        A matplotlib colormap (or the name of one) to use, if color is None.
        If both color and cmap are None, the ``'jet_r'`` colormap is used.
    cmap_lims : (float, float), optional
        The wavelength limits for the colormap, in Angstroms. Default is
        (3000., 10000.), meaning that a bandpass with a central wavelength of
        3000 Angstroms will be assigned a color at the high end of the
        colormap and a bandpass with a central wavelength of 10000 will be
        assigned a color at the low end of the colormap.
    fill_data_marker : array_like, optional
        Array of booleans indicating whether to plot a filled or unfilled
        marker for each data point. Values are converted to bool. Default is
        all filled markers.
    fname : str, optional
        Filename (or file-like / PdfPages object) to pass to savefig. If None
        (default), figure is returned.
    fill_percentiles : (float, float, float), optional
        When multiple models are given, the percentiles for a light
        curve confidence interval. The upper and lower percentiles
        define a fill between region, and the middle percentile
        defines a line that will be plotted over the fill between
        region.
    kwargs : optional
        Any additional keyword args are passed to
        `~matplotlib.figure.Figure.savefig`. Popular options include
        ``dpi``, ``format``, ``transparent``. See matplotlib docs for full
        list.

    Returns
    -------
    fig : matplotlib `~matplotlib.figure.Figure`
        Only returned if `fname` is `None`. Display to screen with
        ``plt.show()`` or save with ``fig.savefig(filename)``. When creating
        many figures, be sure to close with ``plt.close(fig)``.

    Raises
    ------
    ValueError
        If neither data nor model is given; if bands is missing when only
        models are plotted; if both xfigsize and yfigsize are given; if
        model_label or fill_data_marker do not match model or data in
        length/shape; or if no bands are left to plot.
    TypeError
        If model contains anything other than `~sncosmo.Model` instances.

    Examples
    --------
    >>> import sncosmo
    >>> import matplotlib.pyplot as plt

    Load some example data:

    >>> data = sncosmo.load_example_data()

    Plot the data, displaying to the screen:

    >>> fig = plot_lc(data)
    >>> plt.show()

    Plot a model along with the data:

    >>> model = sncosmo.Model('salt2')
    >>> model.set(z=0.5, c=0.2, t0=55100., x0=1.547e-5)
    >>> sncosmo.plot_lc(data, model=model)

    .. image:: /pyplots/plotlc_example.png

    Plot just the model, for selected bands:

    >>> sncosmo.plot_lc(model=model,
    ...                 bands=['sdssg', 'sdssr'])

    Plot figures on a multipage pdf:

    >>> from matplotlib.backends.backend_pdf import PdfPages
    >>> pp = PdfPages('output.pdf')
    >>> # Do the following as many times as you like:
    >>> sncosmo.plot_lc(data, fname=pp, format='pdf')
    >>> # Don't forget to close at the end:
    >>> pp.close()
    """
    # Cheap argument checks first, so that invalid calls fail before
    # matplotlib is imported or any data is processed.
    if data is None and model is None:
        raise ValueError('must specify at least one of: data, model')
    if data is None and bands is None:
        raise ValueError('must specify bands to plot for model(s)')
    if xfigsize is not None and yfigsize is not None:
        raise ValueError('cannot specify both xfigsize and yfigsize')

    from matplotlib import pyplot as plt
    from matplotlib.ticker import MaxNLocator
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    # Get the model(s).
    if model is None:
        models = []
    elif isinstance(model, (tuple, list)):
        models = list(model)
    else:
        models = [model]
    if not all(isinstance(m, Model) for m in models):
        raise TypeError('model(s) must be Model instance(s)')

    # Get the model labels: exactly one (possibly None) per model.
    if model_label is None:
        model_labels = [None] * len(models)
    elif isinstance(model_label, str):
        model_labels = [model_label]
    else:
        model_labels = list(model_label)
    if len(model_labels) != len(models):
        raise ValueError('if given, length of model_label must match '
                         'that of model')

    # Color options. pyplot.get_cmap accepts a colormap name or a Colormap
    # instance; matplotlib.cm.get_cmap is gone in matplotlib >= 3.9.
    if color is None:
        cmap = plt.get_cmap('jet_r' if cmap is None else cmap)

    # Standardize and normalize data; sort it by time if it is not already.
    sortidx = None
    if data is not None:
        data = photometric_data(data)
        data = data.normalized(zp=zp, zpsys=zpsys)
        if not np.all(np.ediff1d(data.time) >= 0.0):
            sortidx = np.argsort(data.time)
            data = data[sortidx]

    # Bands to plot. Requested bands are resolved with get_bandpass *before*
    # intersecting with the data's bands, so band names (strings) can match
    # the Bandpass objects in `data.band` (relies on get_bandpass handing
    # back one shared instance per name).
    if isinstance(bands, str):
        bands = [bands]
    if data is None:
        bands = {get_bandpass(b) for b in bands}
    else:
        available = {get_bandpass(b) for b in set(data.band)}
        if bands is None:
            bands = available
        else:
            bands = available & {get_bandpass(b) for b in bands}
    bands = list(bands)
    if not bands:
        raise ValueError('no bands to plot')

    # filled: used only if data is not None. Guarantee array of booleans.
    if data is not None:
        if fill_data_marker is None:
            fill_data_marker = np.ones(data.time.shape, dtype=bool)
        else:
            # Cast to bool: `~` on an integer array is a bitwise NOT, which
            # would turn the masks used by _add_errorbar/_add_plot into
            # fancy integer indices.
            fill_data_marker = np.asarray(fill_data_marker, dtype=bool)
            if fill_data_marker.shape != data.time.shape:
                raise ValueError("fill_data_marker shape does not match data")
            if sortidx is not None:  # sort like we sorted the data
                fill_data_marker = fill_data_marker[sortidx]

    # Build figtext (including model parameters, if there is exactly 1 model).
    if errors is None:
        errors = {}
    if figtext is None:
        figtext = []
    elif isinstance(figtext, str):
        figtext = [figtext]
    else:
        # Copy so the appends below never modify the caller's list.
        figtext = list(figtext)
    if len(models) == 1 and show_model_params:
        m = models[0]
        param_lines = [
            '${0} = {1}$'.format(
                lname, format_value(value, errors.get(name), latex=True))
            for name, lname, value in zip(m.param_names, m.param_names_latex,
                                          m.parameters)]
        # Split lines into two columns (first column gets the extra line).
        n = len(param_lines) - len(param_lines) // 2
        figtext.append('\n'.join(param_lines[:n]))
        figtext.append('\n'.join(param_lines[n:]))
    if len(figtext) == 0:
        figtextsize = 0.

    # Calculate layout of figure (columns, rows, figure size). We have to
    # calculate these explicitly because plt.tight_layout() doesn't space the
    # subplots as we'd like them when only some of them have xlabels/xticks.
    wspace = 0.6  # All in inches.
    hspace = 0.3
    lspace = 1.0
    bspace = 0.7
    trspace = 0.2
    nrow = (len(bands) - 1) // ncol + 1
    if xfigsize is None and yfigsize is None:
        hpanel = 2.25
        wpanel = 3.
    elif xfigsize is None:
        hpanel = (yfigsize - figtextsize - bspace - trspace -
                  hspace * (nrow - 1)) / nrow
        wpanel = hpanel * 3. / 2.25
    else:  # only xfigsize given (both given was rejected above)
        wpanel = (xfigsize - lspace - trspace - wspace * (ncol - 1)) / ncol
        hpanel = wpanel * 2.25 / 3.

    figsize = (lspace + wpanel * ncol + wspace * (ncol - 1) + trspace,
               bspace + hpanel * nrow + hspace * (nrow - 1) + trspace +
               figtextsize)

    # Create the figure and axes.
    fig, axes = plt.subplots(nrow, ncol, figsize=figsize, squeeze=False)
    fig.subplots_adjust(left=lspace / figsize[0],
                        bottom=bspace / figsize[1],
                        right=1. - trspace / figsize[0],
                        top=1. - (figtextsize + trspace) / figsize[1],
                        wspace=wspace / wpanel,
                        hspace=hspace / hpanel)

    # Write figtext at the top of the figure, one column per item.
    for i, coltext in enumerate(figtext):
        if coltext is not None:
            xpos = (trspace / figsize[0] +
                    (1. - 2. * trspace / figsize[0]) * (i / len(figtext)))
            ypos = 1. - trspace / figsize[1]
            fig.text(xpos, ypos, coltext, va="top", ha="left",
                     multialignment="left")

    # If there is exactly one model (and data), offset the time axis by the
    # model's t0. Parameter index 1 is assumed to be t0 (z, t0, ...).
    if len(models) == 1 and data is not None:
        toff = models[0].parameters[1]
    else:
        toff = 0.

    # Global min and max of time axis.
    tmins, tmaxs = [], []
    if data is not None:
        tmins.append(np.min(data.time) - 10.)
        tmaxs.append(np.max(data.time) + 10.)
    for m in models:
        tmins.append(m.mintime())
        tmaxs.append(m.maxtime())
    tmin = min(tmins)
    tmax = max(tmaxs)
    tgrid = np.linspace(tmin, tmax, int(tmax - tmin) + 1)

    # Panels are filled in order of increasing effective wavelength. Sort on
    # the wavelength only: on ties, comparing Bandpass objects would raise.
    waves_and_bands = sorted(((b.wave_eff, b) for b in bands),
                             key=lambda wb: wb[0])
    plotci = len(models) > 1 and fill_percentiles is not None

    for axnum in range(ncol * nrow):
        row, col = divmod(axnum, ncol)
        ax = axes[row, col]

        # Hide the surplus panels at the end of the last row.
        if axnum >= len(waves_and_bands):
            ax.set_visible(False)
            ax.set_frame_on(False)
            continue

        wave, band = waves_and_bands[axnum]

        bandname_coords = (0.92, 0.92)
        bandname_ha = 'right'
        if color is None:
            # Linear map of wave_eff onto the colormap:
            # cmap_lims[1] -> 0.0, cmap_lims[0] -> 1.0.
            bandcolor = cmap((cmap_lims[1] - wave) /
                             (cmap_lims[1] - cmap_lims[0]))
        else:
            bandcolor = color

        # Plot data if there are any.
        if data is not None:
            mask = data.band == band
            time = data.time[mask]
            flux = data.flux[mask]
            fluxerr = data.fluxerr[mask]
            bandfilled = fill_data_marker[mask]
            _add_errorbar(ax, time - toff, flux, fluxerr, bandfilled,
                          color=bandcolor, markersize=3.)

        # Plot model(s) if there are any.
        lines = []
        labels = []
        mflux_ranges = []
        mfluxes = []
        for i, (m, label) in enumerate(zip(models, model_labels)):
            ls = _model_ls[i % len(_model_ls)]
            if m.bandoverlap(band):
                mflux = m.bandflux(band, tgrid, zp=zp, zpsys=zpsys)
                if plotci:
                    mfluxes.append(mflux)
                else:
                    mflux_ranges.append((mflux.min(), mflux.max()))
                    line, = ax.plot(tgrid - toff, mflux, ls=ls,
                                    marker='None', color=bandcolor)
                    lines.append(line)
            else:
                # Add a dummy line so the legend displays all models in the
                # first panel.
                lines.append(plt.Line2D([0, 1], [0, 1], ls=ls,
                                        marker='None', color=bandcolor))
            labels.append(label)

        # Confidence region across models. Skipped when no model covers this
        # band: percentiles of an empty list cannot be plotted on tgrid.
        if plotci and mfluxes:
            lo, med, up = np.percentile(mfluxes, fill_percentiles, axis=0)
            mflux_ranges.append((lo.min(), up.max()))
            line, = ax.plot(tgrid - toff, med, marker='None',
                            color=bandcolor)
            lines.append(line)
            ax.fill_between(tgrid - toff, lo, up, color=bandcolor,
                            alpha=0.4)

        # Add a legend to the first panel whenever model labels were given.
        if row == 0 and col == 0 and model_label is not None:
            ax.legend(lines, labels, loc='upper right',
                      fontsize='small', frameon=True)
            bandname_coords = (0.08, 0.92)  # Move bandname to upper left
            bandname_ha = 'left'

        # Band name in corner
        text = band.name if band.name is not None else str(band)
        ax.text(bandname_coords[0], bandname_coords[1], text,
                color='k', ha=bandname_ha, va='top', transform=ax.transAxes)

        ax.axhline(y=0., ls='--', c='k')  # horizontal line at flux = 0.
        ax.set_xlim((tmin - toff, tmax - toff))

        # If we plotted any models, narrow axes limits so that the model
        # is visible.
        if tighten_ylim and mflux_ranges:
            mfluxmin = min(r[0] for r in mflux_ranges)
            mfluxmax = max(r[1] for r in mflux_ranges)
            ymin, ymax = ax.get_ylim()
            ymax = min(ymax, 4. * mfluxmax)
            ymin = max(ymin, mfluxmin - (ymax - mfluxmax))
            ax.set_ylim(ymin, ymax)

        if col == 0:
            ax.set_ylabel('flux ($ZP_{{{0}}} = {1}$)'
                          .format(get_magsystem(zpsys).name.upper(), zp))

        show_pulls = (pulls and
                      data is not None and
                      len(models) == 1 and models[0].bandoverlap(band))

        # steal part of the axes and plot pulls
        if show_pulls:
            divider = make_axes_locatable(ax)
            axpulls = divider.append_axes('bottom', size='30%', pad=0.15,
                                          sharex=ax)
            mflux = models[0].bandflux(band, time, zp=zp, zpsys=zpsys)
            fluxpulls = (flux - mflux) / fluxerr
            axpulls.axhspan(ymin=-1., ymax=1., color='0.95')
            axpulls.axhline(y=0., color=bandcolor)
            _add_plot(axpulls, time - toff, fluxpulls, bandfilled,
                      markersize=4., color=bandcolor)

            # Ensure y range is centered at 0.
            ymin, ymax = axpulls.get_ylim()
            absymax = max(abs(ymin), abs(ymax))
            axpulls.set_ylim((-absymax, absymax))

            # Set x limits to global values.
            axpulls.set_xlim((tmin - toff, tmax - toff))

            # Set small number of y ticks so tick labels don't overlap.
            axpulls.yaxis.set_major_locator(MaxNLocator(5))

            # Label the y axis and make sure ylabels align between axes.
            if col == 0:
                axpulls.set_ylabel('pull')
                axpulls.yaxis.set_label_coords(-0.75 * lspace / wpanel, 0.5)
                ax.yaxis.set_label_coords(-0.75 * lspace / wpanel, 0.5)

            # Set top axis ticks invisible
            for ticklabel in ax.get_xticklabels():
                ticklabel.set_visible(False)

            # Set ax to axpulls in order to adjust plots.
            bottomax = axpulls
        else:
            bottomax = ax

        # If this axes is one of the last `ncol`, set x label.
        # Otherwise don't show tick labels.
        if (len(bands) - axnum - 1) < ncol:
            if toff == 0.:
                bottomax.set_xlabel('time')
            else:
                bottomax.set_xlabel('time - {0:.2f}'.format(toff))
        else:
            for ticklabel in bottomax.get_xticklabels():
                ticklabel.set_visible(False)

    if fname is None:
        return fig
    # Save and close this specific figure rather than pyplot's "current" one.
    fig.savefig(fname, **kwargs)
    plt.close(fig)