Skip to content

Commit

Permalink
Don't rely on argsort to sort - directly match to order in parameter …
Browse files Browse the repository at this point in the history
…list
  • Loading branch information
astrofrog committed Oct 7, 2013
1 parent 94325ee commit e48ce0f
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 9 deletions.
9 changes: 6 additions & 3 deletions sedfitter/convolve/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

from ..convolved_fluxes import ConvolvedFluxes
from ..sed import SED

from ..models import load_parameter_table
from .. import six

def convolve_model_dir(model_dir, filters, overwrite=False):
"""
Expand Down Expand Up @@ -44,6 +45,8 @@ def convolve_model_dir(model_dir, filters, overwrite=False):
glob.glob(model_dir + '/seds/*.fits') +
glob.glob(model_dir + '/seds/*/*.fits'))

par_table = load_parameter_table(model_dir)

if len(sed_files) == 0:
raise Exception("No SEDs found in %s" % model_dir)
else:
Expand All @@ -56,7 +59,7 @@ def convolve_model_dir(model_dir, filters, overwrite=False):
apertures = fits.open(sed_files[0], memmap=False)[2].data['APERTURE']

# Set up convolved fluxes
fluxes = [ConvolvedFluxes(model_names=np.zeros(len(sed_files), dtype='S30'), apertures=apertures, initialize_arrays=True) for i in range(len(filters))]
fluxes = [ConvolvedFluxes(model_names=np.zeros(len(sed_files), dtype='U30' if six.PY3 else 'S30'), apertures=apertures, initialize_arrays=True) for i in range(len(filters))]

# Set up list of binned filters
binned_filters = []
Expand Down Expand Up @@ -96,6 +99,6 @@ def convolve_model_dir(model_dir, filters, overwrite=False):
fluxes[i].error[im] = np.sqrt(np.sum((s.error * f.r) ** 2, axis=1))

for i, f in enumerate(binned_filters):
fluxes[i].sort_by_name()
fluxes[i].sort_to_match(par_table['MODEL_NAME'])
fluxes[i].write(model_dir + '/convolved/' + f.name + '.fits',
overwrite=overwrite)
9 changes: 6 additions & 3 deletions sedfitter/convolve/monochromatic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

from ..convolved_fluxes import ConvolvedFluxes
from ..sed import SED

from ..models import load_parameter_table
from .. import six

def convolve_model_dir_monochromatic(model_dir, overwrite=False, max_ram=8,
wav_min=-np.inf, wav_max=np.inf):
Expand Down Expand Up @@ -45,6 +46,8 @@ def convolve_model_dir_monochromatic(model_dir, overwrite=False, max_ram=8,
glob.glob(model_dir + '/seds/*.fits') +
glob.glob(model_dir + '/seds/*/*.fits'))

par_table = load_parameter_table(model_dir)

# Find number of models
n_models = len(sed_files)

Expand Down Expand Up @@ -93,7 +96,7 @@ def convolve_model_dir_monochromatic(model_dir, overwrite=False, max_ram=8,
log.info('Processing wavelengths {0} to {1}'.format(jmin, jmax))

# Set up convolved fluxes
fluxes = [ConvolvedFluxes(model_names=np.zeros(n_models, dtype='S30'), apertures=apertures, initialize_arrays=True) for i in range(chunk_size)]
fluxes = [ConvolvedFluxes(model_names=np.zeros(n_models, dtype='U30' if six.PY3 else 'S30'), apertures=apertures, initialize_arrays=True) for i in range(chunk_size)]

b = ProgressBar(len(sed_files))

Expand Down Expand Up @@ -122,7 +125,7 @@ def convolve_model_dir_monochromatic(model_dir, overwrite=False, max_ram=8,
fluxes[j].error[im, :] = s.error[:, j + jmin]

for j in range(chunk_size):
fluxes[j].sort_by_name()
fluxes[j].sort_to_match(par_table['MODEL_NAME'])
fluxes[j].write('{0:s}/convolved/MO{1:03d}.fits'.format(model_dir, j + jmin + 1),
overwrite=overwrite)
filters['filter'][j + jmin] = "MO{0:03d}".format(j + jmin + 1)
Expand Down
17 changes: 14 additions & 3 deletions sedfitter/convolved_fluxes/convolved_fluxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,22 @@ def __eq__(self, other):
and np.all(self.flux == other.flux) \
and np.all(self.error == other.error)

def sort_by_name(self):
def sort_to_match(self, requested_model_names):
"""
Sort the models by model name
"""
order = np.argsort(self.model_names)

order = np.arange(self.n_models)

subset = np.in1d(requested_model_names, self.model_names)
order = order[subset]
index = np.argsort(self.model_names)
order = order[index]

# Double check that the sorting will work
if not np.all(self.model_names[order] == requested_model_names):
raise Exception("Sorting failed")

self.model_names = self.model_names[order]
self.flux = self.flux[order, :]
self.error = self.error[order, :]
Expand Down Expand Up @@ -236,7 +247,7 @@ def write(self, filename, overwrite=False):
from astropy.table import Table, Column

tc = Table()
tc['MODEL_NAME'] = self.model_names
tc['MODEL_NAME'] = self.model_names.astype('S30')
tc['TOTAL_FLUX'] = self.flux
tc['TOTAL_FLUX_ERR'] = self.error

Expand Down

0 comments on commit e48ce0f

Please sign in to comment.