From 4af52491dfe35cb85a7e09f634b0087c9b5cd48e Mon Sep 17 00:00:00 2001 From: harrisonyhq Date: Wed, 19 Nov 2025 19:44:48 -0800 Subject: [PATCH 1/5] [Feat] Support launch from config file --- examples/offline_inference.py | 3 +- examples/ucm_config_example.yaml | 32 ++++++++++ ucm/integration/vllm/uc_connector.py | 22 ++----- ucm/integration/vllm/ucm_connector.py | 25 +++----- ucm/utils.py | 92 +++++++++++++++++++++++++++ 5 files changed, 139 insertions(+), 35 deletions(-) create mode 100644 examples/ucm_config_example.yaml create mode 100644 ucm/utils.py diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 320f04c8..5a2fea37 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -22,8 +22,7 @@ def build_llm_with_uc(module_path: str, name: str, model: str): kv_connector_module_path=module_path, kv_role="kv_both", kv_connector_extra_config={ - "ucm_connector_name": "UcmNfsStore", - "ucm_connector_config": {"storage_backends": "/home/share/wc/nfs"}, + "UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml" }, ) diff --git a/examples/ucm_config_example.yaml b/examples/ucm_config_example.yaml new file mode 100644 index 00000000..d8c29714 --- /dev/null +++ b/examples/ucm_config_example.yaml @@ -0,0 +1,32 @@ +# UCM Configuration File Example +# +# This file demonstrates how to configure UCM using YAML. +# You can use this config file by setting its path in kv_connector_extra_config, either in the launch script or on the command line, like this: +# kv_connector_extra_config={"UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml"} +# +# Alternatively, you can still use kv_connector_extra_config in KVTransferConfig +# for backward compatibility.
+ +# Connector name (e.g., "UcmNfsStore", "UcmDramStore") +ucm_connector_name: "UcmNfsStore" + +# Connector-specific configuration +ucm_connector_config: + storage_backends: "/mnt/test" + +# Sparse attention configuration +# Format 1: Dictionary format (for methods like ESA, KvComp) +# ucm_sparse_config: +# ESA: +# init_window_sz: 1 +# local_window_sz: 2 +# min_blocks: 4 +# sparse_ratio: 0.3 +# retrieval_stride: 5 + # Or for GSA: + # GSA: {} + + +# Whether to use layerwise loading/saving (optional, default: True for UnifiedCacheConnectorV1) +# use_layerwise: true + diff --git a/ucm/integration/vllm/uc_connector.py b/ucm/integration/vllm/uc_connector.py index 2689e0c1..8f6e4925 100644 --- a/ucm/integration/vllm/uc_connector.py +++ b/ucm/integration/vllm/uc_connector.py @@ -44,6 +44,7 @@ from ucm.logger import init_logger from ucm.store.factory import UcmConnectorFactory from ucm.store.ucmstore import Task +from ucm.utils import UCMConfig if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -113,22 +114,11 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): vllm_config.parallel_config ) self.head_size = vllm_config.model_config.get_head_size() - if ( - self._vllm_config.kv_transfer_config is not None - and "ucm_connector_name" - in self._vllm_config.kv_transfer_config.kv_connector_extra_config - ): - name = self._vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_connector_name" - ] - config = {} - if ( - "ucm_connector_config" - in self._vllm_config.kv_transfer_config.kv_connector_extra_config - ): - config = self._vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_connector_config" - ] + ucm_config = UCMConfig(vllm_config.kv_transfer_config) + launch_config = ucm_config.get_config() + if "ucm_connector_name" in launch_config: + name = launch_config.get("ucm_connector_name") + config = launch_config.get("ucm_connector_config") or {} config["device"] = self.rank config["role"] = ( 
"scheduler" if role == KVConnectorRole.SCHEDULER else "worker" diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index 3c1ade22..cfd39dc9 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -17,6 +17,7 @@ from ucm.logger import init_logger from ucm.store.factory import UcmConnectorFactory from ucm.store.ucmstore import Task, UcmKVStoreBase +from ucm.utils import UCMConfig if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -120,24 +121,12 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): # save block info, avoid hash request twice self.request_meta: dict[str, ReqMeta] = {} + ucm_config = UCMConfig(vllm_config.kv_transfer_config) + launch_config = ucm_config.get_config() - # TODO use yaml - if ( - vllm_config.kv_transfer_config is not None - and "ucm_connector_name" - in vllm_config.kv_transfer_config.kv_connector_extra_config - ): - name = vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_connector_name" - ] - config = {} - if ( - "ucm_connector_config" - in vllm_config.kv_transfer_config.kv_connector_extra_config - ): - config = vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_connector_config" - ] + if "ucm_connector_name" in launch_config: + name = launch_config.get("ucm_connector_name") + config = launch_config.get("ucm_connector_config") or {} config["device"] = self.rank config["role"] = ( "scheduler" if role == KVConnectorRole.SCHEDULER else "worker" @@ -168,6 +157,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): config["kv_block_size"] / 1024 / 1024, config["io_size"] / 1024, ) + else: + raise TypeError(f"no storage connector name in config.") def get_num_new_matched_tokens( self, diff --git a/ucm/utils.py b/ucm/utils.py new file mode 100644 index 00000000..68fd4f00 --- /dev/null +++ b/ucm/utils.py @@ -0,0 +1,92 @@ +# +# MIT License +# +# Copyright (c) 2025 
Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# + +from typing import Any, Dict + +import yaml + +from ucm.logger import init_logger + +logger = init_logger(__name__) + + +class UCMConfig: + def __init__(self, kv_transfer_config: Any): + self.kv_transfer_config = kv_transfer_config + self.config: Dict[str, Any] = {} + self._load_config() + def load_ucm_config_from_yaml(self, file_path: str) -> Dict[str, Any]: + if not file_path: + logger.warning("No UCM config file path provided.") + return {} + + try: + with open(file_path, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) or {} + if not isinstance(config, dict): + logger.warning( + f"Config file {file_path} does not contain a dictionary. " + "Returning empty config." 
+ ) + return {} + logger.info(f"Loaded UCM config from {file_path}") + return config + except FileNotFoundError: + logger.error(f"UCM config file not found: {file_path}") + return {} + except yaml.YAMLError as e: + logger.error(f"Failed to parse YAML config file {file_path}: {e}") + return {} + + def _load_config(self) -> None: + has_extra_config = ( + self.kv_transfer_config is not None + and hasattr(self.kv_transfer_config, "kv_connector_extra_config") + and self.kv_transfer_config.kv_connector_extra_config is not None + ) + if not has_extra_config: + self.config = self._get_default_config() + else: + extra_config = self.kv_transfer_config.kv_connector_extra_config + if "UCM_CONFIG_FILE" in extra_config: + config_file = extra_config["UCM_CONFIG_FILE"] + self.config = self.load_ucm_config_from_yaml(config_file) + else: + if extra_config == {}: + self.config = self._get_default_config() + else: + self.config = dict(extra_config) + logger.info("Using kv_connector_extra_config from terminal input") + + + def _get_default_config(self) -> Dict[str, Any]: + config = { + "ucm_connector_name": "UcmDramStore" + } + logger.warning(f"No UCM config provided, using default configuration {config}") + return config + + def get_config(self) -> Dict[str, Any]: + logger.info(f"Using UCM with config: {self.config}") + return self.config \ No newline at end of file From b109092036a2bbfae45b1f5f4e6a8be8bdb002bf Mon Sep 17 00:00:00 2001 From: harrisonyhq Date: Thu, 20 Nov 2025 00:55:36 -0800 Subject: [PATCH 2/5] [Docs] Update documents for launch with yaml --- docs/source/getting-started/quick_start.md | 18 ++++++++++----- .../user-guide/prefix-cache/dram_store.md | 23 +++++++++---------- .../user-guide/prefix-cache/nfs_store.md | 21 +++++++++-------- ucm/integration/vllm/uc_connector.py | 4 ++-- ucm/integration/vllm/ucm_connector.py | 4 ++-- ucm/utils.py | 16 ++++++------- 6 files changed, 45 insertions(+), 41 deletions(-) diff --git a/docs/source/getting-started/quick_start.md 
b/docs/source/getting-started/quick_start.md index 9e7630e1..1f33ab3b 100644 --- a/docs/source/getting-started/quick_start.md +++ b/docs/source/getting-started/quick_start.md @@ -59,7 +59,17 @@ First, specify the python hash seed by: export PYTHONHASHSEED=123456 ``` -Run the following command to start the vLLM server with the Qwen/Qwen2.5-14B-Instruct model: +Create a YAML config file like the following and save it to your own directory: +```yaml +# UCM Configuration File Example +# Refer to file unified-cache-management/examples/ucm_config_example.yaml for more details +ucm_connector_name: "UcmNfsStore" + +ucm_connector_config: + storage_backends: "/mnt/test" +``` + +Run the following command to start the vLLM server with the Qwen/Qwen2.5-14B-Instruct model and your config file path: ```bash # Change the model path to your own model path @@ -77,11 +87,7 @@ vllm serve ${MODEL_PATH} \ "kv_connector_module_path": "ucm.integration.vllm.uc_connector", "kv_role": "kv_both", "kv_connector_extra_config": { - "ucm_connector_name": "UcmDramStore", - "ucm_connector_config": { - "max_cache_size": 5368709120, - "kv_block_size": 262144 - } + "UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml" } }' ``` diff --git a/docs/source/user-guide/prefix-cache/dram_store.md b/docs/source/user-guide/prefix-cache/dram_store.md index 1be2f30a..157e098e 100644 --- a/docs/source/user-guide/prefix-cache/dram_store.md +++ b/docs/source/user-guide/prefix-cache/dram_store.md @@ -49,10 +49,15 @@ To use the DRAM connector, you need to configure the `connector_config` dictiona ### Example: -```python -# Allocate up to 8GB DRAM for KV cache -# KV Block size (in byte) is 262144 -kv_connector_extra_config={"ucm_connector_name": "UcmDramStore", "ucm_connector_config":{"max_cache_size": 5368709120, "kv_block_size": 262144}} +Create a YAML config file like the following and save it to your own directory: +```yaml +# UCM Configuration File Example +# Refer to file
unified-cache-management/examples/ucm_config_example.yaml for more details +ucm_connector_name: "UcmDramStore" + +ucm_connector_config: + max_cache_size: 5368709120 + kv_block_size: 262144 ``` ## Launching Inference @@ -65,7 +70,7 @@ To start **offline inference** with the DRAM connector，modify the script `exam # In examples/offline_inference.py ktc = KVTransferConfig( ... - kv_connector_extra_config={"ucm_connector_name": "UcmDramStore", "ucm_connector_config":{"max_cache_size": 5368709120, "kv_block_size": 262144}} + kv_connector_extra_config={"UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml"} ) ``` @@ -99,13 +104,7 @@ vllm serve /home/models/Qwen2.5-14B-Instruct \ "kv_connector": "UnifiedCacheConnectorV1", "kv_connector_module_path": "ucm.integration.vllm.uc_connector", "kv_role": "kv_both", - "kv_connector_extra_config": { - "ucm_connector_name": "UcmDramStore", - "ucm_connector_config": { - "max_cache_size": 5368709120, - "kv_block_size": 262144 - } - } + "kv_connector_extra_config": {"UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml"} }' ``` diff --git a/docs/source/user-guide/prefix-cache/nfs_store.md b/docs/source/user-guide/prefix-cache/nfs_store.md index b581acf5..741fcedf 100644 --- a/docs/source/user-guide/prefix-cache/nfs_store.md +++ b/docs/source/user-guide/prefix-cache/nfs_store.md @@ -87,8 +87,15 @@ To use the NFS connector, you need to configure the `connector_config` dictionar ### Example: -```python -kv_connector_extra_config={"ucm_connector_name": "UcmNfsStore", "ucm_connector_config":{"storage_backends": "/mnt/test1", "transferStreamNumber": 32}} +Create a YAML config file like the following and save it to your own directory: +```yaml +# UCM Configuration File Example +# Refer to file unified-cache-management/examples/ucm_config_example.yaml for more details +ucm_connector_name: "UcmNfsStore" + +ucm_connector_config: + storage_backends: "/mnt/test" + transferStreamNumber: 32
``` ## Launching Inference @@ -101,7 +108,7 @@ To start **offline inference** with the NFS connector,modify the script `examp # In examples/offline_inference.py ktc = KVTransferConfig( ... - kv_connector_extra_config={"ucm_connector_name": "UcmNfsStore", "ucm_connector_config":{"storage_backends": "/mnt/test1", "transferStreamNumber": 32}} + kv_connector_extra_config={"UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml"} ) ``` @@ -131,13 +138,7 @@ vllm serve /home/models/Qwen2.5-14B-Instruct \ "kv_connector": "UnifiedCacheConnectorV1", "kv_connector_module_path": "ucm.integration.vllm.uc_connector", "kv_role": "kv_both", - "kv_connector_extra_config": { - "ucm_connector_name": "UcmNfsStore", - "ucm_connector_config": { - "storage_backends": "/mnt/test", - "transferStreamNumber":32 - } - } + "kv_connector_extra_config": {"UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml"} }' ``` diff --git a/ucm/integration/vllm/uc_connector.py b/ucm/integration/vllm/uc_connector.py index 8f6e4925..7952bf30 100644 --- a/ucm/integration/vllm/uc_connector.py +++ b/ucm/integration/vllm/uc_connector.py @@ -44,7 +44,7 @@ from ucm.logger import init_logger from ucm.store.factory import UcmConnectorFactory from ucm.store.ucmstore import Task -from ucm.utils import UCMConfig +from ucm.utils import Config if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -114,7 +114,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): vllm_config.parallel_config ) self.head_size = vllm_config.model_config.get_head_size() - ucm_config = UCMConfig(vllm_config.kv_transfer_config) + ucm_config = Config(vllm_config.kv_transfer_config) launch_config = ucm_config.get_config() if "ucm_connector_name" in launch_config: name = launch_config.get("ucm_connector_name") diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index cfd39dc9..76a5853b 100644 
--- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -17,7 +17,7 @@ from ucm.logger import init_logger from ucm.store.factory import UcmConnectorFactory from ucm.store.ucmstore import Task, UcmKVStoreBase -from ucm.utils import UCMConfig +from ucm.utils import Config if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -121,7 +121,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): # save block info, avoid hash request twice self.request_meta: dict[str, ReqMeta] = {} - ucm_config = UCMConfig(vllm_config.kv_transfer_config) + ucm_config = Config(vllm_config.kv_transfer_config) launch_config = ucm_config.get_config() if "ucm_connector_name" in launch_config: diff --git a/ucm/utils.py b/ucm/utils.py index 68fd4f00..bf07f6b8 100644 --- a/ucm/utils.py +++ b/ucm/utils.py @@ -31,16 +31,17 @@ logger = init_logger(__name__) -class UCMConfig: +class Config: def __init__(self, kv_transfer_config: Any): self.kv_transfer_config = kv_transfer_config self.config: Dict[str, Any] = {} self._load_config() + def load_ucm_config_from_yaml(self, file_path: str) -> Dict[str, Any]: if not file_path: logger.warning("No UCM config file path provided.") return {} - + try: with open(file_path, "r", encoding="utf-8") as f: config = yaml.safe_load(f) or {} @@ -58,7 +59,7 @@ def load_ucm_config_from_yaml(self, file_path: str) -> Dict[str, Any]: except yaml.YAMLError as e: logger.error(f"Failed to parse YAML config file {file_path}: {e}") return {} - + def _load_config(self) -> None: has_extra_config = ( self.kv_transfer_config is not None @@ -79,14 +80,11 @@ def _load_config(self) -> None: self.config = dict(extra_config) logger.info("Using kv_connector_extra_config from terminal input") - def _get_default_config(self) -> Dict[str, Any]: - config = { - "ucm_connector_name": "UcmDramStore" - } + config = {"ucm_connector_name": "UcmDramStore"} logger.warning(f"No UCM config provided, using default 
configuration {config}") return config - + def get_config(self) -> Dict[str, Any]: logger.info(f"Using UCM with config: {self.config}") - return self.config \ No newline at end of file + return self.config From 6fc1ecb1f2d9ed2cf861912141277be0abeb8286 Mon Sep 17 00:00:00 2001 From: harrisonyhq Date: Thu, 20 Nov 2025 01:34:50 -0800 Subject: [PATCH 3/5] [Fix] Change load only on first rank into configuration --- examples/ucm_config_example.yaml | 2 ++ ucm/integration/vllm/ucm_connector.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/examples/ucm_config_example.yaml b/examples/ucm_config_example.yaml index d8c29714..cb6f6d17 100644 --- a/examples/ucm_config_example.yaml +++ b/examples/ucm_config_example.yaml @@ -13,6 +13,8 @@ ucm_connector_name: "UcmNfsStore" # Connector-specific configuration ucm_connector_config: storage_backends: "/mnt/test" + transferIoDirect: false + load_only_first_rank: false # Sparse attention configuration # Format 1: Dictionary format (for methods like ESA, KvComp) diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index 76a5853b..3f2cd553 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -109,12 +109,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self.block_size = self._vllm_config.cache_config.block_size self.is_mla = self._vllm_config.model_config.is_deepseek_mla - self.load_only_first_rank = self.is_mla - if self.is_mla: - if role == KVConnectorRole.WORKER: - self.group_coordinator = get_tp_group() - self.broadcast_fn = self.group_coordinator.broadcast - self.broadcast_stream = torch.cuda.Stream() self.store: UcmKVStoreBase self.request_hasher = RequestHasher() @@ -149,6 +143,14 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): config["io_size"] = block_size_per_layer * ( 1 if self.is_mla else num_head_per_tp ) + self.load_only_first_rank: bool = config.get( + 
"load_only_first_rank", self.is_mla + ) + if self.load_only_first_rank: + if role == KVConnectorRole.WORKER: + self.group_coordinator = get_tp_group() + self.broadcast_fn = self.group_coordinator.broadcast + self.broadcast_stream = torch.cuda.Stream() self.store = UcmConnectorFactory.create_connector(name, config) logger.info("init UCConnectorImpl, connector: %s", name) From 2f51ccd7e3b7180de7218fc423c5a5aeca375dc3 Mon Sep 17 00:00:00 2001 From: harrisonyhq Date: Thu, 20 Nov 2025 03:09:02 -0800 Subject: [PATCH 4/5] [Feat] Add support for hit ratio in yaml --- examples/ucm_config_example.yaml | 1 + ucm/integration/vllm/ucm_connector.py | 12 +++++------- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/ucm_config_example.yaml b/examples/ucm_config_example.yaml index cb6f6d17..562f8ab9 100644 --- a/examples/ucm_config_example.yaml +++ b/examples/ucm_config_example.yaml @@ -31,4 +31,5 @@ ucm_connector_config: # Whether to use layerwise loading/saving (optional, default: True for UnifiedCacheConnectorV1) # use_layerwise: true +# hit_ratio: 0.9 diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index 9c32805e..8b55a951 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -110,11 +110,11 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): # save block info, avoid hash request twice, and track them until request finished self.requests_meta: dict[str, RequestMeta] = {} ucm_config = Config(vllm_config.kv_transfer_config) - launch_config = ucm_config.get_config() + self.launch_config = ucm_config.get_config() - if "ucm_connector_name" in launch_config: - name = launch_config.get("ucm_connector_name") - config = launch_config.get("ucm_connector_config") or {} + if "ucm_connector_name" in self.launch_config: + name = self.launch_config.get("ucm_connector_name") + config = self.launch_config.get("ucm_connector_config") or {} config["device"] = self.rank 
config["role"] = ( "scheduler" if role == KVConnectorRole.SCHEDULER else "worker" @@ -624,9 +624,7 @@ class UCMMockConnector(UCMDirectConnector): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config, role) - self._hit_ratio = float( - self._vllm_config.kv_transfer_config.kv_connector_extra_config["hit_ratio"] - ) + self._hit_ratio = float(self.launch_config["hit_ratio"]) logger.info(f"hit_ratio: {self._hit_ratio}") def get_num_new_matched_tokens( From 767062dd174cc17b5c1b4dbc5e72e90ed02c4f02 Mon Sep 17 00:00:00 2001 From: harrisonyhq Date: Thu, 20 Nov 2025 03:16:50 -0800 Subject: [PATCH 5/5] [Fix] Fix load only first rank in non mla scene --- ucm/integration/vllm/ucm_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index 8b55a951..10c6785e 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -137,8 +137,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): config["io_size"] = block_size_per_layer * ( 1 if self.is_mla else num_head_per_tp ) - self.load_only_first_rank: bool = config.get( - "load_only_first_rank", self.is_mla + self.load_only_first_rank: bool = ( + config.get("load_only_first_rank", self.is_mla) and self.is_mla ) if self.load_only_first_rank: if role == KVConnectorRole.WORKER: