Initial implementation of the inference system #869

Merged
merged 5 commits on Jan 21, 2023
35 changes: 35 additions & 0 deletions inference/README.md
@@ -0,0 +1,35 @@
# OpenAssistant Inference

Preliminary implementation of the inference engine for OpenAssistant.

## Development (you'll need multiple terminals)

Run a Redis container (or use the one from the general docker compose file):

```bash
docker run --rm -it -p 6379:6379 redis
```

Run the inference server:

```bash
cd server
pip install -r requirements.txt
uvicorn main:app --reload
```

Run one (or more) workers:

```bash
cd worker
pip install -r requirements.txt
python __main__.py
```

Run the client:

```bash
cd text-client
pip install -r requirements.txt
python __main__.py
```
10 changes: 10 additions & 0 deletions inference/server/README.md
@@ -0,0 +1,10 @@
# OpenAssistant Inference Server

Workers communicate with the `/work` endpoint via WebSocket. They send their
configuration and, if a compatible task is available, the server returns it. The
worker then performs the task and streams the result back to the server, also
via WebSocket.
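
For reference, a minimal sketch of that handshake (illustrative only; it assumes the server listens on `localhost:8000` and substitutes plain JSON dicts for the `oasst_shared` schemas used by the real worker):

```python
import json

from websocket import create_connection  # from the websocket-client package

ws = create_connection("ws://localhost:8000/work")
ws.send(json.dumps({"model_name": "distilgpt2"}))  # announce the worker configuration
work_request = json.loads(ws.recv())  # blocks until the server assigns a compatible task
# ... run generation for work_request["conversation"] here ...
ws.send(json.dumps({"token": "Hello", "is_end": False}))  # stream tokens back one by one
ws.send(json.dumps({"token": "", "is_end": True}))  # signal the end of the stream
ws.close()
```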

Clients first call `/chat` to create a new chat, then add messages to it via
`/chat/<id>/message`. The response is an SSE event stream that sends tokens as
they become available.
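
The client side can be exercised in a few lines as well. A minimal sketch (it mirrors the bundled text client and assumes the server runs at the default `http://127.0.0.1:8000`):

```python
import json

import requests
import sseclient  # from the sseclient-py package

backend_url = "http://127.0.0.1:8000"

# create a new chat, then post a message and stream the SSE token events
chat_id = requests.post(f"{backend_url}/chat", json={}).json()["id"]
response = requests.post(
    f"{backend_url}/chat/{chat_id}/message",
    json={"message": "Hello!"},
    stream=True,
    headers={"Accept": "text/event-stream"},
)
response.raise_for_status()

for event in sseclient.SSEClient(response).events():
    print(json.loads(event.data)["token"], end="", flush=True)
print()
```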
193 changes: 193 additions & 0 deletions inference/server/main.py
@@ -0,0 +1,193 @@
import asyncio
import enum
import uuid

import fastapi
import pydantic
import redis.asyncio as redis
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from oasst_shared.schemas import inference, protocol
from sse_starlette.sse import EventSourceResponse

app = fastapi.FastAPI()

# Allow CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class Settings(pydantic.BaseSettings):
    redis_host: str = "localhost"
    redis_port: int = 6379
    redis_db: int = 0

    sse_retry_timeout: int = 15000


settings = Settings()

# create async redis client
redisClient = redis.Redis(
    host=settings.redis_host, port=settings.redis_port, db=settings.redis_db, decode_responses=True
)


class CreateChatRequest(pydantic.BaseModel):
    pass


class CreateChatResponse(pydantic.BaseModel):
    id: str


class MessageRequest(pydantic.BaseModel):
    message: str = pydantic.Field(..., repr=False)
    model_name: str = "distilgpt2"
    max_new_tokens: int = 100

    def compatible_with(self, worker_config: inference.WorkerConfig) -> bool:
        return self.model_name == worker_config.model_name


class TokenResponseEvent(pydantic.BaseModel):
    token: str


class MessageRequestState(str, enum.Enum):
    pending = "pending"
    in_progress = "in_progress"
    complete = "complete"


class DbChatEntry(pydantic.BaseModel):
    id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))
    conversation: protocol.Conversation = pydantic.Field(default_factory=protocol.Conversation)
    pending_message_request: MessageRequest | None = None
    message_request_state: MessageRequestState | None = None


# TODO: make real database
CHATS: dict[str, DbChatEntry] = {}


@app.post("/chat")
async def create_chat(request: CreateChatRequest) -> CreateChatResponse:
    """Allows a client to create a new chat."""
    logger.info(f"Received {request}")
    chat = DbChatEntry()
    CHATS[chat.id] = chat
    return CreateChatResponse(id=chat.id)


@app.get("/chat/{id}")
async def get_chat(id: str) -> protocol.Conversation:
    """Allows a client to get the current state of a chat."""
    return CHATS[id].conversation


@app.post("/chat/{id}/message")
async def create_message(id: str, message_request: MessageRequest, fastapi_request: fastapi.Request):
    """Allows the client to stream the results of a request."""

    chat = CHATS[id]
    if not chat.conversation.is_prompter_turn:
        raise fastapi.HTTPException(status_code=400, detail="Not your turn")
    if chat.pending_message_request is not None:
        raise fastapi.HTTPException(status_code=400, detail="Already pending")

    chat.conversation.messages.append(
        protocol.ConversationMessage(
            text=message_request.message,
            is_assistant=False,
        )
    )

    chat.pending_message_request = message_request
    chat.message_request_state = MessageRequestState.pending

    async def event_generator():
        result_data = []

        try:
            while True:
                if await fastapi_request.is_disconnected():
                    logger.warning("Client disconnected")
                    break

                item = await redisClient.blpop(chat.id, 1)
                if item is None:
                    continue

                _, response_packet_str = item
                response_packet = inference.WorkResponsePacket.parse_raw(response_packet_str)
                result_data.append(response_packet)

                if response_packet.is_end:
                    break

                yield {
                    "retry": settings.sse_retry_timeout,
                    "data": TokenResponseEvent(token=response_packet.token).json(),
                }
            logger.info(f"Finished streaming {chat.id} {len(result_data)=}")
        except Exception:
            logger.exception(f"Error streaming {chat.id}")

        chat.conversation.messages.append(
            protocol.ConversationMessage(
                text="".join([d.token for d in result_data[:-1]]),
                is_assistant=True,
            )
        )
        chat.pending_message_request = None

    return EventSourceResponse(event_generator())


@app.websocket("/work")
async def work(websocket: fastapi.WebSocket):
    await websocket.accept()
    worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())
    while True:
        # find a pending task that matches the worker's config
        # could also be implemented using task queues
        # but general compatibility matching is tricky
        for chat in CHATS.values():
            if (request := chat.pending_message_request) is not None:
                if chat.message_request_state == MessageRequestState.pending:
                    if request.compatible_with(worker_config):
                        break
        else:
            logger.debug("No pending tasks")
            await asyncio.sleep(1)
            continue

        chat.message_request_state = MessageRequestState.in_progress
Review comment (Collaborator):
If we have more than one "message-broker" instance and CHATS is stored in a DB/Redis, then this "dequeue" operation will become a congestion point. One idea would be to define clear "configuration" tiers, e.g. based on GPU memory requirements, and have independent task queues for them.
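
For illustration, such per-tier queues could look roughly like the sketch below; the tier names, the `queue:` key prefix, and the `tier_for` helper are hypothetical and not part of this PR:

```python
import redis.asyncio as redis

redis_client = redis.Redis(decode_responses=True)


def tier_for(worker_config) -> str:
    # hypothetical mapping from a worker's configuration (e.g. the GPU memory
    # requirements of its model) to a queue tier
    return "small" if worker_config.model_name == "distilgpt2" else "large"


async def enqueue(tier: str, message_request_json: str) -> None:
    # the client-facing endpoint pushes serialized requests onto its tier's queue
    await redis_client.rpush(f"queue:{tier}", message_request_json)


async def dequeue(worker_config) -> str | None:
    # each worker connection pops only from the queue matching its tier, so
    # multiple message-broker instances never scan a shared CHATS dict
    item = await redis_client.blpop(f"queue:{tier_for(worker_config)}", timeout=1)
    return item[1] if item is not None else None
```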


        work_request = inference.WorkRequest(
            conversation=chat.conversation,
            model_name=request.model_name,
            max_new_tokens=request.max_new_tokens,
        )

        logger.info(f"Created {work_request}")
        try:
            await websocket.send_text(work_request.json())
            while True:
                # maybe unnecessary to parse and re-serialize
                # could just pass the raw string and mark end via empty string
                response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text())
                await redisClient.rpush(chat.id, response_packet.json())
                if response_packet.is_end:
                    break
        except fastapi.WebSocketException:
            # TODO: handle this better
            logger.exception(f"Websocket closed during handling of {chat.id}")

        chat.message_request_state = MessageRequestState.complete
6 changes: 6 additions & 0 deletions inference/server/requirements.txt
@@ -0,0 +1,6 @@
fastapi[all]
loguru
pydantic
redis
sse-starlette
websockets
40 changes: 40 additions & 0 deletions inference/text-client/__main__.py
@@ -0,0 +1,40 @@
"""Simple REPL frontend."""

import json

import requests
import sseclient
import typer

app = typer.Typer()


@app.command()
def main(backend_url: str = "http://127.0.0.1:8000"):
    """Simple REPL client."""
    chat_id = requests.post(f"{backend_url}/chat", json={}).json()["id"]
    while True:
        message = typer.prompt("User").strip()

        # wait for stream to be ready
        # could implement a queue position indicator
        # could be implemented with long polling
        # but server load needs to be considered
        response = requests.post(
            f"{backend_url}/chat/{chat_id}/message",
            json={"message": message},
            stream=True,
            headers={"Accept": "text/event-stream"},
        )
        response.raise_for_status()

        client = sseclient.SSEClient(response)
        print("Assistant: ", end="", flush=True)
        for event in client.events():
            data = json.loads(event.data)
            print(data["token"], end="", flush=True)
        print()


if __name__ == "__main__":
    app()
3 changes: 3 additions & 0 deletions inference/text-client/requirements.txt
@@ -0,0 +1,3 @@
requests
sseclient-py
typer
79 changes: 79 additions & 0 deletions inference/worker/__main__.py
@@ -0,0 +1,79 @@
import re
import time

import rel
import torch
import typer
import websocket
from loguru import logger
from oasst_shared.schemas import inference, protocol
from transformers import pipeline

app = typer.Typer()


@app.command()
def main(
    backend_url: str = "ws://localhost:8000",
    model_name: str = "distilgpt2",
):
    pipe = pipeline("text-generation", model=model_name)

    def on_open(ws: websocket.WebSocket):
        worker_config = inference.WorkerConfig(model_name=model_name)
        ws.send(worker_config.json())

    def on_message(ws: websocket.WebSocket, message: str):
        # TODO: what if this comes in, but one is already in progress?
        # also need to think of enabling batching
        work_request = inference.WorkRequest.parse_raw(message)

        def _prepare_message(message: protocol.ConversationMessage) -> str:
            prefix = "Assistant: " if message.is_assistant else "User: "
            return prefix + message.text

        # construct prompt
        messages = [_prepare_message(message) for message in work_request.conversation.messages]

        prompt = "\n".join(messages) + "\nAssistant:"

        # TODO: replace this with incremental generation
        torch.manual_seed(work_request.seed)
        model_output = pipe(prompt, max_new_tokens=work_request.max_new_tokens, do_sample=True, return_full_text=False)[
            0
        ]["generated_text"]
        model_output = model_output.strip()

        # fake streaming
        split_idcs = [m.start() for m in re.finditer(r"([\w:]+)", model_output)]
        pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])]
        for piece in pieces:
            if not piece:
                continue
            if piece.strip() in ("User:", "Assistant:"):
                break
            ws.send(inference.WorkResponsePacket(token=piece).json())
            time.sleep(0.1)
        ws.send(inference.WorkResponsePacket(is_end=True).json())

    def on_error(ws: websocket.WebSocket, error: Exception):
        logger.error(f"Connection error: {error}")

    def on_close(ws: websocket.WebSocket, close_status_code: int, close_msg: str):
        logger.warning(f"Connection closed: {close_status_code=} {close_msg=}")

    ws = websocket.WebSocketApp(
        f"{backend_url}/work",
        on_message=on_message,
        on_error=on_error,
        on_close=on_close,
        on_open=on_open,
    )

    ws.run_forever(dispatcher=rel, reconnect=5)
    rel.signal(2, rel.abort)
    rel.dispatch()


if __name__ == "__main__":
    app()
6 changes: 6 additions & 0 deletions inference/worker/requirements.txt
@@ -0,0 +1,6 @@
loguru
rel
torch
transformers
typer
websocket-client