Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix/#41/stream not closing after terminal operation #42

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 53 additions & 7 deletions pystreamapi/_streams/__base_stream.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,50 @@
# pylint: disable=protected-access
from __future__ import annotations
import functools
import itertools
from abc import abstractmethod
from builtins import reversed
from functools import cmp_to_key
from typing import Iterable, Callable, Any, TypeVar, Iterator
from typing import Iterable, Callable, Any, TypeVar, Iterator, TYPE_CHECKING

from pystreamapi.__optional import Optional
from pystreamapi._itertools.tools import dropwhile
from pystreamapi._lazy.process import Process
from pystreamapi._lazy.queue import ProcessQueue
from pystreamapi._streams.error.__error import ErrorHandler
from pystreamapi._itertools.tools import dropwhile
if TYPE_CHECKING:
from pystreamapi._streams.numeric.__numeric_base_stream import NumericBaseStream

K = TypeVar('K')
_V = TypeVar('_V')
_identity_missing = object()


def _operation(func):
    """
    Decorator that rejects calls on a stream which is no longer open.

    The wrapper calls ``_verify_open()`` on the stream and only then delegates to the
    decorated method. It does NOT execute the processes in the queue.
    To be applied to intermediate operations.

    :param func: The stream method to guard
    :return: The guarded method, carrying the metadata of ``func``
    """
    @functools.wraps(func)
    def wrapper(self: 'BaseStream', *args, **kwargs):
        # Raises (through _verify_open) when the stream is not open anymore.
        self._verify_open()
        return func(self, *args, **kwargs)

    return wrapper


def terminal(func):
    """
    Decorator to execute all the processes in the queue and to close the stream before
    executing the decorated function. To be applied to terminal operations.

    On every call the wrapper does the following, in this order:
      1. verify that the stream is still open (through ``_operation``),
      2. execute all processes in the queue of the stream,
      3. close the stream,
      4. call the decorated function and return its result.

    Because the stream is already closed when ``func`` runs, ``func`` must not call other
    decorated operations on the same stream.

    :param func: The terminal stream method to wrap
    :return: The wrapped method, carrying the metadata of ``func``
    """
    @functools.wraps(func)
    @_operation
    def wrapper(self: 'BaseStream', *args, **kwargs):
        # pylint: disable=protected-access
        self._queue.execute_all()
        # Close before delegating: the stream counts as consumed from here on.
        self._close()
        return func(self, *args, **kwargs)

    return wrapper
Expand All @@ -47,6 +66,16 @@ class BaseStream(Iterable[K], ErrorHandler):
def __init__(self, source: Iterable[K]):
    """
    Create a new, open stream.

    :param source: The iterable providing the elements of the stream
    """
    # A stream starts open; it stays usable until _open is set to False.
    self._open = True
    self._queue = ProcessQueue()
    self._source = source

def _close(self):
    """
    Close the stream.

    Only flips the ``_open`` flag to ``False``; the source and the queue are not touched.
    """
    self._open = False

def _verify_open(self):
"""Verify if stream is open. If not, raise an exception."""
if not self._open:
raise RuntimeError("The stream has been closed")

@terminal
def __iter__(self) -> Iterator[K]:
Expand All @@ -63,6 +92,7 @@ def concat(cls, *streams: "BaseStream[K]"):
"""
return cls(itertools.chain(*list(streams)))

@_operation
def distinct(self) -> 'BaseStream[_V]':
"""Returns a stream consisting of the distinct elements of this stream."""
self._queue.append(Process(self.__distinct))
Expand All @@ -72,6 +102,7 @@ def __distinct(self):
"""Removes duplicate elements from the stream."""
self._source = list(set(self._source))

@_operation
def drop_while(self, predicate: Callable[[K], bool]) -> 'BaseStream[_V]':
"""
Returns, if this stream is ordered, a stream consisting of the remaining elements of this
Expand All @@ -86,6 +117,7 @@ def __drop_while(self, predicate: Callable[[Any], bool]):
"""Drops elements from the stream while the predicate is true."""
self._source = list(dropwhile(predicate, self._source, self))

