diff --git a/delphi/__main__.py b/delphi/__main__.py
index 064c34a2..a1cf6ca2 100644
--- a/delphi/__main__.py
+++ b/delphi/__main__.py
@@ -145,7 +145,7 @@ async def process_cache(
     if run_cfg.explainer_provider == "offline":
         llm_client = Offline(
             run_cfg.explainer_model,
-            max_memory=0.9,
+            max_memory=run_cfg.max_memory,
             # Explainer models context length - must be able to accommodate the longest
             # set of examples
             max_model_len=run_cfg.explainer_model_max_len,
diff --git a/delphi/clients/offline.py b/delphi/clients/offline.py
index 9dea693f..ecd07d37 100644
--- a/delphi/clients/offline.py
+++ b/delphi/clients/offline.py
@@ -11,6 +11,7 @@
     destroy_distributed_environment,
     destroy_model_parallel,
 )
+from vllm.inputs import TokensPrompt
 
 from delphi import logger
 
@@ -103,6 +104,7 @@ async def process_func(
             prompt = self.tokenizer.apply_chat_template(
                 batch, add_generation_prompt=True, tokenize=True
             )
+            prompt = TokensPrompt(prompt_token_ids=prompt)
             prompts.append(prompt)
             if self.statistics:
                 non_cached_tokens = len(
@@ -121,7 +123,7 @@ async def process_func(
             None,
             partial(
                 self.client.generate,  # type: ignore
-                prompt_token_ids=prompts,
+                prompts,
                 sampling_params=self.sampling_params,
                 use_tqdm=False,
             ),
diff --git a/delphi/config.py b/delphi/config.py
index 05b723df..0cf6452e 100644
--- a/delphi/config.py
+++ b/delphi/config.py
@@ -191,6 +191,11 @@ class RunConfig(Serializable):
     )
     """Number of GPUs to use for explanation and scoring."""
 
+    max_memory: float = field(
+        default=0.9,
+    )
+    """Fraction of GPU memory to allocate to running explainer model."""
+
     seed: int = field(
         default=22,
     )
diff --git a/pyproject.toml b/pyproject.toml
index 844cb085..554e8036 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
     "blobfile",
    "bitsandbytes",
     "flask",
-    "vllm",
+    "vllm>=0.10.2",
     "aiofiles",
     "sentence_transformers",
     "anyio>=4.8.0",
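
The offline.py change reflects the call pattern expected by the pinned vLLM version (vllm>=0.10.2): pre-tokenized prompts are wrapped in TokensPrompt objects and passed positionally to LLM.generate, rather than through a prompt_token_ids= keyword. A minimal standalone sketch of that pattern follows; the model name "facebook/opt-125m" and the prompt text are placeholders, not values from this PR.

    # Sketch of the TokensPrompt call pattern, assuming vllm>=0.10.2.
    from vllm import LLM, SamplingParams
    from vllm.inputs import TokensPrompt

    llm = LLM(model="facebook/opt-125m")  # placeholder model
    sampling_params = SamplingParams(max_tokens=16)

    # Tokenize the prompt ourselves, then wrap the token IDs in TokensPrompt,
    # mirroring the apply_chat_template -> TokensPrompt flow in offline.py.
    token_ids = llm.get_tokenizer().encode("Hello, world!")
    prompts = [TokensPrompt(prompt_token_ids=token_ids)]

    # Prompts are passed positionally, as in the updated self.client.generate call.
    outputs = llm.generate(prompts, sampling_params=sampling_params, use_tqdm=False)
    print(outputs[0].outputs[0].text)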