Skip to content

Commit

Permalink
Merge pull request #6 from IlyasMoutawwakil/update-2.0
Browse files Browse the repository at this point in the history
Updates for TGI 2.0
  • Loading branch information
IlyasMoutawwakil committed Apr 13, 2024
2 parents 182092b + 59f8217 commit 567e43e
Show file tree
Hide file tree
Showing 10 changed files with 135 additions and 55 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ jobs:
- name: Run test
run: |
make test
make test_cpu
7 changes: 5 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ style:
ruff format .
ruff check --fix .

test:
pytest tests/ -x
test_cpu:
pytest tests/ -s -x -k "cpu"

test_gpu:
pytest tests/ -s -x -k "gpu"

install:
pip install -e .
15 changes: 12 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,34 @@ pip install py-txi

Py-TXI is designed to be used in a similar way to Transformers API. We use `docker-py` (instead of a dirty `subprocess` solution) so that the containers you run are linked to the main process and are stopped automatically when your code finishes or fails.

## Advantages

- **Easy to use**: Py-TXI is designed to be used in a similar way to Transformers API.
- **Automatic cleanup**: Py-TXI stops the Docker container when your code finishes or fails.
- **Batched inference**: Py-TXI supports sending a batch of inputs to the server for inference.
- **Automatic port allocation**: Py-TXI automatically allocates a free port for the Inference server.
- **Configurable**: Py-TXI allows you to configure the Inference servers using a simple configuration object.
- **Verbose**: Py-TXI streams the logs of the underlying Docker container to the main process so you can debug easily.

## Usage

Here's an example of how to use it:

```python
from py_txi import TGI, TGIConfig

llm = TGI(config=TGIConfig(sharded="false"))
llm = TGI(config=TGIConfig(model_id="bigscience/bloom-560m", gpus="0"))
output = llm.generate(["Hi, I'm a language model", "I'm fine, how are you?"])
print("LLM:", output)
llm.close()
```

Output: ```LLM: ["er. I'm a language modeler. I'm a language modeler. I'm a language", " I'm fine, how are you? I'm fine, how are you? I'm fine,"]```
Output: ```LLM: [' student. I have a problem with the following code. I have a class that has a method that', '"\n\n"I\'m fine," said the girl, "but I don\'t want to be alone.']```

