switched workers to websockets
yk committed Jan 21, 2023
1 parent da26315 commit 0726913
Showing 8 changed files with 105 additions and 109 deletions.
2 changes: 1 addition & 1 deletion inference/README.md
@@ -4,7 +4,7 @@ Preliminary implementation of the inference engine for OpenAssistant.

## Development (you'll need multiple terminals)

Run a redis container:
Run a redis container (or use the one from the general docker compose file):

```bash
docker run --rm -it -p 6379:6379 redis
11 changes: 5 additions & 6 deletions inference/server/README.md
@@ -1,14 +1,13 @@
# OpenAssistant Inference Server

Workers communicate with the `/work` endpoint. They provide their configuration
and if a task is available, the server returns it. The server also returns the
key of a Redis list where the worker should push the results.
Workers communicate with the `/work` endpoint via WebSocket. They provide their
configuration, and if a task is available, the server returns it. The worker then
performs the task and streams the result back to the server, also via
WebSocket.
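
For illustration, a minimal worker-side sketch of this exchange (assuming the
`websocket-client` package and the shared `inference` schemas; the actual worker
lives in `inference/worker/__main__.py`):

```python
import websocket  # provided by the websocket-client package

from oasst_shared.schemas import inference

# Connect and announce the worker's configuration.
ws = websocket.create_connection("ws://localhost:8000/work")
ws.send(inference.WorkerConfig(model_name="distilgpt2").json())

# The server replies with a WorkRequest once a compatible task is pending.
work_request = inference.WorkRequest.parse_raw(ws.recv())

# Stream the result back token by token, then signal the end.
for token in ("fake ", "tokens"):  # a real worker runs the model here
    ws.send(inference.WorkResponsePacket(token=token).json())
ws.send(inference.WorkResponsePacket(is_end=True).json())
ws.close()
```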

Clients communicate first with the `/complete` endpoint to place a request for
prompt completion. The server returns a unique ID for the request. The client
then polls the `/stream` endpoint with the ID to check if the request has been
assigned to a worker. Once it is assigned, the response will be an SSE event
source.
then calls the `/stream` endpoint with the ID to get an SSE event source.

Notably, `/complete` could be proxied via a frontend, while `/stream` can be
accessed directly by the client, since the unique ID provides enough security.
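
A correspondingly minimal client sketch (assuming `requests` and the
`sseclient-py` package, and that the server emits `ResponseEvent` JSON as the
SSE event data):

```python
import json

import requests
import sseclient

backend_url = "http://127.0.0.1:8000"

# Place the completion request and remember its id.
completion_id = requests.post(f"{backend_url}/complete", json={"prompt": "Hello"}).json()["id"]

# Open the SSE stream and print tokens as they arrive.
response = requests.get(
    f"{backend_url}/stream/{completion_id}",
    stream=True,
    headers={"Accept": "text/event-stream"},
)
response.raise_for_status()
for event in sseclient.SSEClient(response).events():
    print(json.loads(event.data)["token"], end="", flush=True)
```
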
101 changes: 53 additions & 48 deletions inference/server/main.py
@@ -1,3 +1,4 @@
import asyncio
import enum
import random
import uuid
@@ -48,7 +49,7 @@ def compatible_with(self, worker_config: inference.WorkerConfig) -> bool:


class CompletionResponse(pydantic.BaseModel):
completion_id: str
id: str


class ResponseEvent(pydantic.BaseModel):
@@ -62,9 +63,8 @@ class CompletionState(str, enum.Enum):


class DbEntry(pydantic.BaseModel):
id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))
completion_request: CompletionRequest
completion_id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))
stream_queue_id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))
seed: int = pydantic.Field(default_factory=lambda: random.randint(0, 2**32 - 1))
result_data: list[inference.WorkResponsePacket] | None = None
state: CompletionState = CompletionState.pending
@@ -82,56 +82,61 @@ async def complete(request: CompletionRequest) -> CompletionResponse:
db_entry = DbEntry(
completion_request=request,
)
DATABASE[db_entry.completion_id] = db_entry
return CompletionResponse(completion_id=db_entry.completion_id)


class StartWorkRequest(pydantic.BaseModel):
worker_config: inference.WorkerConfig


@app.post("/work")
async def work(start_work_request: StartWorkRequest) -> inference.WorkRequest:
"""Allows a worker to request work, given its configuration."""

# find a pending request that is compatible with the worker config
# this could be implemented using queues, databases, long polling, etc.
# we might also think about replacing this endpoint with a message queue
# but we need to know which worker is dequeueing a particular request
# to do proper credit assignment and load balancing (+ security)
for db_entry in DATABASE.values():
if db_entry.state == CompletionState.pending:
if db_entry.completion_request.compatible_with(start_work_request.worker_config):
break
else:
raise fastapi.HTTPException(status_code=202, detail="No pending requests")

request = db_entry.completion_request

work_request = inference.WorkRequest(
stream_queue_id=db_entry.stream_queue_id,
prompt=request.prompt,
model_name=request.model_name,
max_length=request.max_length,
seed=db_entry.seed,
)

logger.info(f"Created {work_request}")
db_entry.state = CompletionState.in_progress
return work_request
DATABASE[db_entry.id] = db_entry
return CompletionResponse(id=db_entry.id)


@app.websocket("/work")
async def work(websocket: fastapi.WebSocket):
await websocket.accept()
worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())
while True:
# find a pending task that matches the worker's config
# could also be implemented using task queues
# but general compatibility matching is tricky
for db_entry in DATABASE.values():
if db_entry.state == CompletionState.pending:
if db_entry.completion_request.compatible_with(worker_config):
break
else:
logger.debug("No pending tasks")
await asyncio.sleep(1)
continue

request = db_entry.completion_request

work_request = inference.WorkRequest(
prompt=request.prompt,
model_name=request.model_name,
max_length=request.max_length,
seed=db_entry.seed,
)

logger.info(f"Created {work_request}")
db_entry.state = CompletionState.in_progress
try:
await websocket.send_text(work_request.json())
while True:
# maybe unnecessary to parse and re-serialize
# could just pass the raw string and mark end via empty string
response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text())
await redisClient.rpush(db_entry.id, response_packet.json())
if response_packet.is_end:
break
except fastapi.WebSocketException:
# TODO: handle this better
logger.exception(f"Websocket closed during handling of {db_entry.id}")


@app.get("/stream/{db_id}")
async def message_stream(db_id: str, request: fastapi.Request):
@app.get("/stream/{id}")
async def message_stream(id: str, request: fastapi.Request):
"""Allows the client to stream the results of a request."""

db_entry = DATABASE[db_id]
db_entry = DATABASE[id]

if db_entry.state not in (CompletionState.pending, CompletionState.in_progress):
raise fastapi.HTTPException(status_code=404, detail="Request not found")

stream_queue_id = db_entry.stream_queue_id

async def event_generator():
result_data = []

@@ -141,7 +146,7 @@ async def event_generator():
logger.warning("Client disconnected")
break

item = await redisClient.blpop(stream_queue_id, 1)
item = await redisClient.blpop(db_entry.id, 1)
if item is None:
continue

@@ -156,9 +161,9 @@ async def event_generator():
"retry": settings.sse_retry_timeout,
"data": ResponseEvent(token=response_packet.token).json(),
}
logger.info(f"Finished streaming {stream_queue_id} {len(result_data)=}")
logger.info(f"Finished streaming {db_entry.id} {len(result_data)=}")
except Exception:
logger.exception(f"Error streaming {stream_queue_id}")
logger.exception(f"Error streaming {db_entry.id}")

# store the generated data in the database
db_entry.result_data = result_data
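
The data flow the new `/work` endpoint sets up: the websocket handler `rpush`es
each `WorkResponsePacket` onto a Redis list keyed by the database entry's id,
while the SSE generator `blpop`s from the same list with a short timeout so it
can keep checking whether the client has disconnected. A standalone sketch of
that handoff (hypothetical key name, plain strings instead of packets):

```python
import asyncio

import redis.asyncio as redis


async def demo():
    r = redis.Redis(decode_responses=True)
    key = "example-entry-id"  # the server uses DbEntry.id here

    async def producer():  # stands in for the /work websocket handler
        for token in ("Hello", ", ", "world"):
            await r.rpush(key, token)
            await asyncio.sleep(0.1)
        await r.rpush(key, "")  # empty string marks the end in this sketch

    async def consumer():  # stands in for the SSE event generator
        while True:
            item = await r.blpop(key, 1)  # 1-second timeout, like the server's loop
            if item is None:
                continue  # nothing yet; the real generator also checks for disconnects
            _, token = item
            if not token:
                break
            print(token, end="", flush=True)

    await asyncio.gather(producer(), consumer())


asyncio.run(demo())
```
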
1 change: 1 addition & 0 deletions inference/server/requirements.txt
@@ -3,3 +3,4 @@ loguru
pydantic
redis
sse-starlette
websockets
12 changes: 3 additions & 9 deletions inference/text-client/__main__.py
@@ -1,7 +1,6 @@
"""Simple REPL frontend."""

import json
import time

import requests
import sseclient
@@ -15,19 +14,14 @@ def main(backend_url: str = "http://127.0.0.1:8000"):
"""Simple REPL client."""
while True:
prompt = typer.prompt("Enter text to complete").strip()
complete_response = requests.post(f"{backend_url}/complete", json={"prompt": prompt}).json()
completion_id = complete_response["completion_id"]
id = requests.post(f"{backend_url}/complete", json={"prompt": prompt}).json()["id"]

# wait for stream to be ready
# could implement a queue position indicator
# could be implemented with long polling
# but server load needs to be considered
while True:
headers = {"Accept": "text/event-stream"}
response = requests.get(f"{backend_url}/stream/{completion_id}", stream=True, headers=headers)
if response.status_code == 200:
break
time.sleep(0.25)
response = requests.get(f"{backend_url}/stream/{id}", stream=True, headers={"Accept": "text/event-stream"})
response.raise_for_status()

client = sseclient.SSEClient(response)
for event in client.events():
82 changes: 40 additions & 42 deletions inference/worker/__main__.py
@@ -1,10 +1,10 @@
import re
import time

import redis
import requests
import rel
import torch
import typer
import websocket
from loguru import logger
from oasst_shared.schemas import inference
from transformers import pipeline
@@ -14,56 +14,54 @@

@app.command()
def main(
backend_url: str = "ws://localhost:8000",
model_name: str = "distilgpt2",
backend_url: str = "http://localhost:8000",
redis_host: str = "localhost",
redis_port: int = 6379,
redis_db: int = 0,
):
pipe = pipeline("text-generation", model=model_name)

redis_client = redis.StrictRedis(host=redis_host, port=redis_port, db=redis_db, decode_responses=True)
def on_open(ws: websocket.WebSocket):
worker_config = inference.WorkerConfig(model_name=model_name)
ws.send(worker_config.json())

worker_config = inference.WorkerConfig(model_name=model_name)
def on_message(ws: websocket.WebSocket, message: str):
# TODO: what if this comes in, but one is already in progress?
# also need to think of enabling batching
work_request = inference.WorkRequest.parse_raw(message)

pipe = pipeline("text-generation", model=model_name)
# TODO: replace this with incremental generation
torch.manual_seed(work_request.seed)
model_output = pipe(
work_request.prompt, max_length=work_request.max_length, do_sample=True, return_full_text=False
)[0]["generated_text"]
model_output = model_output.strip()

# TODO: use batching
while True:
# fake streaming
split_idcs = [m.start() for m in re.finditer(r"(\w+)", model_output)]
pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])]
for piece in pieces:
if not piece:
continue
ws.send(inference.WorkResponsePacket(token=piece).json())
time.sleep(0.1)
ws.send(inference.WorkResponsePacket(is_end=True).json())

# wait for work to be ready
# could possibly be switched to a message queue
while True:
try:
response = requests.post(f"{backend_url}/work", json={"worker_config": worker_config.dict()})
if response.status_code == 200:
break
except Exception:
logger.exception("Error connecting to backend")
time.sleep(1)
def on_error(ws: websocket.WebSocket, error: Exception):
logger.error(f"Connection error: {error}")

try:
work_request = inference.WorkRequest.parse_raw(response.content)
print(f"Processing {work_request}")
queue_id = work_request.stream_queue_id
def on_close(ws: websocket.WebSocket, close_status_code: int, close_msg: str):
logger.warning(f"Connection closed: {close_status_code=} {close_msg=}")

# TODO: replace this with incremental generation
torch.manual_seed(work_request.seed)
model_output = pipe(
work_request.prompt, max_length=work_request.max_length, do_sample=True, return_full_text=False
)[0]["generated_text"]
model_output = model_output.strip()
ws = websocket.WebSocketApp(
f"{backend_url}/work",
on_message=on_message,
on_error=on_error,
on_close=on_close,
on_open=on_open,
)

# fake streaming
split_idcs = [m.start() for m in re.finditer(r"(\w+)", model_output)]
pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])]
for piece in pieces:
if not piece:
continue
redis_client.rpush(queue_id, inference.WorkResponsePacket(token=piece).json())
time.sleep(0.1)
redis_client.rpush(queue_id, inference.WorkResponsePacket(is_end=True).json())
except Exception:
logger.exception(f"Error processing {work_request}")
ws.run_forever(dispatcher=rel, reconnect=5)
rel.signal(2, rel.abort)
rel.dispatch()


if __name__ == "__main__":
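
In both the old and the new worker, the "fake streaming" block splits the
finished model output at word boundaries so it can be sent piecewise. A quick
illustration of what that regex split produces (example string is arbitrary):

```python
import re

model_output = "Hello world, how are you"
split_idcs = [m.start() for m in re.finditer(r"(\w+)", model_output)]
pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])]
# pieces == ['', 'Hello ', 'world, ', 'how ', 'are ', 'you']
# the empty leading piece is skipped by the `if not piece: continue` check
```
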
4 changes: 2 additions & 2 deletions inference/worker/requirements.txt
@@ -1,6 +1,6 @@
loguru
redis
requests
rel
torch
transformers
typer
websocket-client
1 change: 0 additions & 1 deletion oasst-shared/oasst_shared/schemas/inference.py
@@ -6,7 +6,6 @@ class WorkerConfig(pydantic.BaseModel):


class WorkRequest(pydantic.BaseModel):
stream_queue_id: str
prompt: str = pydantic.Field(..., repr=False)
model_name: str = "distilgpt2"
max_length: int = 100
