
Commit

Merge pull request #4 from IlyasMoutawwakil/use-semaphore
Added semaphores and event loop management
IlyasMoutawwakil committed Mar 6, 2024
2 parents 651e62e + d44fb9d commit bfe39c7
Showing 4 changed files with 33 additions and 38 deletions.
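
In short, the pull request bounds how many requests are in flight at once: every per-item coroutine acquires an asyncio.Semaphore before calling the inference client, and asyncio.gather fans the batch out under that cap. A minimal, self-contained sketch of the pattern, assuming nothing from py_txi (fake_inference_call, bounded_call, and run_batch are illustrative names, not repository APIs):

import asyncio

async def fake_inference_call(text: str) -> str:
    # Stand-in for one HTTP call to the inference server.
    await asyncio.sleep(0.01)
    return text.upper()

async def bounded_call(semaphore: asyncio.Semaphore, text: str) -> str:
    # Only `max_concurrent_requests` coroutines get past this point at a time.
    async with semaphore:
        return await fake_inference_call(text)

async def run_batch(texts: list, max_concurrent_requests: int = 128) -> list:
    semaphore = asyncio.Semaphore(max_concurrent_requests)
    return await asyncio.gather(*(bounded_call(semaphore, t) for t in texts))

if __name__ == "__main__":
    results = asyncio.run(run_batch([f"prompt {i}" for i in range(1000)]))
    print(len(results), results[0])
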
10 changes: 6 additions & 4 deletions example.py
@@ -2,11 +2,13 @@
from py_txi.text_generation_inference import TGI, TGIConfig

embed = TEI(config=TEIConfig(pooling="cls"))
output = embed.encode(["Hi, I'm an embedding model", "I'm fine, how are you?"])
print("Embed:", output)
output = embed.encode(["Hi, I'm an embedding model", "I'm fine, how are you?"] * 100)
print(len(output))
print("Embed:", output[0])
embed.close()

llm = TGI(config=TGIConfig(sharded="false"))
output = llm.generate(["Hi, I'm a language model", "I'm fine, how are you?"])
print("LLM:", output)
output = llm.generate(["Hi, I'm a language model", "I'm fine, how are you?"] * 50)
print(len(output))
print("LLM:", output[0])
llm.close()
12 changes: 12 additions & 0 deletions py_txi/docker_inference_server.py
@@ -45,6 +45,7 @@ class DockerInferenceServerConfig:
)

timeout: int = 60
max_concurrent_requests: int = 128

def __post_init__(self) -> None:
if self.ports["80/tcp"][1] == 0:
@@ -125,6 +126,13 @@ def __init__(self, config: DockerInferenceServerConfig) -> None:
else:
LOGGER.info(f"\t {log}")

try:
asyncio.set_event_loop(asyncio.get_event_loop())
except RuntimeError:
asyncio.set_event_loop(asyncio.new_event_loop())

self.semaphore = asyncio.Semaphore(self.config.max_concurrent_requests)

LOGGER.info(f"\t+ Waiting for {self.NAME} server to be ready")
start_time = time.time()
while time.time() - start_time < self.config.timeout:
@@ -153,6 +161,10 @@ def close(self) -> None:
LOGGER.info("\t+ Docker container stopped")
del self.container

if hasattr(self, "semaphore"):
del self.semaphore

if hasattr(self, "client"):
del self.client

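The constructor runs in synchronous code, where the current thread may or may not already have an event loop, so the try/except above reuses the loop returned by asyncio.get_event_loop() and falls back to asyncio.new_event_loop() when that raises RuntimeError. A sketch of the same get-or-create pattern in isolation (get_or_create_event_loop and ping are illustrative helpers, not py_txi APIs):

import asyncio

def get_or_create_event_loop() -> asyncio.AbstractEventLoop:
    # Mirror the diff: reuse the thread's current loop if one exists,
    # otherwise create a new loop and register it for this thread.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:  # e.g. "There is no current event loop in thread ..."
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop

async def ping() -> str:
    await asyncio.sleep(0)
    return "ready"

loop = get_or_create_event_loop()
print(loop.run_until_complete(ping()))  # a synchronous caller driving async client code
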
10 changes: 6 additions & 4 deletions py_txi/text_embedding_inference.py
@@ -15,7 +15,7 @@
DType_Literal = Literal["float32", "float16"]


@dataclass(order=False)
@dataclass
class TEIConfig(DockerInferenceServerConfig):
# Docker options
image: str = "ghcr.io/huggingface/text-embeddings-inference:cpu-latest"
@@ -24,7 +24,8 @@ class TEIConfig(DockerInferenceServerConfig):
revision: str = "main"
dtype: Optional[DType_Literal] = None
pooling: Optional[Pooling_Literal] = None
tokenization_workers: Optional[int] = None
# Concurrency options
max_concurrent_requests: int = 512

def __post_init__(self) -> None:
super().__post_init__()
@@ -45,8 +46,9 @@ def __init__(self, config: TEIConfig) -> None:
super().__init__(config)

async def single_client_call(self, text: str, **kwargs) -> np.ndarray:
output = await self.client.feature_extraction(text=text, **kwargs)
return output
async with self.semaphore:
output = await self.client.feature_extraction(text=text, **kwargs)
return output

async def batch_client_call(self, text: List[str], **kwargs) -> List[np.ndarray]:
output = await asyncio.gather(*[self.single_client_call(t, **kwargs) for t in text])
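
For the embedding server the new field defaults to 512 concurrent requests. A short usage sketch; the import path is assumed by analogy with the TGI import in example.py, and the override value is illustrative rather than taken from the repository:

from py_txi.text_embedding_inference import TEI, TEIConfig  # assumed path

# pooling="cls" comes from example.py; max_concurrent_requests=256 is an assumed override.
embed = TEI(config=TEIConfig(pooling="cls", max_concurrent_requests=256))
vectors = embed.encode(["first sentence", "second sentence"] * 200)
print(len(vectors), vectors[0])
embed.close()
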
39 changes: 9 additions & 30 deletions py_txi/text_generation_inference.py
@@ -20,37 +20,15 @@ class TGIConfig(DockerInferenceServerConfig):
# Launcher options
model_id: str = "gpt2"
revision: str = "main"
num_shard: Optional[int] = None
dtype: Optional[DType_Literal] = None
quantize: Optional[Quantize_Literal] = None
enable_cuda_graphs: Optional[bool] = None
sharded: Optional[Shareded_Literal] = None
num_shard: Optional[int] = None
trust_remote_code: Optional[bool] = None
quantize: Optional[Quantize_Literal] = None
disable_custom_kernels: Optional[bool] = None
# Inference options
max_best_of: Optional[int] = None
max_concurrent_requests: Optional[int] = None
max_stop_sequences: Optional[int] = None
max_top_n_tokens: Optional[int] = None
max_input_length: Optional[int] = None
max_total_tokens: Optional[int] = None
waiting_served_ratio: Optional[float] = None
max_batch_prefill_tokens: Optional[int] = None
max_batch_total_tokens: Optional[int] = None
max_waiting_tokens: Optional[int] = None
max_batch_size: Optional[int] = None
enable_cuda_graphs: Optional[bool] = None
huggingface_hub_cache: Optional[str] = None
weights_cache_override: Optional[str] = None
cuda_memory_fraction: Optional[float] = None
rope_scaling: Optional[str] = None
rope_factor: Optional[str] = None
json_output: Optional[bool] = None
otlp_endpoint: Optional[str] = None
cors_allow_origin: Optional[list] = None
watermark_gamma: Optional[str] = None
watermark_delta: Optional[str] = None
tokenizer_config_path: Optional[str] = None
disable_grammar_support: Optional[bool] = None
trust_remote_code: Optional[bool] = None
# Concurrency options
max_concurrent_requests: int = 128

def __post_init__(self) -> None:
super().__post_init__()
@@ -72,8 +50,9 @@ def __init__(self, config: TGIConfig) -> None:
super().__init__(config)

async def single_client_call(self, prompt: str, **kwargs) -> str:
output = await self.client.text_generation(prompt=prompt, **kwargs)
return output
async with self.semaphore:
output = await self.client.text_generation(prompt=prompt, **kwargs)
return output

async def batch_client_call(self, prompt: list, **kwargs) -> list:
output = await asyncio.gather(*[self.single_client_call(prompt=p, **kwargs) for p in prompt])
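
With the long list of pass-through launcher options dropped, max_concurrent_requests is the concurrency knob left on TGIConfig. A usage sketch modeled on example.py; the specific values are assumptions, not documented settings:

from py_txi.text_generation_inference import TGI, TGIConfig

# sharded="false" mirrors example.py; max_concurrent_requests=64 is an assumed override.
llm = TGI(config=TGIConfig(sharded="false", max_concurrent_requests=64))

outputs = llm.generate(["Hi, I'm a language model"] * 200)  # at most 64 requests in flight
print(len(outputs), outputs[0])
llm.close()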
