# Stream

> An asynchronous stream abstraction, an in-memory stream writer, and stream combinators.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#| default_exp streams

In [None]:
#| export

import abc
import asyncio
import enum
import functools
import inspect
from typing import AsyncIterable, AsyncIterator, Callable, Generic, Iterable, TypeVar


In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
from fastcore.test import *

## Stream

In [None]:
#| export

_T = TypeVar("T")


class StreamStatus(enum.Enum):
  OK = enum.auto()
  SHUTDOWN = enum.auto()


class Stream(abc.ABC, AsyncIterator[_T]):

  @abc.abstractmethod
  async def next(
      self,
      with_status: bool = False,
  ) -> _T | None:
    pass

  async def __anext__(self) -> _T:
    e, status = await self.next(with_status=True)
    if status == StreamStatus.SHUTDOWN:
      raise StopAsyncIteration
    return e

  def __aiter__(self) -> AsyncIterator[_T]:
    return self

## StreamWriter

In [None]:
#| export


class StreamWriter(abc.ABC, Generic[_T]):
  """Abstract write side of a stream; `readonly` exposes the matching read side."""

  @abc.abstractmethod
  async def put(self, *items: _T):
    """Writes the given items to the stream."""

  @abc.abstractmethod
  async def shutdown(self):
    """Shuts the stream down."""

  @abc.abstractmethod
  def readonly(self) -> Stream[_T]:
    """Returns a `Stream` for reading what is written through this writer."""

In [None]:
#| export


class InMemStreamWriter(StreamWriter[_T]):
  """`StreamWriter` backed by an in-memory, unbounded `asyncio.Queue`."""

  def __init__(self):
    self._q = asyncio.Queue()
    self._lock = asyncio.Lock()

  async def put(self, *items: _T):
    """Enqueues `items` in order.

    Writing after `shutdown` is a silent no-op.
    """
    async with self._lock:
      # Lock to ensure all elements are enqueued without being shutdown.
      try:
        for item in items:
          await self._q.put(item)
      except asyncio.QueueShutDown:
        # Deliberate best effort: items written after shutdown are dropped.
        pass

  async def shutdown(self):
    """Shuts the queue down; items already enqueued remain readable."""
    async with self._lock:
      self._q.shutdown()

  def readonly(self) -> Stream[_T]:
    """Returns a `Stream` that reads the items written through this writer.

    Every call returns a new reader over the same queue, so each item is
    delivered to only one of the readers.
    """
    writer = self

    class _S(Stream[_T]):

      # Same signature as `Stream.next`, so `with_status` may be passed
      # positionally or by keyword.
      async def next(
          self,
          with_status: bool = False,
      ) -> _T | None | tuple[_T | None, StreamStatus]:
        try:
          item = await writer._q.get()
        except asyncio.QueueShutDown:
          # Raised once the queue is shut down and drained.
          item, status = None, StreamStatus.SHUTDOWN
        else:
          writer._q.task_done()
          status = StreamStatus.OK

        if with_status:
          return item, status
        return item

    return _S()

### InMemStreamWriter tests

In [None]:
sw = InMemStreamWriter()
sr = sw.readonly()

_access_pattern = []  # P: Producer, C: Consumer.


async def producer():
  for i in range(3):
    global _access_pattern
    await asyncio.sleep(0.01)
    _access_pattern.append(("P", i))
    await sw.put(i)
  await sw.shutdown()


async def consumer():
  async for e in sr:
    _access_pattern.append(("C", e))


async with asyncio.TaskGroup() as tg:
  tg.create_task(consumer())
  tg.create_task(producer())

test_eq(_access_pattern, [("P", 0), ("C", 0), ("P", 1), ("C", 1), ("P", 2), ("C", 2)])

In [None]:
sw = InMemStreamWriter()
sr = sw.readonly()

await sw.put("a", "b")

await sw.shutdown() # Items written before the shutdown remain readable.

test_eq(await sr.next(), "a")
test_eq(await sr.next(), "b")

await sw.put("c") # No-op: the writer is already shut down.
test_eq(await sr.next(), None)
test_eq(await sr.next(with_status=True), (None, StreamStatus.SHUTDOWN))

## Stream Utils

### tolist

In [None]:
#| export


async def tolist(s: Stream[_T]) -> list[_T]:
  """Drains `s` and returns its items, in order, as a list."""
  return [item async for item in s]

