Skip to content

Commit

Permalink
add test cases for multicast reduce operator
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelSchneeberger committed Dec 19, 2019
1 parent f537b49 commit 5091042
Show file tree
Hide file tree
Showing 22 changed files with 850 additions and 186 deletions.
1 change: 0 additions & 1 deletion examples/multicast/mergeexample.py
Expand Up @@ -5,7 +5,6 @@


result = rxbp.multicast.merge(m1, m2).pipe(
# rxbp.multicast.op.debug('d1'),
rxbp.multicast.op.reduce(),
).to_flowable().run()

Expand Down
1 change: 0 additions & 1 deletion examples/multicast/reduceexample.py
Expand Up @@ -22,7 +22,6 @@
rxbp.multicast.op.merge(
rxbp.multicast.from_flowable(base2)
),
# rxbp.multicast.op.debug('d1'),
rxbp.multicast.op.reduce(),
rxbp.multicast.op.map(lambda v: v['val1'].zip(v['val2'])),
).to_flowable().run()
Expand Down
7 changes: 5 additions & 2 deletions rxbp/multicast/collectablemulticasts/collectablemulticast.py
Expand Up @@ -74,8 +74,11 @@ def map(self, func: Callable[[MultiCastValue], MultiCastValue]):
def pipe(self, *operators: MultiCastOperator) -> 'CollectableMultiCast':
return reduce(lambda acc, op: op(acc), operators, self)

def reduce(self):
main = self._main.reduce()
def reduce(
self,
maintain_order: bool = None,
):
main = self._main.reduce(maintain_order=maintain_order)
return CollectableMultiCast(main=main, collected=self._collected)

def share(self):
Expand Down
@@ -1,12 +1,12 @@
from typing import Any, Callable

from rxbp.flowablebase import FlowableBase
from rxbp.multicast.observables.flatmapnobackpressureobservable import FlatMapNoBackpressureObservable
from rxbp.multicast.observables.flatconcatnobackpressureobservable import FlatConcatNoBackpressureObservable
from rxbp.subscriber import Subscriber
from rxbp.subscription import Subscription, SubscriptionInfo


class FlatMapNoBackpressureFlowable(FlowableBase):
class FlatConcatNoBackpressureFlowable(FlowableBase):
def __init__(self, source: FlowableBase, selector: Callable[[Any], FlowableBase]):
super().__init__()

Expand All @@ -21,8 +21,8 @@ def observable_selector(elem: Any):
subscription = flowable.unsafe_subscribe(subscriber=subscriber)
return subscription.observable

observable = FlatMapNoBackpressureObservable(source=subscription.observable, selector=observable_selector,
scheduler=subscriber.scheduler, subscribe_scheduler=subscriber.subscribe_scheduler)
observable = FlatConcatNoBackpressureObservable(source=subscription.observable, selector=observable_selector,
scheduler=subscriber.scheduler, subscribe_scheduler=subscriber.subscribe_scheduler)

# base becomes undefined after flat mapping
base = None
Expand Down
7 changes: 5 additions & 2 deletions rxbp/multicast/multicast.py
Expand Up @@ -107,8 +107,11 @@ def get_source(_, info: MultiCastInfo) -> rx.typing.Observable:
def map(self, func: Callable[[MultiCastValue], MultiCastValue]):
return MultiCast(MapMultiCast(source=self, func=func))

def reduce(self):
return MultiCast(ReduceMultiCast(source=self))
def reduce(
self,
maintain_order: bool = None,
):
return MultiCast(ReduceMultiCast(source=self, maintain_order=maintain_order))

def share(self):
subject = Subject()
Expand Down
5 changes: 4 additions & 1 deletion rxbp/multicast/multicastopmixin.py
Expand Up @@ -61,7 +61,10 @@ def map(self, func: Callable[[MultiCastValue], MultiCastValue]):
...

@abstractmethod
def reduce(self):
def reduce(
self,
maintain_order: bool = None,
):
...

