changed basic hf server to support quantization and streaming (#2293)
I have also folded the multi-worker-image PR into this one, since this change builds heavily upon it.

---------

Co-authored-by: Andreas Köpf <andreas.koepf@provisio.com>
Co-authored-by: Dimitri <dimitrivr@icloud.com>
Co-authored-by: IRFN <irfantogluk@gmail.com>
Co-authored-by: AbdBarho <ka70911@gmail.com>
Co-authored-by: Oliver Stanley <olivergestanley@gmail.com>
Co-authored-by: Michael Gartsbein <mikegarts@users.noreply.github.com>
Co-authored-by: mishka <gartsocial@gmail.com>
Co-authored-by: Theodor Peifer <teddypeifer@gmail.com>
9 people committed Apr 3, 2023
1 parent 74cdb54 commit 8a97cd4
Showing 10 changed files with 308 additions and 79 deletions.
6 changes: 4 additions & 2 deletions inference/worker/__main__.py
@@ -31,10 +31,9 @@ def main():

if model_config.is_lorem:
tokenizer = None
elif model_config.is_llama:
tokenizer: transformers.PreTrainedTokenizer = transformers.LlamaTokenizer.from_pretrained(model_config.model_id)
else:
tokenizer: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(model_config.model_id)
logger.warning(f"Tokenizer {tokenizer.name_or_path} vocab size: {tokenizer.vocab_size}")

while True:
try:
@@ -76,6 +75,9 @@ def main():
for ftr in done:
ftr.result()
message = ws.recv()
if not message:
logger.warning("Connection closed, reconnecting...")
break
worker_request = pydantic.parse_raw_as(inference.WorkerRequest, message)
match worker_request.request_type:
case "work":
199 changes: 144 additions & 55 deletions inference/worker/basic_hf_server.py
@@ -1,19 +1,25 @@
# a basic fastapi server to run generation on HF models

import signal
import sys
import threading
from queue import Queue

import fastapi
import hf_streamer
import interface
import torch
import transformers
import uvicorn
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from oasst_shared import model_configs
from settings import settings
from sse_starlette.sse import EventSourceResponse

app = fastapi.FastAPI()

DECODE_TOKEN = "<decode-token>"


# Allow CORS
app.add_middleware(
@@ -35,78 +41,161 @@ async def log_exceptions(request: fastapi.Request, call_next):
return response


def terminate_server(signum, frame):
logger.warning(f"Signal {signum}. Terminating server...")
sys.exit(0)


model: transformers.PreTrainedModel
tokenizer: transformers.PreTrainedTokenizer


@app.on_event("startup")
async def load_models():
global model, tokenizer
signal.signal(signal.SIGINT, terminate_server)
model_loaded: bool = False
fully_loaded: bool = False
model_input_queue: Queue = Queue()


def model_thread():
model: transformers.PreTrainedModel
tokenizer: transformers.PreTrainedTokenizer
model, tokenizer, decode_token = load_models()

request: interface.GenerateStreamRequest
output_queue: Queue
eos_token_id = tokenizer.eos_token_id if hasattr(tokenizer, "eos_token_id") else None
while True:
request, output_queue = model_input_queue.get()
try:
prompt = request.inputs
params = request.parameters.dict()
seed = params.pop("seed")
params.pop("stop")
params.pop("details")

last_token_id = None # need to delay by 1 to simulate tgi

def print_text(token_id: int):
nonlocal last_token_id
if last_token_id is not None:
text = decode_token(last_token_id)
stream_response = interface.GenerateStreamResponse(
token=interface.Token(text=text, id=last_token_id),
)
output_queue.put_nowait(stream_response)
last_token_id = token_id

with torch.no_grad():
ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False)
streamer = hf_streamer.HFStreamer(input_ids=ids, printer=print_text)
ids = ids.to(model.device)
output = model.generate(ids, **params, streamer=streamer, eos_token_id=eos_token_id)
output = output.cpu()
output_ids = output[0][len(ids[0]) :]
decoded = tokenizer.decode(output_ids, skip_special_tokens=True)

stream_response = interface.GenerateStreamResponse(
token=interface.Token(
text=decode_token(last_token_id), # hack because the "normal" inference server does this at once
id=last_token_id,
),
generated_text=decoded.strip(),
details=interface.StreamDetails(
finish_reason="eos_token",
generated_tokens=len(output_ids),
seed=seed,
),
)
output_queue.put_nowait(stream_response)
except Exception as e:
logger.exception("Exception in model thread")
output_queue.put_nowait(interface.GenerateStreamResponse(error=str(e)))


def load_models():
global model_loaded

torch.set_num_threads(1)
torch.set_num_interop_threads(1)

model_config = model_configs.MODEL_CONFIGS.get(settings.model_config_name)
if model_config is None:
logger.error(f"Unknown model config name: {settings.model_config_name}")
sys.exit(2)

hf_config = transformers.AutoConfig.from_pretrained(model_config.model_id)
logger.warning(f"Loading model {model_config.model_id}...")
if model_config.is_llama:
config = transformers.LlamaConfig.from_pretrained(model_config.model_id)
tokenizer = transformers.LlamaTokenizer.from_pretrained(model_config.model_id)
model = transformers.LlamaForCausalLM.from_pretrained(model_config.model_id, torch_dtype=config.torch_dtype)
else:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_config.model_id)
model = transformers.AutoModelForCausalLM.from_pretrained(model_config.model_id)
if torch.cuda.is_available():
logger.warning("Using GPU")
model = model.cuda()
logger.warning("Model loaded")
signal.signal(signal.SIGINT, signal.SIG_DFL)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_config.model_id)
logger.warning(f"tokenizer {tokenizer.name_or_path} has vocab size {tokenizer.vocab_size}")

# see `decode_token` method, taken from HF text-generation-inference
tokenizer.add_special_tokens({"additional_special_tokens": ["<decode-token>"]})

special_decode_token_id = tokenizer.convert_tokens_to_ids("<decode-token>")
special_decode_token_length = len("<decode-token>")

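# Decoding a single sentencepiece token on its own can drop its leading space;
# decoding it right after a known special token and slicing that prefix off
# preserves the spacing (same trick as in text-generation-inference).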
def decode_token(token_id):
result = tokenizer.decode([special_decode_token_id, token_id], skip_special_tokens=False)
# slice to remove special decode token
return result[special_decode_token_length:]

config_dtype = hf_config.torch_dtype if hasattr(hf_config, "torch_dtype") else torch.float32
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else config_dtype

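# load_in_8bit relies on the bitsandbytes and accelerate packages added to
# requirements-hf.txt; device_map="auto" lets accelerate place the weights on
# the available GPU(s).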
model = transformers.AutoModelForCausalLM.from_pretrained(
model_config.model_id,
torch_dtype=dtype,
load_in_8bit=settings.quantize,
device_map="auto" if torch.cuda.is_available() else None,
).eval()
logger.warning("Model loaded, using it once...")

# warmup
with torch.no_grad():
text = "Hello, world"
tokens = tokenizer.encode(text, return_tensors="pt")
tokens = tokens.to(model.device)
model.generate(tokens, max_length=10, num_beams=1, do_sample=False)

model_loaded = True

return model, tokenizer, decode_token


@app.on_event("startup")
async def use_model_once():
logger.warning("Generating once to warm up the model...")
await generate(
interface.GenerateStreamRequest(
inputs="Hello world",
parameters=interface.GenerateStreamParameters(
max_new_tokens=10,
),
)
)
logger.warning("Model warmed up")
async def start_model_thread():
logger.warning("Starting model thread...")
threading.Thread(target=model_thread, daemon=True).start()
logger.warning("Model thread started")


@app.on_event("startup")
async def welcome_message():
global fully_loaded
logger.warning("Server started")
logger.warning("To stop the server, press Ctrl+C")


@app.post("/generate")
async def generate(request: interface.GenerateStreamRequest):
global model, tokenizer
prompt = request.inputs
params = request.parameters.dict()
params.pop("seed")
params.pop("stop")
params.pop("details")
with torch.no_grad():
ids = tokenizer.encode(prompt, return_tensors="pt")
ids = ids.to(model.device)
output = model.generate(ids, **params)
output = output.cpu()
output_ids = output[0][len(ids[0]) :]
decoded = tokenizer.decode(output_ids, skip_special_tokens=True)
return {"text": decoded.strip()}
fully_loaded = True


@app.post("/generate_stream")
async def generate(
request: interface.GenerateStreamRequest,
):
def event_stream():
try:
output_queue: Queue = Queue()
model_input_queue.put_nowait((request, output_queue))
while True:
output = output_queue.get() # type: interface.GenerateStreamResponse
yield {"data": output.json()}
if output.is_end:
break
if output.is_error:
raise Exception(output.error)
except Exception as e:
logger.exception("Exception in event stream")
output_queue.put_nowait(interface.GenerateStreamResponse(error=str(e)))
raise

return EventSourceResponse(event_stream())


@app.get("/health")
async def health():
if not (fully_loaded and model_loaded):
raise fastapi.HTTPException(status_code=503, detail="Server not fully loaded")
return {"status": "ok"}


if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
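
For context, a minimal client sketch for the new /generate_stream endpoint: it assumes the server above is running on localhost:8000 and that requests and sseclient-py are installed (work.py consumes the same event format with the same libraries). The payload fields mirror interface.GenerateStreamRequest.

import json

import requests
import sseclient

payload = {
    "inputs": "Hello, my name is",
    "parameters": {"max_new_tokens": 32, "do_sample": False},
}
response = requests.post(
    "http://localhost:8000/generate_stream",
    json=payload,
    stream=True,
    headers={"Accept": "text/event-stream"},
)
client = sseclient.SSEClient(response)
for event in client.events():
    if event.event == "ping":
        continue
    data = json.loads(event.data)
    if data.get("error"):
        raise RuntimeError(data["error"])
    if data.get("token"):
        # each intermediate event carries one token
        print(data["token"]["text"], end="", flush=True)
    if data.get("generated_text") is not None:
        # the final event also carries the full generated text and details
        print("\n" + data["generated_text"])
        break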
21 changes: 21 additions & 0 deletions inference/worker/download_model_hf.py
@@ -0,0 +1,21 @@
import os
import signal
import sys
from pathlib import Path

import huggingface_hub


def terminate(signum, frame):
print("Terminating...")
sys.exit(0)


if __name__ == "__main__":
signal.signal(signal.SIGINT, terminate)
model_id = os.getenv("MODEL_ID")
snapshot_dir = Path(huggingface_hub.snapshot_download(model_id))
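# Some older LLaMA snapshots spell class names like "LLaMATokenizer" in their JSON
# configs; current transformers expects the "Llama" spelling, so normalize the
# downloaded files in place (that is all the loop below does).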
for file in snapshot_dir.rglob("*.json"):
text = file.read_text()
text = text.replace("LLaMA", "Llama")
file.write_text(text)
36 changes: 36 additions & 0 deletions inference/worker/hf_streamer.py
@@ -0,0 +1,36 @@
import typing

import transformers
from loguru import logger


class Printer(typing.Protocol):
def __call__(self, value: int) -> None:
...


def _unpack(value):
if len(value.shape) > 1 and value.shape[0] > 1:
raise ValueError("HFStreamer only supports batch size 1")
elif len(value.shape) > 1:
value = value[0]
return value.cpu().tolist()


# based on HF text streamer
class HFStreamer(transformers.generation.streamers.BaseStreamer):
def __init__(self, input_ids, printer: Printer):
self.input_ids = _unpack(input_ids)[::-1]
self.printer = printer

def put(self, value):
for token_id in _unpack(value):
if self.input_ids:
input_id = self.input_ids.pop()
if input_id != token_id:
logger.warning(f"Input id {input_id} does not match output id {token_id}")
else:
self.printer(token_id)

def end(self):
pass
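
A hedged usage sketch of how HFStreamer plugs into transformers' generate(), run from inference/worker so hf_streamer is importable; the tiny checkpoint below is only a placeholder, and streamer= support requires a recent transformers release (requirements-hf.txt pins it to main).

import torch
import transformers

import hf_streamer

model_id = "sshleifer/tiny-gpt2"  # placeholder checkpoint, purely for illustration
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
model = transformers.AutoModelForCausalLM.from_pretrained(model_id)

ids = tokenizer.encode("Hello, world", return_tensors="pt")
# the printer callback receives each newly generated token id; prompt ids are skipped by HFStreamer
streamer = hf_streamer.HFStreamer(input_ids=ids, printer=lambda token_id: print(token_id, flush=True))
with torch.no_grad():
    model.generate(ids, max_new_tokens=5, do_sample=False, streamer=streamer)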
12 changes: 6 additions & 6 deletions inference/worker/interface.py
@@ -7,12 +7,12 @@
class GenerateStreamParameters(pydantic.BaseModel):
max_new_tokens: int = 1024
do_sample: bool = True
top_k: int | None
top_p: float | None
typical_p: float | None
temperature: float | None
repetition_penalty: float | None
seed: int | None
top_k: int | None = None
top_p: float | None = None
typical_p: float | None = None
temperature: float | None = None
repetition_penalty: float | None = None
seed: int | None = None
stop: list[str] = []
details: bool = True

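A small sketch of what the added "= None" defaults buy, run from inference/worker where interface.py lives: a request can now be built without spelling out every optional sampling field.

import interface

params = interface.GenerateStreamParameters(max_new_tokens=64, temperature=0.8)
request = interface.GenerateStreamRequest(inputs="Hello, world", parameters=params)
print(request.json())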
5 changes: 5 additions & 0 deletions inference/worker/requirements-hf.txt
@@ -1,3 +1,8 @@
accelerate
bitsandbytes
fastapi
huggingface_hub
sse-starlette
torch
git+https://github.com/huggingface/transformers@main#egg=transformers
uvicorn
3 changes: 3 additions & 0 deletions inference/worker/settings.py
@@ -17,5 +17,8 @@ class Settings(pydantic.BaseSettings):
perform_oom_test: bool = False
oom_test_max_length: int | None = None

# for hf basic server
quantize: bool = False


settings = Settings()
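
A hedged launcher sketch for the new flag: pydantic's BaseSettings reads quantize from the environment, assuming Settings keeps the default field-name-to-env-var mapping (no env_prefix).

import os
import subprocess

# Enable 8-bit loading; basic_hf_server.py passes load_in_8bit=settings.quantize
env = dict(os.environ, QUANTIZE="true")
subprocess.run(["python", "basic_hf_server.py"], env=env, check=True)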
20 changes: 15 additions & 5 deletions inference/worker/work.py
@@ -123,11 +123,15 @@ def handle_work_request(

if model_config.is_llama:
generated_ids.append(token.id)
with tokenizer_lock:
text = tokenizer.decode(generated_ids)
new_text = text[len(decoded_text) :]
if not decoded_text:
new_text = new_text.lstrip()
try:
with tokenizer_lock:
text = tokenizer.decode(generated_ids, skip_special_tokens=True)
new_text = text[len(decoded_text) :]
if not decoded_text:
new_text = new_text.lstrip()
except Exception:
text = decoded_text
new_text = ""
token.text = new_text
decoded_text = text

@@ -190,6 +194,12 @@ def get_inference_server_stream_events(request: interface.GenerateStreamRequest)

client = sseclient.SSEClient(response)
for event in client.events():
if event.event == "error":
logger.error(f"Error from inference server: {event.data}")
yield interface.GenerateStreamResponse(error=event.data)
raise RuntimeError(f"Error from inference server: {event.data}")
if event.event == "ping":
continue
stream_response = interface.GenerateStreamResponse.parse_raw(event.data)
yield stream_response

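The llama branch in handle_work_request above re-decodes the whole generated id list on every step and emits only the new suffix, so tokens that merge into one word render correctly. A standalone sketch of that pattern, using a placeholder tokenizer rather than the project's model:

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer for illustration

decoded_text = ""
generated_ids: list[int] = []

def on_token(token_id: int) -> str:
    """Append the id, decode everything so far, and return only the not-yet-emitted suffix."""
    global decoded_text
    generated_ids.append(token_id)
    text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    new_text = text[len(decoded_text):]
    decoded_text = text
    return new_text

for token_id in tokenizer.encode("Streaming token by token can split words."):
    print(on_token(token_id), end="")
print()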