```python
from py_txi import TEI, TEIConfig

embed = TEI(config=TEIConfig(pooling="cls"))
embed = TEI(config=TEIConfig(model_id="BAAI/bge-base-en-v1.5"))
output = embed.encode(["Hi, I'm an embedding model", "I'm fine, how are you?"])
print("Embed:", output)
embed.close()
Expand Down
16 changes: 8 additions & 8 deletions example.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from py_txi.text_embedding_inference import TEI, TEIConfig
from py_txi.text_generation_inference import TGI, TGIConfig

embed = TEI(config=TEIConfig(pooling="cls"))
output = embed.encode(["Hi, I'm an embedding model", "I'm fine, how are you?"] * 100)
llm = TGI(config=TGIConfig(model_id="bigscience/bloom-560m", gpus="0"))
output = llm.generate(["Hi, I'm a language model", "I'm fine, how are you?"])
print(len(output))
print("Embed:", output[0])
embed.close()
print("LLM:", output)
llm.close()

llm = TGI(config=TGIConfig(sharded="false"))
output = llm.generate(["Hi, I'm a language model", "I'm fine, how are you?"] * 50)
embed = TEI(config=TEIConfig(model_id="BAAI/bge-base-en-v1.5"))
output = embed.encode(["Hi, I'm an embedding model", "I'm fine, how are you?"])
print(len(output))
print("LLM:", output[0])
llm.close()
# print("Embed:", output)
embed.close()
50 changes: 35 additions & 15 deletions py_txi/inference_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,44 +4,48 @@
import time
from abc import ABC
from dataclasses import asdict, dataclass, field
from logging import INFO, basicConfig, getLogger
from logging import INFO, getLogger
from typing import Any, Dict, List, Optional, Union

import coloredlogs
import docker
import docker.errors
import docker.types
from huggingface_hub import AsyncInferenceClient

from .utils import get_free_port
from .utils import get_free_port, styled_logs

basicConfig(level=INFO)
coloredlogs.install(level=INFO, fmt="[%(asctime)s][%(filename)s][%(levelname)s] %(message)s")

DOCKER = docker.from_env()
LOGGER = getLogger("Inference-Server")


@dataclass
class InferenceServerConfig:
# Common options
model_id: str
revision: Optional[str] = "main"
# Image to use for the container
image: str
image: Optional[str] = None
# Shared memory size for the container
shm_size: str = "1g"
shm_size: Optional[str] = None
# List of custom devices to forward to the container e.g. ["/dev/kfd", "/dev/dri"] for ROCm
devices: Optional[List[str]] = None
# NVIDIA-docker GPU device options e.g. "all" (all) or "0,1,2,3" (ids) or 4 (count)
gpus: Optional[Union[str, int]] = None

ports: Dict[str, Any] = field(
default_factory=lambda: {"80/tcp": ("127.0.0.1", 0)},
default_factory=lambda: {"80/tcp": ("0.0.0.0", 0)},
metadata={"help": "Dictionary of ports to expose from the container."},
)
volumes: Dict[str, Any] = field(
default_factory=lambda: {os.path.expanduser("~/.cache/huggingface/hub"): {"bind": "/data", "mode": "rw"}},
metadata={"help": "Dictionary of volumes to mount inside the container."},
)
environment: Dict[str, str] = field(
default_factory=lambda: {"HUGGINGFACE_HUB_TOKEN": os.environ.get("HUGGINGFACE_HUB_TOKEN", "")},
metadata={"help": "Dictionary of environment variables to forward to the container."},
environment: List[str] = field(
default_factory=lambda: ["HUGGINGFACE_HUB_TOKEN"],
metadata={"help": "List of environment variables to forward to the container."},
)

max_concurrent_requests: Optional[int] = None
Expand All @@ -52,6 +56,10 @@ def __post_init__(self) -> None:
LOGGER.info("\t+ Getting a free port for the server")
self.ports["80/tcp"] = (self.ports["80/tcp"][0], get_free_port())

if self.shm_size is None:
LOGGER.warning("\t+ Shared memory size not provided. Defaulting to '1g'.")
self.shm_size = "1g"


class InferenceServer(ABC):
NAME: str = "Inference-Server"
Expand Down Expand Up @@ -97,8 +105,15 @@ def __init__(self, config: InferenceServerConfig) -> None:
else:
self.command.append(f"--{k.replace('_', '-')}={str(v).lower()}")

address, port = self.config.ports["80/tcp"]
self.url = f"http://{address}:{port}"
self.command.append("--json-output")

LOGGER.info(f"\t+ Building {self.NAME} environment")
self.environment = {}
for key in self.config.environment:
if key in os.environ:
self.environment[key] = os.environ[key]
else:
LOGGER.warning(f"\t+ Environment variable {key} not found in the system")

LOGGER.info(f"\t+ Running {self.NAME} container")
self.container = DOCKER.containers.run(
Expand All @@ -107,7 +122,7 @@ def __init__(self, config: InferenceServerConfig) -> None:
volumes=self.config.volumes,
devices=self.config.devices,
shm_size=self.config.shm_size,
environment=self.config.environment,
environment=self.environment,
device_requests=self.device_requests,
command=self.command,
auto_remove=True,
Expand All @@ -117,14 +132,19 @@ def __init__(self, config: InferenceServerConfig) -> None:
LOGGER.info(f"\t+ Streaming {self.NAME} server logs")
for line in self.container.logs(stream=True):
log = line.decode("utf-8").strip()
log = styled_logs(log)

if self.SUCCESS_SENTINEL.lower() in log.lower():
LOGGER.info(f"\t {log}")
LOGGER.info(f"\t+ {log}")
break
elif self.FAILURE_SENTINEL.lower() in log.lower():
LOGGER.info(f"\t {log}")
LOGGER.info(f"\t+ {log}")
raise Exception(f"{self.NAME} server failed to start")
else:
LOGGER.info(f"\t {log}")
LOGGER.info(f"\t+ {log}")

address, port = self.config.ports["80/tcp"]
self.url = f"http://{address}:{port}"

try:
asyncio.set_event_loop(asyncio.get_event_loop())
Expand Down
21 changes: 13 additions & 8 deletions py_txi/text_embedding_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@

@dataclass
class TEIConfig(InferenceServerConfig):
# Docker options
image: str = "ghcr.io/huggingface/text-embeddings-inference:cpu-latest"
# Launcher options
model_id: str = "bert-base-uncased"
revision: str = "main"
dtype: Optional[DType_Literal] = None
pooling: Optional[Pooling_Literal] = None
# Concurrency options
Expand All @@ -30,11 +26,20 @@ class TEIConfig(InferenceServerConfig):
def __post_init__(self) -> None:
super().__post_init__()

if self.image is None:
if is_nvidia_system() and self.gpus is not None:
LOGGER.info("\t+ Using the latest NVIDIA GPU image for Text-Embedding-Inference")
self.image = "ghcr.io/huggingface/text-embeddings-inference:latest"
else:
LOGGER.info("\t+ Using the latest CPU image for Text-Embedding-Inference")
self.image = "ghcr.io/huggingface/text-embeddings-inference:cpu-latest"

if is_nvidia_system() and "cpu" in self.image:
LOGGER.warning(
"Your system has NVIDIA GPU, but you are using a CPU image."
"Consider using a GPU image for better performance."
)
LOGGER.warning("\t+ You are running on a NVIDIA GPU system but using a CPU image.")

if self.pooling is None:
LOGGER.warning("\t+ Pooling strategy not provided. Defaulting to 'cls' pooling.")
self.pooling = "cls"


class TEI(InferenceServer):
Expand Down
29 changes: 17 additions & 12 deletions py_txi/text_generation_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,21 @@
from typing import Literal, Optional, Union

from .inference_server import InferenceServer, InferenceServerConfig
from .utils import is_rocm_system
from .utils import is_nvidia_system, is_rocm_system

LOGGER = getLogger("Text-Generation-Inference")

Shareded_Literal = Literal["true", "false"]
DType_Literal = Literal["float32", "float16", "bfloat16"]
Quantize_Literal = Literal["bitsandbytes-nf4", "bitsandbytes-fp4", "gptq"]
Quantize_Literal = Literal["bitsandbytes-nf4", "bitsandbytes-fp4", "gptq", "awq", "eetq", "fp8"]


@dataclass
class TGIConfig(InferenceServerConfig):
# Docker options
image: str = "ghcr.io/huggingface/text-generation-inference:latest"
# Launcher options
model_id: str = "gpt2"
revision: str = "main"
num_shard: Optional[int] = None
cuda_graphs: Optional[int] = None
dtype: Optional[DType_Literal] = None
enable_cuda_graphs: Optional[bool] = None
sharded: Optional[Shareded_Literal] = None
quantize: Optional[Quantize_Literal] = None
disable_custom_kernels: Optional[bool] = None
Expand All @@ -33,12 +29,21 @@ class TGIConfig(InferenceServerConfig):
def __post_init__(self) -> None:
super().__post_init__()

if self.image is None:
if is_nvidia_system() and self.gpus is not None:
LOGGER.info("\t+ Using the latest NVIDIA GPU image for Text-Generation-Inference")
self.image = "ghcr.io/huggingface/text-generation-inference:latest"
elif is_rocm_system() and self.devices is not None:
LOGGER.info("\t+ Using the latest ROCm AMD GPU image for Text-Generation-Inference")
self.image = "ghcr.io/huggingface/text-generation-inference:latest-rocm"
else:
raise ValueError(
"Unsupported system. Please either provide the image to use explicitly "
"or use a supported system (NVIDIA/ROCm) while specifying gpus/devices."
)

if is_rocm_system() and "rocm" not in self.image:
LOGGER.warning(
"You are running on a ROCm system but the image is not rocm specific. "
"Add 'rocm' to the image name to use the rocm specific image."
)
self.image += "-rocm"
LOGGER.warning("\t+ You are running on a ROCm AMD GPU system but using a non-ROCM image.")


class TGI(InferenceServer):
Expand Down
37 changes: 37 additions & 0 deletions py_txi/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import socket
import subprocess
from datetime import datetime
from json import loads


def get_free_port() -> int:
Expand All @@ -22,3 +24,38 @@ def is_nvidia_system() -> bool:
return True
except FileNotFoundError:
return False


LEVEL_TO_MESSAGE_STYLE = {
"DEBUG": "\033[37m",
"INFO": "\033[37m",
"WARN": "\033[33m",
"WARNING": "\033[33m",
"ERROR": "\033[31m",
"CRITICAL": "\033[31m",
}
TIMESTAMP_STYLE = "\033[32m"
TARGET_STYLE = "\033[0;38"
LEVEL_STYLE = "\033[1;30m"


def color_text(text: str, color: str) -> str:
return f"{color}{text}\033[0m"


def styled_logs(log: str) -> str:
dict_log = loads(log)

fields = dict_log.get("fields", {})
level = dict_log.get("level", "could not parse level")
target = dict_log.get("target", "could not parse target")
timestamp = dict_log.get("timestamp", "could not parse timestamp")
message = fields.get("message", dict_log.get("message", "could not parse message"))
timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ").strftime("%Y-%m-%d %H:%M:%S")

message = color_text(message, LEVEL_TO_MESSAGE_STYLE.get(level, "\033[37m"))
timestamp = color_text(timestamp, TIMESTAMP_STYLE)
target = color_text(target, TARGET_STYLE)
level = color_text(level, LEVEL_STYLE)

return f"[{timestamp}][{target}][{level}] - {message}"
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from setuptools import find_packages, setup

PY_TXI_VERSION = "0.6.0"
PY_TXI_VERSION = "0.7.0"

common_setup_kwargs = {
"author": "Ilyas Moutawwakil",
Expand All @@ -24,7 +24,7 @@
name="py-txi",
version=PY_TXI_VERSION,
packages=find_packages(),
install_requires=["docker", "huggingface-hub", "numpy", "aiohttp"],
install_requires=["docker", "huggingface-hub", "numpy", "aiohttp", "coloredlogs"],
extras_require={"quality": ["ruff"], "testing": ["pytest"]},
**common_setup_kwargs,
)
9 changes: 5 additions & 4 deletions tests/test_txi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@
from py_txi import TEI, TGI, TEIConfig, TGIConfig


def test_tei():
embed = TEI(config=TEIConfig(pooling="cls"))
def test_cpu_tei():
embed = TEI(config=TEIConfig(model_id="BAAI/bge-base-en-v1.5"))
output = embed.encode("Hi, I'm a language model")
assert isinstance(output, np.ndarray)
output = embed.encode(["Hi, I'm a language model", "I'm fine, how are you?"])
assert isinstance(output, list) and all(isinstance(x, np.ndarray) for x in output)
embed.close()


def test_tgi():
llm = TGI(config=TGIConfig(sharded="false"))
# tested locally with gpu
def test_gpu_tgi():
llm = TGI(config=TGIConfig(model_id="bigscience/bloom-560m", gpus="0"))
output = llm.generate("Hi, I'm a sanity test")
assert isinstance(output, str)
output = llm.generate(["Hi, I'm a sanity test", "I'm a second sentence"])
Expand Down

0 comments on commit 567e43e

Please sign in to comment.