### of

In [None]:
#| export


def of(*args: _T | AsyncIterable[_T] | Iterable[_T]) -> Stream[_T]:
  """Builds a `Stream` over the given source(s).

  A single argument that is itself an iterable or async iterable is used as
  the source: `of(range(3))` yields 0, 1, 2. In every other case the arguments
  themselves are the items: `of(0, 1, 2)` yields 0, 1, 2.

  Note: a lone `str` argument is iterable, so it is streamed character by
  character.
  """
  # Decide up front what is being streamed: the lone iterable, or the args.
  source = args
  if len(args) == 1:
    (only,) = args
    if isinstance(only, (AsyncIterable, Iterable)):
      source = only

  async def _drip(items):
    """Adapts a synchronous iterable to an async iterator."""
    for element in items:
      # Simulate asynchronous behavior.
      await asyncio.sleep(0)
      yield element

  class _FromIterableStream(Stream[_T]):

    def __init__(self, src):
      is_async = isinstance(src, AsyncIterable)
      self._iter = src.__aiter__() if is_async else _drip(src)

    async def next(
        self,
        with_status: bool = False,
    ) -> _T | None:
      try:
        item = await self._iter.__anext__()
      except StopAsyncIteration:
        item, status = None, StreamStatus.SHUTDOWN
      else:
        status = StreamStatus.OK

      return (item, status) if with_status else item

  return _FromIterableStream(source)

#### of Tests

In [None]:
s = of(0, 1, 2)
test_eq(await s.next(), 0)
test_eq(await s.next(), 1)
test_eq(await s.next(), 2)
test_eq(await s.next(), None)

In [None]:
s = of(0, 1, 2)
test_eq(await s.next(with_status=True), (0, StreamStatus.OK))
test_eq(await s.next(with_status=True), (1, StreamStatus.OK))
test_eq(await s.next(with_status=True), (2, StreamStatus.OK))
test_eq(await s.next(with_status=True), (None, StreamStatus.SHUTDOWN))

In [None]:
# A single iterable argument is used as the source of items.
s = of(range(3))
test_eq(await tolist(s), [0, 1, 2])

### concat

In [None]:
#| export


def concat(*streams: Stream[_T]) -> Stream[_T]:
  """Concatenates the given streams.

  The returned stream yields every item of the first stream, then every item
  of the second one, and so on. A stream is only read once all the streams
  before it have reported `StreamStatus.SHUTDOWN`.
  """

  class _ConcatStream(Stream[_T]):

    def __init__(self):
      # Index of the stream currently being read.
      self._idx = 0

    async def next(
        self,
        with_status: bool = False,
    ) -> _T | None | tuple[_T | None, StreamStatus]:
      while self._idx < len(streams):
        item, status = await streams[self._idx].next(with_status=True)
        if status == StreamStatus.OK:
          return (item, status) if with_status else item
        if status != StreamStatus.SHUTDOWN:
          # Raised explicitly: an `assert` statement is stripped under `-O`,
          # which would leave this loop spinning on the same stream.
          raise AssertionError(f"Unexpected status: {status}")
        # Current stream is exhausted; move on to the next one.
        self._idx += 1

      return (None, StreamStatus.SHUTDOWN) if with_status else None

  return _ConcatStream()

#### concat Tests

In [None]:
s0 = of(0, 1)
s1 = of(2, 3)
s2 = of(4, 5)
s = concat(s0, s1, s2)

test_eq(await tolist(s), [0, 1, 2, 3, 4, 5])

In [None]:
# `concat` must preserve stream order even when a later stream is written to
# before an earlier one.
sw0 = InMemStreamWriter()
sw1 = InMemStreamWriter()
sr = concat(sw0.readonly(), sw1.readonly())

producers_access_pattern = []

async def slow_producer(sw):
  # Writes 0 and 1, pausing 50 ms before each, then shuts the writer down.
  for i in range(2):
    await asyncio.sleep(0.05)
    producers_access_pattern.append(("P0", i))
    await sw.put(i)
  await sw.shutdown()


async def fast_producer(sw):
  # Writes "a" and "b", pausing 10 ms before each, then shuts the writer down.
  for i in ["a", "b"]:
    await asyncio.sleep(0.01)
    producers_access_pattern.append(("P1", i))
    await sw.put(i)
  await sw.shutdown()