@_operation
def filter(self, predicate: Callable[[K], bool]) -> 'BaseStream[K]':
"""
Returns a stream consisting of the elements of this stream that match the given predicate.
Expand All @@ -99,6 +131,7 @@ def filter(self, predicate: Callable[[K], bool]) -> 'BaseStream[K]':
def _filter(self, predicate: Callable[[K], bool]):
"""Implementation of filter. Should be implemented by subclasses."""

@_operation
def flat_map(self, predicate: Callable[[K], Iterable[_V]]) -> 'BaseStream[_V]':
"""
Returns a stream consisting of the results of replacing each element of this stream with
Expand All @@ -114,6 +147,7 @@ def flat_map(self, predicate: Callable[[K], Iterable[_V]]) -> 'BaseStream[_V]':
def _flat_map(self, predicate: Callable[[K], Iterable[_V]]):
"""Implementation of flat_map. Should be implemented by subclasses."""

@_operation
def group_by(self, key_mapper: Callable[[K], Any]) -> 'BaseStream[K]':
"""
Returns a Stream consisting of the results of grouping the elements of this stream
Expand All @@ -133,6 +167,7 @@ def __group_by(self, key_mapper: Callable[[Any], Any]):
def _group_to_dict(self, key_mapper: Callable[[K], Any]) -> dict[K, list]:
"""Groups the stream into a dictionary. Should be implemented by subclasses."""

@_operation
def limit(self, max_size: int) -> 'BaseStream[_V]':
"""
Returns a stream consisting of the elements of this stream, truncated to be no longer
Expand All @@ -147,6 +182,7 @@ def __limit(self, max_size: int):
"""Limits the stream to the first n elements."""
self._source = itertools.islice(self._source, max_size)

@_operation
def map(self, mapper: Callable[[K], _V]) -> 'BaseStream[_V]':
"""
Returns a stream consisting of the results of applying the given function to the elements
Expand All @@ -161,18 +197,20 @@ def map(self, mapper: Callable[[K], _V]) -> 'BaseStream[_V]':
def _map(self, mapper: Callable[[K], _V]):
"""Implementation of map. Should be implemented by subclasses."""

def map_to_int(self) -> 'BaseStream[_V]':
@_operation
def map_to_int(self) -> NumericBaseStream[_V]:
"""
Returns a stream consisting of the results of converting the elements of this stream to
integers.
"""
self._queue.append(Process(self.__map_to_int))
return self
return self._to_numeric_stream()

def __map_to_int(self):
    """Converts the stream to integers by passing ``int`` as the mapper to ``_map``."""
    self._map(int)

@_operation
def map_to_str(self) -> 'BaseStream[_V]':
"""
Returns a stream consisting of the results of converting the elements of this stream to
Expand All @@ -185,6 +223,7 @@ def __map_to_str(self):
"""Converts the stream to strings."""
self._map(str)

@_operation
def peek(self, action: Callable) -> 'BaseStream[_V]':
"""
Returns a stream consisting of the elements of this stream, additionally performing the
Expand All @@ -196,9 +235,11 @@ def peek(self, action: Callable) -> 'BaseStream[_V]':
return self

@abstractmethod
@_operation
def _peek(self, action: Callable):
"""Implementation of peek. Should be implemented by subclasses."""

@_operation
def reversed(self) -> 'BaseStream[_V]':
"""
Returns a stream consisting of the elements of this stream, with their order being
Expand All @@ -214,6 +255,7 @@ def __reversed(self):
except TypeError:
self._source = reversed(list(self._source))

