diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py
index 8bec088f..3809f6fd 100644
--- a/ucm/integration/vllm/ucm_connector.py
+++ b/ucm/integration/vllm/ucm_connector.py
@@ -87,13 +87,14 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         )
         self.block_size = self._vllm_config.cache_config.block_size
         self.is_mla = self._vllm_config.model_config.is_deepseek_mla
+        self.is_dsa = False
         self.kv_cache_dtype: torch.dtype = None
         if current_platform.is_cuda_alike():
             logger.info("CUDA device is available.")
             torch_dev = torch
             dev_name = "cuda"
-        elif current_platform.is_npu():
+        elif current_platform.device_type == "npu":
             logger.info("NPU device is available.")
             torch_dev = torch.npu
             dev_name = "npu"
@@ -345,8 +346,17 @@ def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"
                 self.kv_caches[layer_name] = attn_layer.kv_cache[
                     forward_context.virtual_engine
                 ]
-                if self.kv_cache_dtype is None:
-                    self.kv_cache_dtype = self.kv_caches[layer_name][0].dtype
+        # Since vllm_ascend >= 0.10.0, the MLA model's KV-cache tensor shape has
+        # changed to (2, num_blocks, block_size, num_kv_heads, nope_dim/rope_dim).
+        # Currently we treat it as GQA and use is_dsa to mark it, which works
+        # but leads to space inefficiency.
+        # TODO: Optimize this to avoid unnecessary space usage.
+        sample_kv_layer = next(iter(self.kv_caches.values()))
+        if self.is_mla and len(sample_kv_layer) == 2:
+            self.is_mla = False
+            self.is_dsa = True
+        if self.kv_cache_dtype is None:
+            self.kv_cache_dtype = sample_kv_layer[0].dtype
 
     @staticmethod
     def _extract_layer_index(layer_name: str) -> Optional[int]:
@@ -491,7 +501,7 @@ def save_kv_layer(
 
     def wait_for_save(self) -> None:
-        if self.is_mla and self.rank != 0:
+        if (self.is_mla or self.is_dsa) and self.rank != 0:
             return
         metadata = self._get_connector_metadata()
diff --git a/ucm/store/nfsstore/nfsstore_connector.py b/ucm/store/nfsstore/nfsstore_connector.py
index bd30f628..c21c686a 100644
--- a/ucm/store/nfsstore/nfsstore_connector.py
+++ b/ucm/store/nfsstore/nfsstore_connector.py
@@ -52,14 +52,17 @@ def __init__(self, config: Dict):
             param.transferDeviceId = config["device"]
             param.transferIoSize = config["io_size"]
             param.transferIoDirect = config.get("use_direct", False)
-
+            param.transferStreamNumber = config.get("stream_number", 32)
+            param.transferBufferNumber = config.get("buffer_number", 512)
         # NOTE: compatible with legacy nfsstore lib
-        if hasattr(param, "storageCapacity"):
-            param.storageCapacity = config.get("storageCapacity", 0)
-        if hasattr(param, "recycleEnable"):
-            param.recycleEnable = True if config.get("recycleEnable", 0) == 1 else False
+        if hasattr(param, "storageCapacity"):
+            param.storageCapacity = config.get("storage_capacity", 0)
+        if hasattr(param, "recycleEnable"):
+            param.recycleEnable = (
+                True if config.get("recycle_enable", 0) == 1 else False
+            )
         if param.recycleEnable:
-            param.recycleThresholdRatio = config.get("recycleThresholdRatio", 0.7)
+            param.recycleThresholdRatio = config.get("recycle_threshold_ratio", 0.7)
         ret = self.store.Setup(param)
         if ret != 0:
diff --git a/ucm/store/pcstore/pcstore_connector.py b/ucm/store/pcstore/pcstore_connector.py
index e8486c3d..e9e0d46d 100644
--- a/ucm/store/pcstore/pcstore_connector.py
+++ b/ucm/store/pcstore/pcstore_connector.py
@@ -49,9 +49,10 @@ def __init__(self, config: Dict):
         if transfer_enable:
             param.transferDeviceId = config["device"]
             param.transferIoSize = config["io_size"]
-            param.transferIoDirect = True
-            param.transferStreamNumber = 8
-            param.transferBufferNumber = 4096
+            param.transferIoDirect = config.get("use_direct", False)
+            param.transferStreamNumber = config.get("stream_number", 8)
+            param.transferBufferNumber = config.get("buffer_number", 4096)
+            param.transferLocalRankSize = config.get("local_rank_size", 8)
         ret = self.store.Setup(param)
         if ret != 0:
             msg = f"Failed to initialize ucmpcstore, errcode: {ret}."
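
Reviewer notes (illustrative sketches, not part of the patch):

The is_mla/is_dsa switch in ucm_connector.py relies only on the leading size of a layer's
KV cache, as described in the in-diff comment. A minimal standalone sketch of that check;
the helper name looks_like_dsa and all concrete sizes are made up for illustration:

    import torch

    # Legacy MLA layout: one cache tensor per layer; dim 0 is num_blocks.
    mla_cache = torch.zeros(128, 16, 1, 576)    # (num_blocks, block_size, 1, head_dim), sizes illustrative
    # vllm_ascend >= 0.10.0 layout: leading dimension of 2 per the connector comment.
    dsa_cache = torch.zeros(2, 128, 16, 1, 64)  # (2, num_blocks, block_size, num_kv_heads, dim)

    def looks_like_dsa(kv_layer: torch.Tensor) -> bool:
        # len() of a tensor is the size of dim 0, so only the new layout reports 2.
        return len(kv_layer) == 2

    assert not looks_like_dsa(mla_cache)  # keeps is_mla
    assert looks_like_dsa(dsa_cache)      # flips to is_dsa and is handled as GQA

The store connectors now read their Setup parameters from snake_case config keys with
defaults instead of hard-coded values. A hypothetical config dict for the NFS store
connector after this change; only the key names and defaults are taken from the diff,
the values are placeholders and keys not touched by this diff are omitted:

    nfsstore_config = {
        "device": 0,                     # transferDeviceId
        "io_size": 262144,               # transferIoSize (placeholder value)
        "use_direct": False,             # transferIoDirect
        "stream_number": 32,             # transferStreamNumber (default 32)
        "buffer_number": 512,            # transferBufferNumber (default 512)
        "storage_capacity": 0,           # legacy nfsstore lib only
        "recycle_enable": 1,             # legacy nfsstore lib only; 1 enables recycling
        "recycle_threshold_ratio": 0.7,  # read only when recycling is enabled
    }

The PC store connector reads the same transfer keys with different defaults
(stream_number=8, buffer_number=4096) plus local_rank_size (default 8), replacing the
previously hard-coded values.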