@abstractmethod
Expand Down
67 changes: 40 additions & 27 deletions rxbp/multicast/multicasts/defermulticast.py
Expand Up @@ -54,7 +54,14 @@ def unsafe_subscribe(self, subscriber: Subscriber) -> Subscription:
start = StartWithInitialValueFlowable()
shared = RefCountFlowable(source=start)

# mutual curr_index variable is used to index the input Flowables
# as the sequal to the deferred Flowables
# {0: deferred_flowable_1,
# 1: input_flowable_1,
# 2: input_flowable_2,}
curr_index = 0

# map initial value(s) to a dictionary
if isinstance(initial, list):
curr_index = len(initial)
initial_dict = {idx: val for idx, val in enumerate(initial)}
Expand Down Expand Up @@ -102,51 +109,57 @@ def map_to_flowable_dict(base: MultiCastBase):
output = self.func(init)

def map_func(base: MultiCastValue):

def select_first_index(state):
class SingleFlowableDict(SingleFlowableMixin, FlowableDict):
def get_single_flowable(self) -> Flowable:
return state[0]

return SingleFlowableDict(state)

def select_none(state):
return FlowableDict(state)

match_error_message = f'defer function returned "{base}" which does not match initial "{initial}"'

if isinstance(base, Flowable) and len(initial_dict) == 1:

# if initial would be a dicitonary, then the input to the defer operator
# must be a dicitonary and not a Flowable.
assert not isinstance(initial, dict), match_error_message

if isinstance(initial, list):
assert len(base) == 1, match_error_message
# create standard form
flowable_state = {0: base}

deferred_values = {list(initial_dict.keys())[0]: base} # deferred values refer to the values returned by the defer function
select_flowable_dict = select_none
# function that will map the resulting state back to a Flowable
from_state = lambda state: state[0]

elif isinstance(base, list):
assert isinstance(initial, list) and len(initial) == len(base)

deferred_values = {idx: val for idx, val in enumerate(base)}
select_flowable_dict = select_first_index
# create standard form
flowable_state = {idx: val for idx, val in enumerate(base)}

# def select_first_index(state):
# class SingleFlowableDict(SingleFlowableMixin, FlowableDict):
# def get_single_flowable(self) -> Flowable:
# return state[0]
#
# return SingleFlowableDict(state)

# function that will map the resulting state to a list
from_state = lambda state: list(state.values())

elif isinstance(base, dict) or isinstance(base, FlowableStateMixin):
if isinstance(base, FlowableStateMixin):
deferred_values = base.get_flowable_state()
flowable_state = base.get_flowable_state()
else:
deferred_values = base
flowable_state = base

match_error_message = f'defer function returned "{deferred_values.keys()}", ' \
match_error_message = f'defer function returned "{flowable_state.keys()}", ' \
f'which does not match initial "{initial.keys()}"'

assert isinstance(initial, dict) and set(initial.keys()) <= set(deferred_values.keys()), match_error_message
assert isinstance(initial, dict) and set(initial.keys()) <= set(flowable_state.keys()), match_error_message

select_flowable_dict = select_none
# function that will map the resulting state to a FlowableDict
def select_none(state):
return FlowableDict(state)
from_state = select_none

else:
raise Exception(f'illegal case "{base}"')

shared_deferred_values = {key: RefCountFlowable(value) for key, value in deferred_values.items()}
# share flowables
shared_flowable_state = {key: RefCountFlowable(value) for key, value in flowable_state.items()}

lock = threading.RLock()
is_first = [True]
Expand All @@ -164,14 +177,14 @@ def unsafe_subscribe(self, subscriber: Subscriber) -> Subscription:
is_first[0] = False
close_loop = True

# close defer loop only if first element has received
# close defer loop only if first subscribed
if close_loop:

