re-worked with security in mind
yk committed Jan 20, 2023
1 parent 73a1918 commit 3a33827
Showing 8 changed files with 252 additions and 54 deletions.
35 changes: 35 additions & 0 deletions inference/README.md
@@ -0,0 +1,35 @@
# OpenAssistant Inference

Preliminary implementation of the inference engine for OpenAssistant.

## Development (you'll need multiple terminals)

Run a Redis container:

```bash
docker run --rm -it -p 6379:6379 redis
```

Run the inference server:

```bash
cd server
pip install -r requirements.txt
uvicorn main:app --reload
```

Run one (or more) workers:

```bash
cd worker
pip install -r requirements.txt
python __main__.py
```

Run the client:

```bash
cd text-client
pip install -r requirements.txt
python __main__.py
```
14 changes: 14 additions & 0 deletions inference/server/README.md
@@ -0,0 +1,14 @@
# OpenAssistant Inference Server

Workers communicate with the `/work` endpoint. They provide their configuration,
and if a compatible task is available, the server returns it. The server also
returns the key of a Redis list to which the worker should push the results.
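
For illustration, the worker's side of this exchange might look roughly like the
sketch below (modeled on the worker code in this commit; the server address and
the one-second retry interval are assumptions):

```python
import time

import requests

BACKEND_URL = "http://localhost:8000"  # assumed server address
worker_config = {"model_name": "distilgpt2"}

while True:
    # Announce our configuration and ask the server for work.
    response = requests.post(f"{BACKEND_URL}/work", json={"worker_config": worker_config})
    if response.status_code == 200:
        work_request = response.json()
        # work_request carries the prompt, max_length, a seed, and
        # stream_queue_id -- the Redis list to push result packets to.
        break
    # Any other status means no compatible task is pending yet.
    time.sleep(1)
```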

Clients first call the `/complete` endpoint to place a request for prompt
completion. The server returns a unique ID for the request. The client then
polls the `/stream` endpoint with that ID to check whether the request has been
assigned to a worker. Once it has been assigned, the response is an SSE event
stream.
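
As a sketch of the client side (mirroring the text client in this commit: the
server address, the `sseclient` library, and the 0.25 s polling interval are
taken from that client; the prompt string is just an example):

```python
import json
import time

import requests
import sseclient

backend_url = "http://127.0.0.1:8000"

# 1. Place the completion request; the server answers with a unique ID.
completion_id = requests.post(
    f"{backend_url}/complete", json={"prompt": "Hello, OpenAssistant"}
).json()["completion_id"]

# 2. Poll /stream until the request has been assigned to a worker.
while True:
    response = requests.get(
        f"{backend_url}/stream/{completion_id}",
        stream=True,
        headers={"Accept": "text/event-stream"},
    )
    if response.status_code == 200:
        break
    time.sleep(0.25)

# 3. Consume the SSE stream; each event carries one generated token.
for event in sseclient.SSEClient(response).events():
    print(json.loads(event.data)["token"], end="", flush=True)
```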

Notably, `/complete` could be proxied via a frontend, while `/stream` can be
accessed directly by the client, since the unique ID provides enough security.
147 changes: 116 additions & 31 deletions inference/server/main.py
@@ -1,14 +1,15 @@
import json
import random
import uuid

import fastapi
import pydantic
import redis.asyncio as redis
-from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from oasst_shared.schemas import inference
from sse_starlette.sse import EventSourceResponse

-app = FastAPI()
app = fastapi.FastAPI()

# Allow CORS
app.add_middleware(
@@ -36,39 +37,123 @@ class Settings(pydantic.BaseSettings):
)


@app.get("/stream/{queue_id}")
async def message_stream(queue_id: str, request: Request):
async def event_generator():
while True:
if await request.is_disconnected():
logger.warning("Client disconnected")
break
item = await redisClient.blpop(queue_id, 1)
if item is None:
continue
_, token = item
class CompletionRequest(pydantic.BaseModel):
prompt: str = pydantic.Field(..., repr=False)
model_name: str = "distilgpt2"
max_length: int = 100

-            if token == "<END>":
-                await redisClient.delete(queue_id)
-                break
def compatible_with(self, worker_config: inference.WorkerConfig) -> bool:
return self.model_name == worker_config.model_name

-            yield {
-                "retry": settings.sse_retry_timeout,
-                "data": json.dumps({"token": token}),
-            }
-
-    return EventSourceResponse(event_generator())
class CompletionResponse(pydantic.BaseModel):
completion_id: str


-class CompleteRequest(pydantic.BaseModel):
-    text: str
-    model_name: str = "distilgpt2"
class ResponseEvent(pydantic.BaseModel):
token: str


class DbEntry(pydantic.BaseModel):
completion_id: str
completion_request: CompletionRequest
work_request: inference.WorkRequest | None = None
result_data: list[inference.WorkResponsePacket] | None = None


# TODO: make real database
DATABASE: dict[str, DbEntry] = {}


@app.post("/complete")
-async def complete(request: CompleteRequest) -> str:
-    queue_id = str(uuid.uuid4())
-    work_queue_name = f"work-{request.model_name}"
-    logger.info(f"Pushing {queue_id} {len(request.text)=} to {work_queue_name}")
-    await redisClient.lpush(work_queue_name, json.dumps({"queue_id": queue_id, "text": request.text}))
-    return queue_id
async def complete(request: CompletionRequest) -> CompletionResponse:
"""Allows a client to request completion of a prompt."""
logger.info(f"Received {request}")
completion_id = str(uuid.uuid4())
DATABASE[completion_id] = DbEntry(completion_id=completion_id, completion_request=request)
return CompletionResponse(completion_id=completion_id)


class StartWorkRequest(pydantic.BaseModel):
worker_config: inference.WorkerConfig


@app.post("/work")
async def work(start_work_request: StartWorkRequest) -> inference.WorkRequest:
"""Allows a worker to request work, given its configuration."""

# find a pending request that is compatible with the worker config
# this could be implemented using queues, databases, long polling, etc.
# we might also think about replacing this endpoint with a message queue
# but we need to know which worker is dequeueing a particular request
# to do proper credit assignment and load balancing (+ security)
for db_entry in DATABASE.values():
if db_entry.completion_request and not db_entry.work_request:
if db_entry.completion_request.compatible_with(start_work_request.worker_config):
break
else:
raise fastapi.HTTPException(status_code=202, detail="No pending requests")

request = db_entry.completion_request

# generate a stream id to use for this request
stream_queue_id = str(uuid.uuid4())
seed = random.randint(0, 2**32 - 1)
work_request = inference.WorkRequest(
stream_queue_id=stream_queue_id,
model_name=request.model_name,
prompt=request.prompt,
max_length=request.max_length,
seed=seed,
)

# store the work request in the database
db_entry.work_request = work_request

logger.info(f"Created {work_request}")
return work_request


@app.get("/stream/{db_id}")
async def message_stream(db_id: str, request: fastapi.Request):
"""Allows the client to stream the results of a request."""

db_entry = DATABASE[db_id]
if not db_entry.work_request:
raise fastapi.HTTPException(status_code=202, detail="Not ready")

stream_queue_id = db_entry.work_request.stream_queue_id

async def event_generator():
result_data = []

try:
while True:
if await request.is_disconnected():
logger.warning("Client disconnected")
break

item = await redisClient.blpop(stream_queue_id, 1)
if item is None:
continue

_, response_packet_str = item
response_packet = inference.WorkResponsePacket.parse_raw(response_packet_str)
result_data.append(response_packet)

if response_packet.is_end:
await redisClient.delete(stream_queue_id)
break

yield {
"retry": settings.sse_retry_timeout,
"data": ResponseEvent(token=response_packet.token).json(),
}
logger.info(f"Finished streaming {stream_queue_id} {len(result_data)=}")
except Exception:
logger.exception(f"Error streaming {stream_queue_id}")

# store the generated data in the database
db_entry.result_data = result_data

return EventSourceResponse(event_generator())
1 change: 1 addition & 0 deletions inference/server/requirements.txt
@@ -1,4 +1,5 @@
fastapi[all]
loguru
pydantic
redis
sse-starlette
18 changes: 14 additions & 4 deletions inference/text-client/__main__.py
@@ -1,6 +1,7 @@
"""Simple REPL frontend."""

import json
import time

import requests
import sseclient
@@ -14,10 +15,19 @@ def main(backend_url: str = "http://127.0.0.1:8000"):
"""Simple REPL client."""
while True:
prompt = typer.prompt("Enter text to complete").strip()
-        queue_id = requests.post(f"{backend_url}/complete", json={"text": prompt}).json()

-        headers = {"Accept": "text/event-stream"}
-        response = requests.get(f"{backend_url}/stream/{queue_id}", stream=True, headers=headers)
complete_response = requests.post(f"{backend_url}/complete", json={"prompt": prompt}).json()
completion_id = complete_response["completion_id"]

# wait for stream to be ready
# could implement a queue position indicator
# could be implemented with long polling
# but server load needs to be considered
while True:
headers = {"Accept": "text/event-stream"}
response = requests.get(f"{backend_url}/stream/{completion_id}", stream=True, headers=headers)
if response.status_code == 200:
break
time.sleep(0.25)

client = sseclient.SSEClient(response)
for event in client.events():
70 changes: 51 additions & 19 deletions inference/worker/__main__.py
@@ -1,37 +1,69 @@
-import json
import re
import time

import redis
import requests
import torch
import typer
from loguru import logger
from oasst_shared.schemas import inference
from transformers import pipeline

app = typer.Typer()


@app.command()
-def main(model_name: str = "distilgpt2", redis_host: str = "localhost", redis_port: int = 6379, redis_db: int = 0):
-    pipe = pipeline("text-generation", model=model_name)
def main(
model_name: str = "distilgpt2",
backend_url: str = "http://localhost:8000",
redis_host: str = "localhost",
redis_port: int = 6379,
redis_db: int = 0,
):

redis_client = redis.StrictRedis(host=redis_host, port=redis_port, db=redis_db, decode_responses=True)

-    work_queue_name = f"work-{model_name}"
worker_config = inference.WorkerConfig(model_name=model_name)

pipe = pipeline("text-generation", model=model_name)

# TODO: use batching
while True:
-        item = redis_client.brpop(work_queue_name)
-        if item is None:
-            continue
-        _, work_str = item
-        work = json.loads(work_str)
-        queue_id = work["queue_id"]
-        text = work["text"]
-        print(f"Processing {queue_id} {len(text)=}...")
-
-        # TODO: replace this with incremental generation
-        model_output = pipe(text, max_length=50, do_sample=True, return_full_text=False)[0]["generated_text"]
-        for word in model_output.split():
-            redis_client.rpush(queue_id, word + " ")
-            time.sleep(0.1)
-        redis_client.rpush(queue_id, "<END>")

# wait for work to be ready
# could possibly be switched to a message queue
while True:
try:
response = requests.post(f"{backend_url}/work", json={"worker_config": worker_config.dict()})
if response.status_code == 200:
break
except Exception:
logger.exception("Error connecting to backend")
time.sleep(1)

try:
work_request = inference.WorkRequest.parse_raw(response.content)
print(f"Processing {work_request}")
queue_id = work_request.stream_queue_id

# TODO: replace this with incremental generation
torch.manual_seed(work_request.seed)
model_output = pipe(
work_request.prompt, max_length=work_request.max_length, do_sample=True, return_full_text=False
)[0]["generated_text"]
model_output = model_output.strip()

# fake streaming
split_idcs = [m.start() for m in re.finditer(r"(\w+)", model_output)]
pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])]
for piece in pieces:
if not piece:
continue
redis_client.rpush(queue_id, inference.WorkResponsePacket(token=piece).json())
time.sleep(0.1)
redis_client.rpush(queue_id, inference.WorkResponsePacket(is_end=True).json())
except Exception:
logger.exception(f"Error processing {work_request}")


if __name__ == "__main__":
3 changes: 3 additions & 0 deletions inference/worker/requirements.txt
@@ -1,3 +1,6 @@
loguru
redis
requests
torch
transformers
typer
18 changes: 18 additions & 0 deletions oasst-shared/oasst_shared/schemas/inference.py
@@ -0,0 +1,18 @@
import pydantic


class WorkerConfig(pydantic.BaseModel):
model_name: str = "distilgpt2"


class WorkRequest(pydantic.BaseModel):
stream_queue_id: str
prompt: str = pydantic.Field(..., repr=False)
model_name: str = "distilgpt2"
max_length: int = 100
seed: int = 42


class WorkResponsePacket(pydantic.BaseModel):
token: str | None = None
is_end: bool = False
