Skip to content

Commit

Permalink
Merge pull request #5 from ConorMacBride/merge-results
Browse files Browse the repository at this point in the history
Added function to merge FITS files produced by `mcalf.models.results.FitResults.save`
  • Loading branch information
ConorMacBride committed Dec 23, 2020
2 parents d9e4c47 + dce837a commit b28d860
Show file tree
Hide file tree
Showing 8 changed files with 280 additions and 2 deletions.
212 changes: 211 additions & 1 deletion src/mcalf/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
from shutil import copyfile

import numpy as np
from astropy.io import fits
from scipy.io import readsav


__all__ = ['make_iter', 'load_parameter']
__all__ = ['make_iter', 'load_parameter', 'merge_results']


def make_iter(*args):
Expand Down Expand Up @@ -141,3 +142,212 @@ def load_parameter(parameter, wl=None):
value = float(value)

return value


def merge_results(filenames, output):
    """Merge FITS files generated by `mcalf.models.results.FitResults.save`.

    The first file in `filenames` defines the expected HDU names and the
    header values that every other file must match. Each remaining file is
    then overlaid onto the arrays of the first file: only entries that are
    set in the incoming file are copied, and only onto entries that are
    still unset in the merged arrays.

    Parameters
    ----------
    filenames : list of str, length>1
        List of FITS files generated by `mcalf.models.results.FitResults.save` method.
    output : str
        Name of FITS file to save merged input files to. Will be clobbered.

    Raises
    ------
    TypeError
        If `filenames` is not a list containing more than one entry.
    ValueError
        If a file has an unexpected set of HDUs, if a verified header value
        differs from the first file, or if two files both provide a value
        for the same array entry.

    Notes
    -----
    See `mcalf.models.results.FitResults` for details on the output FITS file data structure.
    """
    if not isinstance(filenames, list) or len(filenames) <= 1:
        raise TypeError("`filenames` must be a list of length greater than 1.")

    # Verification headers (initialise and give keys)
    verification = {
        'PRIMARY': {
            'NTIME': None,
            'NROWS': None,
            'NCOLS': None,
            'TIME': None,
        },
        'PARAMETERS': {
            'NPARAMS': None,
        },
        'CLASSIFICATIONS': {
        },
        'PROFILE': {
            'PROFILES': None
        },
        'SUCCESS': {
        },
        'CHI2': {
        },
        'VLOSA': {
            'VTYPE': None,
            'UNIT': None,
        },
        'VLOSQ': {
            'VTYPE': None,
            'UNIT': None,
        },
    }

    # Values if not fitted (or unsuccessful)
    unset_value = {
        'PRIMARY': '__SKIP__',
        'PARAMETERS': np.nan,
        'CLASSIFICATIONS': -1,
        'PROFILE': 0,
        'SUCCESS': False,
        'CHI2': np.nan,
        'VLOSA': np.nan,
        'VLOSQ': np.nan,
    }

    # Read the reference structure from the first input file. The context
    # manager guarantees the file is closed even if validation raises.
    with fits.open(filenames[0], mode='readonly') as main_hdul:

        # Record the order for easy access {'NAME': index, ...}
        main_index = {hdu.name: index for index, hdu in enumerate(main_hdul)}

        # Remove optional keys if not present in first file
        for optional_key in ['VLOSA', 'VLOSQ']:
            if optional_key not in main_index:
                verification.pop(optional_key)

        # Check that the expected HDUs are present
        if main_index.keys() != verification.keys():
            raise ValueError(f"Unexpected HDU name in {filenames[0]}.")

        # Get expected values for the headers from the first file
        for name, attributes in verification.items():
            header = main_hdul[main_index[name]].header
            for attribute in attributes:
                attributes[attribute] = header[attribute]

        # Load the initial arrays (copied so they outlive the closed file)
        arrays = {name: main_hdul[main_index[name]].data.copy() for name in verification}

    # Copy across the remainder of the FITS files
    for filename in filenames[1:]:
        with fits.open(filename, mode='readonly') as hdul:

            # Check that the expected HDUs are present in `filename`
            input_index = {hdu.name: index for index, hdu in enumerate(hdul)}
            if input_index.keys() != verification.keys():
                raise ValueError(f"Unexpected HDUs in {filename}.")

            for name, expected_headers in verification.items():  # Loop through the HDUs
                input_hdu = hdul[input_index[name]]

                # Verify that the important header items match
                for attribute, expected_value in expected_headers.items():
                    if input_hdu.header[attribute] != expected_value:
                        # TODO: Handle the case where there are different profiles in each file
                        raise ValueError(f"FITS attribute {attribute} for {name} HDU in {filename} is different.")

                # Choose the function to test if data is being overwritten
                invalid = unset_value[name]
                if invalid == '__SKIP__':  # PRIMARY HDU (do nothing)
                    continue
                elif np.isnan(invalid):  # floats (can only overwrite nan)
                    test_function = _nan_test
                elif isinstance(invalid, bool) and not invalid:  # bool (can only overwrite False)
                    test_function = _false_test
                elif isinstance(invalid, int) and invalid == -1:
                    test_function = _minus_one_test
                elif isinstance(invalid, int) and invalid == 0:
                    test_function = _zero_test
                else:
                    raise ValueError(f"Unexpected invalid value {invalid}.")

                # Create aliases for the input and output arrays
                output_array = arrays[name]
                input_array = input_hdu.data

                # Verify that no data is being overwritten: none of the
                # entries set in the input may already be set in the output.
                should_edit = test_function(input_array)
                would_edit = output_array[should_edit]
                if np.any(test_function(would_edit)):
                    raise ValueError(f"Overlapping values in {name} HDU at {filename}.")

                # Merge `input_array` onto output (boolean mask indexing)
                output_array[should_edit] = input_array[should_edit]

    # Copy the first FITS input to the output file so that its headers and
    # HDU layout are kept; only the data arrays are replaced below.
    copyfile(filenames[0], output)

    # Open the output file for updating
    with fits.open(output, mode='update') as output_hdul:
        for hdu in output_hdul:
            hdu.data = arrays[hdu.name]