@_operation
def skip(self, n: int) -> 'BaseStream[_V]':
"""
Returns a stream consisting of the remaining elements of this stream after discarding the
Expand All @@ -228,6 +270,7 @@ def __skip(self, n: int):
"""Skips the first n elements of the stream."""
self._source = self._source[n:]

@_operation
def sorted(self, comparator: Callable[[K], int] = None) -> 'BaseStream[_V]':
"""
Returns a stream consisting of the elements of this stream, sorted according to natural
Expand All @@ -243,6 +286,7 @@ def __sorted(self, comparator: Callable[[K], int] = None):
else:
self._source = sorted(self._source, key=cmp_to_key(comparator))

@_operation
def take_while(self, predicate: Callable[[K], bool]) -> 'BaseStream[_V]':
"""
Returns, if this stream is ordered, a stream consisting of the longest prefix of elements
Expand All @@ -257,8 +301,6 @@ def __take_while(self, predicate: Callable[[Any], bool]):
"""Takes elements from the stream while the predicate is true."""
self._source = list(itertools.takewhile(predicate, self._source))

# Terminal Operations:

@abstractmethod
@terminal
def all_match(self, predicate: Callable[[K], bool]):
Expand Down Expand Up @@ -373,3 +415,7 @@ def to_dict(self, key_mapper: Callable[[K], Any]) -> dict:

:param key_mapper:
"""

@abstractmethod
def _to_numeric_stream(self) -> NumericBaseStream[_V]:
    """
    Converts a stream to a numeric stream. To be implemented by subclasses.

    :return: The stream as a NumericBaseStream
    """
6 changes: 6 additions & 0 deletions pystreamapi/_streams/__parallel_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,9 @@ def _set_parallelizer_src(self):

def __mapper(self, mapper):
return lambda x: self._one(mapper=mapper, item=x)

def _to_numeric_stream(self):
    """
    Converts this stream to a ParallelNumericStream.

    The conversion happens in place: the class of this very instance is reassigned and the
    same object is returned, so no attribute of the instance is copied or changed.

    :return: This instance, now being a ParallelNumericStream
    """
    # Imported at call time instead of at module level.
    # NOTE(review): presumably to avoid a circular import - confirm.
    # pylint: disable=import-outside-toplevel
    from pystreamapi._streams.numeric.__parallel_numeric_stream import ParallelNumericStream
    self.__class__ = ParallelNumericStream
    return self
