integration_tests/sample_client.py (85 additions, 0 deletions)
@@ -0,0 +1,85 @@
import time
import grpc
from google.protobuf import json_format
from text_generation_tests.pb import generation_pb2_grpc as gpb2, generation_pb2 as pb2


def get_streaming_response_tgis(response):
    stop = False
    generated_tokens = 0
    while not stop:
        try:
            x = next(response)
            timestamp = time.time_ns()
            data = json_format.MessageToDict(x)
            # skip first response (tokenizer output only)
            if "inputTokenCount" not in data:
                n_tokens = data["generatedTokenCount"] - generated_tokens
                generated_tokens = data["generatedTokenCount"]
                yield data, n_tokens, timestamp, True, None
        except Exception as e:
            # end of stream (StopIteration) or gRPC error: report it to the caller
            timestamp = time.time_ns()
            yield None, 0, timestamp, False, e


channel = grpc.insecure_channel("localhost:8033")
stub = gpb2.GenerationServiceStub(channel)
max_new_tokens = 100

template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{}\n\n### Response:"
num_req = 0
while True:
    prompt_input = input(f"\n{num_req}) Enter a prompt:\n")

    print("-" * 40)
    print("Output:")
    prompt = template.format(prompt_input)
    sample_request = {
        "model_id": "dummy-model-name",
        "request": {"text": prompt},
        "params": {
            "method": "GREEDY",
            "stopping": {
                "max_new_tokens": max_new_tokens,
                "min_new_tokens": max_new_tokens,
            },
        },
    }
    message = json_format.ParseDict(sample_request, pb2.SingleGenerationRequest())
    output = []
    total_time = 0
    response = stub.GenerateStream(message)
    response_generator = get_streaming_response_tgis(response)
    t0 = time.time_ns()
    response = ""
    stop = False
    while not stop:
        r, n_tokens, t, ok, err = next(response_generator)

        if not ok:
            stop = True
            # check if we have reached end of stream
            if type(err) is StopIteration:
                continue
        duration = (t - t0) / 1000.0 / 1000.0
        record = {
            "response": r,
            "ok": ok,
            "error": str(err),
            "timestamp": t,
            "duration_ms": duration,
            "n_tokens": n_tokens,
        }
        total_time += duration
        if ok:
            # accumulate the generated text; on error r is None, so skip it
            response += r.get("text", "")
        output.append(record)
        t0 = t

    # print(json.dumps(output, indent=4))
    print("-" * 40)
    print(response)
    print("-" * 40)
    print(f"Total_time : {total_time}ms")
    print(f"Time_per_token : {total_time/max_new_tokens}ms")
    print("-" * 40)
    num_req += 1
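Not part of the PR, just a minimal sketch of how the per-token `output` records that sample_client.py collects could be summarized once a request finishes. It assumes only the record fields defined above ("ok", "n_tokens", "duration_ms"); the helper name summarize_records is made up for illustration.

def summarize_records(output):
    # keep only records that actually carried generated tokens
    token_records = [r for r in output if r["ok"] and r["n_tokens"] > 0]
    if not token_records:
        return None
    total_ms = sum(r["duration_ms"] for r in token_records)
    total_tokens = sum(r["n_tokens"] for r in token_records)
    return {
        "time_to_first_token_ms": token_records[0]["duration_ms"],
        "total_ms": total_ms,
        "total_tokens": total_tokens,
        "ms_per_token": total_ms / total_tokens,
    }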
router/src/batcher.rs (13 additions, 4 deletions)
@@ -839,10 +839,19 @@ impl<'a> TokenProcessor<'a> {
        let request_id = output.request_id;
        let next_token_id = output.token_id;

-        let e = self
-            .entries
-            .get_mut(&request_id)
-            .expect("ID not found. This is a bug.");
+        let e = self.entries.get_mut(&request_id);

+        // if a client cancelled a request and speculative decoding is
+        // enabled, it's possible that the request will get removed
+        // from the entries table, but there can still be tokens in the
+        // outputs stream corresponding to that request. Ideally we could
+        // defer removing the request_id from the entries table until all
+        // tokens have been processed... but for now let's just ignore them.
+        if e.is_none() {
+            continue;
+        }

+        let e = e.unwrap();

        let is_stream = e.stream_tx.is_some();
        let stop_seqs = &e.request.parameters.stop_seqs;
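The comment above carries the key reasoning: with speculative decoding enabled, the outputs stream can still contain tokens for a request that the client already cancelled and that was dropped from the entries table, so those stray tokens are simply skipped. Purely as an illustration of that guard (plain Python with made-up data, not the router's Rust code):

entries = {7: ""}                                # request 42 was cancelled and already removed
outputs = [(7, "Hello"), (42, "stale token"), (7, " world")]

for request_id, token_text in outputs:
    entry = entries.get(request_id)
    if entry is None:
        continue                                 # ignore tokens for requests no longer tracked
    entries[request_id] = entry + token_text

print(entries)                                   # {7: 'Hello world'}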
server/pyproject.toml (1 addition, 1 deletion)
@@ -30,7 +30,7 @@ onnxruntime-gpu = { version = "^1.17.1", optional = true }
onnx = { version = "^1.16.0", optional = true }
einops = "^0.7.0"
ibm-fms = { version = "^0.0", optional = true }
-fms-extras = { git = "https://github.com/foundation-model-stack/fms-extras", rev = "fdb1636de4261fd4102da659ab45d3fcc33fe8ef", optional = true }
+fms-extras = { git = "https://github.com/foundation-model-stack/fms-extras", rev = "a010516ff2c938c206b9b342b16bd747ef07d43c", optional = true }

# Explicitly install some transitive dependencies to avoid CVEs
jinja2 = ">=3.1.3"
server/text_generation_server/inference_engine/tgis_native.py (13 additions, 5 deletions)
@@ -7,7 +7,7 @@

from transformers.models.auto.auto_factory import _BaseAutoModelClass

-from text_generation_server.models import FLASH_ATTENTION
+from text_generation_server.models import FLASH_ATTENTION, PAGED_ATTENTION
from text_generation_server.utils import Weights

from text_generation_server.inference_engine import BaseInferenceEngine
@@ -83,8 +83,12 @@ def __init__(
        elif model_type == "gpt_bigcode":
            self._config.transpose = self._config.architectures[0].startswith("GPT2")
            aliases = {"transformer.wte.weight": ["lm_head.weight"]}
-            from text_generation_server.models.custom_modeling.flash_santacoder_modeling import FlashSantacoderForCausalLM
-            model_class = FlashSantacoderForCausalLM
+            if PAGED_ATTENTION:
+                from text_generation_server.models.custom_modeling.paged_santacoder_modeling import PagedSantacoderForCausalLM
+                model_class = PagedSantacoderForCausalLM
+            else:
+                from text_generation_server.models.custom_modeling.flash_santacoder_modeling import FlashSantacoderForCausalLM
+                model_class = FlashSantacoderForCausalLM

        elif model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
            if sharded and self._config.alibi:
@@ -97,8 +101,12 @@ def __init__(
            model_class = FlashRWForCausalLM

        elif model_type == "llama":
-            from text_generation_server.models.custom_modeling.flash_llama_modeling import FlashLlamaForCausalLM
-            model_class = FlashLlamaForCausalLM
+            if PAGED_ATTENTION:
+                from text_generation_server.models.custom_modeling.paged_llama_modeling import PagedLlamaForCausalLM
+                model_class = PagedLlamaForCausalLM
+            else:
+                from text_generation_server.models.custom_modeling.flash_llama_modeling import FlashLlamaForCausalLM
+                model_class = FlashLlamaForCausalLM

        self._config.quantize = quantize
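For context only (not part of the diff): both branches above key off the PAGED_ATTENTION flag, which models/__init__.py (the next file) reads from an environment variable. A stripped-down sketch of that dispatch pattern, with placeholder class names standing in for the real model classes:

import os

# same convention as models/__init__.py: the flag comes from an env var
PAGED_ATTENTION = os.getenv("PAGED_ATTENTION", "false").lower() == "true"

class PagedModel: ...   # placeholder for e.g. PagedLlamaForCausalLM
class FlashModel: ...   # placeholder for e.g. FlashLlamaForCausalLM

model_class = PagedModel if PAGED_ATTENTION else FlashModel
print(model_class.__name__)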

server/text_generation_server/models/__init__.py (34 additions, 1 deletion)
@@ -13,8 +13,9 @@
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoModelForCausalLM, PretrainedConfig

FLASH_ATTENTION = os.getenv("FLASH_ATTENTION", "false").lower() == "true"
+PAGED_ATTENTION = os.getenv("PAGED_ATTENTION", "false").lower() == "true"

-__all__ = ["Model", "CausalLM", "Seq2SeqLM", "get_model", "FLASH_ATTENTION", "PT2_COMPILE"]
+__all__ = ["Model", "CausalLM", "Seq2SeqLM", "get_model", "FLASH_ATTENTION", "PAGED_ATTENTION", "PT2_COMPILE"]

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
@@ -43,6 +44,38 @@ def get_model(
    model_config_dict, _kwargs = PretrainedConfig.get_config_dict(model_path)
    model_type = model_config_dict["model_type"]

+    if PAGED_ATTENTION:
+        print(f"Using Paged Attention")

+        if deployment_framework != "tgis_native":
+            print_rank_n(
+                f"WARNING: Using deployment engine tgis_native rather than {deployment_framework} "
+                "because PAGED_ATTENTION is enabled"
+            )
+            deployment_framework = "tgis_native"

+        if model_type == "llama":
+            # Custom config type for LLaMA models
+            from text_generation_server.models.custom_modeling.paged_llama_modeling import LlamaConfig
+            model_config = LlamaConfig.from_pretrained(model_path)
+        elif model_type == "gpt_bigcode":
+            from transformers import GPTBigCodeConfig
+            model_config = GPTBigCodeConfig.from_pretrained(model_path)
+            # num_key_value_heads is used when creating the KV cache; set it here based on the multi-query (MQA) setting
+            model_config.num_key_value_heads = 1 if model_config.multi_query else model_config.num_attention_heads
+        else:
+            raise NotImplementedError("PAGED_ATTENTION only supported for gpt_bigcode and llama for now")

+        from text_generation_server.models.paged_causal_lm import PagedCausalLM
+        return PagedCausalLM(
+            model_name,
+            revision,
+            deployment_framework,
+            dtype, quantize,
+            model_config,
+            max_sequence_length=max_sequence_length,
+        )

    if FLASH_ATTENTION:
        # This will raise an exception if flash attention is not supported by the device
        import text_generation_server.utils.flash_attn as flash_attn
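The gpt_bigcode branch above sets num_key_value_heads from the multi_query flag because a paged KV cache is sized by key/value heads rather than query heads. A rough, self-contained sketch of that sizing logic follows; the shape layout and the numbers are illustrative assumptions, not the repository's actual cache format.

def kv_cache_shape(num_blocks, block_size, num_key_value_heads, head_size):
    # one K and one V tensor per layer, sized by KV heads, not query heads
    return (2, num_blocks, block_size, num_key_value_heads, head_size)

num_attention_heads = 48      # illustrative config value
head_size = 128               # illustrative
multi_query = True            # gpt_bigcode models typically use multi-query attention

num_key_value_heads = 1 if multi_query else num_attention_heads
print(kv_cache_shape(num_blocks=1024, block_size=16,
                     num_key_value_heads=num_key_value_heads,
                     head_size=head_size))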