# LLMs

> Client for interacting with LLMs

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#| default_exp llms

In [None]:
#| export

import abc
from dataclasses import dataclass, field
from dataclasses_json import dataclass_json, config
from typing import Any, Sequence
import io
import base64

import openai
import msglm
from PIL import Image
from fastcore import imghdr

from fastagent_hacking import transforms as tx
from fastagent_hacking import channels as cx
from fastagent_hacking import streams as sx

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
from fastcore.test import *

In [None]:
from IPython.core.magic import register_cell_magic
import nest_asyncio

nest_asyncio.apply()  # Required to use run_if magic with async code.


@register_cell_magic
def run_if(line, cell):
  """Cell magic that runs the cell body only when a condition holds.

  Usage: `%%run_if <python expression>`. The expression on the magic line is
  evaluated against this module's globals; the cell is executed only when the
  result is truthy.
  """
  should_run = eval(line, globals())
  if not should_run:
    return
  get_ipython().run_cell(cell)


## LLM API

In [None]:
#| export

# The basic unit of a message is either a string, image or raw bytes.
_MsgLeafContent = str | Image.Image | bytes

# A message can be a single leaf content or a sequence of leaf contents.
MsgContent = _MsgLeafContent | Sequence[_MsgLeafContent]

In [None]:
#| export

# Utils for encoding/decoding messages.

# Dict key used to tag encoded non-string content (images, bytes) so that
# `_decode` can tell which type to rebuild.
_TYPE_KEY = "__type__"


def _encode(content: MsgContent) -> Any:
  """Converts message content into a JSON-serializable structure.

  Strings are returned unchanged. PIL images (saved as PNG) and raw bytes are
  turned into dicts tagged with `_TYPE_KEY` whose "data" entry holds the
  base64-encoded payload. Lists and tuples are encoded item by item into a
  list.

  Args:
    content: The message content to encode.

  Returns:
    A structure made only of strings, dicts and lists.

  Raises:
    ValueError: If `content` (or one of its items) has an unsupported type.
  """
  if isinstance(content, Image.Image):
    buff = io.BytesIO()
    content.save(buff, format='PNG')
    return {
        _TYPE_KEY: "PIL.Image",
        "data": base64.b64encode(buff.getvalue()).decode()
    }
  elif isinstance(content, bytes):
    # Use the shared constant (not a literal) so the tag always matches
    # what `_decode` looks for.
    return {_TYPE_KEY: "bytes", "data": base64.b64encode(content).decode()}
  elif isinstance(content, (list, tuple)):
    return [_encode(item) for item in content]
  elif isinstance(content, str):
    return content

  raise ValueError(f"Cannot serialize {content} with type {type(content)}")


def _decode(content: Any) -> MsgContent:
  """Rebuilds message content from the structure produced by `_encode`.

  Args:
    content: A string, a list of encoded items, or a dict tagged with
      `_TYPE_KEY`.

  Returns:
    The decoded string, bytes, PIL image, or list of those.

  Raises:
    ValueError: If `content` is not a recognised encoding.
  """
  if isinstance(content, str):
    return content

  if isinstance(content, list):
    return [_decode(entry) for entry in content]

  if isinstance(content, dict) and _TYPE_KEY in content:
    tag = content[_TYPE_KEY]
    if tag == "PIL.Image":
      raw = base64.b64decode(content["data"])
      return Image.open(io.BytesIO(raw))
    if tag == "bytes":
      return base64.b64decode(content["data"])

  # Anything else, including a tagged dict with an unknown tag, is rejected.
  raise ValueError(f"Cannot deserialize {content} with type {type(content)}")

In [None]:
#| export


@dataclass_json
@dataclass(frozen=True)
class Msg:
  """A message in a chat.

  Instances are immutable (frozen dataclass) and JSON-serializable through
  `dataclass_json`.

  Attributes:

    role: Dictates the purpose and perspective of the message.
      For example, 'user', 'system' or 'assistant'.
    content: The content of the message. Serialized with `_encode` and
      deserialized with `_decode`.
    name: Optional. Associates the message to a named entity.
      It doesn't have any effect on the LLM output. Defaults to empty string.
  """
  role: str
  content: MsgContent = field(metadata=config(
      encoder=_encode,
      decoder=_decode,
  ))
  name: str = ""