async with asyncio.TaskGroup() as tg:
  t = tg.create_task(tolist(sr))
  tg.create_task(slow_producer(sw0))
  tg.create_task(fast_producer(sw1))
  consumed = await t

# The slow stream comes first in the output although it was written last.
test_eq(consumed, [0, 1, "a", "b"])
test_eq(producers_access_pattern, [("P1", "a"), ("P1", "b"), ("P0", 0), ("P0", 1)])

### interleave

In [None]:
#| export


# Strong references to the background tasks started by `interleave`. The event
# loop only keeps weak references to tasks, so an otherwise unreferenced task
# could be garbage collected before it finishes.
_interleave_tasks: set[asyncio.Task] = set()


def interleave(*streams: Stream[_T]) -> Stream[_T]:
  """Merges the given streams, yielding items in the order they are read.

  Each input stream is drained concurrently by its own background task into a
  single in-memory stream, which is shut down once all inputs are exhausted.
  It is also shut down if draining fails, so readers do not wait forever.

  Must be called while an event loop is running, because it starts tasks with
  `asyncio.create_task`.
  """
  w = InMemStreamWriter()

  async def consume(s):
    async for e in s:
      await w.put(e)

  ts = [asyncio.create_task(consume(s)) for s in streams]

  async def cleanup():
    try:
      await asyncio.gather(*ts)
    finally:
      # Stop the remaining consumers if one of them failed; cancelling a
      # finished task is a no-op.
      for t in ts:
        t.cancel()
      await w.shutdown()

  cleanup_task = asyncio.create_task(cleanup())
  _interleave_tasks.add(cleanup_task)
  cleanup_task.add_done_callback(_interleave_tasks.discard)

  return w.readonly()


#### interleave Tests

In [None]:
async def fast_producer(sw):
  # Writes "a", "b", "c" about 10, 20 and 30 ms after start, then shuts down.
  for i in ("a", "b", "c"):
    await asyncio.sleep(0.01)
    await sw.put(i)
  await sw.shutdown()


async def slow_producer(sw):
  # Writes "x", "y" about 16 and 32 ms after start, then shuts down.
  for i in ("x", "y"):
    await asyncio.sleep(0.016)
    await sw.put(i)
  await sw.shutdown()


sw0 = InMemStreamWriter()
sw1 = InMemStreamWriter()
sr = interleave(sw0.readonly(), sw1.readonly())

async with asyncio.TaskGroup() as tg:
  t = tg.create_task(tolist(sr))
  tg.create_task(fast_producer(sw0))
  tg.create_task(slow_producer(sw1))

  consumed = await t

# NOTE(review): the expected order depends on timer accuracy; "c" (~30 ms) and
# "y" (~32 ms) are only about 2 ms apart, so this may be flaky on a loaded machine.
test_eq(consumed, ["a", "x", "b", "c", "y"])

### flatten

In [None]:
#| export


def flatten(s: Stream[_T | Stream[_T]]) -> Stream[_T]:
  """Flattens a stream by one level.

  Items that are themselves `Stream`s are replaced by their own items; every
  other item, including lists and other iterables, is passed through as is.
  """

  async def _expand(outer):
    async for element in outer:
      if not isinstance(element, Stream):
        yield element
        continue
      # Drain the nested stream completely before moving on.
      async for inner in element:
        yield inner

  return of(_expand(s))

#### flatten Tests

In [None]:
# Nested streams are expanded in place; plain items are kept.
s =  flatten(of(0, of(1, 2), 3, of(4, 5)))
test_eq(await tolist(s) , list(range(6)))

In [None]:
# Lists are not `Stream`s, so they pass through unchanged.
s =  flatten(of([0], [1]))
test_eq(await tolist(s) , [[0], [1]]) # only flattens streams.

### streamify

In [None]:
#| export


