diff --git a/rx/core/operators/flatmap.py b/rx/core/operators/flatmap.py index 0f6b73d89..f02626589 100644 --- a/rx/core/operators/flatmap.py +++ b/rx/core/operators/flatmap.py @@ -10,11 +10,12 @@ def _flat_map_internal(source, mapper=None, mapper_indexed=None): def projection(x, i): mapper_result = mapper(x) if mapper else mapper_indexed(x, i) - if isinstance(mapper_result, collections.abc.Iterable): + if is_future(mapper_result): + result = from_future(mapper_result) + elif isinstance(mapper_result, collections.abc.Iterable): result = from_(mapper_result) else: - result = from_future(mapper_result) if is_future( - mapper_result) else mapper_result + result = mapper_result return result return source.pipe( diff --git a/tests/test_observable/test_flatmap_async.py b/tests/test_observable/test_flatmap_async.py new file mode 100644 index 000000000..f568638e6 --- /dev/null +++ b/tests/test_observable/test_flatmap_async.py @@ -0,0 +1,33 @@ +import unittest +import asyncio +from rx import operators as ops +from rx.subject import Subject + +from rx.scheduler.eventloop import AsyncIOScheduler + + +class TestFlatMapAsync(unittest.TestCase): + + def test_flat_map_async(self): + actual_next = None + loop = asyncio.get_event_loop() + scheduler = AsyncIOScheduler(loop=loop) + + def mapper(i): + async def _mapper(i): + return i + 1 + + return asyncio.ensure_future(_mapper(i)) + + def on_next(i): + nonlocal actual_next + actual_next = i + + async def test_flat_map(): + x = Subject() + x.pipe(ops.flat_map(mapper)).subscribe(on_next, scheduler=scheduler) + x.on_next(10) + await asyncio.sleep(0.1) + + loop.run_until_complete(test_flat_map()) + assert actual_next == 11