Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions ucm/integration/vllm/ucm_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,14 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
)
self.block_size = self._vllm_config.cache_config.block_size
self.is_mla = self._vllm_config.model_config.is_deepseek_mla
self.is_dsa = False
self.kv_cache_dtype: torch.dtype = None

if current_platform.is_cuda_alike():
logger.info("CUDA device is available.")
torch_dev = torch
dev_name = "cuda"
elif current_platform.is_npu():
elif current_platform.device_type == "npu":
logger.info("NPU device is available.")
torch_dev = torch.npu
dev_name = "npu"
Expand Down Expand Up @@ -345,8 +346,17 @@ def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"
self.kv_caches[layer_name] = attn_layer.kv_cache[
forward_context.virtual_engine
]
if self.kv_cache_dtype is None:
self.kv_cache_dtype = self.kv_caches[layer_name][0].dtype
# Since vllm_ascend >= 0.10.0, the MLA model's tensor shape has changed to
# (2, num_blocks, block_size, num_kv_heads, nope_dim/rope_dim).
# Currently, we treat it as GQA, and use is_dsa to mark it,
# which works but leads to space inefficiency.
# TODO: Optimize this to avoid unnecessary space usage.
sample_kv_layer = next(iter(self.kv_caches.values()))
if self.is_mla and len(sample_kv_layer) == 2:
self.is_mla = False
self.is_dsa = True
if self.kv_cache_dtype is None:
self.kv_cache_dtype = sample_kv_layer[0].dtype

@staticmethod
def _extract_layer_index(layer_name: str) -> Optional[int]:
Expand Down Expand Up @@ -491,7 +501,7 @@ def save_kv_layer(

def wait_for_save(self) -> None:

if self.is_mla and self.rank != 0:
if (self.is_mla or self.is_dsa) and self.rank != 0:
return

metadata = self._get_connector_metadata()
Expand Down
15 changes: 9 additions & 6 deletions ucm/store/nfsstore/nfsstore_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,17 @@ def __init__(self, config: Dict):
param.transferDeviceId = config["device"]
param.transferIoSize = config["io_size"]
param.transferIoDirect = config.get("use_direct", False)

param.transferStreamNumber = config.get("stream_number", 32)
param.transferBufferNumber = config.get("buffer_number", 512)
# NOTE: compatible with legacy nfsstore lib
if hasattr(param, "storageCapacity"):
param.storageCapacity = config.get("storageCapacity", 0)
if hasattr(param, "recycleEnable"):
param.recycleEnable = True if config.get("recycleEnable", 0) == 1 else False
if hasattr(param, "storage_capacity"):
param.storageCapacity = config.get("storage_capacity", 0)
if hasattr(param, "recycle_enable"):
param.recycleEnable = (
True if config.get("recycle_enable", 0) == 1 else False
)
if param.recycleEnable:
param.recycleThresholdRatio = config.get("recycleThresholdRatio", 0.7)
param.recycleThresholdRatio = config.get("recycle_threshold_ratio", 0.7)

ret = self.store.Setup(param)
if ret != 0:
Expand Down
7 changes: 4 additions & 3 deletions ucm/store/pcstore/pcstore_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ def __init__(self, config: Dict):
if transfer_enable:
param.transferDeviceId = config["device"]
param.transferIoSize = config["io_size"]
param.transferIoDirect = True
param.transferStreamNumber = 8
param.transferBufferNumber = 4096
param.transferIoDirect = config.get("use_direct", False)
param.transferStreamNumber = config.get("stream_number", 8)
param.transferBufferNumber = config.get("buffer_number", 4096)
param.transferLocalRankSize = config.get("local_rank_size", 8)
ret = self.store.Setup(param)
if ret != 0:
msg = f"Failed to initialize ucmpcstore, errcode: {ret}."
Expand Down