def streamify(func) -> Callable[..., Stream[_T]]:
  """Decorator that makes `func` return a `Stream` of its result.

  Calling the decorated function starts a background task that calls `func`
  (awaiting it when it is a coroutine function), converts the result with `of`,
  and writes the resulting items to the returned stream. A plain value becomes
  a one-item stream; a sync or async iterable is streamed item by item. The
  stream is shut down when the task ends, whether it succeeded or failed.

  The decorated function must be called while an event loop is running,
  because it starts the task with `asyncio.create_task`.
  """
  is_coroutine_fn = inspect.iscoroutinefunction(func)
  # Strong references to running background tasks: the event loop only keeps
  # weak references, so an unreferenced task may be garbage collected early.
  background_tasks: set[asyncio.Task] = set()

  @functools.wraps(func)
  def wrapper(*args, **kwargs) -> Stream[_T]:
    sw = InMemStreamWriter()

    async def mk_stream():
      try:
        result = func(*args, **kwargs)
        if is_coroutine_fn:
          result = await result
        s = of(result)  # Handles also async and sync iterables.
        async for e in s:
          await sw.put(e)
      finally:
        # Always release readers, even when `func` or the source raises.
        await sw.shutdown()

    # Write to the stream in the background.
    # FIXME: Handle errors otherwise they are silently ignored.
    task = asyncio.create_task(mk_stream())
    background_tasks.add(task)
    task.add_done_callback(background_tasks.discard)
    return sw.readonly()

  return wrapper

#### streamify Tests

In [None]:
# Sync function returning a plain value: the stream carries that single value.
@streamify
def fn(x):
  return x

s  = fn(5)
test_eq(isinstance(s, Stream), True)
test_eq(await tolist(s), [5])

In [None]:
# Sync function returning an iterable: its elements become the stream's items.
@streamify
def fn(*, n):
  return range(n)

s  = fn(n=5)
test_eq(isinstance(s, Stream), True)
test_eq(await tolist(s), [0, 1, 2, 3, 4])

In [None]:
# Sync generator function: the yielded values become the stream's items.
@streamify
def fn(*, n):
  yield from range(n)

s  = fn(n=5)
test_eq(isinstance(s, Stream), True)
test_eq(await tolist(s), [0, 1, 2, 3, 4])

In [None]:
# Coroutine function: the awaited result is the stream's single item.
@streamify
async def fn(x):
  return x

s = fn(0)
test_eq(isinstance(s, Stream), True)
test_eq(await tolist(s), [0])

In [None]:
# Async generator function: the yielded values become the stream's items.
@streamify
async def fn(*, n):
  for i in range(n):
    yield i

s = fn(n=5)
test_eq(isinstance(s, Stream), True)
test_eq(await tolist(s), [0, 1, 2, 3, 4])

### map

In [None]:
#| export


def map(func, *streams) -> Stream[_T]:
  """Maps the given function over the given streams.

  Each read takes one item from every stream, in argument order, and returns
  `func(*items)`; the result is awaited when `func` is a coroutine function.
  The mapped stream ends as soon as any input stream reports a status other
  than `StreamStatus.OK`. Items already taken from the other streams during
  that read are discarded.

  Args:
    func: Sync or coroutine function taking one argument per stream.
    *streams: The input streams.

  Returns:
    A stream of the results of `func`.
  """
  # Loop-invariant, so decided once rather than on every read.
  is_coroutine_fn = inspect.iscoroutinefunction(func)

  class _MappedStream(Stream[_T]):

    async def next(
        self,
        with_status: bool = False,
    ) -> _T | None | tuple[_T | None, StreamStatus]:
      args = []
      for s in streams:
        e, status = await s.next(with_status=True)
        if status != StreamStatus.OK:
          # Honor `with_status` here too: plain reads get `None`, not a tuple.
          return (None, status) if with_status else None
        args.append(e)

      result = func(*args)
      if is_coroutine_fn:
        result = await result

      return (result, StreamStatus.OK) if with_status else result

  return _MappedStream()

#### map Tests

In [None]:
# Single input stream.
s = map(lambda x: x + 1, of(0, 1, 2))
test_eq(await tolist(s), [1, 2, 3])

In [None]:
# Two input streams of equal length: items are combined pairwise.
s = map(lambda x, y: x + y, of(0, 1, 2), of(3, 4, 5))
test_eq(await tolist(s), [3, 5, 7])

In [None]:
# Streams of different lengths: mapping stops with the shortest one.
s = map(lambda x, y: x + y, of(0, 1), of(3, 4, 5))
test_eq(await tolist(s), [3, 5])

In [None]:
# Coroutine functions are awaited for every item.
async def upper(s: str):
  await asyncio.sleep(0.01)
  return s.upper()

s = map(upper, of("a", "b", "c"))
test_eq(await tolist(s), ["A", "B", "C"])

## Export

In [None]:
#| hide
# Run nbdev's notebook export.
import nbdev

nbdev.nbdev_export()