def _nan_test(x):
"""Finds where not NaN
False if index is NaN.
Parameters
----------
x : array_like
Array to search.
Returns
-------
array : array of bool
Whether corresponding index is not NaN.
"""
return ~np.isnan(x)


def _false_test(x):
"""Finds where not False (where is True)
Parameters
----------
x : array_like
Array to search.
Returns
-------
array : array of bool
Whether corresponding index is True. (Is not False.)
Notes
-----
Converts to bool dtype as integer could have been given.
"""
return x.astype(bool)


def _minus_one_test(x):
"""Finds where not -1
Parameters
----------
x : array_like
Array to search.
Returns
-------
array : array of bool
Whether corresponding index is not -1.
"""
return x != -1


def _zero_test(x):
"""Finds where not 0
Parameters
----------
x : array_like
Array to search.
Returns
-------
array : array of bool
Whether corresponding index is not 0.
"""
return x != 0
Binary file added tests/utils/data/test_merge_results_1.fits
Binary file not shown.
Binary file not shown.
Binary file added tests/utils/data/test_merge_results_2.fits
Binary file not shown.
Binary file not shown.
Binary file added tests/utils/data/test_merge_results_3.fits
Binary file not shown.
Binary file added tests/utils/data/test_merge_results_all.fits
Binary file not shown.
70 changes: 69 additions & 1 deletion tests/utils/test_misc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pytest
import os
import numpy as np
from astropy.io import fits

from mcalf.utils.misc import make_iter, load_parameter
from mcalf.utils.misc import make_iter, load_parameter, merge_results


def test_make_iter():
Expand Down Expand Up @@ -46,3 +47,70 @@ def test_load_parameter_file():
for e in ['typeerror', 'syntaxerror']:
with pytest.raises(SyntaxError): # Note: the TypeError is converted into a SyntaxError
res = load_parameter(f"{abspath}test_load_parameter_file_{e}.csv", wl=2523.43)


def test_merge_results(tmp_path):
    """Check `merge_results` on compatible, overlapping and malformed inputs.

    Uses the FITS fixtures in the ``data`` directory next to this test file
    and writes the merged output into pytest's ``tmp_path``.
    """

    abspath = f"{os.path.dirname(os.path.abspath(__file__))}{os.path.sep}data{os.path.sep}"

    # Compatible files to test merging
    compatible_files = [
        abspath + "test_merge_results_1.fits",
        abspath + "test_merge_results_2.fits",
        abspath + "test_merge_results_3.fits",
    ]

    # Merge and save
    output_file = tmp_path / "test_merge_results_output.fits"
    merge_results(compatible_files, output_file)

    # Compare merged files to expected merge. Both files are opened in a
    # context manager so that the handles are released before the output
    # file is written again by the calls below.
    with fits.open(output_file, mode='readonly') as test, \
            fits.open(abspath + "test_merge_results_all.fits", mode='readonly') as verify:
        # Diff ignoring checksums as too strict (compare values instead)
        diff = fits.FITSDiff(test, verify, ignore_keywords=['CHECKSUM', 'DATASUM'])
        assert diff.identical  # If this fails tolerances *may* need to be adjusted

    # Incompatible files to test merging
    incompatible_files = [
        abspath + "test_merge_results_1.fits",
        abspath + "test_merge_results_2.fits",
        abspath + "test_merge_results_3.fits",
        abspath + "test_merge_results_2.fits",  # Duplicate (overlapping) file should fail
    ]

    # Merge (should fail before saving)
    with pytest.raises(ValueError):
        merge_results(incompatible_files, output_file)

    # Compatible files but wrong time for one (should fail)
    compatible_files_wrongtime = [
        abspath + "test_merge_results_1.fits",
        abspath + "test_merge_results_2_wrongtime.fits",
        abspath + "test_merge_results_3.fits",
    ]

    # Merge (should fail before saving)
    with pytest.raises(ValueError):
        merge_results(compatible_files_wrongtime, output_file)

    # Must provide a list of multiple files
    with pytest.raises(TypeError):  # single string
        merge_results(abspath + "test_merge_results_1.fits", output_file)
    with pytest.raises(TypeError):  # list of length 1
        merge_results([abspath + "test_merge_results_1.fits"], output_file)

    # The extra HDU should cause an error
    with pytest.raises(ValueError) as excinfo:
        merge_results([
            abspath + "test_merge_results_2.fits",
            abspath + "test_merge_results_1_extrahdu.fits",
        ], output_file)
    assert 'nexpected' in str(excinfo.value)  # "Unexpected"
    # reverse (now extra is in first file)
    with pytest.raises(ValueError) as excinfo:
        merge_results([
            abspath + "test_merge_results_1_extrahdu.fits",
            abspath + "test_merge_results_2.fits",
        ], output_file)
    assert 'nexpected' in str(excinfo.value)  # "Unexpected"

0 comments on commit b28d860

Please sign in to comment.