diff --git a/rx/core/observable/zip.py b/rx/core/observable/zip.py index 8b7564771..f0d0c1a14 100644 --- a/rx/core/observable/zip.py +++ b/rx/core/observable/zip.py @@ -34,6 +34,7 @@ def subscribe(observer: typing.Observer, n = len(sources) queues: List[List] = [[] for _ in range(n)] lock = RLock() + is_completed = [False] * n @synchronized(lock) def next(i): @@ -46,6 +47,14 @@ def next(i): return observer.on_next(res) + elif all([x for j, x in enumerate(is_completed) if j != i]) \ + and all([len(x) == 0 for j, x in enumerate(queues) if j != i]): + observer.on_completed() + + def completed(i): + is_completed[i] = True + if all(is_completed) or all([len(q) == 0 for q in queues]): + observer.on_completed() subscriptions = [None] * n @@ -58,7 +67,7 @@ def on_next(x): queues[i].append(x) next(i) - sad.disposable = source.subscribe_(on_next, observer.on_error, observer.on_completed, scheduler) + sad.disposable = source.subscribe_(on_next, observer.on_error, lambda: completed(i), scheduler) subscriptions[i] = sad for idx in range(n): diff --git a/tests/test_observable/test_zip.py b/tests/test_observable/test_zip.py index fe09b4616..a14a79831 100644 --- a/tests/test_observable/test_zip.py +++ b/tests/test_observable/test_zip.py @@ -95,7 +95,7 @@ def create(): ops.map(sum)) results = scheduler.start(create) - assert results.messages == [on_completed(220)] + assert results.messages == [] def test_zip_non_empty_never(self): scheduler = TestScheduler() @@ -109,7 +109,7 @@ def create(): ops.map(sum)) results = scheduler.start(create) - assert results.messages == [on_completed(220)] + assert results.messages == [] def test_zip_non_empty_non_empty(self): scheduler = TestScheduler() @@ -126,6 +126,36 @@ def create(): results = scheduler.start(create) assert results.messages == [on_next(220, 2 + 3), on_completed(230)] + def test_zip_non_empty_non_empty_sequential(self): + scheduler = TestScheduler() + msgs1 = [on_next(210, 1), on_next(215, 2), on_completed(230)] + msgs2 = [on_next(240, 1), on_next(245, 3), on_completed(250)] + e1 = scheduler.create_cold_observable(msgs1) + e2 = scheduler.create_cold_observable(msgs2) + + def create(): + return e1.pipe( + ops.zip(e2), + ops.map(sum)) + + results = scheduler.start(create) + assert results.messages == [on_next(200+240, 1 + 1), on_next(200+245, 2 + 3), on_completed(200+250)] + + def test_zip_non_empty_partial_sequential(self): + scheduler = TestScheduler() + msgs1 = [on_next(210, 1), on_next(215, 2), on_completed(230)] + msgs2 = [on_next(240, 1), on_completed(250)] + e1 = scheduler.create_cold_observable(msgs1) + e2 = scheduler.create_cold_observable(msgs2) + + def create(): + return e1.pipe( + ops.zip(e2), + ops.map(sum)) + + results = scheduler.start(create) + assert results.messages == [on_next(200+240, 1 + 1), on_completed(200+250)] + def test_zip_empty_error(self): ex = 'ex' scheduler = TestScheduler()