Skip to content

Commit

Permalink
Merge f807bc3 into ebb1bdd
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Feb 14, 2019
2 parents ebb1bdd + f807bc3 commit f292815
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 11 deletions.
30 changes: 20 additions & 10 deletions specutils/io/registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,29 @@ def wrapper(*args, **kwargs):
return False
return wrapper

identifier = identifier_wrapper(identifier)

def decorator(func):
io_registry.register_reader(label, dtype, func)

# If the identifier is not defined, but the extensions are, create
# a simple identifier based off file extension.
if identifier is None and extensions is not None:
io_registry.register_identifier(
label, dtype, lambda *args, **kwargs: any([args[0].endswith(x)
for x in extensions]))
if identifier is None:
# If the identifier is not defined, but the extensions are, create
# a simple identifier based off file extension.
if extensions is not None:
logging.info("'{}' data loader provided for {} without "
"explicit identifier. Creating identifier using "
"list of compatible extensions".format(
label, dtype.__name__))
id_func = lambda *args, **kwargs: any([args[1].endswith(x)
for x in extensions])
# Otherwise, create a dummy identifier
else:
logging.warning("'{}' data loader provided for {} without "
"explicit identifier or list of compatible "
"extensions".format(label, dtype.__name__))
id_func = lambda *args, **kwargs: True
else:
io_registry.register_identifier(label, dtype, identifier)
id_func = identifier_wrapper(identifier)

io_registry.register_identifier(label, dtype, id_func)

# Include the file extensions as attributes on the function object
func.extensions = extensions
Expand Down Expand Up @@ -90,7 +100,7 @@ def load_spectrum_list(*args, **kwargs):
load_spectrum_list.priority = priority

io_registry.register_reader(label, SpectrumList, load_spectrum_list)
io_registry.register_identifier(label, SpectrumList, identifier)
io_registry.register_identifier(label, SpectrumList, id_func)
logging.debug("Created SpectrumList reader for \"{}\".".format(label))

@wraps(func)
Expand Down
87 changes: 86 additions & 1 deletion specutils/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
import pytest
import warnings

from specutils import SpectrumList
from specutils import Spectrum1D, SpectrumList
from specutils.io import data_loader


def test_generic_spectrum_from_table(recwarn):
Expand Down Expand Up @@ -71,3 +72,87 @@ def test_speclist_autoidentify():

formats = registry.get_formats(SpectrumList)
assert (formats['Auto-identify'] == 'Yes').all()


def test_default_identifier(tmpdir):

fname = str(tmpdir.join('empty.txt'))
with open(fname, 'w') as ff:
ff.write('\n')

format_name = 'default_identifier_test'

@data_loader(format_name)
def reader(*args, **kwargs):
"""Doesn't actually get used."""
return

for datatype in [Spectrum1D, SpectrumList]:
fmts = registry.identify_format('read', datatype, fname, None, [], {})
assert format_name in fmts

# Clean up after ourselves
registry.unregister_reader(format_name, datatype)
registry.unregister_identifier(format_name, datatype)


def test_default_identifier_extension(tmpdir):

good_fname = str(tmpdir.join('empty.fits'))
bad_fname = str(tmpdir.join('empty.txt'))

# Create test data files.
for name in [good_fname, bad_fname]:
with open(name, 'w') as ff:
ff.write('\n')

format_name = 'default_identifier_extension_test'

@data_loader(format_name, extensions=['fits'])
def reader(*args, **kwargs):
"""Doesn't actually get used."""
return

for datatype in [Spectrum1D, SpectrumList]:
fmts = registry.identify_format('read', datatype, good_fname, None, [], {})
assert format_name in fmts

fmts = registry.identify_format('read', datatype, bad_fname, None, [], {})
assert format_name not in fmts

# Clean up after ourselves
registry.unregister_reader(format_name, datatype)
registry.unregister_identifier(format_name, datatype)


def test_custom_identifier(tmpdir):

good_fname = str(tmpdir.join('good.txt'))
bad_fname = str(tmpdir.join('bad.txt'))

# Create test data files.
for name in [good_fname, bad_fname]:
with open(name, 'w') as ff:
ff.write('\n')

format_name = 'custom_identifier_test'

def identifier(origin, *args, **kwargs):
fname = args[0]
return 'good' in fname

@data_loader(format_name, identifier=identifier)
def reader(*args, **kwargs):
"""Doesn't actually get used."""
return

for datatype in [Spectrum1D, SpectrumList]:
fmts = registry.identify_format('read', datatype, good_fname, None, [], {})
assert format_name in fmts

fmts = registry.identify_format('read', datatype, bad_fname, None, [], {})
assert format_name not in fmts

# Clean up after ourselves
registry.unregister_reader(format_name, datatype)
registry.unregister_identifier(format_name, datatype)

0 comments on commit f292815

Please sign in to comment.