Skip to content

Commit

Permalink
Merge pull request #13974 from WilliamJamieson/black/io_other
Browse files Browse the repository at this point in the history
Apply `black` to remaining parts of `astropy.io`
  • Loading branch information
astrofrog committed Nov 14, 2022
2 parents 88196df + c375cfb commit a8cb166
Show file tree
Hide file tree
Showing 9 changed files with 372 additions and 253 deletions.
139 changes: 83 additions & 56 deletions astropy/io/registry/base.py
Expand Up @@ -9,17 +9,18 @@

import numpy as np

__all__ = ['IORegistryError']
__all__ = ["IORegistryError"]


class IORegistryError(Exception):
"""Custom error for registry clashes.
"""
"""Custom error for registry clashes."""

pass


# -----------------------------------------------------------------------------


class _UnifiedIORegistryBase(metaclass=abc.ABCMeta):
"""Base class for registries in Astropy's Unified IO.
Expand All @@ -44,7 +45,7 @@ def __init__(self):
# what this class can do: e.g. 'read' &/or 'write'
self._registries = dict()
self._registries["identify"] = dict(attr="_identifiers", column="Auto-identify")
self._registries_order = ("identify", ) # match keys in `_registries`
self._registries_order = ("identify",) # match keys in `_registries`

# If multiple formats are added to one class the update of the docs is quite
# expensive. Classes for which the doc update is temporarily delayed are added
Expand Down Expand Up @@ -87,12 +88,16 @@ def get_formats(self, data_class=None, filter_on=None):

# set up the column names
colnames = (
"Data class", "Format",
"Data class",
"Format",
*[self._registries[k]["column"] for k in self._registries_order],
"Deprecated")
"Deprecated",
)
i_dataclass = colnames.index("Data class")
i_format = colnames.index("Format")
i_regstart = colnames.index(self._registries[self._registries_order[0]]["column"])
i_regstart = colnames.index(
self._registries[self._registries_order[0]]["column"]
)
i_deprecated = colnames.index("Deprecated")

# registries
Expand All @@ -103,52 +108,64 @@ def get_formats(self, data_class=None, filter_on=None):
# the format classes from all registries except "identify"

rows = []
for (fmt, cls) in format_classes:
for fmt, cls in format_classes:
# see if can skip, else need to document in row
if (data_class is not None and not self._is_best_match(
data_class, cls, format_classes)):
if data_class is not None and not self._is_best_match(
data_class, cls, format_classes
):
continue

# flags for each registry
has_ = {k: "Yes" if (fmt, cls) in getattr(self, v["attr"]) else "No"
for k, v in self._registries.items()}
has_ = {
k: "Yes" if (fmt, cls) in getattr(self, v["attr"]) else "No"
for k, v in self._registries.items()
}

# Check if this is a short name (e.g. 'rdb') which is deprecated in
# favor of the full 'ascii.rdb'.
ascii_format_class = ('ascii.' + fmt, cls)
ascii_format_class = ("ascii." + fmt, cls)
# deprecation flag
deprecated = "Yes" if ascii_format_class in format_classes else ""

# add to rows
rows.append((cls.__name__, fmt,
*[has_[n] for n in self._registries_order], deprecated))
rows.append(
(
cls.__name__,
fmt,
*[has_[n] for n in self._registries_order],
deprecated,
)
)

# filter_on can be in self._registries_order or None
if str(filter_on).lower() in self._registries_order:
index = self._registries_order.index(str(filter_on).lower())
rows = [row for row in rows if row[i_regstart + index] == 'Yes']
rows = [row for row in rows if row[i_regstart + index] == "Yes"]
elif filter_on is not None:
raise ValueError('unrecognized value for "filter_on": {0}.\n'
f'Allowed are {self._registries_order} and None.')
raise ValueError(
'unrecognized value for "filter_on": {0}.\n'
f"Allowed are {self._registries_order} and None."
)

# Sorting the list of tuples is much faster than sorting it after the
# table is created. (#5262)
if rows:
# Indices represent "Data Class", "Deprecated" and "Format".
data = list(zip(*sorted(
rows, key=itemgetter(i_dataclass, i_deprecated, i_format))))
data = list(
zip(*sorted(rows, key=itemgetter(i_dataclass, i_deprecated, i_format)))
)
else:
data = None

# make table
# need to filter elementwise comparison failure issue
# https://github.com/numpy/numpy/issues/6784
with warnings.catch_warnings():
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action="ignore", category=FutureWarning)

format_table = Table(data, names=colnames)
if not np.any(format_table['Deprecated'].data == 'Yes'):
format_table.remove_column('Deprecated')
if not np.any(format_table["Deprecated"].data == "Yes"):
format_table.remove_column("Deprecated")

return format_table

Expand Down Expand Up @@ -238,9 +255,10 @@ def my_identifier(*args, **kwargs):
if not (data_format, data_class) in self._identifiers or force:
self._identifiers[(data_format, data_class)] = identifier
else:
raise IORegistryError("Identifier for format '{}' and class '{}' is "
'already defined'.format(data_format,
data_class.__name__))
raise IORegistryError(
f"Identifier for format {data_format!r} and class"
f" {data_class.__name__!r} is already defined"
)

def unregister_identifier(self, data_format, data_class):
"""
Expand All @@ -256,8 +274,10 @@ def unregister_identifier(self, data_format, data_class):
if (data_format, data_class) in self._identifiers:
self._identifiers.pop((data_format, data_class))
else:
raise IORegistryError("No identifier defined for format '{}' and class"
" '{}'".format(data_format, data_class.__name__))
raise IORegistryError(
f"No identifier defined for format {data_format!r} and class"
f" {data_class.__name__!r}"
)