def gen_index_for_each_deferred_state():
""" for each value returned by the defer function """
for key in initial_dict.keys():
def for_func(key=key):
return Flowable(MapFlowable(shared_deferred_values[key], selector=lambda v: (key, v)))
return Flowable(MapFlowable(shared_flowable_state[key], selector=lambda v: (key, v)))
yield for_func()
indexed_deferred_values = gen_index_for_each_deferred_state()

Expand Down Expand Up @@ -231,9 +244,9 @@ def observe(self, observer_info: ObserverInfo):
return Subscription(info=SubscriptionInfo(None), observable=defer_observable)

# create a flowable for all deferred values
new_states = {k: Flowable(DeferFlowable(v, k)) for k, v in shared_deferred_values.items()}
new_states = {k: Flowable(DeferFlowable(v, k)) for k, v in shared_flowable_state.items()}

return select_flowable_dict(new_states)
return from_state(new_states)

return output.get_source(info=info).pipe(
rxop.map(map_func),
Expand Down
132 changes: 70 additions & 62 deletions rxbp/multicast/multicasts/reducemulticast.py
Expand Up @@ -7,7 +7,7 @@
from rxbp.flowable import Flowable
from rxbp.flowables.refcountflowable import RefCountFlowable
from rxbp.multicast.flowables.connectableflowable import ConnectableFlowable
from rxbp.multicast.flowables.flatmapnobackpressureflowable import FlatMapNoBackpressureFlowable
from rxbp.multicast.flowables.flatconcatnobackpressureflowable import FlatConcatNoBackpressureFlowable
from rxbp.multicast.flowables.flatmergenobackpressureflowable import FlatMergeNoBackpressureFlowable
from rxbp.multicast.flowablestatemixin import FlowableStateMixin
from rxbp.multicast.multicastInfo import MultiCastInfo
Expand All @@ -19,8 +19,13 @@


class ReduceMultiCast(MultiCastBase):
def __init__(self, source: MultiCastBase):
def __init__(
self,
source: MultiCastBase,
maintain_order: bool = None,
):
self.source = source
self.maintain_order = maintain_order

def get_source(self, info: MultiCastInfo):
source = self.source.get_source(info=info).pipe(
Expand All @@ -30,7 +35,7 @@ def get_source(self, info: MultiCastInfo):
or isinstance(v, list)),
)

def func(first: Union[FlowableStateMixin, dict], lifted_obs: Observable):
def func(lifted_obs: Observable, first: Union[FlowableStateMixin, dict]):
if isinstance(first, dict):
to_state = lambda s: s
from_state = lambda s: s
Expand All @@ -48,70 +53,73 @@ def func(first: Union[FlowableStateMixin, dict], lifted_obs: Observable):

first_state = to_state(first)

class ReduceObservable(Observable):
def __init__(self, first: FlowableStateMixin):
super().__init__()

self.first = first

def _subscribe_core(
self,
observer: rx.typing.Observer,
scheduler: Optional[rx.typing.Scheduler] = None
) -> rx.typing.Disposable:
# # share only if more than one elem in first state
# shared_source = lifted_obs.pipe(
# rxop.share(),
# )

# lifted_flowable = rxbp.from_rx(lifted_obs)

conn_observer = ConnectableObserver(
underlying=None,
scheduler=info.multicast_scheduler,
subscribe_scheduler=info.multicast_scheduler,
)

# subscribe to source rx.Observables immediately
source_flowable = rxbp.from_rx(lifted_obs)
subscriber = Subscriber(
scheduler=info.multicast_scheduler,
subscribe_scheduler=info.multicast_scheduler,
)
subscription = source_flowable.unsafe_subscribe(subscriber=subscriber)
subscription.observable.observe(ObserverInfo(conn_observer))

conn_flowable = ConnectableFlowable(conn_observer=conn_observer)

if 1 < len(first_state):
shared_flowable = RefCountFlowable(conn_flowable)
else:
shared_flowable = conn_flowable

