# Transforms

> Channel transformations.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#| default_exp transforms

In [None]:
#| export

import abc
import asyncio
import time
import uuid
import inspect
import contextlib
import contextvars
import dataclasses
from typing import Any, Callable, ParamSpec, Protocol, Generic, TypeVar, Awaitable
import functools

from fastcore.basics import patch

import fastagent_hacking.streams as sx
import fastagent_hacking.channels as cx

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
from fastcore.test import *

## Transforms


In [None]:
#| export

_I = TypeVar('_I')
_O = TypeVar('_O')


class Transform(abc.ABC, Generic[_I, _O]):
  """Base class for channel-to-channel transformations.

  Subclasses implement `__call__`, which consumes an input channel of `_I`
  elements and returns an output channel of `_O` elements.
  """

  @abc.abstractmethod
  def __call__(self, chan: cx.Channel[_I]) -> cx.Channel[_O]:
    """Transforms the input channel into an output channel."""


In [None]:
#| export


# FIXME: Move this to a separate module.
def _print_task_errors(task: asyncio.Task):
  if task.exception():
    task.print_stack()
    print(f"Task failed with exception: {task.exception()}")


### ParDo Transform

In [None]:
#| export

import collections


class ParDo(Transform[_I, _O]):
  """Processes each element in the input channel using a user-defined function.

  Each DATA packet's payload is handed to `fn` (wrapped with `sx.streamify`),
  and packets are processed concurrently rather than one after another. Every
  value produced for a packet becomes a DATA packet that carries the input
  packet's tags and records it as its parent. Non-DATA packets are forwarded
  unchanged on a side stream that is interleaved with the outputs.

  A CANCELLATION packet whose payload is a tag invokes the shutdown function of
  every stream started for packets carrying that tag.
  """

  def __init__(self, fn):  # FIXME: type hint
    self._fn = fn

    # Maps a packet tag to the shutdown functions of the streams started for
    # packets carrying that tag. When a cancellation packet for the tag is
    # received, all of them are invoked.
    # FIXME: Terminated tasks are currently not removed from the map.
    self._cncls_map = collections.defaultdict(list)

    # Strong references to background tasks: the event loop only keeps weak
    # references, so an unreferenced task could be garbage-collected mid-run.
    self._bg_tasks = set()

  def __call__(self, chan: cx.Channel[_I]) -> cx.Channel[_O]:
    main_stream = sx.InMemStreamWriter()  # One inner stream per DATA packet.
    side_stream = sx.InMemStreamWriter()  # Forwarded non-DATA packets.

    async def proc(chan):
      try:
        async for p in chan:
          assert isinstance(p, cx.Packet)
          if self._is_passthrough(p):
            await side_stream.put(p)
            if p.packet_type == cx.PacketType.CANCELLATION_PACKET:
              # Shut down every stream associated with the tag; `pop` also
              # drops the entry and tolerates tags that were never seen.
              for cncl in self._cncls_map.pop(p.payload, ()):
                cncl()
            continue

          await main_stream.put(self._proc_packet(p))
      finally:
        await main_stream.shutdown()
        await side_stream.shutdown()

    t = asyncio.create_task(proc(chan))
    self._bg_tasks.add(t)
    t.add_done_callback(_print_task_errors)
    t.add_done_callback(self._bg_tasks.discard)

    return cx.as_chan(
        sx.interleave(
            side_stream.readonly(),
            sx.flatten(main_stream.readonly()),
        ))

  def _proc_packet(self, p: cx.Packet[_I]) -> sx.Stream[cx.Packet[_O]]:
    """Runs `fn` on `p.payload` and wraps each produced value in a DATA packet.

    The stream's shutdown function is registered under every tag of `p`, so a
    later cancellation packet for any of those tags can stop it.
    """
    assert p.packet_type == cx.PacketType.DATA
    s, cncl = sx.streamify(self._fn, return_shutdown_fn=True)(p.payload)
    for tag in p.tags:
      self._cncls_map[tag].append(cncl)
    return sx.map(
        lambda x: cx.Packet(
            payload=x,
            packet_type=cx.PacketType.DATA,
            parent_packet_id=p.packet_id,
            tags=p.tags,
        ),
        s,
    )

  def _is_passthrough(self, p: cx.Packet) -> bool:
    """Non-DATA packets are forwarded unchanged instead of being processed."""
    return p.packet_type != cx.PacketType.DATA

