changed basic hf server to support quantization and streaming #2293

Merged
70 commits merged on Apr 3, 2023

Commits (70)
6afd4fa  Starting multiple workers in inference image when multiple GPUs avail…  (yk, Mar 31, 2023)
a0360fb  added some echoes  (yk, Mar 31, 2023)
8bd03fc  added some echoes  (yk, Mar 31, 2023)
1f2ebdc  more entrypoint stuff  (yk, Mar 31, 2023)
bf1b77d  master port  (yk, Mar 31, 2023)
7419753  adding sleep  (yk, Apr 1, 2023)
ef4a4e0  more configs  (yk, Apr 2, 2023)
9cb9075  Use LLaMA impl of Huggingface Transformers (#2263)  (andreaskoepf, Mar 31, 2023)
cf161e9  Fix GPTNeoX-20B training (#2240)  (dvruette, Mar 31, 2023)
4359c6a  Updated Turkish language (#2270)  (irfantogluk, Mar 31, 2023)
54892e2  Add loader for CodeAlpaca-20k & gpt4all_pruned dataset (#2273)  (andreaskoepf, Mar 31, 2023)
01f94f7  Add support for Cerebras-GPT for training (#2276)  (olliestanley, Mar 31, 2023)
3c1335e  typo in parsing openai/summarize_from_feedback (#2268)  (mikegarts, Mar 31, 2023)
7e05077  Add rng_seed parameter to trainers (#2254)  (andreaskoepf, Mar 31, 2023)
1b72c07  Computing message queue positions (#2235)  (yk, Mar 31, 2023)
e88efb7  Remove assigning eos token id (llama compatibility) (#2280)  (andreaskoepf, Mar 31, 2023)
a167e10  Fix call-to-action responsiveness (#2290)  (theopfr, Apr 1, 2023)
696889c  Added max size to work queue and an error response if full when enque…  (yk, Apr 1, 2023)
7d47021  Fix loading of Nebulous/gpt4all_pruned dataset (#2291)  (andreaskoepf, Apr 1, 2023)
e6ad876  changed basic hf server to support quantization and streaming  (yk, Apr 2, 2023)
9924bab  updated main script  (yk, Apr 2, 2023)
dcf7d43  ctrl c trap  (yk, Apr 2, 2023)
ad12aa7  replacing llama config  (yk, Apr 2, 2023)
90297b2  sleep param  (yk, Apr 2, 2023)
f7ff758  removed dtypes  (yk, Apr 2, 2023)
6e4e0b8  loading in thread  (yk, Apr 2, 2023)
2b7e048  removed signal  (yk, Apr 2, 2023)
417d8ff  removed double start  (yk, Apr 2, 2023)
18de6b7  bugfix  (yk, Apr 2, 2023)
a6e7550  bugfix  (yk, Apr 2, 2023)
8dd1866  exception handling in stream  (yk, Apr 2, 2023)
b915ac4  bug handling  (yk, Apr 2, 2023)
1ea2323  bug handling  (yk, Apr 2, 2023)
710833d  bugfix  (yk, Apr 2, 2023)
57a7eec  bugfix  (yk, Apr 2, 2023)
b9cbe0c  bugfix  (yk, Apr 2, 2023)
b3a0599  bugfix  (yk, Apr 2, 2023)
1ba8929  bugfix  (yk, Apr 2, 2023)
7d0480a  bugfix  (yk, Apr 2, 2023)
c649c92  bugfix  (yk, Apr 2, 2023)
ecf60b9  logging  (yk, Apr 2, 2023)
743e409  bugfix  (yk, Apr 2, 2023)
b66fd96  bugfix  (yk, Apr 2, 2023)
3e0ea1f  bugfix  (yk, Apr 2, 2023)
7610f21  bugfix  (yk, Apr 2, 2023)
0835395  bugfix  (yk, Apr 2, 2023)
fd926b5  bugfix  (yk, Apr 2, 2023)
e16a990  logging  (yk, Apr 2, 2023)
9d35bfa  logging  (yk, Apr 2, 2023)
d0e4412  vocab size fix  (yk, Apr 2, 2023)
5a44c93  vocab size fix  (yk, Apr 2, 2023)
86ea260  vocab size fix  (yk, Apr 2, 2023)
8139699  vocab size fix  (yk, Apr 2, 2023)
b27b6c4  vocab size fix  (yk, Apr 2, 2023)
072d1d7  vocab size fix  (yk, Apr 2, 2023)
1b9d0c3  vocab size fix  (yk, Apr 2, 2023)
9caf954  vocab size fix  (yk, Apr 2, 2023)
e733abd  vocab size fix  (yk, Apr 2, 2023)
ed38c87  decode hack  (yk, Apr 2, 2023)
17e43bd  more fixes  (yk, Apr 2, 2023)
6b5788e  added back token hack  (yk, Apr 2, 2023)
9701b0d  removed logging  (yk, Apr 2, 2023)
2ad7820  feedback  (yk, Apr 2, 2023)
f517214  decode fix  (yk, Apr 2, 2023)
000f12c  warmup change  (yk, Apr 2, 2023)
9047652  torch threads  (yk, Apr 2, 2023)
82506ea  model loading fix  (yk, Apr 2, 2023)
d27e712  delaying tokens by 1  (yk, Apr 3, 2023)
54e9994  Merge branch 'main' into hf-worker-server-bnb  (yk, Apr 3, 2023)
3195604  feedback  (yk, Apr 3, 2023)
6 changes: 4 additions & 2 deletions inference/worker/__main__.py
@@ -31,10 +31,9 @@ def main():

if model_config.is_lorem:
tokenizer = None
elif model_config.is_llama:
tokenizer: transformers.PreTrainedTokenizer = transformers.LlamaTokenizer.from_pretrained(model_config.model_id)
else:
tokenizer: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(model_config.model_id)
logger.warning(f"Tokenizer {tokenizer.name_or_path} vocab size: {tokenizer.vocab_size}")

while True:
try:
@@ -76,6 +75,9 @@ def main():
for ftr in done:
ftr.result()
message = ws.recv()
if not message:
logger.warning("Connection closed, reconnecting...")
break
worker_request = pydantic.parse_raw_as(inference.WorkerRequest, message)
match worker_request.request_type:
case "work":
199 changes: 144 additions & 55 deletions inference/worker/basic_hf_server.py
@@ -1,19 +1,25 @@
# a basic fastapi server to run generation on HF models

import signal
import sys
import threading
from queue import Queue

import fastapi
import hf_streamer
import interface
import torch
import transformers
import uvicorn
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from oasst_shared import model_configs
from settings import settings
from sse_starlette.sse import EventSourceResponse

app = fastapi.FastAPI()

DECODE_TOKEN = "<decode-token>"


# Allow CORS
app.add_middleware(
@@ -35,78 +41,161 @@ async def log_exceptions(request: fastapi.Request, call_next):
return response


def terminate_server(signum, frame):
logger.warning(f"Signal {signum}. Terminating server...")
sys.exit(0)


model: transformers.PreTrainedModel
tokenizer: transformers.PreTrainedTokenizer


@app.on_event("startup")
async def load_models():
global model, tokenizer
signal.signal(signal.SIGINT, terminate_server)
model_loaded: bool = False
fully_loaded: bool = False
model_input_queue: Queue = Queue()


def model_thread():
model: transformers.PreTrainedModel
tokenizer: transformers.PreTrainedTokenizer
model, tokenizer, decode_token = load_models()

request: interface.GenerateStreamRequest
output_queue: Queue
eos_token_id = tokenizer.eos_token_id if hasattr(tokenizer, "eos_token_id") else None
while True:
request, output_queue = model_input_queue.get()
try:
prompt = request.inputs
params = request.parameters.dict()
seed = params.pop("seed")
params.pop("stop")
params.pop("details")

last_token_id = None # need to delay by 1 to simulate tgi

def print_text(token_id: int):
nonlocal last_token_id
if last_token_id is not None:
text = decode_token(last_token_id)
stream_response = interface.GenerateStreamResponse(
token=interface.Token(text=text, id=last_token_id),
)
output_queue.put_nowait(stream_response)
last_token_id = token_id

with torch.no_grad():
ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False)
streamer = hf_streamer.HFStreamer(input_ids=ids, printer=print_text)
ids = ids.to(model.device)
output = model.generate(ids, **params, streamer=streamer, eos_token_id=eos_token_id)
output = output.cpu()
output_ids = output[0][len(ids[0]) :]
decoded = tokenizer.decode(output_ids, skip_special_tokens=True)

stream_response = interface.GenerateStreamResponse(
token=interface.Token(
text=decode_token(last_token_id), # hack because the "normal" inference server does this at once
id=last_token_id,
),
generated_text=decoded.strip(),
details=interface.StreamDetails(
finish_reason="eos_token",
generated_tokens=len(output_ids),
seed=seed,
),
)
output_queue.put_nowait(stream_response)
except Exception as e:
logger.exception("Exception in model thread")
output_queue.put_nowait(interface.GenerateStreamResponse(error=str(e)))


def load_models():
global model_loaded

torch.set_num_threads(1)
torch.set_num_interop_threads(1)

model_config = model_configs.MODEL_CONFIGS.get(settings.model_config_name)
if model_config is None:
logger.error(f"Unknown model config name: {settings.model_config_name}")
sys.exit(2)

hf_config = transformers.AutoConfig.from_pretrained(model_config.model_id)
logger.warning(f"Loading model {model_config.model_id}...")
if model_config.is_llama:
config = transformers.LlamaConfig.from_pretrained(model_config.model_id)
tokenizer = transformers.LlamaTokenizer.from_pretrained(model_config.model_id)
model = transformers.LlamaForCausalLM.from_pretrained(model_config.model_id, torch_dtype=config.torch_dtype)
else:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_config.model_id)
model = transformers.AutoModelForCausalLM.from_pretrained(model_config.model_id)
if torch.cuda.is_available():
logger.warning("Using GPU")
model = model.cuda()
logger.warning("Model loaded")
signal.signal(signal.SIGINT, signal.SIG_DFL)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_config.model_id)
logger.warning(f"tokenizer {tokenizer.name_or_path} has vocab size {tokenizer.vocab_size}")

# see `decode_token` method, taken from HF text-generation-inference
tokenizer.add_special_tokens({"additional_special_tokens": ["<decode-token>"]})

special_decode_token_id = tokenizer.convert_tokens_to_ids("<decode-token>")
special_decode_token_length = len("<decode-token>")

def decode_token(token_id):
result = tokenizer.decode([special_decode_token_id, token_id], skip_special_tokens=False)
# slice to remove special decode token
return result[special_decode_token_length:]

config_dtype = hf_config.torch_dtype if hasattr(hf_config, "torch_dtype") else torch.float32
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else config_dtype

model = transformers.AutoModelForCausalLM.from_pretrained(
model_config.model_id,
torch_dtype=dtype,
load_in_8bit=settings.quantize,
device_map="auto" if torch.cuda.is_available() else None,
).eval()
logger.warning("Model loaded, using it once...")

# warmup
with torch.no_grad():
text = "Hello, world"
tokens = tokenizer.encode(text, return_tensors="pt")
tokens = tokens.to(model.device)
model.generate(tokens, max_length=10, num_beams=1, do_sample=False)

model_loaded = True

return model, tokenizer, decode_token


@app.on_event("startup")
async def use_model_once():
logger.warning("Generating once to warm up the model...")
await generate(
interface.GenerateStreamRequest(
inputs="Hello world",
parameters=interface.GenerateStreamParameters(
max_new_tokens=10,
),
)
)
logger.warning("Model warmed up")
async def start_model_thread():
logger.warning("Starting model thread...")
threading.Thread(target=model_thread, daemon=True).start()
logger.warning("Model thread started")


@app.on_event("startup")
async def welcome_message():
global fully_loaded
logger.warning("Server started")
logger.warning("To stop the server, press Ctrl+C")


@app.post("/generate")
async def generate(request: interface.GenerateStreamRequest):
global model, tokenizer
prompt = request.inputs
params = request.parameters.dict()
params.pop("seed")
params.pop("stop")
params.pop("details")
with torch.no_grad():
ids = tokenizer.encode(prompt, return_tensors="pt")
ids = ids.to(model.device)
output = model.generate(ids, **params)
output = output.cpu()
output_ids = output[0][len(ids[0]) :]
decoded = tokenizer.decode(output_ids, skip_special_tokens=True)
return {"text": decoded.strip()}
fully_loaded = True


@app.post("/generate_stream")
async def generate(
request: interface.GenerateStreamRequest,
):
def event_stream():
try:
output_queue: Queue = Queue()
model_input_queue.put_nowait((request, output_queue))
while True:
output = output_queue.get() # type: interface.GenerateStreamResponse
yield {"data": output.json()}
if output.is_end:
break
if output.is_error:
raise Exception(output.error)
except Exception as e:
logger.exception("Exception in event stream")
output_queue.put_nowait(interface.GenerateStreamResponse(error=str(e)))
raise

return EventSourceResponse(event_stream())


@app.get("/health")
async def health():
if not (fully_loaded and model_loaded):
raise fastapi.HTTPException(status_code=503, detail="Server not fully loaded")
return {"status": "ok"}


if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
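For reviewers who want to exercise the new streaming endpoint, here is a minimal client sketch. It assumes the server runs at localhost:8000 (the uvicorn call above), uses sseclient-py the same way work.py does, and reads the response fields (token, generated_text, error) as defined in interface.py. Illustrative only, not part of the PR.

```python
# Minimal SSE client sketch for POST /generate_stream (illustrative only).
import json

import requests
import sseclient  # sseclient-py, as used in inference/worker/work.py

payload = {
    "inputs": "Hello, world",
    "parameters": {"max_new_tokens": 16, "do_sample": False},
}
response = requests.post(
    "http://localhost:8000/generate_stream",
    json=payload,
    stream=True,
    headers={"Accept": "text/event-stream"},
)
client = sseclient.SSEClient(response)
for event in client.events():
    data = json.loads(event.data)
    if data.get("error"):
        raise RuntimeError(f"server error: {data['error']}")
    token = data.get("token") or {}
    print(token.get("text", ""), end="", flush=True)
    if data.get("generated_text") is not None:  # final message carries the full text
        print()
        break
```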
21 changes: 21 additions & 0 deletions inference/worker/download_model_hf.py
@@ -0,0 +1,21 @@
import os
import signal
import sys
from pathlib import Path

import huggingface_hub


def terminate(signum, frame):
print("Terminating...")
sys.exit(0)


if __name__ == "__main__":
signal.signal(signal.SIGINT, terminate)
model_id = os.getenv("MODEL_ID")
snapshot_dir = Path(huggingface_hub.snapshot_download(model_id))
for file in snapshot_dir.rglob("*.json"):
text = file.read_text()
text = text.replace("LLaMA", "Llama")
file.write_text(text)
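Context for this helper: community-converted LLaMA checkpoints predate the rename of the transformers classes from LLaMA* to Llama*, so their config JSONs reference the old names; the script patches the downloaded snapshot in place. A hypothetical invocation follows (the model id is a placeholder and the actual wiring inside the worker image is not shown in this diff):

```python
# Hypothetical invocation of the snapshot pre-download/patch step.
import os
import subprocess

# MODEL_ID is the only input the script reads; the id below is a placeholder.
env = dict(os.environ, MODEL_ID="your-org/llama-7b-hf")
subprocess.run(["python", "download_model_hf.py"], env=env, check=True)
```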
36 changes: 36 additions & 0 deletions inference/worker/hf_streamer.py
@@ -0,0 +1,36 @@
import typing

import transformers
from loguru import logger


class Printer(typing.Protocol):
def __call__(self, value: int) -> None:
...


def _unpack(value):
if len(value.shape) > 1 and value.shape[0] > 1:
raise ValueError("HFStreamer only supports batch size 1")
elif len(value.shape) > 1:
value = value[0]
return value.cpu().tolist()


# based on HF text streamer
class HFStreamer(transformers.generation.streamers.BaseStreamer):
def __init__(self, input_ids, printer: Printer):
self.input_ids = _unpack(input_ids)[::-1]
self.printer = printer

def put(self, value):
for token_id in _unpack(value):
if self.input_ids:
input_id = self.input_ids.pop()
if input_id != token_id:
logger.warning(f"Input id {input_id} does not match output id {token_id}")
else:
self.printer(token_id)

def end(self):
pass
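To make the streamer's contract concrete, here is a small hand-driven sketch (the tensor values are invented for illustration): it pushes ids through put() the way model.generate does, prompt first, and checks that only newly generated ids reach the printer callback.

```python
# Standalone sketch of HFStreamer behaviour (hand-driven, no model involved).
import torch

import hf_streamer

received: list[int] = []

def printer(token_id: int) -> None:
    # In basic_hf_server.py this callback turns each token id into a stream response.
    received.append(token_id)

prompt_ids = torch.tensor([[1, 2, 3]])
streamer = hf_streamer.HFStreamer(input_ids=prompt_ids, printer=printer)

# generate() re-emits the prompt through the streamer first; HFStreamer pops and
# filters those ids ...
streamer.put(torch.tensor([[1, 2, 3]]))
# ... and forwards only newly generated token ids to the printer.
streamer.put(torch.tensor([[42]]))
streamer.end()

assert received == [42]
```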
12 changes: 6 additions & 6 deletions inference/worker/interface.py
@@ -7,12 +7,12 @@
class GenerateStreamParameters(pydantic.BaseModel):
max_new_tokens: int = 1024
do_sample: bool = True
top_k: int | None
top_p: float | None
typical_p: float | None
temperature: float | None
repetition_penalty: float | None
seed: int | None
top_k: int | None = None
top_p: float | None = None
typical_p: float | None = None
temperature: float | None = None
repetition_penalty: float | None = None
seed: int | None = None
stop: list[str] = []
details: bool = True

5 changes: 5 additions & 0 deletions inference/worker/requirements-hf.txt
@@ -1,3 +1,8 @@
accelerate
bitsandbytes
fastapi
huggingface_hub
sse-starlette
torch
git+https://github.com/huggingface/transformers@main#egg=transformers
uvicorn
3 changes: 3 additions & 0 deletions inference/worker/settings.py
@@ -17,5 +17,8 @@ class Settings(pydantic.BaseSettings):
perform_oom_test: bool = False
oom_test_max_length: int | None = None

# for hf basic server
quantize: bool = False
Review comment (Collaborator): Should we have a more descriptive setting name here so people don't expect it to have an effect when not using the basic server?

Reply (Author): This needs to be called quantize because the hf-inference server also expects it to be called like this.



settings = Settings()
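To make the scope concrete for reviewers: in this PR the flag is only read by basic_hf_server.py, where it is forwarded to 8-bit loading. Below is a minimal sketch of that path; the model id is a placeholder, and the exact environment variable used to set settings.quantize depends on the pydantic BaseSettings configuration, so treat both as assumptions rather than part of the diff.

```python
# Sketch: how settings.quantize reaches model loading in basic_hf_server.py.
import torch
import transformers

use_gpu = torch.cuda.is_available()
dtype = torch.bfloat16 if use_gpu and torch.cuda.is_bf16_supported() else torch.float32
quantize = use_gpu  # 8-bit loading via bitsandbytes requires a GPU

model = transformers.AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-70m",  # placeholder model id for illustration only
    torch_dtype=dtype,
    load_in_8bit=quantize,  # the settings.quantize flag discussed above
    device_map="auto" if use_gpu else None,
).eval()
```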
20 changes: 15 additions & 5 deletions inference/worker/work.py
@@ -123,11 +123,15 @@ def handle_work_request(

if model_config.is_llama:
generated_ids.append(token.id)
with tokenizer_lock:
text = tokenizer.decode(generated_ids)
new_text = text[len(decoded_text) :]
if not decoded_text:
new_text = new_text.lstrip()
try:
with tokenizer_lock:
text = tokenizer.decode(generated_ids, skip_special_tokens=True)
new_text = text[len(decoded_text) :]
if not decoded_text:
new_text = new_text.lstrip()
except Exception:
text = decoded_text
new_text = ""
token.text = new_text
decoded_text = text

@@ -190,6 +194,12 @@ def get_inference_server_stream_events(request: interface.GenerateStreamRequest)

client = sseclient.SSEClient(response)
for event in client.events():
if event.event == "error":
logger.error(f"Error from inference server: {event.data}")
yield interface.GenerateStreamResponse(error=event.data)
raise RuntimeError(f"Error from inference server: {event.data}")
if event.event == "ping":
continue
stream_response = interface.GenerateStreamResponse.parse_raw(event.data)
yield stream_response
