Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] Thread-based server request throttling #8

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions poet-server/server.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,31 @@
import os
from typing import Literal, Optional
from multiprocessing import Manager
from typing import Literal

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from fastapi.applications import JSONResponse
from loguru import logger

from poet import solve

SOLVE_THREADS = min(4, os.cpu_count())
MAX_THREADS = SOLVE_THREADS
NUM_WORKERS = 2 * os.cpu_count() + 1

# makes ANSI color codes work on Windows
os.system("")

app = FastAPI()
# TODO: this doesn't work with workers since each worker has its own threads_used
# see https://stackoverflow.com/questions/65686318/sharing-python-objects-across-multiple-workers/65699375#65699375
# need to implement Redis cache or something similar
threads_used = None


@app.get("/solve")
def solve_handler(
request: Request,
model: Literal[
"linear",
"vgg16",
Expand All @@ -35,8 +47,24 @@ def solve_handler(
time_limit_s: float = 1e100,
solve_threads: int = SOLVE_THREADS, # different default than a direct solve
):
host = request.client.host
if host not in threads_used:
threads_used[host] = 0
if threads_used[host] + solve_threads > MAX_THREADS:
return JSONResponse(
status_code=429,
content={
"detail": "Too many threads requested for solves by this user "
+ f"({threads_used[host]} in use; solving with {solve_threads} more would "
+ f"exceed the max per user of {MAX_THREADS} threads). Please wait until "
+ "some of your solves finish, or retry with less solve_threads."
},
)

threads_used[host] += solve_threads

try:
return solve(
result = solve(
model=model,
platform=platform,
ram_budget=ram_budget,
Expand All @@ -50,11 +78,21 @@ def solve_handler(
solve_threads=solve_threads,
print_graph_info=False,
)
threads_used[host] -= solve_threads
if threads_used[host] == 0:
del threads_used[host]
return result
except Exception as e:
logger.exception(e)
raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
logger.info("Initializing an instance of the POET server.")
uvicorn.run("server:app", host="0.0.0.0", port=80, reload=os.environ.get("DEV"))
uvicorn.run(
"server:app",
host="0.0.0.0",
port=80,
reload=os.environ.get("DEV"),
workers=NUM_WORKERS,
)