In [None]:
#| export


def as_transform(fn: Callable | Transform) -> Transform:
  """Lifts `fn` into a Transform.

  A Transform is returned as-is. Any other callable is wrapped in a `ParDo`,
  which applies it to every DATA packet of the input channel.
  """
  return fn if isinstance(fn, Transform) else ParDo(fn)

In [None]:
#| export


class _ComposedTransform(Transform):
  """Applies `first` to a channel, then `second` to the result."""

  def __init__(self, first: Transform, second: Transform):
    self._first = first
    self._second = second

  def __call__(self, chan: cx.Channel) -> cx.Channel:
    return self._second(self._first(chan))


@patch
def __or__(
    self: Transform,
    other,
) -> Transform:
  """`self | other`: feeds the output channel of `self` into `other`.

  `other` may be a Transform or a plain callable (lifted with `as_transform`).
  """
  return _ComposedTransform(self, as_transform(other))


@patch
def __ror__(
    self: Transform,
    other,
) -> Transform:
  """`other | self` when `other` does not handle `|` itself.

  `other` is lifted with `as_transform`, and its output channel is fed into
  `self`.
  """
  return _ComposedTransform(as_transform(other), self)

In [None]:
def fake_packet(payload) -> cx.Packet:
  """Builds a DATA packet carrying `payload`, with fixed ids and timestamp.

  Meant for tests, which compare packets by payload only.
  """
  fixed_fields = dict(
      packet_type=cx.PacketType.DATA,
      packet_id='1',
      parent_packet_id='1',
      created_at=1,
  )
  return cx.Packet(payload=payload, **fixed_fields)

In [None]:
from typing import Iterable

In [None]:
def cmp_packet_payloads(p1s: Iterable[cx.Packet], p2s: Iterable[cx.Packet]):
  """Returns True iff both packet sequences carry equal payloads, in order."""

  def payloads_of(packets):
    return [pkt.payload for pkt in packets]

  return payloads_of(p1s) == payloads_of(p2s)

In [None]:
async def fn(x):
  await asyncio.sleep(0.1 * x)
  return x + 1

s = sx.of(
    fake_packet(1),
    fake_packet(2),
    fake_packet(3),
)
sink = cx.as_chan(s)

t = ParDo(fn)

start = time.monotonic()
got = await sx.tolist(t(sink))
end = time.monotonic()

test(
  got,
  [fake_packet(2), fake_packet(3), fake_packet(4)],
  cmp=cmp_packet_payloads,
)

# ParDo runs the calls concurrently: they sleep 0.1s, 0.2s and 0.3s, so the
# whole run takes about as long as the slowest one.
test_close(end - start, 0.3, eps=0.01)

In [None]:
async def add1(x):
  await asyncio.sleep(0.1 * x)
  return x + 1

async def mul2(x):
  await asyncio.sleep(0.1 * x)
  return x * 2

s = sx.of(
    fake_packet(1),
    fake_packet(2),
    fake_packet(2),
    fake_packet(2),
    fake_packet(2),
    fake_packet(3),
)
sink = cx.as_chan(s)

t = ParDo(add1) | ParDo(mul2)

start = time.monotonic()
got = await sx.tolist(t(sink))
end = time.monotonic()

test(
  got,
  [fake_packet(4), fake_packet(6), fake_packet(6), fake_packet(6), fake_packet(6), fake_packet(8)],
  cmp=cmp_packet_payloads,
)

# Packets are processed concurrently, so the total time is that of the slowest chain:
# fake_packet(3) -> add1(3) [0.3s] -> mul2(4) [0.4s] => 0.7s
test_close(end - start, 0.7, eps=0.01)

In [None]:
async def add1(x):
  await asyncio.sleep(0.01 * x)
  return x + 1


s = sx.of(fake_packet(1))
sink = cx.as_chan(s)

