Initial implementation of the inference system #869

Merged (5 commits)
`README.md` (new file):
# OpenAssistant Inference

Preliminary implementation of the inference engine for OpenAssistant.

## Development (you'll need multiple terminals)

Run a Redis container (or use the one from the general docker compose file):

```bash
docker run --rm -it -p 6379:6379 redis
```
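If you use the repo's general compose file instead, something like the following should work (an assumption on my part: that the compose file defines a service named `redis`; it isn't shown in this diff):

```bash
# assumes the general docker compose file defines a "redis" service
docker compose up redis
```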
Run the inference server:

```bash
cd server
pip install -r requirements.txt
uvicorn main:app --reload
```

Run one (or more) workers:

```bash
cd worker
pip install -r requirements.txt
python __main__.py
```

Run the client:

```bash
cd text-client
pip install -r requirements.txt
python __main__.py
```
`server/README.md` (new file):
# OpenAssistant Inference Server

Workers communicate with the `/work` endpoint via WebSocket. They provide their
configuration and, if a task is available, the server returns it. The worker then
performs the task and streams the result back to the server, also via WebSocket.

Clients first call `/chat` to create a new chat, then add messages to it via
`/chat/<id>/message`. The response is an SSE event stream, which sends tokens
as they become available.
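For illustration, a minimal client session against a local server might look like this (the chat id is a placeholder; the payload fields follow `MessageRequest` in `main.py` below):

```bash
# create a new chat; the server responds with {"id": "<chat-id>"}
curl -X POST http://127.0.0.1:8000/chat \
  -H 'Content-Type: application/json' -d '{}'

# post a message and stream the assistant's tokens as SSE events
curl -N -X POST http://127.0.0.1:8000/chat/<chat-id>/message \
  -H 'Content-Type: application/json' \
  -H 'Accept: text/event-stream' \
  -d '{"message": "Hello"}'
```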
`server/main.py` (new file):
```python
import asyncio
import enum
import uuid

import fastapi
import pydantic
import redis.asyncio as redis
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from oasst_shared.schemas import inference, protocol
from sse_starlette.sse import EventSourceResponse

app = fastapi.FastAPI()

# Allow CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class Settings(pydantic.BaseSettings):
    redis_host: str = "localhost"
    redis_port: int = 6379
    redis_db: int = 0

    sse_retry_timeout: int = 15000


settings = Settings()

# create async redis client
redisClient = redis.Redis(
    host=settings.redis_host, port=settings.redis_port, db=settings.redis_db, decode_responses=True
)


class CreateChatRequest(pydantic.BaseModel):
    pass


class CreateChatResponse(pydantic.BaseModel):
    id: str


class MessageRequest(pydantic.BaseModel):
    message: str = pydantic.Field(..., repr=False)
    model_name: str = "distilgpt2"
    max_new_tokens: int = 100

    def compatible_with(self, worker_config: inference.WorkerConfig) -> bool:
        return self.model_name == worker_config.model_name


class TokenResponseEvent(pydantic.BaseModel):
    token: str


class MessageRequestState(str, enum.Enum):
    pending = "pending"
    in_progress = "in_progress"
    complete = "complete"


class DbChatEntry(pydantic.BaseModel):
    id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))
    conversation: protocol.Conversation = pydantic.Field(default_factory=protocol.Conversation)
    pending_message_request: MessageRequest | None = None
    message_request_state: MessageRequestState | None = None


# TODO: make real database
CHATS: dict[str, DbChatEntry] = {}


@app.post("/chat")
async def create_chat(request: CreateChatRequest) -> CreateChatResponse:
    """Allows a client to create a new chat."""
    logger.info(f"Received {request}")
    chat = DbChatEntry()
    CHATS[chat.id] = chat
    return CreateChatResponse(id=chat.id)


@app.get("/chat/{id}")
async def get_chat(id: str) -> protocol.Conversation:
    """Allows a client to get the current state of a chat."""
    return CHATS[id].conversation


@app.post("/chat/{id}/message")
async def create_message(id: str, message_request: MessageRequest, fastapi_request: fastapi.Request):
    """Allows the client to stream the results of a request."""

    chat = CHATS[id]
    if not chat.conversation.is_prompter_turn:
        raise fastapi.HTTPException(status_code=400, detail="Not your turn")
    if chat.pending_message_request is not None:
        raise fastapi.HTTPException(status_code=400, detail="Already pending")

    chat.conversation.messages.append(
        protocol.ConversationMessage(
            text=message_request.message,
            is_assistant=False,
        )
    )

    chat.pending_message_request = message_request
    chat.message_request_state = MessageRequestState.pending

    async def event_generator():
        result_data = []

        try:
            while True:
                if await fastapi_request.is_disconnected():
                    logger.warning("Client disconnected")
                    break

                item = await redisClient.blpop(chat.id, 1)
                if item is None:
                    continue

                _, response_packet_str = item
                response_packet = inference.WorkResponsePacket.parse_raw(response_packet_str)
                result_data.append(response_packet)

                if response_packet.is_end:
                    break

                yield {
                    "retry": settings.sse_retry_timeout,
                    "data": TokenResponseEvent(token=response_packet.token).json(),
                }
            logger.info(f"Finished streaming {chat.id} {len(result_data)=}")
        except Exception:
            logger.exception(f"Error streaming {chat.id}")

        chat.conversation.messages.append(
            protocol.ConversationMessage(
                text="".join([d.token for d in result_data[:-1]]),
                is_assistant=True,
            )
        )
        chat.pending_message_request = None

    return EventSourceResponse(event_generator())


@app.websocket("/work")
async def work(websocket: fastapi.WebSocket):
    await websocket.accept()
    worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())
    while True:
        # find a pending task that matches the worker's config
        # could also be implemented using task queues
        # but general compatibility matching is tricky
        for chat in CHATS.values():
            if (request := chat.pending_message_request) is not None:
                if chat.message_request_state == MessageRequestState.pending:
                    if request.compatible_with(worker_config):
                        break
        else:
            logger.debug("No pending tasks")
            await asyncio.sleep(1)
            continue

        chat.message_request_state = MessageRequestState.in_progress

        work_request = inference.WorkRequest(
            conversation=chat.conversation,
            model_name=request.model_name,
            max_new_tokens=request.max_new_tokens,
        )

        logger.info(f"Created {work_request}")
        try:
            await websocket.send_text(work_request.json())
            while True:
                # maybe unnecessary to parse and re-serialize
                # could just pass the raw string and mark end via empty string
                response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text())
                await redisClient.rpush(chat.id, response_packet.json())
                if response_packet.is_end:
                    break
        except fastapi.WebSocketException:
            # TODO: handle this better
            logger.exception(f"Websocket closed during handling of {chat.id}")

        chat.message_request_state = MessageRequestState.complete
```
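For reference, each event yielded by `event_generator` is serialized by sse-starlette roughly as follows on the wire (token values illustrative; `retry` comes from `settings.sse_retry_timeout`):

```
retry: 15000
data: {"token": "Hello"}

retry: 15000
data: {"token": " world"}
```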
`server/requirements.txt` (new file):
```
fastapi[all]
loguru
pydantic
redis
sse-starlette
websockets
```
`text-client/__main__.py` (new file):
"""Simple REPL frontend.""" | ||
|
||
import json | ||
|
||
import requests | ||
import sseclient | ||
import typer | ||
|
||
app = typer.Typer() | ||
|
||
|
||
@app.command() | ||
def main(backend_url: str = "http://127.0.0.1:8000"): | ||
"""Simple REPL client.""" | ||
chat_id = requests.post(f"{backend_url}/chat", json={}).json()["id"] | ||
while True: | ||
message = typer.prompt("User").strip() | ||
|
||
# wait for stream to be ready | ||
# could implement a queue position indicator | ||
# could be implemented with long polling | ||
# but server load needs to be considered | ||
response = requests.post( | ||
f"{backend_url}/chat/{chat_id}/message", | ||
json={"message": message}, | ||
stream=True, | ||
headers={"Accept": "text/event-stream"}, | ||
) | ||
response.raise_for_status() | ||
|
||
client = sseclient.SSEClient(response) | ||
print("Assistant: ", end="", flush=True) | ||
for event in client.events(): | ||
data = json.loads(event.data) | ||
print(data["token"], end="", flush=True) | ||
print() | ||
|
||
|
||
if __name__ == "__main__": | ||
app() |
`text-client/requirements.txt` (new file):
```
requests
sseclient-py
typer
```
`worker/__main__.py` (new file):
```python
import re
import time

import rel
import torch
import typer
import websocket
from loguru import logger
from oasst_shared.schemas import inference, protocol
from transformers import pipeline

app = typer.Typer()


@app.command()
def main(
    backend_url: str = "ws://localhost:8000",
    model_name: str = "distilgpt2",
):
    pipe = pipeline("text-generation", model=model_name)

    def on_open(ws: websocket.WebSocket):
        worker_config = inference.WorkerConfig(model_name=model_name)
        ws.send(worker_config.json())

    def on_message(ws: websocket.WebSocket, message: str):
        # TODO: what if this comes in, but one is already in progress?
        # also need to think of enabling batching
        work_request = inference.WorkRequest.parse_raw(message)

        def _prepare_message(message: protocol.ConversationMessage) -> str:
            prefix = "Assistant: " if message.is_assistant else "User: "
            return prefix + message.text

        # construct prompt
        messages = [_prepare_message(message) for message in work_request.conversation.messages]

        prompt = "\n".join(messages) + "\nAssistant:"

        # TODO: replace this with incremental generation
        torch.manual_seed(work_request.seed)
        model_output = pipe(
            prompt, max_new_tokens=work_request.max_new_tokens, do_sample=True, return_full_text=False
        )[0]["generated_text"]
        model_output = model_output.strip()

        # fake streaming
        split_idcs = [m.start() for m in re.finditer(r"([\w:]+)", model_output)]
        pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])]
        for piece in pieces:
            if not piece:
                continue
            if piece.strip() in ("User:", "Assistant:"):
                break
            ws.send(inference.WorkResponsePacket(token=piece).json())
            time.sleep(0.1)
        ws.send(inference.WorkResponsePacket(is_end=True).json())

    def on_error(ws: websocket.WebSocket, error: Exception):
        logger.error(f"Connection error: {error}")

    def on_close(ws: websocket.WebSocket, close_status_code: int, close_msg: str):
        logger.warning(f"Connection closed: {close_status_code=} {close_msg=}")

    ws = websocket.WebSocketApp(
        f"{backend_url}/work",
        on_message=on_message,
        on_error=on_error,
        on_close=on_close,
        on_open=on_open,
    )

    ws.run_forever(dispatcher=rel, reconnect=5)
    rel.signal(2, rel.abort)
    rel.dispatch()


if __name__ == "__main__":
    app()
```
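For illustration, given a hypothetical three-message conversation, the prompt assembled in `on_message` above would be:

```
User: What is the capital of France?
Assistant: Paris.
User: And of Germany?
Assistant:
```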
`worker/requirements.txt` (new file):
```
loguru
rel
torch
transformers
typer
websocket-client
```
Review comment:

> If we have more than one "message-broker" instance and `CHATS` in a DB/Redis, then this "dequeue" operation will become a congestion point. One idea would be to define clear "configuration" tiers, e.g. based on GPU memory requirements, and have independent task queues for them.
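A minimal sketch of that idea (hypothetical names, not part of this PR): derive a Redis queue key from a configuration tier and have the server enqueue chat ids there, so workers block-pop only their own tier's queue instead of every server instance scanning `CHATS`:

```python
import redis.asyncio as redis

redis_client = redis.Redis(decode_responses=True)


def queue_key(model_name: str) -> str:
    # hypothetical tier scheme: route by model name; a real scheme could
    # bucket models by GPU memory requirements instead
    return f"work_queue:{model_name}"


async def enqueue_request(model_name: str, chat_id: str) -> None:
    # server side: push the chat id onto the matching tier's queue
    await redis_client.rpush(queue_key(model_name), chat_id)


async def claim_request(model_name: str, timeout: int = 1) -> str | None:
    # worker-handler side: block-pop this tier's queue only, so multiple
    # server instances can dequeue without a global scan of all chats
    item = await redis_client.blpop(queue_key(model_name), timeout)
    return item[1] if item else None
```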