Skip to content

Commit

Permalink
Generalized the plot_spectrum() method so it's not BDNYCdb specific. …
Browse files Browse the repository at this point in the history
…Also removed *plot argument from inventory() method since it i superfluous, slow, and a little brittle.
  • Loading branch information
hover2pi committed Mar 15, 2016
1 parent 21f0545 commit 77e502e
Showing 1 changed file with 28 additions and 20 deletions.
48 changes: 28 additions & 20 deletions astrodbkit/astrodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,18 +327,18 @@ def _compare_records(self, table, duplicate, options=['r','c','k','sql']):
else:
print "\nInvalid command: {}\nTry again or type 'help' or 'abort'.\n".format(replace)

def inventory(self, source_id, plot=False, fetch=False, fmt='table'):
def inventory(self, source_id, fetch=False, fmt='table'):
"""
Prints a summary of all objects in the database. Input string or list of strings in **ID** or **unum** for specific objects.
Parameters
----------
source_id: int
The id from the SOURCES table whose data across all tables is to be printed.
plot: bool
Plots all spectra for the object.
fetch: bool
Return the results.
fmt: str
Returns the data as a dictionary, array, or astropy.table given 'dict', 'array', or 'table'
Returns
-------
Expand Down Expand Up @@ -378,14 +378,7 @@ def inventory(self, source_id, plot=False, fetch=False, fmt='table'):
fetch=True, fmt=fmt)
else:
data = data[list(columns)]
pprint(data, title=table.upper())

# Plot all the spectra
if plot and table.lower()=='spectra' and data:
if len(data)>5: print 'Only plotting first 5 spectra. Run plot_spectrum() method to plot individually.'

for i in self.query("SELECT id FROM spectra WHERE source_id={} LIMIT 5".format(source_id), unpack=True)[0]:
self.plot_spectrum(i)
pprint(data, title=table.upper())

else: pass

Expand Down Expand Up @@ -600,35 +593,45 @@ def output_spectrum(self, spectrum_id, filepath, original=False):

else: print "No spectrum found with id {}".format(spectrum_id)

def plot_spectrum(self, spectrum_id, overplot=False, color='b', norm=False):
def plot_spectrum(self, spectrum_id, table='spectra', column='spectrum', overplot=False, color='b', norm=False):
"""
Plots spectrum **ID** from SPECTRA table.
Parameters
----------
spectrum_id: int
The id from the SPECTRA table of the spectrum to plot.
The id from the table of the spectrum to plot.
overplot: bool
Overplot the spectrum
table: str
The table from which the plot is being made
column: str
The column with SPECTRUM data type to plot
color: str
The color used for the data
norm: bool, sequence
True or (min,max) wavelength range in which to normalize the spectrum
"""
i = self.query("SELECT * FROM spectra WHERE id={}".format(spectrum_id), fetch='one', fmt='dict')
i = self.query("SELECT * FROM {} WHERE id={}".format(table,spectrum_id), fetch='one', fmt='dict')
if i:
try:
spec = i['spectrum']
spec = i[column]
w, f, e = scrub(spec.data, units=False)

# Draw the axes and add the metadata
if not overplot:
fig, ax = plt.subplots()
plt.rc('text', usetex=False)
ax.set_yscale('log', nonposy='clip'), plt.title('source_id = {}'.format(i['source_id']))
plt.figtext(0.15,0.88, '{}\n{}\n{}\n{}'.format(i['filename'],self.query("SELECT name FROM telescopes WHERE id={}".format(i['telescope_id']), fetch='one')[0] if i['telescope_id'] else '',self.query("SELECT name FROM instruments WHERE id={}".format(i['instrument_id']), fetch='one')[0] if i['instrument_id'] else '',i['obs_date']), verticalalignment='top')
ax.set_xlabel(r'$\lambda$ [{}]'.format(i['wavelength_units'])), ax.set_ylabel(r'$F_\lambda$ [{}]'.format(i['flux_units'])), ax.legend(loc=8, frameon=False)
ax.set_yscale('log', nonposy='clip')
plt.title('source_id = {}'.format(i['source_id']))
plt.figtext(0.15,0.88, '\n'.join(['{}: {}'.format(k,v) for k,v in i.items() if k!=column]), \
verticalalignment='top')
try:
ax.set_xlabel(r'$\lambda$ [{}]'.format(i['wavelength_units']))
ax.set_ylabel(r'$F_\lambda$ [{}]'.format(i['flux_units']))
except: pass
ax.legend(loc=8, frameon=False)
else: ax = plt.gca()

# Normalize the data
Expand All @@ -650,10 +653,14 @@ def plot_spectrum(self, spectrum_id, overplot=False, color='b', norm=False):
X, Y = plt.xlim(), plt.ylim()
try: ax.fill_between(w, f-e, f+e, color=color, alpha=0.3), ax.set_xlim(X), ax.set_ylim(Y)
except: print 'No uncertainty array for spectrum {}'.format(spectrum_id)
plt.ion()

except: print "Could not plot spectrum {}".format(spectrum_id); plt.close()

else: print "No spectrum {} in the SPECTRA table.".format(spectrum_id)

def query(self, SQL, params='', fmt='array', fetch='all', unpack=False, export='', verbose=False, use_converters=True):
def query(self, SQL, params='', fmt='array', fetch='all', unpack=False, export='', \
verbose=False, use_converters=True):
"""
Wrapper for cursors so data can be retrieved as a list or dictionary from same method
Expand Down Expand Up @@ -713,7 +720,8 @@ def query(self, SQL, params='', fmt='array', fetch='all', unpack=False, export='
# Print the results to file
if export:
# If .vot or .xml, assume VOTable export with votools
if export.lower().endswith('.xml') or export.lower().endswith('.vot'): votools.dict_tovot(dictionary, export)
if export.lower().endswith('.xml') or export.lower().endswith('.vot'):
votools.dict_tovot(dictionary, export)

# Otherwise print as ascii
else: ii.write(table, export, Writer=ii.FixedWidthTwoLine, fill_values=[('None', '-')])
Expand Down

0 comments on commit 77e502e

Please sign in to comment.