# The single packet goes through the four stages one after another, so their
# delays add up.
t = ParDo(add1) | ParDo(add1) | ParDo(add1) | ParDo(add1)

start = time.monotonic()
got = await sx.tolist(t(sink))
end = time.monotonic()

test(
  got,
  [fake_packet(5)],
  cmp=cmp_packet_payloads,
)
# 0.01 + 0.02 + 0.03 + 0.04 = 0.1
test_close(end - start, 0.1, eps=0.01)

In [None]:
async def add1(x):
  await asyncio.sleep(0.01 * x)
  return x + 1


s = sx.of(fake_packet(1))
sink = cx.as_chan(s)

# The plain function `add1` on the left is lifted to a `ParDo` transform (via
# `Transform.__ror__`) because the right-hand operand is a Transform.
t = add1 | ParDo(add1)

start = time.monotonic()
got = await sx.tolist(t(sink))
end = time.monotonic()

test(
  got,
  [fake_packet(3)],
  cmp=cmp_packet_payloads,
)


Test that async generators can also be used in `ParDo` transforms: each yielded value becomes a separate output packet.

In [None]:
async def mk_chunks(s: str):
  for x in s.split():
    yield x

s = sx.of(fake_packet("A B C"))
chan = cx.as_chan(s)

# Each value yielded by the async generator becomes its own output packet.
t = ParDo(mk_chunks)

start = time.monotonic()
got = await sx.tolist(t(chan))
end = time.monotonic()

test(
  got,
  [fake_packet("A"), fake_packet("B"), fake_packet("C")],
  cmp=cmp_packet_payloads,
)


In [None]:
async def mk_chunks(s: str):
  for x in s.split():
    await asyncio.sleep(0.1)
    yield x

async def repeat(x):
  for _ in range(2):
    await asyncio.sleep(0.1)
    yield x

s = sx.of(fake_packet("A B"))
chan = cx.as_chan(s)

# Each chunk yielded by mk_chunks is handed to repeat as soon as it is produced.
t = ParDo(mk_chunks) | ParDo(repeat)

start = time.monotonic()
got = await sx.tolist(t(chan))
end = time.monotonic()

test(
  got,
  [fake_packet("A"), fake_packet("A"), fake_packet("B"), fake_packet("B")],
  cmp=cmp_packet_payloads,
)

# Execution sketch:
#  - mk_chunks("A B") (@t=0)
#     - yield "A" (@t=0.1)
#        - repeat("A") (@t=0.1)
#           - yield "A" (@t=0.2)
#           - yield "A" (@t=0.3)
#     - yield "B" (@t=0.2)
#        - repeat("B") (@t=0.2)
#           - yield "B" (@t=0.3)
#           - yield "B" (@t=0.4)
test_close(end - start, 0.4, eps=0.01)

### SeqDo Transform

In [None]:
#| export


class SeqDo(Transform[_I, _O]):
  """Processes each element in the input channel using a user-defined function,
  one element at a time.

  Unlike `ParDo`, the next DATA packet is only started once every value
  produced for the previous one has been written. Each produced value becomes a
  DATA packet that carries the input packet's tags and records it as its
  parent. Non-DATA packets are forwarded unchanged. Cancellation packets are
  only forwarded; they do not interrupt the call in progress.

  Args:
    fn: an async function or an async generator function.
  """

  def __init__(self, fn):  # FIXME: type hint
    assert inspect.isasyncgenfunction(fn) or asyncio.iscoroutinefunction(
        fn), f"Expected an async function, got {fn}"
    self._fn = fn

    # Strong references to background tasks: the event loop only keeps weak
    # references, so an unreferenced task could be garbage-collected mid-run.
    self._bg_tasks = set()

  def __call__(self, chan: cx.Channel[_I]) -> cx.Channel[_O]:
    writer = sx.InMemStreamWriter()

    async def proc(chan):
      try:
        async for p in chan:
          assert isinstance(p, cx.Packet)
          if self._is_passthrough(p):
            await writer.put(p)
            continue
          fn = sx.streamify(self._fn)
          # FIXME: This loop blocks side packets from being processed.
          async for e in fn(p.payload):
            await writer.put(
                cx.Packet(
                    payload=e,
                    packet_type=cx.PacketType.DATA,
                    parent_packet_id=p.packet_id,
                    tags=p.tags,
                ),)
      finally:
        await writer.shutdown()

    t = asyncio.create_task(proc(chan))
    self._bg_tasks.add(t)
    t.add_done_callback(_print_task_errors)
    t.add_done_callback(self._bg_tasks.discard)

    return cx.as_chan(writer.readonly())

  def _is_passthrough(self, p: cx.Packet) -> bool:
    """Non-DATA packets are forwarded unchanged instead of being processed."""
    return p.packet_type != cx.PacketType.DATA

