# Channels

> Packets and channels: typed streams of packets.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#| default_exp channels

In [None]:
#| export

import abc
import dataclasses
import enum
import json
import time
import uuid
from typing import Any, Generic, TypeVar

import fastagent_hacking.streams as sx

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
from fastcore.test import *

## Packets

In [None]:
#| export

_T = TypeVar('T')


class PacketType(enum.StrEnum):
  MAIN = enum.auto()
  LOG_PACKET = enum.auto()
  EVENT_PACKET = enum.auto()


@dataclasses.dataclass(frozen=True)
class Packet(Generic[_T]):
  """Represents a unit of data inside a Channel
  
  Attributes:
    payload: The data that the packet carries
    packet_type: The type of the packet
    packet_id: ID for the packet.
    parent_packet_id: ID of the packet that spawned this packet.
    stamp: An ID that corrolates a packet to its producer. 
        It's unique per invocation of the producer.
    created_at: The time the packet was created in milliseconds.
  """
  payload: _T
  packet_type: PacketType
  packet_id: str
  parent_packet_id: str
  stamp: str
  created_at: float

  def to_json(self):
    return json.dumps(
        self,
        default=lambda o: o.__dict__,
        sort_keys=True,
        indent=2,
    )

  @classmethod
  def from_json(cls, json_str):
    return cls(**json.loads(json_str))


In [None]:
# A packet must survive a to_json/from_json round trip unchanged.
p = Packet(payload=1, packet_type=PacketType.MAIN, packet_id='1',
           parent_packet_id='1', stamp='1', created_at=1)
test_eq(Packet.from_json(p.to_json()), p)

## Channels


In [None]:
#| export


class Channel(sx.Stream[Packet[Any]], Generic[_T]):
  """A stream of `Packet`s that advertises the type of its main payloads."""

  @property
  @abc.abstractmethod
  def elm_type(self) -> type[_T]:
    """The declared type of the channel's main payloads."""


def as_chan(s: sx.Stream[Packet[Any]], elm_type: type[_T]) -> Channel[_T]:
  """View the packet stream `s` as a `Channel` whose `elm_type` is `elm_type`.

  The returned channel pulls every packet from `s`, so `s` must not be used
  after this call. Payloads are not checked against `elm_type`.
  """
  source, declared = s, elm_type

  class _ChanStream(Channel[_T]):

    async def next(self, *args, **kwargs):
      # Pure delegation: whatever `source` returns (or raises) passes through.
      return await source.next(*args, **kwargs)

    @property
    def elm_type(self) -> type[_T]:
      return declared

  return _ChanStream()

In [None]:
#| export


class ChannelWriter(sx.StreamWriter[Packet[Any]], Generic[_T]):
  """The writable end of a channel: a stream writer of `Packet`s."""

  elm_type: type[_T]  # Main packet payload type of the channel

  @abc.abstractmethod
  def readonly(self) -> Channel[_T]:
    """Return a readonly version of the channel"""


def as_chan_writer(
    s: sx.StreamWriter[Packet[Any]],
    elm_type: type[_T],
) -> ChannelWriter[_T]:
  """Coerce a stream writer of packets to a channel writer.

  Do not use `s` after this function: the returned writer forwards `put`,
  `shutdown` and `readonly` to it.

  Args:
    s: The packet stream writer to wrap.
    elm_type: The main payload type reported by the writer and by the
      channels returned from its `readonly`.

  Returns:
    A `ChannelWriter` wrapping `s`.
  """

  class _ChanWriter(ChannelWriter[_T]):

    @property
    def elm_type(self) -> type[_T]:
      return elm_type

    async def put(self, *args, **kwargs):
      """Forward packets to `s`.

      Raises:
        TypeError: if any positional argument is not a `Packet`.
      """
      # Explicit check rather than `assert`: asserts are stripped under
      # `python -O`, which would let non-packets into the channel.
      for item in args:
        if not isinstance(item, Packet):
          raise TypeError(
              f'Channel writers only accept Packet, got {type(item).__name__}')
      await s.put(*args, **kwargs)

    async def shutdown(self, *args, **kwargs):
      await s.shutdown(*args, **kwargs)

    def readonly(self, *args, **kwargs) -> Channel[_T]:
      return as_chan(s.readonly(*args, **kwargs), elm_type)

  return _ChanWriter()

### Channel tests

In [None]:
def fake_packet(payload) -> Packet:
  """Return a MAIN packet carrying `payload`; every other field is a fixed dummy."""
  dummy_fields = dict(packet_id='1', parent_packet_id='1', stamp='1', created_at=1)
  return Packet(payload=payload, packet_type=PacketType.MAIN, **dummy_fields)

In [None]:
chan = as_chan(
    sx.of(
        fake_packet(1),
        fake_packet("a"),
    ),
    int,
)

test_eq(isinstance(chan, Channel), True)
test_eq(chan.elm_type, int)

ps = []
async for p in chan:
  ps.append(p)
test_eq(ps, [fake_packet(1), fake_packet("a")])

In [None]:
chan_writer = as_chan_writer(sx.InMemStreamWriter(), str)

test_eq(isinstance(chan_writer, ChannelWriter), True)
test_eq(chan_writer.elm_type, str)

await chan_writer.put(
  fake_packet("a"),
  fake_packet("b"),
)
await chan_writer.shutdown()

ps = []
chan = chan_writer.readonly()
async for p in chan:
  ps.append(p)
test_eq(ps, [fake_packet("a"), fake_packet("b")])

In [None]:
c0 = as_chan(sx.of(fake_packet(0), fake_packet(1)), int)
c1 = as_chan(sx.of(fake_packet(2), fake_packet(3)), int)

ps = []
async for payload in sx.concat(c0, c1):
  ps.append(payload)

test_eq(ps, [fake_packet(i) for i in range(4)])

### Send to channels

In [None]:
#| export


async def send(
    payload: _T,
    chan: ChannelWriter[_T],
    *,
    packet_type: PacketType = PacketType.MAIN,
    parent_packet_id: str | None = None,
    stamp: str = '',
) -> Packet[_T]:
  """Wrap `payload` in a new packet and put it on `chan`.

  Args:
    payload: The data to send.
    chan: The channel writer that receives the packet.
    packet_type: Type of the new packet (defaults to `PacketType.MAIN`).
    parent_packet_id: ID of the packet that spawned this one. When omitted,
      a fresh random UUID is used (the historical behavior).
    stamp: Producer stamp for the packet (defaults to `''`).

  Returns:
    The packet that was sent. Callers can use its `packet_id` as the
    `parent_packet_id` of follow-up packets.
  """
  packet_id = str(uuid.uuid4())
  if parent_packet_id is None:
    # NOTE(review): a random parent ID refers to no existing packet; kept so
    # existing callers see unchanged packets.
    parent_packet_id = str(uuid.uuid4())
  p = Packet(
      payload=payload,
      packet_type=packet_type,
      packet_id=packet_id,
      parent_packet_id=parent_packet_id,
      stamp=stamp,  # FIXME: producers should supply a real stamp
      created_at=time.time(),
  )
  await chan.put(p)
  return p

#### Send test

In [None]:
chan = as_chan_writer(sx.InMemStreamWriter(), int)

await send(1, chan)
await send(2, chan)
await send(3, chan)
await chan.shutdown()

chan = chan.readonly()

packets = await sx.tolist(chan)

test_eq(all(isinstance(p, Packet) for p in packets), True)
test_eq([p.payload for p in packets], [1, 2, 3])

In [None]:
#| hide
from nbdev import nbdev_export

# Export the notebooks' `#| export` cells to Python modules.
nbdev_export()