In [None]:
#| export


@dataclass_json
@dataclass(frozen=True)
class MsgChunk:
  """A partial piece of a message, as emitted while a response is streamed.

  Attributes:

    role: Role of the message this chunk belongs to, e.g. 'assistant'.
    content: The content carried by this chunk. Serialized with `_encode`
      and deserialized with `_decode`.
    end: Whether this chunk marks the end of the message.
    name: Optional. Associates the chunk to a named entity.
      Defaults to empty string.
  """
  role: str
  content: MsgContent = field(metadata=config(
      encoder=_encode,
      decoder=_decode,
  ))
  end: bool
  name: str = ""

In [None]:
import numpy as np

In [None]:
# Test Serialization/Deserialization of messages

# str content
msg = Msg(role="user", content="Hello!")
test_eq(msg, Msg.from_json(msg.to_json()))

# bytes content
msg = Msg(role="ai", content=b"12345")
test_eq(msg, Msg.from_json(msg.to_json()))

# PIL Image content
img = Image.new("RGB", (100, 100), color=1)
msg = Msg(role="ai", content=img)
m = Msg.from_json(msg.to_json())
test_eq(np.array(m.content), np.array(img))

# list content
msg = Msg(role="ai", content=["Hello", b"12345"])
test_eq(msg, Msg.from_json(msg.to_json()))

In [None]:
# Test Serialization/Deserialization of message chunks

msg = MsgChunk(
    role="user",
    content=["Hello!", b"12345676"],
    end=True,
    name="ai",
)
test_eq(msg, MsgChunk.from_json(msg.to_json()))

In [None]:
#| export

# Anything accepted as a chat message: either a full `Msg` or bare content.
# Bare content is treated as coming from the 'user' role.
MsgLike = Msg | MsgContent

In [None]:
#| export


class Backend(abc.ABC):
  """Abstract interface that every LLM backend must implement."""

  @abc.abstractmethod
  async def chat(
      self,
      msgs: Sequence[MsgLike],
      *,
      name: str = "",
      temperature: float | None = None,
      sink=None,
  ) -> Msg:
    """Returns a chat response given a sequence of messages.

    Note: This method is stateless. It means that you must always
      provide the full chat history.

    Args:
      msgs: A sequence of messages. If a message is a `MsgContent`,
        the 'user' equivalent role will be assumed.
      name: Optional name to the chat assistant.
        It doesn't have any effect on the LLM output. Defaults to empty string.
      temperature: Optional. The temperature of the response.
        If None, the backend will use its default value.
      sink: Internal use only. Defaults to None.

    Returns:
      The response message.
    """

  # TODO: Add embed method.

## OpenAI Backend

In [None]:
#| export