#### SeqDo tests

In [None]:
async def add1(x):
  await asyncio.sleep(0.1)
  return x + 1

s = sx.of(fake_packet(0), fake_packet(10), fake_packet(100))
sink = cx.as_chan(s)

t = SeqDo(add1) | SeqDo(add1)

start = time.monotonic()
got = await sx.tolist(t(sink))
end = time.monotonic()

test(
  got,
  [fake_packet(2), fake_packet(12), fake_packet(102)],
  cmp=cmp_packet_payloads,
)

# Each SeqDo handles one packet at a time, but the two stages form a pipeline:
# stage 1 emits 1, 11, 101 at t=0.1, 0.2, 0.3;
# stage 2 emits 2, 12, 102 at t=0.2, 0.3, 0.4.
# Each step takes 0.1s, so the total time is 0.4s.
test_close(end - start, 0.4, eps=0.05)

In [None]:
async def mk_chunks(s: str):
  for x in s.split():
    await asyncio.sleep(0.1)
    yield x

async def tolower(s: str):
  await asyncio.sleep(0.15)
  return s.lower()

s = sx.of(fake_packet("A B"))
sink = cx.as_chan(s)

t = SeqDo(mk_chunks) | SeqDo(tolower)

start = time.monotonic()
got = await sx.tolist(t(sink))
end = time.monotonic()

test(
  got,
  [fake_packet("a"), fake_packet("b")],
  cmp=cmp_packet_payloads,
)

# SeqDo(tolower) handles one element at a time, so "B" waits for "A" to finish:
# "A B" -> mk_chunks("A B")
#            |_ "A" [T=0.1] -> tolower("A") -> "a" [T=0.25]
#            |_ "B" [T=0.2] ->    WAITING   -> tolower("B") -> "b" [T=0.4]
test_close(end - start, 0.4, eps=0.01)


### CancelPrev Transform

In [None]:
#| export


class CancelPrev(Transform[_I, _O]):
  """Cancels previous packets and their derivatives when a new packet arrives.

  This is useful when you want to ensure that only the latest packet is processed.
  For example, to avoid double texting in a chat application: when a user sends a new
  message while the previous message is still being processed, we may want to cancel
  the processing of the previous message.

  Every incoming DATA packet is re-emitted with an extra, freshly generated abort
  tag. Before that, a cancellation packet targeting the previous DATA packet's abort
  tag is emitted, so downstream transforms (e.g. `ParDo`) can stop work derived from
  it. Non-DATA packets are forwarded unchanged.
  """

  def __init__(self):
    # Strong references to background tasks: the event loop only keeps weak
    # references, so an unreferenced task could be garbage-collected mid-run.
    self._bg_tasks = set()

  def __call__(self, chan: cx.Channel[_I]) -> cx.Channel[_O]:
    writer = sx.InMemStreamWriter()

    async def proc(chan):
      # Abort tag of the previously emitted DATA packet ("" before the first one).
      abort_tag = ""
      try:
        async for p in chan:
          assert isinstance(p, cx.Packet)

          if self._is_passthrough(p):
            # Control packets (e.g. cancellations coming from upstream) are not
            # new messages: forward them as-is, like ParDo and SeqDo do.
            await writer.put(p)
            continue

          # We broadcast a cancellation packet that targets the previous
          # packet and its derivatives.
          if abort_tag:
            await writer.put(cx.mk_cancellation_packet(tag=abort_tag))

          # Compute a new abort tag for the current packet.
          abort_tag = f"latch-{str(uuid.uuid4())}"
          await writer.put(
              cx.Packet(
                  payload=p.payload,
                  packet_type=cx.PacketType.DATA,
                  parent_packet_id=p.packet_id,
                  tags=(*p.tags, abort_tag),
              ),)
      finally:
        await writer.shutdown()

    t = asyncio.create_task(proc(chan))
    self._bg_tasks.add(t)
    t.add_done_callback(_print_task_errors)
    t.add_done_callback(self._bg_tasks.discard)

    return cx.as_chan(writer.readonly())

  def _is_passthrough(self, p: cx.Packet) -> bool:
    """Non-DATA packets are forwarded unchanged instead of being processed."""
    return p.packet_type != cx.PacketType.DATA

