Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "sortedlist" option for faster rolling median (#26) #27

Merged
merged 1 commit into from
Feb 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ Rolling objects to apply statistical operations to the window.
| Object | Update | Memory | Description | Builtin |
| ---------------- |:--------:|:------:|-----------------------------------------------------------------|----------------------|
| `Mean` | O(1) | O(k) | Arithmetic mean of window values | [`statistics.mean`](https://docs.python.org/3.9/library/statistics.html#statistics.mean) |
| `Median` | O(log k) | O(k) | Median value of window | [`statistics.median`](https://docs.python.org/3.9/library/statistics.html#statistics.median) |
| `Median` | O(log k) | O(k) | Median value of window: O(log k) update with the 'skiplist' tracker, O(k) update with the default 'sortedlist' tracker | [`statistics.median`](https://docs.python.org/3.9/library/statistics.html#statistics.median) |
| `Mode` | O(1) | O(k) | Set of most frequently appearing values in window | [`statistics.multimode`](https://docs.python.org/3.9/library/statistics.html#statistics.multimode) |
| `Var` | O(1) | O(k) | Variance of window, with specified degrees of freedom | [`statistics.pvariance`](https://docs.python.org/3.9/library/statistics.html#statistics.pvariance) |
| `Std` | O(1) | O(k) | Standard deviation of window, with specified degrees of freedom | [`statistics.pstdev`](https://docs.python.org/3.9/library/statistics.html#statistics.pstdev) |
Expand Down
66 changes: 47 additions & 19 deletions rolling/stats/median.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,42 +3,70 @@

from rolling.base import RollingObject
from rolling.structures.skiplist import IndexableSkiplist
from rolling.structures.sorted_list import SortedList


class Median(RollingObject):
"""
Iterator object that computes the median value
of a rolling window over a Python iterable.
Median value of a rolling window over a Python iterable.

Parameters
----------

iterable : any iterable object
window_size : integer, the size of the rolling
window moving over the iterable
window_type : 'fixed' (default) or 'variable'
tracker : 'sortedlist' (default) or 'skiplist'
data structure used to track the order of the window values

Complexity
----------

Update time: O(log k)
Memory usage: O(k)
For 'sortedlist' tracker:

where k is the size of the rolling window
Update time: O(k)
Memory usage: O(k)

For 'skiplist' tracker:

Update time: O(log k)
Memory usage: O(k)

where k is the size of the rolling window.

Note that the 'sortedlist' tracker may be faster for smaller
window sizes due to the overhead of skiplist operations.

Notes
-----

An indexable skiplist is used to track the median
as the window moves (using an idea of R. Hettinger [1]).
The indexable skiplist to track the median uses an idea and
code of R. Hettinger [1].

[1] http://code.activestate.com/recipes/576930/

"""
def __init__(
self,
iterable,
window_size,
window_type="fixed",
tracker="sortedlist",
):

def _init_fixed(self, iterable, window_size, **kwargs):
self._buffer = deque(maxlen=window_size)
self._skiplist = IndexableSkiplist(window_size)

if tracker == "skiplist":
self._tracker = IndexableSkiplist(window_size)
elif tracker == "sortedlist":
self._tracker = SortedList()
else:
raise ValueError(f"tracker must be one of 'skiplist' or 'sortedlist'")

super().__init__(iterable, window_size, window_type)

def _init_fixed(self, iterable, window_size, **kwargs):
# update buffer and skiplist with initial values
for new in islice(self._iterator, window_size - 1):
self._add_new(new)
Expand All @@ -47,38 +75,38 @@ def _init_fixed(self, iterable, window_size, **kwargs):
# insert a dummy value (the last element seen) so that
# the window is full and iterator works as expected
self._buffer.appendleft(new)
self._skiplist.insert(new)
self._tracker.insert(new)
except UnboundLocalError:
# if we didn't see any elements (the iterable had no
# elements or just one element), just use 0 instead
self._buffer.appendleft(0)
self._skiplist.insert(0)
self._tracker.insert(0)

def _init_variable(self, iterable, window_size, **kwargs):
self._buffer = deque(maxlen=window_size)
self._skiplist = IndexableSkiplist(window_size)
# no further initialisation required for variable-size windows
pass

def _update_window(self, new):
old = self._buffer.popleft()
self._skiplist.remove(old)
self._skiplist.insert(new)
self._tracker.remove(old)
self._tracker.insert(new)
self._buffer.append(new)

def _add_new(self, new):
self._skiplist.insert(new)
self._tracker.insert(new)
self._buffer.append(new)

def _remove_old(self):
old = self._buffer.popleft()
self._skiplist.remove(old)
self._tracker.remove(old)

@property
def current_value(self):
if self._obs % 2 == 1:
return self._skiplist[self._obs // 2]
return self._tracker[self._obs // 2]
else:
i = self._obs // 2
return (self._skiplist[i] + self._skiplist[i - 1]) / 2
return (self._tracker[i] + self._tracker[i - 1]) / 2

@property
def _obs(self):
Expand Down
45 changes: 45 additions & 0 deletions rolling/structures/sorted_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import bisect
from typing import MutableSequence


class SortedList(MutableSequence):
    """
    Sorted list with an insert method to maintain order.

    This is a very basic version of SortedContainer's
    SortedList object [1], which uses the Python bisect
    module's bisect/insort methods [2] to efficiently locate
    the correct indices for insertion and removal of values.

    Values are kept in ascending order in a plain Python list.
    Locating a position is a binary search, but inserting into
    or popping from the underlying list still shifts elements.

    Unlike ``MutableSequence.insert(index, value)``, ``insert``
    here takes only the value: its position is determined by the
    sort order. ``append`` and ``extend`` are overridden to match,
    because the inherited mixin versions call
    ``insert(index, value)`` and would raise ``TypeError``.

    [1] grantjenks.com/docs/sortedcontainers/_modules/sortedcontainers/sortedlist.html#SortedList
    [2] docs.python.org/3/library/bisect.html

    """

    def __init__(self):
        self._list = []

    def __repr__(self):
        return f"{type(self).__name__}({self._list!r})"

    def remove(self, value):
        """Remove the leftmost occurrence of ``value``.

        Raises ValueError if ``value`` is not present.
        """
        index = bisect.bisect_left(self._list, value)
        # bisect_left returns the position where value *would* go, so we
        # still need to confirm that the value is actually stored there
        if index >= len(self._list) or self._list[index] != value:
            raise ValueError(f"Value not found: {value}")
        self._list.pop(index)

    def insert(self, value):
        """Insert ``value`` at the position that keeps the list sorted."""
        bisect.insort(self._list, value)

    def append(self, value):
        """Add ``value`` in sorted position (same as ``insert``)."""
        self.insert(value)

    def extend(self, values):
        """Add every item of ``values`` in sorted position."""
        # take a snapshot first so that extending a list with itself
        # does not iterate over a sequence that is growing underneath us
        for value in list(values):
            self.insert(value)

    def __len__(self):
        return len(self._list)

    def __getitem__(self, index):
        return self._list[index]

    def __setitem__(self, index, value):
        # NOTE: direct assignment does not re-sort; the caller is
        # responsible for keeping the values in ascending order
        self._list[index] = value

    def __delitem__(self, index):
        del self._list[index]

    def __eq__(self, other):
        # compare against another SortedList or directly against a list
        if isinstance(other, SortedList):
            return self._list == other._list
        return self._list == other
45 changes: 45 additions & 0 deletions tests/structures/test_sorted_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest

from rolling.structures.sorted_list import SortedList


def test_sorted_list():
    """Check that SortedList stays ordered through inserts and removals."""
    tracked = SortedList()
    assert tracked == []

    # removing from an empty list must fail
    with pytest.raises(ValueError):
        tracked.remove(12345)

    # (value to insert, expected contents afterwards)
    insert_steps = [
        (3, [3]),
        (2, [2, 3]),
        (2, [2, 2, 3]),
        (5, [2, 2, 3, 5]),
        (4, [2, 2, 3, 4, 5]),
    ]
    for value, expected in insert_steps:
        tracked.insert(value)
        assert tracked == expected

    # removing a value that was never inserted must fail
    with pytest.raises(ValueError):
        tracked.remove(999)

    # (value to remove, expected contents afterwards)
    remove_steps = [
        (3, [2, 2, 4, 5]),
        (2, [2, 4, 5]),
        (5, [2, 4]),
        (2, [4]),
        (4, []),
    ]
    for value, expected in remove_steps:
        tracked.remove(value)
        assert tracked == expected
5 changes: 3 additions & 2 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,14 @@ def test_rolling_std(array, window_size, window_type):
assert pytest.approx(list(got), nan_ok=True) == list(expected)


@pytest.mark.parametrize("tracker", ["skiplist", "sortedlist"])
@pytest.mark.parametrize(
"array", [[3, 0, 1, 7, 2], [3, -8, 1, 7, -2, 8, 1, -7, -2, 9, 3], [1], []]
)
@pytest.mark.parametrize("window_size", [1, 2, 3, 4, 5, 6])
@pytest.mark.parametrize("window_type", ["fixed", "variable"])
def test_rolling_median(array, window_size, window_type):
got = Median(array, window_size, window_type=window_type)
def test_rolling_median(array, window_size, window_type, tracker):
got = Median(array, window_size, window_type=window_type, tracker=tracker)
expected = Apply(array, window_size, operation=_median, window_type=window_type)
assert pytest.approx(list(got)) == list(expected)

Expand Down