class OpenaiAPI(Backend):
  """`Backend` implementation that talks to the OpenAI chat completions API.

  Responses are always requested in streaming mode; the streamed deltas are
  accumulated into the returned `Msg` and, when a sink is provided, also
  forwarded one by one as `MsgChunk`s.
  """

  def __init__(self, *, model: str, api_key: str | None = None):
    """Creates the backend.

    Args:
      model: Name of the model to query, e.g. 'gpt-4o-mini'.
      api_key: Optional API key, forwarded as-is to `openai.AsyncOpenAI`.
    """
    self._client = openai.AsyncOpenAI(api_key=api_key)
    self._model = model

  @tx.tfn
  async def chat(
      self,
      msgs: Sequence[MsgLike],
      *,
      name: str = "",
      temperature: float | None = None,
      sink=None,
  ) -> Msg:
    """Returns the assistant response for the given chat history.

    Args:
      msgs: The full chat history. Bare contents are sent with the 'user'
        role.
      name: Optional name attached to the produced chunks and message.
      temperature: Optional sampling temperature forwarded to the API.
      sink: Internal use only. When set, every streamed delta is put on it
        as a `MsgChunk`.

    Returns:
      A `Msg` with role 'assistant' holding the concatenated response text.
    """
    stream = await self._client.chat.completions.create(
        messages=[self._to_openai_msg(msg) for msg in msgs],
        model=self._model,
        temperature=temperature,
        stream=True,
    )
    deltas: list[str] = []
    async for chunk in stream:
      # A streamed chunk may legitimately carry no choices (e.g. a usage-only
      # chunk); there is nothing to forward in that case.
      if not chunk.choices:
        continue
      [choice] = chunk.choices
      delta = choice.delta.content or ""
      end = choice.finish_reason is not None
      deltas.append(delta)
      if sink:
        await sink.put(
            MsgChunk(
                role="assistant",
                content=delta,
                end=end,
                name=name,
            ))
    return Msg(role="assistant", content="".join(deltas), name=name)

  def _to_openai_msg(self, msg: Msg | MsgContent) -> dict:
    """Converts a message (or bare content) into an OpenAI message payload.

    Strings are passed through, PIL images are re-encoded as PNG bytes and raw
    bytes are accepted only when `imghdr.what` recognises them as an image.
    The payload itself is built by `msglm.mk_msg`.

    Raises:
      ValueError: If an item of the content is not a supported type.
    """
    data = msg.content if isinstance(msg, Msg) else msg
    if isinstance(data, _MsgLeafContent):
      # Normalize a single leaf into a one-item sequence.
      data = [data]

    chunks = []
    for d in data:
      if isinstance(d, str):
        chunks.append(d)
      elif isinstance(d, Image.Image):
        buff = io.BytesIO()
        d.save(buff, format="PNG")
        chunks.append(buff.getvalue())
      elif isinstance(d, bytes) and bool(imghdr.what(None, d)):
        chunks.append(d)
      else:
        raise ValueError(f"Invalid message content: {d}")

    # Bare content has no role of its own and is sent as the user.
    role = msg.role if isinstance(msg, Msg) else "user"

    return msglm.mk_msg(chunks, role=role, api="openai")


In [None]:
import httpx
import os

In [None]:
import dotenv
dotenv.load_dotenv()

True

In [None]:
%%run_if os.environ.get("OPENAI_API_KEY")

llm = OpenaiAPI(model="gpt-4o-mini")

In [None]:
%%run_if os.environ.get("OPENAI_API_KEY")

ai_msg = await llm.chat(["Hi my name is Achraf"])
ai_msg

Msg(role='assistant', content='Hi Achraf! How can I assist you today?', name='')

In [None]:
%%run_if os.environ.get("OPENAI_API_KEY")

# The `chat` API is stateless, so we need to provide the full chat history.
ai_msg = await llm.chat([
  "Hi my name is Achraf", 
  ai_msg, 
  "what's my name?",
])
ai_msg

Msg(role='assistant', content='Your name is Achraf. How can I help you today, Achraf?', name='')

In [None]:
%%run_if os.environ.get("OPENAI_API_KEY")

# `chat` API can handle byte images.
img_url = "https://www.atshq.org/wp-content/uploads/2022/07/shutterstock_1626122512.jpg"
img = httpx.get(img_url).content

ai_msg = await llm.chat(["What do you see in the following image?", img])
ai_msg

Msg(role='assistant', content='The image features a toucan, known for its vibrant and distinctive beak. The bird has a black body with a bright yellow throat and colorful beak, showcasing various shades of green, red, and orange. The background appears to be blurred, emphasizing the toucan’s vivid colors and details.', name='')

In [None]:
%%run_if os.environ.get("OPENAI_API_KEY")

# `chat` API can handle PIL images.
pil_img = Image.open(io.BytesIO(img))

ai_msg = await llm.chat(["What do you see in the following image?", img])
ai_msg

Msg(role='assistant', content="The image depicts a toucan, characterized by its distinctive large, colorful beak. The bird has a black body, bright yellow throat, and vibrant green, orange, and red hues on its beak. The background appears to be blurred, enhancing the focus on the toucan's striking features.", name='')

In [None]:
%%run_if os.environ.get("OPENAI_API_KEY")