#### CancelPrev tests

In [None]:
async def add1(x):
  await asyncio.sleep(0.1)
  return x + 1


# When the second packet arrives, CancelPrev emits a cancellation for the first
# one, so ParDo shuts down add1(0) and only add1(10) = 11 reaches the output.
t = CancelPrev() | ParDo(add1)

s = sx.of(fake_packet(0), fake_packet(10))
chan = cx.as_chan(s)

# Keep only DATA packets; the forwarded cancellation packets are dropped.
got = await sx.tolist(sx.filter(
  lambda x: x.packet_type == cx.PacketType.DATA,
   t(chan),
))

test(
  got,
  [fake_packet(11)],
  cmp=cmp_packet_payloads,
)

## The `tfn` lift decorator

In [None]:
#| export


@dataclasses.dataclass(frozen=True)
class Event:
  """An immutable record pairing a payload with a source label."""

  # The event's data; any type is accepted.
  payload: Any
  # Source label; empty string when not given.
  # NOTE(review): not read anywhere in this notebook — confirm its intended format.
  src: str = dataclasses.field(default="")

In [None]:
#| export

_R = TypeVar('_R')
_P = ParamSpec('_P')


class Streamable(Protocol[_P, _R]):
  """Structural type of the wrappers produced by the `tfn` decorator."""

  def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
    """Invokes the wrapped function directly."""

  def stream(self, *args: _P.args, **kwargs: _P.kwargs) -> cx.Channel[Event]:
    """Runs the wrapped function in the background and returns what it emits."""

  def __or__(self, other) -> Transform:
    """Composes this wrapper with `other` into a Transform."""

In [None]:
#| export

_sink_ctxvar = contextvars.ContextVar('_sink_contextvar', default=None)


@contextlib.contextmanager
def use_sink(sink: sx.StreamWriter):
  """Installs `sink` as the current sink for the duration of the `with` block.

  Yields `sink`. The previous value is restored on exit, including on error.
  Tasks created inside the block copy the current context, so they see `sink`
  as well.
  """
  # Set before entering `try`: if `set` failed inside it, `tok` would be
  # unbound in `finally` and mask the original error.
  tok = _sink_ctxvar.set(sink)
  try:
    yield sink
  finally:
    _sink_ctxvar.reset(tok)


def cur_sink() -> sx.StreamWriter | None:
  """Returns the sink installed by the innermost active `use_sink`, else None."""
  return _sink_ctxvar.get()

In [None]:
#| export

# FIXME How to improve the type hinting for decorated @tfn functions? (e.g., keep their signature).