6 changes: 6 additions & 0 deletions pystreamapi/_streams/__sequential_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,9 @@ def reduce(self, predicate: Callable, identity=_identity_missing, depends_on_sta
@stream.terminal
def to_dict(self, key_mapper: Callable[[Any], Any]) -> dict:
return self._group_to_dict(key_mapper)

def _to_numeric_stream(self):
    """
    Converts this stream to a SequentialNumericStream.

    The conversion happens in place: the class of this very instance is reassigned and the
    same object is returned, so no attribute of the instance is copied or changed.

    :return: This instance, now being a SequentialNumericStream
    """
    # Imported at call time instead of at module level.
    # NOTE(review): presumably to avoid a circular import - confirm.
    # pylint: disable=import-outside-toplevel
    from pystreamapi._streams.numeric.__sequential_numeric_stream import SequentialNumericStream
    self.__class__ = SequentialNumericStream
    return self
16 changes: 14 additions & 2 deletions pystreamapi/_streams/numeric/__numeric_base_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,22 @@ def interquartile_range(self) -> Union[float, int, None]:
Calculates the interquartile range of a numerical Stream
:return: The interquartile range, can be int or float
"""
return self.third_quartile() - self.first_quartile() if len(self._source) > 0 else None
return self._interquartile_range()

def _interquartile_range(self):
"""Implementation of the interquartile range calculation"""
return self._third_quartile() - self._first_quartile() if len(self._source) > 0 else None

@terminal
def first_quartile(self) -> Union[float, int, None]:
"""
Calculates the first quartile of a numerical Stream
:return: The first quartile, can be int or float
"""
return self._first_quartile()

def _first_quartile(self):
    """
    Implementation of the first quartile calculation.

    Replaces the source by its sorted version and returns the median of its lower half.
    """
    ordered = sorted(self._source)
    self._source = ordered
    midpoint = len(ordered) // 2
    return self.__median(ordered[:midpoint])

Expand Down Expand Up @@ -59,7 +67,7 @@ def __median(source) -> Union[float, int, None]:
@terminal
def mode(self) -> Union[list[Union[int, float]], None]:
"""
Calculates the mode(s) (most frequently occurring element) of a numerical Stream
Calculates the mode/modes (most frequently occurring element/elements) of a numerical Stream
:return: The mode, can be int or float
"""
frequency = Counter(self._source)
Expand Down Expand Up @@ -90,5 +98,9 @@ def third_quartile(self) -> Union[float, int, None]:
Calculates the third quartile of a numerical Stream
:return: The third quartile, can be int or float
"""
return self._third_quartile()

def _third_quartile(self):
    """
    Implementation of the third quartile calculation.

    Replaces the source by its sorted version and returns the median of its upper half.
    """
    ordered = sorted(self._source)
    self._source = ordered
    # For an odd number of elements the middle element belongs to neither half.
    upper_start = (len(ordered) + 1) // 2
    return self.__median(ordered[upper_start:])
1 change: 0 additions & 1 deletion pystreamapi/_streams/numeric/__parallel_numeric_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def sum(self) -> Union[float, int, None]:
_sum = self.__sum()
return 0 if _sum == [] else _sum

@terminal
def __sum(self):
"""Parallel sum method"""
self._set_parallelizer_src()
Expand Down
108 changes: 108 additions & 0 deletions tests/test_stream_closed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import unittest

from parameterized import parameterized_class

from pystreamapi._streams.__parallel_stream import ParallelStream
from pystreamapi._streams.__sequential_stream import SequentialStream
from pystreamapi._streams.numeric.__parallel_numeric_stream import ParallelNumericStream
from pystreamapi._streams.numeric.__sequential_numeric_stream import SequentialNumericStream


@parameterized_class("stream", [
    [SequentialStream],
    [ParallelStream],
    [SequentialNumericStream],
    [ParallelNumericStream]])
class BaseStreamClosed(unittest.TestCase):
    """Checks that a stream rejects every operation once a terminal operation was run."""

    def test_closed_stream_throws_exception(self):
        """Every intermediate and terminal operation must raise a RuntimeError."""
        closed_stream = self.stream([])
        # Running a terminal operation is what closes the stream.
        closed_stream.for_each(lambda _: ...)

        # (method name, positional arguments) of every operation to verify
        operations = (
            ("distinct", ()),
            ("drop_while", (lambda x: True,)),
            ("filter", (lambda x: True,)),
            ("flat_map", (lambda x: [x],)),
            ("group_by", (lambda x: x,)),
            ("limit", (5,)),
            ("map", (lambda x: x,)),
            ("map_to_int", ()),
            ("map_to_str", ()),
            ("peek", (lambda x: None,)),
            ("reversed", ()),
            ("skip", (5,)),
            ("sorted", ()),
            ("take_while", (lambda x: True,)),
            ("all_match", (lambda x: True,)),
            ("any_match", (lambda x: True,)),
            ("count", ()),
            ("find_any", ()),
            ("find_first", ()),
            ("for_each", (lambda x: None,)),
            ("none_match", (lambda x: True,)),
            ("min", ()),
            ("max", ()),
            ("reduce", (lambda x, y: x + y,)),
            ("to_list", ()),
            ("to_tuple", ()),
            ("to_set", ()),
            ("to_dict", (lambda x: x,)),
        )

        # Iterating goes through __iter__, which is a terminal operation as well.
        with self.subTest(operation="__iter__"), self.assertRaises(RuntimeError):
            list(closed_stream)

        # subTest keeps going after a failure, so every offending operation gets reported.
        for name, args in operations:
            with self.subTest(operation=name), self.assertRaises(RuntimeError):
                getattr(closed_stream, name)(*args)
Loading
Loading