# `chat` can be streamed.
async for chunk in llm.chat.stream(
    ["Generate a short poem about about AI"],
    temperature=0.7,
):
  print(chunk.content, end="", flush=True)

assert chunk.end, "Last chunk should be the end of the response."

In circuits spun where thoughts align,  
A spark of code, a thread divine.  
With whispers soft of ones and zeroes,  
AI emerges, our digital hero.  

It learns and grows, a mind anew,  
In data's dance, it finds its view.  
From art to words, it weaves its tale,  
A partner bright, where dreams set sail.  

Yet in its glow, we ponder deep,  
What truths we share, what bounds we keep.  
For in this age of silicon dreams,  
We navigate the light and shadows' beams.  

## Chat Object

In [None]:
#| export


class Chat(tx.Transform[MsgLike, MsgChunk]):
  """Stateful chat built on top of a stateless `Backend`.

  Keeps the conversation history and, for every incoming message, streams the
  backend response as `MsgChunk`s.
  """

  def __init__(
      self,
      backend: Backend,
      history: Sequence[MsgLike] = (),  # TODO: Add possibility to load from DB.
      name: str = "",
  ):
    """Creates the chat.

    Args:
      backend: The LLM backend used to produce responses.
      history: Optional initial chat history. It is copied, so the caller's
        sequence is never mutated.
      name: Optional name given to the assistant messages.
    """
    # TODO: Add configuration for the temperature.
    # TODO: Add possibility to send full Msg not just chunks.
    self._backend = backend
    self._history = list(history)
    self._name = name

  def __call__(self, chan: cx.Channel[MsgLike]) -> cx.Channel[MsgChunk]:
    """Applies the chat to a channel of messages.

    The channel is piped through a `tx.Latch` followed by `self.chat`.
    """
    p = tx.Latch() | self.chat
    return p(chan)

  async def chat(self, msg: MsgLike):
    """Streams the response to `msg` and records the turn in the history.

    Args:
      msg: The incoming message. Bare content is wrapped in a 'user' `Msg`.

    Yields:
      The chunks produced by the backend, in order.
    """
    if not isinstance(msg, Msg):
      msg = Msg(role="user", content=msg)

    resp = ""
    async for chunk in self._backend.chat.stream(
        self._history + [msg],
        name=self._name,
    ):
      resp = self._merge_content(new=chunk.content, prev=resp)
      yield chunk

    # Only record the history if the chat completion ends because
    # chats can be interrupted mid turns.
    self._history.extend((
        msg,
        Msg(role="assistant", content=resp, name=self._name),
    ))

  def _merge_content(
      self,
      *,
      new: MsgContent,
      prev: MsgContent,
  ) -> MsgContent:
    """Appends `new` to `prev`; both must be of the same str or bytes type."""
    assert isinstance(new, type(prev)), f"Cannot merge {new} with {prev}"
    # Report the value that is actually being validated here (`new`).
    assert isinstance(
        new, (str, bytes)), f"Cannot merge {new} with type {type(new)}"
    return prev + new


In [None]:
%%run_if os.environ.get("OPENAI_API_KEY")

chat = Chat(
    llm,
    history=[
        "Hi my name is Achraf",
        Msg(role="assistant", content="Nice to meet you!"),
    ],
    name="ai",
)


s = sx.of(
  cx.Packet(payload="Write a haiku about Physics.", packet_type=cx.PacketType.DATA),
  cx.Packet(payload="Write a haiku about Python", packet_type=cx.PacketType.DATA),
)
ch = cx.as_chan(s)

async for p in chat(ch):
  if p.packet_type != cx.PacketType.DATA:
    continue
  chunk = p.payload
  print(chunk.content, end="", flush=True)

# The history should contain 4 messages.
#   - 2 seeded messages: ["Hi my name is Achraf", "Nice to meet you!"]
#   - ["Write a haiku about Physics", <answer>] are dropped because this turn is interrupted
#      by the next message.
#   - ["Write a haiku about Python", <answer>] is the last turn.
test_eq(len(chat._history), 4)

Code flows like a stream,  
Indentations guide the flow,  
Logic weaves through lines.

In [None]:
#| hide
import nbdev

nbdev.nbdev_export()