Skip to content

Commit

Permalink
Merge pull request #14941 from meeseeksmachine/auto-backport-of-pr-14…
Browse files Browse the repository at this point in the history
…907-on-v5.3.x

Backport PR #14907 on branch v5.3.x (Specify stable sort in indexing)
  • Loading branch information
pllim committed Jun 14, 2023
2 parents 23470c0 + 3cca549 commit 31979dd
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 7 deletions.
4 changes: 2 additions & 2 deletions astropy/table/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(self, columns, engine=None, unique=False):
raise ValueError("Cannot create index without at least one column")
elif len(columns) == 1:
col = columns[0]
row_index = Column(col.argsort())
row_index = Column(col.argsort(kind="stable"))
data = Table([col[row_index]])
else:
num_rows = len(columns[0])
Expand All @@ -117,7 +117,7 @@ def __init__(self, columns, engine=None, unique=False):
try:
lines = table[np.lexsort(sort_columns)]
except TypeError: # arbitrary mixins might not work with lexsort
lines = table[table.argsort()]
lines = table[table.argsort(kind="stable")]
data = lines[lines.colnames[:-1]]
row_index = lines[lines.colnames[-1]]

Expand Down
20 changes: 20 additions & 0 deletions astropy/table/tests/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,3 +690,23 @@ def test_group_mixins_unsupported(col):
tg = t.group_by("a")
with pytest.warns(AstropyUserWarning, match="Cannot aggregate column 'mix'"):
tg.groups.aggregate(np.sum)


@pytest.mark.parametrize("add_index", [False, True])
def test_group_stable_sort(add_index):
    """Check that ``group_by`` keeps rows within each group in table order.

    The table holds 1000 rows whose key column ``a`` is drawn at random from
    five values (about 200 rows per group), so it is not statistically
    plausible for every group to come out ordered by chance. Column ``b``
    records the original row position, so each group's ``b`` values must
    already be ascending.

    With ``add_index=True`` this explicitly covers the case where grouping is
    done via the index sort.

    See: https://github.com/astropy/astropy/issues/14882
    """
    keys = np.random.randint(0, 5, 1000)
    row_order = np.arange(len(keys))
    t = Table([keys, row_order], names=["a", "b"])
    if add_index:
        t.add_index("a")

    grouped = t.group_by("a")
    for group in grouped.groups:
        positions = group["b"]
        # Rows are in original order iff sorting their positions is a no-op.
        assert np.all(positions == np.sort(positions))
25 changes: 20 additions & 5 deletions astropy/time/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1441,13 +1441,28 @@ def argmax(self, axis=None, out=None):

return dt.argmax(axis, out)

def argsort(self, axis=-1):
def argsort(self, axis=-1, kind="stable"):
"""Returns the indices that would sort the time array.
This is similar to :meth:`~numpy.ndarray.argsort`, but adapted to ensure
that the full precision given by the two doubles ``jd1`` and ``jd2``
is used, and that corresponding attributes are copied. Internally,
it uses :func:`~numpy.lexsort`, and hence no sort method can be chosen.
This is similar to :meth:`~numpy.ndarray.argsort`, but adapted to ensure that
the full precision given by the two doubles ``jd1`` and ``jd2`` is used, and
that corresponding attributes are copied. Internally, it uses
:func:`~numpy.lexsort`, and hence no sort method can be chosen.
Parameters
----------
axis : int, optional
Axis along which to sort. Default is -1, which means sort along the last
axis.
kind : 'stable', optional
Sorting is done with :func:`~numpy.lexsort` so this argument is ignored, but
kept for compatibility with :func:`~numpy.argsort`. The sorting is stable,
meaning that the order of equal elements is preserved.
Returns
-------
indices : ndarray
An array of indices that sort the time array.
"""
# For procedure, see comment on argmin.
jd1, jd2 = self.jd1, self.jd2
Expand Down
3 changes: 3 additions & 0 deletions docs/changes/table/14907.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fix a bug where table indexes were not built with a stable sort. As a result, when an
indexed table was grouped, the order of rows within each group could differ from the
original table order.
7 changes: 7 additions & 0 deletions docs/table/operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ values and the indices of the group boundaries for those key values. The groups
here correspond to the row slices ``0:4``, ``4:7``, and ``7:10`` in the
``obs_by_name`` table.

The output grouped table has two important properties:

- The groups are in the order of the lexically sorted key values (``M101``, ``M31``,
``M82`` in our example).
- The rows within each group are in the same order as they appear in the
original table.

The initial argument (``keys``) for the :func:`~astropy.table.Table.group_by`
function can take a number of input data types:

Expand Down

0 comments on commit 31979dd

Please sign in to comment.