ENH: skip 4-bit quantization for non-linux or non-cuda local deployment
UranusSeven authored and RayJi01 committed Aug 2, 2023
1 parent 0e522fe commit c409b2b
Showing 4 changed files with 69 additions and 4 deletions.
13 changes: 13 additions & 0 deletions xinference/core/api.py
@@ -66,6 +66,11 @@ async def get_model(self, model_uid: str) -> xo.ActorRefType["ModelActor"]:
         supervisor_ref = await self._get_supervisor_ref()
         return await supervisor_ref.get_model(model_uid)
 
+    async def is_local_deployment(self) -> bool:
+        # TODO: temporary.
+        supervisor_ref = await self._get_supervisor_ref()
+        return await supervisor_ref.is_local_deployment()
+
 
 class SyncSupervisorAPI:
     def __init__(self, supervisor_address: str):
@@ -124,3 +129,11 @@ async def _get_model():
             return await supervisor_ref.get_model(model_uid)
 
         return self._isolation.call(_get_model())
+
+    def is_local_deployment(self) -> bool:
+        # TODO: temporary.
+        async def _is_local_deployment():
+            supervisor_ref = await self._get_supervisor_ref()
+            return await supervisor_ref.is_local_deployment()
+
+        return self._isolation.call(_is_local_deployment())
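For context, a minimal sketch of how a client might call the new endpoint through the synchronous wrapper. This is an illustrative usage example, not part of the commit; the supervisor address is a placeholder:

from xinference.core.api import SyncSupervisorAPI

# Placeholder address; assumes a supervisor is already running there.
api = SyncSupervisorAPI("localhost:9997")

# Blocks on the isolation loop and ultimately asks the supervisor actor
# whether its only registered worker shares the supervisor's address.
if api.is_local_deployment():
    print("single-node local deployment")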
6 changes: 5 additions & 1 deletion xinference/core/gradio.py
@@ -301,7 +301,11 @@ def select_model(
         progress=gr.Progress(),
     ):
         match_result = match_llm(
-            _model_name, _model_format, int(_model_size_in_billions), _quantization
+            _model_name,
+            _model_format,
+            int(_model_size_in_billions),
+            _quantization,
+            self._api.is_local_deployment(),
         )
         if not match_result:
             raise ValueError(
14 changes: 13 additions & 1 deletion xinference/core/service.py
@@ -182,6 +182,13 @@ async def list_models(self) -> Dict[str, Dict[str, Any]]:
             ret.update(await worker.list_models())
         return ret
 
+    def is_local_deployment(self) -> bool:
+        # TODO: temporary.
+        return (
+            len(self._worker_address_to_worker) == 1
+            and list(self._worker_address_to_worker)[0] == self.address
+        )
+
     @log
     async def add_worker(self, worker_address: str):
         assert worker_address not in self._worker_address_to_worker
@@ -290,8 +297,13 @@ async def launch_builtin_model(
 
         from ..model.llm import match_llm, match_llm_cls
 
+        assert self._supervisor_ref is not None
         match_result = match_llm(
-            model_name, model_format, model_size_in_billions, quantization
+            model_name,
+            model_format,
+            model_size_in_billions,
+            quantization,
+            await self._supervisor_ref.is_local_deployment(),
         )
         if not match_result:
             raise ValueError(
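The supervisor-side predicate above counts a deployment as local only when exactly one worker is registered and its address equals the supervisor's own. A standalone sketch of that check, with a hypothetical dict standing in for self._worker_address_to_worker:

def _is_local(worker_address_to_worker: dict, supervisor_address: str) -> bool:
    # Local deployment: a single registered worker co-located with the supervisor.
    return (
        len(worker_address_to_worker) == 1
        and next(iter(worker_address_to_worker)) == supervisor_address
    )

assert _is_local({"127.0.0.1:9997": "worker_ref"}, "127.0.0.1:9997")
assert not _is_local({"10.0.0.2:9997": "worker_ref"}, "127.0.0.1:9997")
assert not _is_local({"a:1": "w1", "b:2": "w2"}, "a:1")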
40 changes: 38 additions & 2 deletions xinference/model/llm/__init__.py
@@ -13,7 +13,9 @@
 # limitations under the License.
 
 import json
+import logging
 import os
+import platform
 from typing import List, Optional, Tuple, Type
 
 from .core import LLM
@@ -29,12 +31,29 @@
 
 LLM_FAMILIES: List["LLMFamilyV1"] = []
 
+logger = logging.getLogger(__name__)
+
+
+def _is_linux():
+    return platform.system() == "Linux"
+
+
+def _has_cuda_device():
+    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
+    if cuda_visible_devices:
+        return True
+    else:
+        from xorbits._mars.resource import cuda_count
+
+        return cuda_count() > 0
+
 
 def match_llm(
     model_name: str,
     model_format: Optional[str] = None,
     model_size_in_billions: Optional[int] = None,
     quantization: Optional[str] = None,
+    is_local_deployment: bool = False,
 ) -> Optional[Tuple[LLMFamilyV1, LLMSpecV1, str]]:
     """
     Find an LLM family, spec, and quantization that satisfy given criteria.
@@ -52,8 +71,25 @@ def match_llm(
                 and quantization not in spec.quantizations
             ):
                 continue
-            # by default, choose the most coarse-grained quantization.
-            return family, spec, quantization or spec.quantizations[0]
+            if quantization:
+                return family, spec, quantization
+            else:
+                # by default, choose the most coarse-grained quantization.
+                # TODO: too hacky.
+                quantizations = spec.quantizations
+                quantizations.sort()
+                for q in quantizations:
+                    if (
+                        is_local_deployment
+                        and not (_is_linux() and _has_cuda_device())
+                        and q == "4-bit"
+                    ):
+                        logger.warning(
+                            "Skipping %s for non-linux or non-cuda local deployment.",
+                            q,
+                        )
+                        continue
+                    return family, spec, q
     return None


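To make the new selection behavior concrete, a sketch under stated assumptions: the model name is a placeholder for any registered pytorch-format model whose quantizations are ["4-bit", "8-bit", "none"], and the host is a local deployment without Linux + CUDA:

from xinference.model.llm import match_llm

# Placeholder name: substitute any built-in pytorch-format model.
result = match_llm("some-builtin-model", model_format="pytorch",
                   is_local_deployment=True)
if result is not None:
    family, spec, quantization = result
    # On a non-linux or non-cuda local deployment "4-bit" is skipped, so the
    # sorted list ["4-bit", "8-bit", "none"] yields "8-bit" here.
    print(quantization)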