def gen_flowables():
for key in first_state.keys():
def for_func(key=key, shared_flowable=shared_flowable):
def selector(v: FlowableStateMixin):
flowable = to_state(v)[key]
return flowable
# class ReduceObservable(Observable):
# def __init__(
# self,
# first: FlowableStateMixin,
# maintain_order: bool = None,
# ):
# super().__init__()
#
# self.first = first
# self.maintain_order = maintain_order
#
# def _subscribe_core(
# self,
# observer: rx.typing.Observer,
# scheduler: Optional[rx.typing.Scheduler] = None
# ) -> rx.typing.Disposable:

conn_observer = ConnectableObserver(
underlying=None,
scheduler=info.multicast_scheduler,
subscribe_scheduler=info.multicast_scheduler,
)

# subscribe to source rx.Observables immediately
source_flowable = rxbp.from_rx(lifted_obs)
subscriber = Subscriber(
scheduler=info.multicast_scheduler,
subscribe_scheduler=info.multicast_scheduler,
)
subscription = source_flowable.unsafe_subscribe(subscriber=subscriber)
subscription.observable.observe(ObserverInfo(conn_observer))

conn_flowable = ConnectableFlowable(conn_observer=conn_observer)

if 1 < len(first_state):
shared_flowable = RefCountFlowable(conn_flowable)
else:
shared_flowable = conn_flowable

flattened_flowable = FlatMergeNoBackpressureFlowable(shared_flowable, selector)
# flattened_flowable = FlatMapNoBackpressureFlowable(shared_flowable, selector)
result = RefCountFlowable(flattened_flowable)
def gen_flowables():
for key in first_state.keys():
def for_func(key=key, shared_flowable=shared_flowable):
def selector(v: FlowableStateMixin):
flowable = to_state(v)[key]
return flowable

return key, Flowable(result)
if self.maintain_order:
flattened_flowable = FlatConcatNoBackpressureFlowable(shared_flowable, selector)
else:
flattened_flowable = FlatMergeNoBackpressureFlowable(shared_flowable, selector)

yield for_func()
result = RefCountFlowable(flattened_flowable)
flowable = Flowable(result)
return key, flowable

result_flowables = dict(gen_flowables())
result = from_state(result_flowables)
yield for_func()

# def action(_, __):
observer.on_next(result)
observer.on_completed()
result_flowables = dict(gen_flowables())
result = from_state(result_flowables)
return result

# info.multicast_scheduler.schedule(action)
# observer.on_next(result)
# observer.on_completed()

return ReduceObservable(first=first)
# return ReduceObservable(
# first=first,
# maintain_order=self.maintain_order,
# )

return LiftObservable(source=source, func=func, subscribe_scheduler=info.multicast_scheduler)
return LiftObservable(source=source, func=func, scheduler=info.multicast_scheduler)
4 changes: 2 additions & 2 deletions rxbp/multicast/multicasts/zipmulticast.py
Expand Up @@ -5,7 +5,7 @@
from rxbp.flowable import Flowable
from rxbp.flowables.refcountflowable import RefCountFlowable
from rxbp.multicast.flowables.connectableflowable import ConnectableFlowable
from rxbp.multicast.flowables.flatmapnobackpressureflowable import FlatMapNoBackpressureFlowable
from rxbp.multicast.flowables.flatconcatnobackpressureflowable import FlatConcatNoBackpressureFlowable
from rxbp.multicast.multicastInfo import MultiCastInfo
from rxbp.multicast.multicastbase import MultiCastBase
from rxbp.multicast.typing import MultiCastValue
Expand Down Expand Up @@ -57,7 +57,7 @@ def for_func(source=source):

conn_flowable = ConnectableFlowable(conn_observer=conn_observer)

flattened_flowable = FlatMapNoBackpressureFlowable(conn_flowable, to_flowable)
flattened_flowable = FlatConcatNoBackpressureFlowable(conn_flowable, to_flowable)

ref_count_flowable = RefCountFlowable(flattened_flowable)

Expand Down

0 comments on commit 5091042

Please sign in to comment.