def tfn(fn: Callable[..., Awaitable[_R]]) -> Streamable:
  """Lifts an async function or async generator function into a `Streamable`.

  The returned wrapper can be:
    * called like `fn` itself (awaited for coroutine functions, iterated with
      `async for` for async generator functions);
    * run in the background with `.stream(...)`, which returns the read-only
      side of an in-memory stream receiving everything put into the sink;
    * composed with `|` into a `Transform` (after lifting with `as_transform`).

  On every call, an explicit `sink=` keyword wins; otherwise the sink installed
  by `use_sink` (see `cur_sink`) is used. The sink is forwarded to `fn` as
  `sink=...` only if `fn` declares a `sink` parameter.

  The wrapper also works on instance methods: accessing it through an instance
  returns a wrapper bound to that instance.

  Raises:
    AssertionError: if `fn` is neither a coroutine function nor an async
      generator function.
  """
  assert asyncio.iscoroutinefunction(fn) or inspect.isasyncgenfunction(
      fn), "tfn can only be used with async functions or async generators"

  # Invariants of `fn`, computed once instead of on every call.
  accepts_sink = "sink" in inspect.signature(fn).parameters
  is_coro_fn = asyncio.iscoroutinefunction(fn)

  # Strong references to the tasks started by `.stream()`: the event loop only
  # keeps weak references, so an unreferenced task may be garbage-collected
  # before it finishes.
  bg_tasks = set()

  class _S(Streamable):

    def __init__(self, instance=None):
      # The object `fn` is bound to when used as an instance method, else None.
      self._instance = instance

    def __get__(self, instance, owner):
      """Ensure correct binding for instance methods.

      Returns a new wrapper bound to `instance` instead of storing the instance
      on this shared object, so accesses through different instances (or
      through the class) cannot clobber each other's binding.
      """
      if instance is None:
        return self
      return _mk(instance)

    def _resolve_call_args(self, args, kwargs):
      """Resolves the sink and prepends the bound instance, if any."""
      sink = kwargs.pop("sink", None)
      if not sink:
        sink = cur_sink()

      if accepts_sink:
        kwargs["sink"] = sink

      if self._instance is not None:
        # This is required for decorated instance methods.
        args = (self._instance, *args)
      return args, kwargs

    if inspect.isasyncgenfunction(fn):

      @functools.wraps(fn)
      async def __call__(self, *args, **kwargs):
        """Handles async generators"""
        args, kwargs = self._resolve_call_args(args, kwargs)
        async for e in fn(*args, **kwargs):
          yield e  # Async generator case
    else:

      @functools.wraps(fn)
      async def __call__(self, *args, **kwargs):
        """Handles normal async functions"""
        args, kwargs = self._resolve_call_args(args, kwargs)
        return await fn(*args, **kwargs)  # Normal async function case

    def stream(self, *args, return_value: bool = False, **kwargs):
      """Returns a streamable version of the function.

      The call runs in a background task with a fresh in-memory sink; the
      read-only side of that sink is returned and is shut down when the call
      finishes.

      Args:
        return_value: if True, the call's return value is put into the stream
          after the call completes.
      """
      sink = sx.InMemStreamWriter()

      async def target():
        try:
          # FIXME Should we overwrite chan if already passed?
          result = await self(*args, **kwargs, sink=sink)
          if return_value:
            await sink.put(result)
        finally:
          await sink.shutdown()

      # `create_task` copies the current context, so inside the task (and in
      # nested @tfn calls that omit `sink=`) `cur_sink()` returns `sink`.
      with use_sink(sink):
        task = asyncio.create_task(target())
        bg_tasks.add(task)
        task.add_done_callback(_print_task_errors)
        task.add_done_callback(bg_tasks.discard)
        return sink.readonly()

    def __or__(self, other) -> Transform:
      t1, t2 = as_transform(self), as_transform(other)
      return t1 | t2

    def __ror__(self, other) -> Transform:
      t2, t1 = as_transform(self), as_transform(other)
      return t1 | t2

  def _mk(instance=None):
    """Creates a wrapper, marked as a coroutine function when `fn` is one so
    that `asyncio.iscoroutinefunction` recognizes it."""
    wrapped = _S(instance)
    if is_coro_fn:
      inspect.markcoroutinefunction(wrapped)
    return wrapped

  return _mk()

### TFN Tests

In [None]:
@tfn
async def add1(x: int):
  await asyncio.sleep(0.01)
  return x + 1

# The wrapper returned by @tfn is still recognized as a coroutine function.
test_eq(asyncio.iscoroutinefunction(add1), True)

In [None]:
class C:
  @tfn
  async def add1(self, x: int):
    await asyncio.sleep(0.01)
    return x + 1

c = C()
# The same holds when the decorated method is accessed through an instance.
test_eq(asyncio.iscoroutinefunction(c.add1), True)