def identify_format(self, origin, data_class_required, path, fileobj, args, kwargs):
"""Loop through identifiers to see which formats match.
Expand Down Expand Up @@ -291,7 +311,8 @@ def identify_format(self, origin, data_class_required, path, fileobj, args, kwar
for data_format, data_class in self._identifiers:
if self._is_best_match(data_class_required, data_class, self._identifiers):
if self._identifiers[(data_format, data_class)](
origin, path, fileobj, *args, **kwargs):
origin, path, fileobj, *args, **kwargs
):
valid_formats.append(data_format)

return valid_formats
Expand All @@ -302,8 +323,8 @@ def identify_format(self, origin, data_class_required, path, fileobj, args, kwar
def _get_format_table_str(self, data_class, filter_on):
"""``get_formats()``, without column "Data class", as a str."""
format_table = self.get_formats(data_class, filter_on)
format_table.remove_column('Data class')
format_table_str = '\n'.join(format_table.pformat(max_lines=-1))
format_table.remove_column("Data class")
format_table_str = "\n".join(format_table.pformat(max_lines=-1))
return format_table_str

def _is_best_match(self, class1, class2, format_classes):
Expand Down Expand Up @@ -334,11 +355,12 @@ def _get_valid_format(self, mode, cls, path, fileobj, args, kwargs):

if len(valid_formats) == 0:
format_table_str = self._get_format_table_str(cls, mode.capitalize())
raise IORegistryError("Format could not be identified based on the"
" file name or contents, please provide a"
" 'format' argument.\n"
"The available formats are:\n"
"{}".format(format_table_str))
raise IORegistryError(
"Format could not be identified based on the"
" file name or contents, please provide a"
" 'format' argument.\n"
f"The available formats are:\n{format_table_str}"
)
elif len(valid_formats) > 1:
return self._get_highest_priority_format(mode, cls, valid_formats)

Expand All @@ -357,14 +379,14 @@ def _get_highest_priority_format(self, mode, cls, valid_formats):
mode_loader = "writer"

best_formats = []
current_priority = - np.inf
current_priority = -np.inf
for format in valid_formats:
try:
_, priority = format_dict[(format, cls)]
except KeyError:
# We could throw an exception here, but get_reader/get_writer handle
# this case better, instead maximally deprioritise the format.
priority = - np.inf
priority = -np.inf

if priority == current_priority:
best_formats.append(format)
Expand All @@ -373,9 +395,10 @@ def _get_highest_priority_format(self, mode, cls, valid_formats):
current_priority = priority

if len(best_formats) > 1:
raise IORegistryError("Format is ambiguous - options are: {}".format(
', '.join(sorted(valid_formats, key=itemgetter(0)))
))
raise IORegistryError(
"Format is ambiguous - options are:"
f" {', '.join(sorted(valid_formats, key=itemgetter(0)))}"
)
return best_formats[0]

def _update__doc__(self, data_class, readwrite):
Expand All @@ -390,7 +413,7 @@ def _update__doc__(self, data_class, readwrite):

from .interface import UnifiedReadWrite

FORMATS_TEXT = 'The available built-in formats are:'
FORMATS_TEXT = "The available built-in formats are:"

# Get the existing read or write method and its docstring
class_readwrite_func = getattr(data_class, readwrite)
Expand All @@ -410,38 +433,42 @@ def _update__doc__(self, data_class, readwrite):
lines = lines[:chop_index]

# Find the minimum indent, skipping the first line because it might be odd
matches = [re.search(r'(\S)', line) for line in lines[1:]]
left_indent = ' ' * min(match.start() for match in matches if match)
matches = [re.search(r"(\S)", line) for line in lines[1:]]
left_indent = " " * min(match.start() for match in matches if match)

# Get the available unified I/O formats for this class
# Include only formats that have a reader, and drop the 'Data class' column
format_table = self.get_formats(data_class, readwrite.capitalize())
format_table.remove_column('Data class')
format_table.remove_column("Data class")

# Get the available formats as a table, then munge the output of pformat()
# a bit and put it into the docstring.
new_lines = format_table.pformat(max_lines=-1, max_width=80)
table_rst_sep = re.sub('-', '=', new_lines[1])
table_rst_sep = re.sub("-", "=", new_lines[1])
new_lines[1] = table_rst_sep
new_lines.insert(0, table_rst_sep)
new_lines.append(table_rst_sep)

# Check for deprecated names and include a warning at the end.
if 'Deprecated' in format_table.colnames:
new_lines.extend(['',
'Deprecated format names like ``aastex`` will be '
'removed in a future version. Use the full ',
'name (e.g. ``ascii.aastex``) instead.'])

new_lines = [FORMATS_TEXT, ''] + new_lines
if "Deprecated" in format_table.colnames:
new_lines.extend(
[
"",
"Deprecated format names like ``aastex`` will be "
"removed in a future version. Use the full ",
"name (e.g. ``ascii.aastex``) instead.",
]
)

new_lines = [FORMATS_TEXT, ""] + new_lines
lines.extend([left_indent + line for line in new_lines])

# Depending on Python version and whether class_readwrite_func is
# an instancemethod or classmethod, one of the following will work.
if isinstance(class_readwrite_func, UnifiedReadWrite):
class_readwrite_func.__class__.__doc__ = '\n'.join(lines)
class_readwrite_func.__class__.__doc__ = "\n".join(lines)
else:
try:
class_readwrite_func.__doc__ = '\n'.join(lines)
class_readwrite_func.__doc__ = "\n".join(lines)
except AttributeError:
class_readwrite_func.__func__.__doc__ = '\n'.join(lines)
class_readwrite_func.__func__.__doc__ = "\n".join(lines)
20 changes: 15 additions & 5 deletions astropy/io/registry/compat.py
Expand Up @@ -4,11 +4,21 @@

from .core import UnifiedIORegistry

__all__ = ["register_reader", "register_writer", "register_identifier", # noqa: F822
"unregister_reader", "unregister_writer", "unregister_identifier",
"get_reader", "get_writer", "get_formats",
"read", "write",
"identify_format", "delay_doc_updates"]
__all__ = [
"register_reader",
"register_writer",
"register_identifier",
"unregister_reader",
"unregister_writer",
"unregister_identifier",
"get_reader",
"get_writer",
"get_formats",
"read",
"write",
"identify_format",
"delay_doc_updates",
]

# make a default global-state registry (not publicly scoped, but often accessed)
# this is for backward compatibility when ``io.registry`` was a file.
Expand Down

0 comments on commit a8cb166

Please sign in to comment.