Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
liamgriffiths committed Jun 18, 2024
1 parent c38a66f commit 92bc89b
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 6 deletions.
12 changes: 8 additions & 4 deletions examples/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,21 @@
if api_key is None:
raise EnvironmentError("No SUBSTRATE_API_KEY set")

from substrate import Substrate, GenerateText
from substrate import Substrate, GenerateText, sb

substrate = Substrate(api_key=api_key, timeout=60 * 5)

story = GenerateText(prompt="tell me a story")
# summary = GenerateText(prompt=sb.concat("Summarize this story: ", story.future.text))
summary = GenerateText(prompt=sb.concat("Summarize this story: ", story.future.text))

# response = substrate.run(story, summary)
response = substrate.run(story)
response = substrate.run(story, summary)
print(response)

print("=== story")
story_out = response.get(story)
print(story_out.text)

print("=== summary")
summary_out = response.get(summary)
print(summary_out.text)

Expand Down
38 changes: 38 additions & 0 deletions examples/streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os
import sys
import asyncio
from pathlib import Path

# add parent dir to sys.path to make 'substrate' importable
parent_dir = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(parent_dir))

api_key = os.environ.get("SUBSTRATE_API_KEY")
if api_key is None:
raise EnvironmentError("No SUBSTRATE_API_KEY set")

from substrate import Substrate, GenerateText

substrate = Substrate(api_key=api_key, timeout=60 * 5)


a = GenerateText(prompt="tell me about windmills", max_tokens=10)
b = GenerateText(prompt="tell me more about cereal", max_tokens=10)


async def amain():
response = await substrate.async_stream(a, b)
async for event in response.async_iter_events():
print(event)


asyncio.run(amain())


def main():
response = substrate.stream(a, b)
for message in response.iter_events():
print(message)


main()
15 changes: 13 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ httpx = ">=0.26.0"
distro = ">=1.8.0"
typing-extensions = "^4.10.0"
pydantic = ">=1.0.0"
httpx-sse = "^0.4.0"

[tool.ruff.lint]
ignore-init-module-imports = true
Expand Down
31 changes: 31 additions & 0 deletions substrate/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import httpx
import distro
import httpx_sse

from ._version import __version__
from .core.id_generator import IDGenerator
Expand Down Expand Up @@ -172,6 +173,13 @@ def default_headers(self) -> Dict[str, str]:
**self._additional_headers,
}

@property
def streaming_headers(self) -> Dict[str, str]:
headers = self.default_headers
headers["Accept"] = "text/event-stream"
headers["X-Substrate-Streaming"] = "1"
return headers

def post_compose(self, dag: Dict[str, Any]) -> APIResponse:
url = f"{self._base_url}/compose"
body = {"dag": dag}
Expand All @@ -190,6 +198,29 @@ def post_compose(self, dag: Dict[str, Any]) -> APIResponse:
)
return res

def post_compose_streaming(self, dag: Dict[str, Any]):
url = f"{self._base_url}/compose"
body = {"dag": dag}

def iterator():
with httpx.Client(timeout=self._timeout, follow_redirects=True) as client:
with httpx_sse.connect_sse(client, "POST", url, json=body, headers=self.streaming_headers) as event_source:
for sse in event_source.iter_sse():
yield sse
return iterator()


async def async_post_compose_streaming(self, dag: Dict[str, Any]):
url = f"{self._base_url}/compose"
body = {"dag": dag}

async def iterator():
async with httpx.AsyncClient(timeout=self._timeout, follow_redirects=True) as client:
async with httpx_sse.aconnect_sse(client, "POST", url, json=body, headers=self.streaming_headers) as event_source:
async for sse in event_source.aiter_sse():
yield sse
return iterator()

async def async_post_compose(self, dag: Dict[str, Any]) -> APIResponse:
url = f"{self._base_url}/compose"
body = {"dag": dag}
Expand Down
41 changes: 41 additions & 0 deletions substrate/streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import json
from typing import Iterator, AsyncIterator

import httpx_sse


class ServerSentEvent:
def __init__(self, event: httpx_sse.ServerSentEvent):
self.event = event

@property
def data(self):
return json.loads(self.event.data)

def __repr__(self):
return self.event.__repr__()

def __str__(self):
"""
Render the Server-Sent Event as a string to be rendered in a streaming response
"""
fields = ["id", "event", "data", "retry"]
lines = [f"{field}: {getattr(self.event, field)}" for field in fields if getattr(self.event, field)]
return "\n".join(lines) + "\n"


class SubstrateStreamingResponse:
"""
Substrate stream response.
"""

def __init__(self, *, iterator):
self.iterator = iterator

def iter_events(self) -> Iterator[ServerSentEvent]:
for sse in self.iterator:
yield ServerSentEvent(sse)

async def async_iter_events(self) -> AsyncIterator[ServerSentEvent]:
async for sse in self.iterator:
yield ServerSentEvent(sse)
19 changes: 19 additions & 0 deletions substrate/substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import base64
from typing import Any, Dict

from substrate.streaming import SubstrateStreamingResponse

from ._client import APIClient
from .core.corenode import CoreNode
from .core.client.graph import Graph
Expand Down Expand Up @@ -48,6 +50,22 @@ async def async_run(self, *nodes: CoreNode) -> SubstrateResponse:
api_response = await self._client.async_post_compose(dag=serialized)
return SubstrateResponse(api_response=api_response)

def stream(self, *nodes: CoreNode) -> SubstrateStreamingResponse:
"""
Run the given nodes and receive results as Server-Sent Events.
"""
serialized = Substrate.serialize(*nodes)
iterator = self._client.post_compose_streaming(dag=serialized)
return SubstrateStreamingResponse(iterator=iterator)

async def async_stream(self, *nodes: CoreNode) -> SubstrateStreamingResponse:
"""
Run the given nodes and receive results as Server-Sent Events.
"""
serialized = Substrate.serialize(*nodes)
iterator = await self._client.async_post_compose_streaming(dag=serialized)
return SubstrateStreamingResponse(iterator=iterator)

@staticmethod
def visualize(*nodes):
"""
Expand All @@ -67,6 +85,7 @@ def serialize(*nodes):
"""

all_nodes = set()

def collect_nodes(node):
all_nodes.add(node)
for referenced_node in node.referenced_nodes:
Expand Down

0 comments on commit 92bc89b

Please sign in to comment.