In [None]:
# FIXME: This test is failing.

# @tfn
# async def add1(x: int):
#   await asyncio.sleep(0.01)
#   yield x + 1

# test_eq(inspect.isasyncgenfunction(add1), True)

In [None]:
@tfn
async def add1(x: int):
  await asyncio.sleep(0.01)
  return x + 1

# Call the function directly.
test_eq(await add1(1), 2)

# Stream the function. Nothing is put into the sink, so the stream is empty.
s = await sx.tolist(add1.stream(1))
test_eq([], s)

# Stream the function and include the return value.
s = await sx.tolist(add1.stream(1, return_value=True))
test_eq(s, [2])

In [None]:
class C:
  @tfn
  async def add1(self, x: int):
    await asyncio.sleep(0.01)
    return x + 1

c = C()
# Call the method directly.
test_eq(await c.add1(1), 2)

# Stream the method. Nothing is put into the sink, so the stream is empty.
s = await sx.tolist(c.add1.stream(1))
test_eq([], s)

# Stream the method and include the return value.
s = await sx.tolist(c.add1.stream(1, return_value=True))
test_eq(s, [2])

In [None]:
@tfn
async def add1(x: int):

  @tfn
  async def _do(x, sink: sx.StreamWriter | None = None):
    await asyncio.sleep(0.01)
    if sink:
      await sink.put("A")
      await sink.put("B")
    return x + 1

  # `_do` is called without `sink=`, so it falls back to `cur_sink()`: the sink
  # installed by `.stream()`, or None when `add1` is called directly.
  return await _do(x)

# Call the function directly.
test_eq(await add1(1), 2)

# Stream the function.
s = await sx.tolist(add1.stream(1))
test_eq(s, ["A", "B"])

# Stream the function and include the return value.
s = await sx.tolist(add1.stream(1, return_value=True))
test_eq(s, ["A", "B", 2])

In [None]:
class C:

  @tfn
  async def add1(self, x: int):

    @tfn
    async def _do(x, sink: sx.StreamWriter | None = None):
      await asyncio.sleep(0.01)
      if sink:
        await sink.put("A")
        await sink.put("B")
      return x + 1

    return await _do(x)

c = C()

# Call the method directly.
test_eq(await c.add1(1), 2)

# Stream the method.
s = await sx.tolist(add1.stream(1))
test_eq(s, ["A", "B"])

# Stream the method and include the return value.
s = await sx.tolist(add1.stream(1, return_value=True))
test_eq(s, ["A", "B", 2])

In [None]:
@tfn
async def mk_chunks(s: str):
  for x in s.split():
    await asyncio.sleep(0.1)
    yield x

@tfn
async def repeat(x):
  for _ in range(2):
    await asyncio.sleep(0.1)
    yield x

s = sx.of(fake_packet("A B"))
chan = cx.as_chan(s)

# @tfn wrappers compose with `|`; each side is lifted to a `ParDo` transform.
t = mk_chunks | repeat

start = time.monotonic()
got = await sx.tolist(t(chan))
end = time.monotonic()

test(
  got,
  [fake_packet("A"), fake_packet("A"), fake_packet("B"), fake_packet("B")],
  cmp=cmp_packet_payloads,
)

# Execution sketch:
#  - mk_chunks("A B") (@t=0)
#     - yield "A" (@t=0.1)
#        - repeat("A") (@t=0.1)
#           - yield "A" (@t=0.2)
#           - yield "A" (@t=0.3)
#     - yield "B" (@t=0.2)
#        - repeat("B") (@t=0.2)
#           - yield "B" (@t=0.3)
#           - yield "B" (@t=0.4)
test_close(end - start, 0.4, eps=0.01)

In [None]:
# TODO: P0 Agent runner
# TODO: P1 Add a playground (text msgs)
# TODO: P1 Add wrapper to OpenAI realtime API.
# TODO: P1 Add a simple tracer.
# TODO: P2 Make .stream liftable to ParDo transform.

In [None]:
#| hide
import nbdev

# Write this notebook's `#| export` cells out to the library module
# (nbdev's export step).
nbdev.nbdev_export()