From 32ce65b8935ec0fca2e2cd86e14b8013fd1863dd Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 9 Apr 2026 16:07:46 +0800 Subject: [PATCH 1/9] use new mooncake API and enable zero-copy Signed-off-by: 0oshowero0 try: use new mooncake API and enable zero-copy Signed-off-by: 0oshowero0 update Signed-off-by: 0oshowero0 update Signed-off-by: 0oshowero0 fix Signed-off-by: 0oshowero0 --- .../storage/clients/mooncake_client.py | 105 +++++++---- transfer_queue/utils/tensor_utils.py | 172 ++++++++++++++++++ 2 files changed, 239 insertions(+), 38 deletions(-) create mode 100644 transfer_queue/utils/tensor_utils.py diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index 6ab610ee..75d3a8cd 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -23,13 +23,15 @@ from transfer_queue.storage.clients.base import TransferQueueStorageKVClient from transfer_queue.storage.clients.factory import StorageClientFactory +from transfer_queue.utils.tensor_utils import allocate_empty_tensors, get_nbytes, merge_continues_memory logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) MOONCAKE_STORE_IMPORTED: bool = True try: - from mooncake.store import MooncakeDistributedStore + from mooncake.store import MooncakeDistributedStore, ReplicateConfig + except ImportError: MOONCAKE_STORE_IMPORTED = False @@ -78,10 +80,9 @@ def __init__(self, config: dict[str, Any]): if not self.metadata_server.startswith("etcd://") and not self.metadata_server.endswith("/metadata"): self.metadata_server = self.metadata_server + "/metadata" - if self.metadata_server is None: - raise ValueError("Missing 'metadata_server' in config") - if self.master_server_address is None: - raise ValueError("Missing 'master_server_address' in config") + self.replica_config = ReplicateConfig() + # FIXME: hard_pin is not supported yet + # 
self.replica_config.with_hard_pin = True self._store = MooncakeDistributedStore() ret = self._store.setup( @@ -116,12 +117,8 @@ def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: for key, value in zip(keys, values, strict=True): if isinstance(value, torch.Tensor): - tensor = value.contiguous() - # TODO: use gpu direct rdma instead - if tensor.device.type == "cuda": - tensor = tensor.cpu() tensor_keys.append(key) - tensor_values.append(tensor) + tensor_values.append(value) else: non_tensor_keys.append(key) non_tensor_values.append(pickle.dumps(value)) @@ -139,7 +136,11 @@ def _batch_put_tensors(self, keys: list[str], tensors: list[Tensor]): batch_keys = keys[i : i + BATCH_SIZE_LIMIT] batch_tensors = tensors[i : i + BATCH_SIZE_LIMIT] - results = self._store.batch_put_tensor(batch_keys, batch_tensors) + batch_ptrs, batch_sizes = self._preprocess_tensors_for_put(batch_tensors) + batch_ptr_reduced, batch_sizes_reduced = merge_continues_memory(batch_ptrs, batch_sizes) + self._register_all_buffers(batch_ptr_reduced, batch_sizes_reduced) + + results = self._store.batch_upsert_from(batch_keys, batch_ptrs, batch_sizes, config=self.replica_config) if not all(r == 0 for r in results): failed_indices = [j for j, r in enumerate(results) if r != 0] error_codes = [results[j] for j in failed_indices] @@ -147,30 +148,38 @@ def _batch_put_tensors(self, keys: list[str], tensors: list[Tensor]): f"batch_put_tensor failed for indices {failed_indices} with error codes: {error_codes}" ) + self._unregister_all_buffers(batch_ptr_reduced) + def _batch_put_bytes(self, keys: list[str], values: list[bytes]): for i in range(0, len(keys), BATCH_SIZE_LIMIT): batch_keys = keys[i : i + BATCH_SIZE_LIMIT] batch_values = values[i : i + BATCH_SIZE_LIMIT] - ret = self._store.put_batch(batch_keys, batch_values) + ret = self._store.upsert_batch(batch_keys, batch_values, self.replica_config) if ret != 0: raise RuntimeError(f"put_batch failed with error code: {ret}") - def get(self, 
keys: list[str], shapes=None, dtypes=None, custom_backend_meta=None) -> list[Any]: + def get( + self, + keys: list[str], + shapes: Optional[list[Any]] = None, + dtypes: Optional[list[Any]] = None, + custom_backend_meta: Optional[list[str]] = None, + ) -> list[Any]: """Get multiple key-value pairs from MooncakeStore. Args: - keys (List[str]): Keys to fetch. - shapes (List[List[int]]): Expected tensor shapes (use [] for scalars). - dtypes (List[Optional[torch.dtype]]): Expected dtypes; use None for non-tensor data. - custom_backend_meta (List[str], optional): ... + keys: Keys to fetch. + shapes: Expected tensor shapes (use [] for scalars). + dtypes: Expected dtypes; use None for non-tensor data. + custom_backend_meta: Optional custom backend metadata. Returns: - List[Any]: Retrieved values in the same order as input keys. + Retrieved values in the same order as input keys. """ if shapes is None or dtypes is None: - raise ValueError("MooncakeStoreClient needs shapes and dtypes") + raise ValueError("MooncakeStoreClient needs shapes and dtypes for zero-copy transfer.") if not (len(keys) == len(shapes) == len(dtypes)): raise ValueError("Lengths of keys, shapes, dtypes must match") @@ -210,14 +219,25 @@ def _batch_get_tensors(self, keys: list[str], shapes: list, dtypes: list) -> lis batch_shapes = shapes[i : i + BATCH_SIZE_LIMIT] batch_dtypes = dtypes[i : i + BATCH_SIZE_LIMIT] - batch_results = self._store.batch_get_tensor(batch_keys) + batch_nbytes = get_nbytes(batch_dtypes, batch_shapes) + batch_buffer_tensors, batch_buffer_ptrs = allocate_empty_tensors(batch_dtypes, batch_shapes) - if len(batch_results) != len(batch_keys): - raise RuntimeError(f"batch_get_tensor returned {len(batch_results)} items, expected {len(batch_keys)}") + batch_ptrs = batch_buffer_ptrs - for j, (tensor, shape, dtype) in enumerate(zip(batch_results, batch_shapes, batch_dtypes, strict=True)): - if tensor is None: - raise RuntimeError(f"batch_get_tensor returned None for key '{batch_keys[j]}'") + 
self._register_all_buffers(batch_ptrs, batch_nbytes) + ret_codes = self._store.batch_get_into(batch_keys, batch_ptrs, batch_nbytes) + self._unregister_all_buffers(batch_ptrs) + + if len(ret_codes) != len(batch_keys): + raise RuntimeError(f"batch_get_into returned {len(ret_codes)} results, expected {len(batch_keys)}") + + # Check result codes and validate tensors + # Note: Positive values indicate success (bytes read), negative values indicate error + for j, (tensor, shape, dtype, ret_code) in enumerate( + zip(batch_buffer_tensors, batch_shapes, batch_dtypes, ret_codes, strict=True) + ): + if ret_code < 0: + raise RuntimeError(f"batch_get_into failed for key '{batch_keys[j]}' with error code: {ret_code}") if tensor.shape != torch.Size(shape): raise RuntimeError( f"Shape mismatch for key '{batch_keys[j]}': expected {shape}, got {tensor.shape}" @@ -243,26 +263,35 @@ def _batch_get_bytes(self, keys: list[str]) -> list[bytes]: def clear(self, keys: list[str], custom_backend_meta=None): """Deletes multiple keys from MooncakeStore. - Args: keys (List[str]): List of keys to remove. custom_backend_meta (List[Any], optional): ... """ - global_indexes_patterns = {key.split("@")[0] + "@.*" for key in keys} - for p in global_indexes_patterns: - ret = self._store.remove_by_regex(p, force=True) - if ret < 0: - logger.warning(f"remove failed for key '{p}' with error code: {ret}") - - # FIXME: controller returned BatchMeta may have mismatched fields in some case, preventing - # key-value based backends to accurately clear all existing keys.. 
- # for key in keys: - # ret = self._store.remove(key) - # if not (ret == 0 or ret == -704): - # logger.warning(f"remove failed for key '{key}' with error code: {ret}") + rets = self._store.batch_remove(keys, force=True) + for i, ret in enumerate(rets): + if not (ret == 0 or ret == -704): + logger.error(f"remove failed for key '{keys[i]}' with error code: {ret}") def close(self): """Closes MooncakeStore.""" if self._store: self._store.close() self._store = None + + @staticmethod + def _preprocess_tensors_for_put(values: list[Tensor]) -> tuple[list[Any], list[Any]]: + ptr_list = [] + size_list = [] + for t in values: + t = t.contiguous() + ptr_list.append(t.data_ptr()) + size_list.append(t.nbytes) + return ptr_list, size_list + + def _register_all_buffers(self, ptrs, sizes): + for ptr, size in zip(ptrs, sizes, strict=False): + self._store.register_buffer(ptr, size) + + def _unregister_all_buffers(self, ptrs): + for ptr in ptrs: + self._store.unregister_buffer(ptr) diff --git a/transfer_queue/utils/tensor_utils.py b/transfer_queue/utils/tensor_utils.py new file mode 100644 index 00000000..95a12c98 --- /dev/null +++ b/transfer_queue/utils/tensor_utils.py @@ -0,0 +1,172 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import operator +import os +from functools import reduce + +import torch +from torch import Tensor + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) + + +def allocate_empty_tensors(dtypes: list[torch.dtype], shapes: list[tuple]) -> tuple[list[Tensor], list[int]]: + """Allocate empty tensors, grouping same dtypes into shared memory blocks. + + Instead of allocating each tensor separately, this function groups tensors + by their dtype and allocates one large contiguous memory block per dtype. + Each tensor is then created as a view into this shared memory. + + Args: + dtypes: List of torch dtypes for each tensor. + shapes: List of shapes (tuples) for each tensor. + + Returns: + A tuple containing: + - List of tensors sharing memory within their dtype groups. + - List of memory pointers (data_ptr) for each tensor. + + Example: + >>> dtypes = [torch.float32, torch.float32, torch.int32, torch.float32] + >>> shapes = [(10,), (20,), (5,), (15,)] + >>> tensors, ptrs = allocate_empty_tensors(dtypes, shapes) + >>> # tensors[0], [1], [3] share the same dtype and memory block + """ + assert len(dtypes) == len(shapes), "dtypes and shapes must have the same length" + + if len(dtypes) == 0: + return [], [] + + # Group indices by dtype + dtype_groups: dict[torch.dtype, list[int]] = {} + for i, dtype in enumerate(dtypes): + if dtype not in dtype_groups: + dtype_groups[dtype] = [] + dtype_groups[dtype].append(i) + + tensor_list = [torch.empty(()) for _ in range(len(dtypes))] + ptr_list = [0] * len(dtypes) + + # For each dtype group, allocate one big tensor and create views + for dtype, indices in dtype_groups.items(): + # Calculate total number of elements needed for this dtype + total_elements = 0 + shape_info = [] # Store (index, shape, num_elements, offset) + + for idx in indices: + shape = shapes[idx] + num_elements = reduce(operator.mul, shape) + shape_info.append((idx, shape, num_elements, 
total_elements)) + total_elements += num_elements + + # Allocate one big contiguous memory block for this dtype + big_tensor = torch.empty(total_elements, dtype=dtype) + + # Create views into the big tensor for each small tensor + for idx, shape, num_elements, offset in shape_info: + # Use as_strided to create a view with the correct shape + small_tensor = big_tensor.as_strided(size=shape, stride=compute_stride(shape), storage_offset=offset) + tensor_list[idx] = small_tensor + ptr_list[idx] = small_tensor.data_ptr() + + return tensor_list, ptr_list + + +def compute_stride(shape: tuple[int, ...]) -> tuple[int, ...]: + """Compute stride for a contiguous row-major (C-style) tensor. + + Args: + shape: The shape of the tensor. + + Returns: + Stride tuple for contiguous storage. + + Example: + >>> compute_stride((2, 3, 4)) + (12, 4, 1) + """ + stride = [] + cumulative = 1 + # Iterate from last dimension to first + for dim in reversed(shape): + stride.append(cumulative) + cumulative *= dim + return tuple(reversed(stride)) + + +def get_nbytes(dtypes, shapes) -> list[int]: + assert len(dtypes) == len(shapes) + nbytes = [] + for i in range(len(dtypes)): + elem_size = torch.tensor([], dtype=dtypes[i]).element_size() + numel = reduce(operator.mul, shapes[i]) + nbytes.append(elem_size * numel) + + return nbytes + + +def merge_continues_memory(ptrs: list[int], sizes: list[int]) -> tuple[list[int], list[int]]: + """Merge continuous memory regions to reduce register_buffer overhead + + Args: + ptrs: List of memory pointers (starting addresses). + sizes: List of memory region sizes corresponding to each pointer. + + Returns: + A tuple of (merged_ptrs, merged_sizes) where continuous regions + have been merged into single regions. 
+ + Example: + >>> merge_continues_memory([0, 10, 30], [10, 20, 10]) + ([0, 30], [30, 10]) + + >>> merge_continues_memory([0, 5, 20], [5, 5, 10]) + ([0, 20], [10, 10]) + """ + if not ptrs or not sizes: + return [], [] + + if len(ptrs) != len(sizes): + raise ValueError("ptrs and sizes must have the same length") + + # Create list of (ptr, size) pairs and sort by pointer address + regions = sorted(zip(ptrs, sizes, strict=False), key=lambda x: x[0]) + + merged_ptrs = [] + merged_sizes = [] + + # Initialize with the first region + current_ptr, current_size = regions[0] + + for ptr, size in regions[1:]: + # Check if current region is contiguous with the next one + # A region is contiguous if: ptr == current_ptr + current_size + if ptr == current_ptr + current_size: + # Merge: extend the current region + current_size += size + else: + # Not contiguous: save the current region and start a new one + merged_ptrs.append(current_ptr) + merged_sizes.append(current_size) + current_ptr, current_size = ptr, size + + # Add the last region + merged_ptrs.append(current_ptr) + merged_sizes.append(current_size) + + return merged_ptrs, merged_sizes From 17b0712b062ef130c0560c2edc3892ca40331610 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 10 Apr 2026 10:57:23 +0800 Subject: [PATCH 2/9] multi-thread support for concurrent data preprocess & transfer Signed-off-by: 0oshowero0 --- .../storage/clients/mooncake_client.py | 170 +++++++++--------- 1 file changed, 86 insertions(+), 84 deletions(-) diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index 75d3a8cd..95ef3a7a 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -16,6 +16,7 @@ import logging import os import pickle +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any, Optional import torch @@ -35,7 +36,8 @@ except ImportError: MOONCAKE_STORE_IMPORTED = False 
-BATCH_SIZE_LIMIT: int = 500 +BATCH_SIZE_LIMIT: int = 200 +MAX_WORKER_THREADS = 4 @StorageClientFactory.register("MooncakeStoreClient") @@ -81,7 +83,7 @@ def __init__(self, config: dict[str, Any]): self.metadata_server = self.metadata_server + "/metadata" self.replica_config = ReplicateConfig() - # FIXME: hard_pin is not supported yet + # FIXME: hard_pin support # self.replica_config.with_hard_pin = True self._store = MooncakeDistributedStore() @@ -97,7 +99,7 @@ def __init__(self, config: dict[str, Any]): if ret != 0: raise RuntimeError(f"Mooncake store setup failed with error code: {ret}") - def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: + def put(self, keys: list[str], values: list[Any]) -> None: """Stores multiple key-value pairs to MooncakeStore. Args: @@ -121,25 +123,33 @@ def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: tensor_values.append(value) else: non_tensor_keys.append(key) - non_tensor_values.append(pickle.dumps(value)) + non_tensor_values.append(value) - if tensor_keys: - self._batch_put_tensors(tensor_keys, tensor_values) + futures = [] + with ThreadPoolExecutor(max_workers=MAX_WORKER_THREADS) as executor: + for i in range(0, len(tensor_keys), BATCH_SIZE_LIMIT): + batch_keys = tensor_keys[i : i + BATCH_SIZE_LIMIT] + batch_tensors = tensor_values[i : i + BATCH_SIZE_LIMIT] + futures.append(executor.submit(self._put_tensors_thread_worker, batch_keys, batch_tensors)) - if non_tensor_keys: - self._batch_put_bytes(non_tensor_keys, non_tensor_values) + for i in range(0, len(non_tensor_keys), BATCH_SIZE_LIMIT): + batch_keys = non_tensor_keys[i : i + BATCH_SIZE_LIMIT] + batch_values = non_tensor_values[i : i + BATCH_SIZE_LIMIT] + futures.append(executor.submit(self._put_bytes_thread_worker, batch_keys, batch_values)) + + for future in as_completed(futures): + future.result() return None - def _batch_put_tensors(self, keys: list[str], tensors: list[Tensor]): - for i in range(0, len(keys), BATCH_SIZE_LIMIT): 
- batch_keys = keys[i : i + BATCH_SIZE_LIMIT] - batch_tensors = tensors[i : i + BATCH_SIZE_LIMIT] + def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[Tensor]): + """Worker thread for putting batch of tensors to MooncakeStore.""" - batch_ptrs, batch_sizes = self._preprocess_tensors_for_put(batch_tensors) - batch_ptr_reduced, batch_sizes_reduced = merge_continues_memory(batch_ptrs, batch_sizes) - self._register_all_buffers(batch_ptr_reduced, batch_sizes_reduced) + batch_ptrs, batch_sizes, contiguous_tensors = self._preprocess_tensors_for_put(batch_tensors) + batch_ptr_reduced, batch_sizes_reduced = merge_continues_memory(batch_ptrs, batch_sizes) + self._register_all_buffers(batch_ptr_reduced, batch_sizes_reduced) + try: results = self._store.batch_upsert_from(batch_keys, batch_ptrs, batch_sizes, config=self.replica_config) if not all(r == 0 for r in results): failed_indices = [j for j, r in enumerate(results) if r != 0] @@ -147,17 +157,17 @@ def _batch_put_tensors(self, keys: list[str], tensors: list[Tensor]): raise RuntimeError( f"batch_put_tensor failed for indices {failed_indices} with error codes: {error_codes}" ) - + finally: self._unregister_all_buffers(batch_ptr_reduced) - def _batch_put_bytes(self, keys: list[str], values: list[bytes]): - for i in range(0, len(keys), BATCH_SIZE_LIMIT): - batch_keys = keys[i : i + BATCH_SIZE_LIMIT] - batch_values = values[i : i + BATCH_SIZE_LIMIT] + def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[bytes]): + """Worker thread for putting batch of non-tensors to MooncakeStore.""" - ret = self._store.upsert_batch(batch_keys, batch_values, self.replica_config) - if ret != 0: - raise RuntimeError(f"put_batch failed with error code: {ret}") + batch_values = [pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL) for v in batch_values] + + ret = self._store.upsert_batch(batch_keys, batch_values, self.replica_config) + if ret != 0: + raise RuntimeError(f"put_batch failed with error 
code: {ret}") def get( self, @@ -194,71 +204,61 @@ def get( results = [None] * len(keys) - if tensor_indices: - tensor_keys = [keys[i] for i in tensor_indices] - tensor_shapes = [shapes[i] for i in tensor_indices] - tensor_dtypes = [dtypes[i] for i in tensor_indices] - tensor_results = self._batch_get_tensors(tensor_keys, tensor_shapes, tensor_dtypes) - # TODO: optimize these for loops - for idx, tensor in zip(tensor_indices, tensor_results, strict=True): - results[idx] = tensor - - if non_tensor_indices: - non_tensor_keys = [keys[i] for i in non_tensor_indices] - non_tensor_results = self._batch_get_bytes(non_tensor_keys) - for idx, data in zip(non_tensor_indices, non_tensor_results, strict=True): - results[idx] = pickle.loads(data) - - return results - - def _batch_get_tensors(self, keys: list[str], shapes: list, dtypes: list) -> list[Tensor]: - tensors = [None] * len(keys) + futures = [] + with ThreadPoolExecutor(max_workers=MAX_WORKER_THREADS) as executor: + for i in range(0, len(tensor_indices), BATCH_SIZE_LIMIT): + batch_indexes = tensor_indices[i : i + BATCH_SIZE_LIMIT] + batch_keys = [keys[i] for i in batch_indexes] + batch_shapes = [shapes[i] for i in batch_indexes] + batch_dtypes = [dtypes[i] for i in batch_indexes] + futures.append( + executor.submit( + self._get_tensors_thread_worker, batch_keys, batch_shapes, batch_dtypes, batch_indexes + ) + ) - for i in range(0, len(keys), BATCH_SIZE_LIMIT): - batch_keys = keys[i : i + BATCH_SIZE_LIMIT] - batch_shapes = shapes[i : i + BATCH_SIZE_LIMIT] - batch_dtypes = dtypes[i : i + BATCH_SIZE_LIMIT] + for i in range(0, len(non_tensor_indices), BATCH_SIZE_LIMIT): + batch_indexes = non_tensor_indices[i : i + BATCH_SIZE_LIMIT] + batch_keys = [keys[i] for i in batch_indexes] + futures.append(executor.submit(self._get_bytes_thread_worker, batch_keys, batch_indexes)) - batch_nbytes = get_nbytes(batch_dtypes, batch_shapes) - batch_buffer_tensors, batch_buffer_ptrs = allocate_empty_tensors(batch_dtypes, batch_shapes) + for 
future in as_completed(futures): + retrieved_values, batch_indexes = future.result() + for idx, val in zip(batch_indexes, retrieved_values, strict=True): + results[idx] = val - batch_ptrs = batch_buffer_ptrs + return results - self._register_all_buffers(batch_ptrs, batch_nbytes) - ret_codes = self._store.batch_get_into(batch_keys, batch_ptrs, batch_nbytes) - self._unregister_all_buffers(batch_ptrs) + def _get_tensors_thread_worker( + self, batch_keys: list[str], batch_shapes: list[tuple], batch_dtypes: list[torch.dtype], indexes: list[int] + ) -> tuple[list[Tensor], list[int]]: + batch_nbytes = get_nbytes(batch_dtypes, batch_shapes) + batch_buffer_tensors, batch_buffer_ptrs = allocate_empty_tensors(batch_dtypes, batch_shapes) + self._register_all_buffers(batch_buffer_ptrs, batch_nbytes) + try: + ret_codes = self._store.batch_get_into(batch_keys, batch_buffer_ptrs, batch_nbytes) if len(ret_codes) != len(batch_keys): raise RuntimeError(f"batch_get_into returned {len(ret_codes)} results, expected {len(batch_keys)}") + for i, ret in enumerate(ret_codes): + if ret < 0: + raise RuntimeError(f"batch_get_into failed for key `{batch_keys[i]}` with error code: {ret}") + finally: + self._unregister_all_buffers(batch_buffer_ptrs) - # Check result codes and validate tensors - # Note: Positive values indicate success (bytes read), negative values indicate error - for j, (tensor, shape, dtype, ret_code) in enumerate( - zip(batch_buffer_tensors, batch_shapes, batch_dtypes, ret_codes, strict=True) - ): - if ret_code < 0: - raise RuntimeError(f"batch_get_into failed for key '{batch_keys[j]}' with error code: {ret_code}") - if tensor.shape != torch.Size(shape): - raise RuntimeError( - f"Shape mismatch for key '{batch_keys[j]}': expected {shape}, got {tensor.shape}" - ) - if tensor.dtype != dtype: - raise RuntimeError( - f"Dtype mismatch for key '{batch_keys[j]}': expected {dtype}, got {tensor.dtype}" - ) - tensors[i + j] = tensor - - return tensors + return batch_buffer_tensors, 
indexes - def _batch_get_bytes(self, keys: list[str]) -> list[bytes]: + def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) -> tuple[list[Any], list[int]]: results = [] - for i in range(0, len(keys), BATCH_SIZE_LIMIT): - batch_keys = keys[i : i + BATCH_SIZE_LIMIT] - batch_results = self._store.get_batch(batch_keys) - if len(batch_results) != len(batch_keys): - raise RuntimeError(f"get_batch returned {len(batch_results)} items, expected {len(batch_keys)}") - results.extend(batch_results) - return results + + batch_results = self._store.get_batch(batch_keys) + if len(batch_results) != len(batch_keys): + raise RuntimeError(f"get_batch returned {len(batch_results)} items, expected {len(batch_keys)}") + + batch_results = [pickle.loads(result) for result in batch_results] + results.extend(batch_results) + + return results, indexes def clear(self, keys: list[str], custom_backend_meta=None): """Deletes multiple keys from MooncakeStore. @@ -267,10 +267,10 @@ def clear(self, keys: list[str], custom_backend_meta=None): keys (List[str]): List of keys to remove. custom_backend_meta (List[Any], optional): ... 
""" - rets = self._store.batch_remove(keys, force=True) - for i, ret in enumerate(rets): + ret_codes = self._store.batch_remove(keys, force=True) + for i, ret in enumerate(ret_codes): if not (ret == 0 or ret == -704): - logger.error(f"remove failed for key '{keys[i]}' with error code: {ret}") + logger.error(f"remove failed for key `{keys[i]}` with error code: {ret}") def close(self): """Closes MooncakeStore.""" @@ -279,17 +279,19 @@ def close(self): self._store = None @staticmethod - def _preprocess_tensors_for_put(values: list[Tensor]) -> tuple[list[Any], list[Any]]: + def _preprocess_tensors_for_put(values: list[Tensor]) -> tuple[list[Any], list[Any], list[Tensor]]: ptr_list = [] size_list = [] + tensor_list = [] # hold reference for the contiguous tensor for t in values: t = t.contiguous() + tensor_list.append(t) ptr_list.append(t.data_ptr()) size_list.append(t.nbytes) - return ptr_list, size_list + return ptr_list, size_list, tensor_list def _register_all_buffers(self, ptrs, sizes): - for ptr, size in zip(ptrs, sizes, strict=False): + for ptr, size in zip(ptrs, sizes, strict=True): self._store.register_buffer(ptr, size) def _unregister_all_buffers(self, ptrs): From fccb1fef9363e39a88ec9ad32dec97e786b7b63e Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 10 Apr 2026 10:59:01 +0800 Subject: [PATCH 3/9] fix docstring Signed-off-by: 0oshowero0 --- transfer_queue/utils/tensor_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transfer_queue/utils/tensor_utils.py b/transfer_queue/utils/tensor_utils.py index 95a12c98..f47377ff 100644 --- a/transfer_queue/utils/tensor_utils.py +++ b/transfer_queue/utils/tensor_utils.py @@ -110,6 +110,7 @@ def compute_stride(shape: tuple[int, ...]) -> tuple[int, ...]: def get_nbytes(dtypes, shapes) -> list[int]: + """Calculate number of bytes according to tensor dtypes and shapes.""" assert len(dtypes) == len(shapes) nbytes = [] for i in range(len(dtypes)): From 5446a819b5a693a1e7613c05bbc559d62c534a7e Mon Sep 17 00:00:00 
2001 From: 0oshowero0 Date: Sat, 11 Apr 2026 15:06:22 +0800 Subject: [PATCH 4/9] update Signed-off-by: 0oshowero0 --- tests/e2e/test_kv_interface_e2e.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py index 8b171717..83f8b7eb 100644 --- a/tests/e2e/test_kv_interface_e2e.py +++ b/tests/e2e/test_kv_interface_e2e.py @@ -212,7 +212,6 @@ def test_kv_put_with_dict_fields(self, controller, tq_api): expected = torch.tensor([[1, 2, 3, 4]]) # unsqueezed assert_tensor_equal(retrieved["data"], expected) - # delete the key (MooncakeStore does not support updating existing key, so we need to clear it before next test) tq_api.kv_clear(keys=key, partition_id=partition_id) def test_kv_put_with_tensordict_fields(self, controller, tq_api): From 108a09e76ed9a46704ee2ef488bcd2694b17b135 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 17 Apr 2026 13:46:54 +0800 Subject: [PATCH 5/9] support hard_pin Signed-off-by: 0oshowero0 --- transfer_queue/storage/clients/mooncake_client.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index 95ef3a7a..57545345 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -83,8 +83,7 @@ def __init__(self, config: dict[str, Any]): self.metadata_server = self.metadata_server + "/metadata" self.replica_config = ReplicateConfig() - # FIXME: hard_pin support - # self.replica_config.with_hard_pin = True + self.replica_config.with_hard_pin = True self._store = MooncakeDistributedStore() ret = self._store.setup( From c2cd296bf6b7fb7e8d5142021206e73953b6f339 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 25 Apr 2026 10:35:24 +0800 Subject: [PATCH 6/9] update mooncake version Signed-off-by: 0oshowero0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/pyproject.toml b/pyproject.toml index 3a067a18..0bbb2f0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,7 +118,7 @@ yuanrong = [ "openyuanrong-datasystem" ] mooncake = [ - "mooncake-transfer-engine==0.3.10.post1" + "mooncake-transfer-engine==0.3.10.post2" ] # If you need to mimic `package_dir={'': '.'}`: From c4c4198063423798064fa225c19ac8fc385c4a84 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 25 Apr 2026 12:30:53 +0800 Subject: [PATCH 7/9] fix Signed-off-by: 0oshowero0 fix Signed-off-by: 0oshowero0 fix Signed-off-by: 0oshowero0 --- tests/test_tensor_utils.py | 196 ++++++++++++++++++ .../storage/clients/mooncake_client.py | 25 ++- transfer_queue/utils/tensor_utils.py | 39 ++-- 3 files changed, 236 insertions(+), 24 deletions(-) create mode 100644 tests/test_tensor_utils.py diff --git a/tests/test_tensor_utils.py b/tests/test_tensor_utils.py new file mode 100644 index 00000000..5d534938 --- /dev/null +++ b/tests/test_tensor_utils.py @@ -0,0 +1,196 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for transfer_queue.utils.tensor_utils.""" + +import pytest +import torch + +from transfer_queue.utils.tensor_utils import ( + allocate_empty_tensors, + compute_stride, + get_nbytes, + merge_contiguous_memory, +) + + +class TestComputeStride: + """Tests for compute_stride.""" + + def test_3d(self): + assert compute_stride((2, 3, 4)) == (12, 4, 1) + + def test_1d(self): + assert compute_stride((5,)) == (1,) + + def test_scalar(self): + assert compute_stride(()) == () + + def test_2d(self): + assert compute_stride((3, 5)) == (5, 1) + + +class TestGetNbytes: + """Tests for get_nbytes.""" + + def test_basic(self): + dtypes = [torch.float32, torch.int32] + shapes = [(2, 3), (4,)] + result = get_nbytes(dtypes, shapes) + assert result == [2 * 3 * 4, 4 * 4] # float32=4, int32=4 + + def test_scalar(self): + dtypes = [torch.float64] + shapes = [()] + result = get_nbytes(dtypes, shapes) + assert result == [8] # scalar = 1 element + + def test_list_shape(self): + dtypes = [torch.float32] + shapes = [[]] # list instead of tuple + result = get_nbytes(dtypes, shapes) + assert result == [4] + + def test_mixed_dtypes(self): + dtypes = [torch.float16, torch.float32, torch.int64] + shapes = [(10,), (10,), (10,)] + result = get_nbytes(dtypes, shapes) + assert result == [10 * 2, 10 * 4, 10 * 8] + + +class TestAllocateEmptyTensors: + """Tests for allocate_empty_tensors.""" + + def test_basic(self): + dtypes = [torch.float32, torch.float32, torch.int32] + shapes = [(2, 3), (4,), (5,)] + tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes) + + assert len(tensors) == 3 + assert len(ptrs) == 3 + assert len(region_ptrs) == 2 # float32 group + int32 group + assert len(region_sizes) == 2 + + # Same dtype tensors share the same underlying storage + assert tensors[0].untyped_storage().data_ptr() == region_ptrs[0] + assert tensors[1].untyped_storage().data_ptr() == region_ptrs[0] + assert tensors[2].untyped_storage().data_ptr() == region_ptrs[1] + + # 
Shapes are correct + assert list(tensors[0].shape) == [2, 3] + assert list(tensors[1].shape) == [4] + assert list(tensors[2].shape) == [5] + + def test_scalar(self): + dtypes = [torch.float32, torch.int32] + shapes = [(), ()] + tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes) + + assert len(tensors) == 2 + assert tensors[0].numel() == 1 + assert tensors[1].numel() == 1 + assert len(region_ptrs) == 2 + + def test_empty(self): + result = allocate_empty_tensors([], []) + assert result == ([], [], [], []) + + def test_regions_complex(self): + """Mixed dtypes and shapes: verify region counts, sizes, and per-tensor offsets.""" + dtypes = [ + torch.float32, # group 0: (2, 3) -> 6 elements + torch.int32, # group 1: (4,) -> 4 elements + torch.float32, # group 0: scalar -> 1 element + torch.float64, # group 2: (2, 2) -> 4 elements + torch.int32, # group 1: (3, 2) -> 6 elements + ] + shapes = [(2, 3), (4,), (), (2, 2), (3, 2)] + tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes) + + # 3 dtype groups in insertion order: float32, int32, float64 + assert len(region_ptrs) == 3 + assert len(region_sizes) == 3 + assert len(set(region_ptrs)) == 3 # distinct allocations + + # float32 region: 6 + 1 = 7 elements * 4 bytes = 28 bytes + assert region_sizes[0] == 7 * 4 + # int32 region: 4 + 6 = 10 elements * 4 bytes = 40 bytes + assert region_sizes[1] == 10 * 4 + # float64 region: 4 elements * 8 bytes = 32 bytes + assert region_sizes[2] == 4 * 8 + + # Per-tensor ptrs must lie inside their respective regions + # tensor 0 (float32, shape (2,3), offset 0) + assert ptrs[0] == region_ptrs[0] + # tensor 1 (int32, shape (4,), offset 0) + assert ptrs[1] == region_ptrs[1] + # tensor 2 (float32, scalar, offset 6) + assert ptrs[2] == region_ptrs[0] + 6 * 4 + # tensor 3 (float64, shape (2,2), offset 0) + assert ptrs[3] == region_ptrs[2] + # tensor 4 (int32, shape (3,2), offset 4) + assert ptrs[4] == region_ptrs[1] + 4 * 4 + + +class 
TestMergeContiguousMemory: + """Tests for merge_contiguous_memory.""" + + def test_basic_merge(self): + ptrs = [0, 10, 30] + sizes = [10, 20, 10] + merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes) + # 0+10=10 (contiguous with 10), 10+20=30 (contiguous with 30) -> all merge into [0] + assert merged_ptrs == [0] + assert merged_sizes == [40] + + def test_no_contiguous(self): + ptrs = [0, 100, 200] + sizes = [50, 50, 50] + merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes) + assert merged_ptrs == [0, 100, 200] + assert merged_sizes == [50, 50, 50] + + def test_unsorted_input(self): + ptrs = [100, 0, 50] + sizes = [50, 50, 50] + merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes) + # After sorting: 0, 50, 100; all contiguous -> merge into [0] + assert merged_ptrs == [0] + assert merged_sizes == [150] + + def test_single_region(self): + ptrs = [10] + sizes = [100] + merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes) + assert merged_ptrs == [10] + assert merged_sizes == [100] + + def test_empty(self): + assert merge_contiguous_memory([], []) == ([], []) + + def test_mismatched_lengths_both_empty_not_triggered(self): + # If one is empty and other is not, should raise ValueError + with pytest.raises(ValueError, match="ptrs and sizes must have the same length"): + merge_contiguous_memory([], [10]) + + with pytest.raises(ValueError, match="ptrs and sizes must have the same length"): + merge_contiguous_memory([0], []) + + def test_three_continuous(self): + ptrs = [0, 10, 20] + sizes = [10, 10, 10] + merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes) + assert merged_ptrs == [0] + assert merged_sizes == [30] diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index 57545345..aea858f9 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -24,7 +24,7 @@ from 
transfer_queue.storage.clients.base import TransferQueueStorageKVClient from transfer_queue.storage.clients.factory import StorageClientFactory -from transfer_queue.utils.tensor_utils import allocate_empty_tensors, get_nbytes, merge_continues_memory +from transfer_queue.utils.tensor_utils import allocate_empty_tensors, get_nbytes, merge_contiguous_memory logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) @@ -144,8 +144,8 @@ def put(self, keys: list[str], values: list[Any]) -> None: def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[Tensor]): """Worker thread for putting batch of tensors to MooncakeStore.""" - batch_ptrs, batch_sizes, contiguous_tensors = self._preprocess_tensors_for_put(batch_tensors) - batch_ptr_reduced, batch_sizes_reduced = merge_continues_memory(batch_ptrs, batch_sizes) + batch_ptrs, batch_sizes, _contiguous_tensors = self._preprocess_tensors_for_put(batch_tensors) + batch_ptr_reduced, batch_sizes_reduced = merge_contiguous_memory(batch_ptrs, batch_sizes) self._register_all_buffers(batch_ptr_reduced, batch_sizes_reduced) try: @@ -154,19 +154,19 @@ def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[ failed_indices = [j for j, r in enumerate(results) if r != 0] error_codes = [results[j] for j in failed_indices] raise RuntimeError( - f"batch_put_tensor failed for indices {failed_indices} with error codes: {error_codes}" + f"batch_upsert_from failed for indices {failed_indices} with error codes: {error_codes}" ) finally: self._unregister_all_buffers(batch_ptr_reduced) - def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[bytes]): + def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[Any]): """Worker thread for putting batch of non-tensors to MooncakeStore.""" batch_values = [pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL) for v in batch_values] ret = self._store.upsert_batch(batch_keys, 
batch_values, self.replica_config) if ret != 0: - raise RuntimeError(f"put_batch failed with error code: {ret}") + raise RuntimeError(f"upsert_batch failed with error code: {ret}") def get( self, @@ -232,9 +232,11 @@ def _get_tensors_thread_worker( self, batch_keys: list[str], batch_shapes: list[tuple], batch_dtypes: list[torch.dtype], indexes: list[int] ) -> tuple[list[Tensor], list[int]]: batch_nbytes = get_nbytes(batch_dtypes, batch_shapes) - batch_buffer_tensors, batch_buffer_ptrs = allocate_empty_tensors(batch_dtypes, batch_shapes) + batch_buffer_tensors, batch_buffer_ptrs, region_ptrs, region_sizes = allocate_empty_tensors( + batch_dtypes, batch_shapes + ) - self._register_all_buffers(batch_buffer_ptrs, batch_nbytes) + self._register_all_buffers(region_ptrs, region_sizes) try: ret_codes = self._store.batch_get_into(batch_keys, batch_buffer_ptrs, batch_nbytes) if len(ret_codes) != len(batch_keys): @@ -243,7 +245,7 @@ def _get_tensors_thread_worker( if ret < 0: raise RuntimeError(f"batch_get_into failed for key `{batch_keys[i]}` with error code: {ret}") finally: - self._unregister_all_buffers(batch_buffer_ptrs) + self._unregister_all_buffers(region_ptrs) return batch_buffer_tensors, indexes @@ -283,6 +285,11 @@ def _preprocess_tensors_for_put(values: list[Tensor]) -> tuple[list[Any], list[A size_list = [] tensor_list = [] # hold reference for the contiguous tensor for t in values: + # TODO: support gpu direct rdma and use different data paths. 
+            # For GPU tensors it is more reasonable to copy to host memory first, since
+            # registering GPU buffers is much more expensive than registering CPU buffers
Example: >>> dtypes = [torch.float32, torch.float32, torch.int32, torch.float32] >>> shapes = [(10,), (20,), (5,), (15,)] - >>> tensors, ptrs = allocate_empty_tensors(dtypes, shapes) + >>> tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes) >>> # tensors[0], [1], [3] share the same dtype and memory block """ assert len(dtypes) == len(shapes), "dtypes and shapes must have the same length" if len(dtypes) == 0: - return [], [] + return [], [], [], [] # Group indices by dtype dtype_groups: dict[torch.dtype, list[int]] = {} @@ -61,6 +65,8 @@ def allocate_empty_tensors(dtypes: list[torch.dtype], shapes: list[tuple]) -> tu tensor_list = [torch.empty(()) for _ in range(len(dtypes))] ptr_list = [0] * len(dtypes) + region_ptrs: list[int] = [] + region_sizes: list[int] = [] # For each dtype group, allocate one big tensor and create views for dtype, indices in dtype_groups.items(): @@ -69,13 +75,15 @@ def allocate_empty_tensors(dtypes: list[torch.dtype], shapes: list[tuple]) -> tu shape_info = [] # Store (index, shape, num_elements, offset) for idx in indices: - shape = shapes[idx] - num_elements = reduce(operator.mul, shape) + shape = tuple(shapes[idx]) + num_elements = reduce(operator.mul, shape, 1) shape_info.append((idx, shape, num_elements, total_elements)) total_elements += num_elements # Allocate one big contiguous memory block for this dtype big_tensor = torch.empty(total_elements, dtype=dtype) + region_ptrs.append(big_tensor.data_ptr()) + region_sizes.append(big_tensor.nbytes) # Create views into the big tensor for each small tensor for idx, shape, num_elements, offset in shape_info: @@ -84,7 +92,7 @@ def allocate_empty_tensors(dtypes: list[torch.dtype], shapes: list[tuple]) -> tu tensor_list[idx] = small_tensor ptr_list[idx] = small_tensor.data_ptr() - return tensor_list, ptr_list + return tensor_list, ptr_list, region_ptrs, region_sizes def compute_stride(shape: tuple[int, ...]) -> tuple[int, ...]: @@ -115,36 +123,37 @@ def 
get_nbytes(dtypes, shapes) -> list[int]: nbytes = [] for i in range(len(dtypes)): elem_size = torch.tensor([], dtype=dtypes[i]).element_size() - numel = reduce(operator.mul, shapes[i]) + shape = tuple(shapes[i]) + numel = reduce(operator.mul, shape, 1) nbytes.append(elem_size * numel) return nbytes -def merge_continues_memory(ptrs: list[int], sizes: list[int]) -> tuple[list[int], list[int]]: - """Merge continuous memory regions to reduce register_buffer overhead +def merge_contiguous_memory(ptrs: list[int], sizes: list[int]) -> tuple[list[int], list[int]]: + """Merge contiguous memory regions to reduce register_buffer overhead Args: ptrs: List of memory pointers (starting addresses). sizes: List of memory region sizes corresponding to each pointer. Returns: - A tuple of (merged_ptrs, merged_sizes) where continuous regions + A tuple of (merged_ptrs, merged_sizes) where contiguous regions have been merged into single regions. Example: - >>> merge_continues_memory([0, 10, 30], [10, 20, 10]) + >>> merge_contiguous_memory([0, 10, 30], [10, 20, 10]) ([0, 30], [30, 10]) - >>> merge_continues_memory([0, 5, 20], [5, 5, 10]) + >>> merge_contiguous_memory([0, 5, 20], [5, 5, 10]) ([0, 20], [10, 10]) """ - if not ptrs or not sizes: - return [], [] - if len(ptrs) != len(sizes): raise ValueError("ptrs and sizes must have the same length") + if not ptrs: + return [], [] + # Create list of (ptr, size) pairs and sort by pointer address regions = sorted(zip(ptrs, sizes, strict=False), key=lambda x: x[0]) From 652d89fd85575bddb7e1aaba94bac9ab7524fa22 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 28 Apr 2026 18:30:16 +0800 Subject: [PATCH 8/9] fix error for getting None value Signed-off-by: 0oshowero0 --- transfer_queue/storage/clients/mooncake_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index aea858f9..e0411608 100644 --- 
a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -256,7 +256,7 @@ def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) -> if len(batch_results) != len(batch_keys): raise RuntimeError(f"get_batch returned {len(batch_results)} items, expected {len(batch_keys)}") - batch_results = [pickle.loads(result) for result in batch_results] + batch_results = [pickle.loads(result) if result != b"" else None for result in batch_results] results.extend(batch_results) return results, indexes From c74d687f78539d0cfbdb79ca8d96dec30fe0ec8c Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 28 Apr 2026 20:11:39 +0800 Subject: [PATCH 9/9] fix type comments Signed-off-by: 0oshowero0 --- .../storage/clients/mooncake_client.py | 12 ++++---- .../storage/clients/yuanrong_client.py | 30 +++++++++++-------- transfer_queue/storage/managers/base.py | 23 +++++++------- 3 files changed, 37 insertions(+), 28 deletions(-) diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index e0411608..4f4f8641 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -141,7 +141,7 @@ def put(self, keys: list[str], values: list[Any]) -> None: return None - def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[Tensor]): + def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[Tensor]) -> None: """Worker thread for putting batch of tensors to MooncakeStore.""" batch_ptrs, batch_sizes, _contiguous_tensors = self._preprocess_tensors_for_put(batch_tensors) @@ -261,7 +261,7 @@ def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) -> return results, indexes - def clear(self, keys: list[str], custom_backend_meta=None): + def clear(self, keys: list[str], custom_backend_meta: Optional[list[Any]] = None) -> None: """Deletes multiple keys 
from MooncakeStore. Args: @@ -280,10 +280,10 @@ def close(self): self._store = None @staticmethod - def _preprocess_tensors_for_put(values: list[Tensor]) -> tuple[list[Any], list[Any], list[Tensor]]: - ptr_list = [] - size_list = [] - tensor_list = [] # hold reference for the contiguous tensor + def _preprocess_tensors_for_put(values: list[Tensor]) -> tuple[list[int], list[int], list[Tensor]]: + ptr_list: list[int] = [] + size_list: list[int] = [] + tensor_list: list[Tensor] = [] # hold reference for the contiguous tensor for t in values: # TODO: support gpu direct rdma and use different data paths. # For GPU, it's more reasonable to perform data copy since diff --git a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py index fccd9a9b..c1622392 100644 --- a/transfer_queue/storage/clients/yuanrong_client.py +++ b/transfer_queue/storage/clients/yuanrong_client.py @@ -57,7 +57,7 @@ def supports_put(self, value: Any) -> bool: """Check if this strategy can store the given value.""" @abstractmethod - def put(self, keys: list[str], values: list[Any]): + def put(self, keys: list[str], values: list[Any]) -> None: """Store key-value pairs using this strategy.""" @abstractmethod @@ -73,7 +73,7 @@ def supports_clear(self, strategy_tag: Any) -> bool: """Check if this strategy owns data identified by metadata.""" @abstractmethod - def clear(self, keys: list[str]): + def clear(self, keys: list[str]) -> None: """Delete keys from storage.""" @@ -131,7 +131,7 @@ def supports_put(self, value: Any) -> bool: # Only contiguous NPU tensors are supported by this adapter. 
return value.is_contiguous() - def put(self, keys: list[str], values: list[Any]): + def put(self, keys: list[str], values: list[Any]) -> None: """Store NPU tensors in batches; deletes before overwrite.""" for i in range(0, len(keys), self.KEYS_LIMIT): batch_keys = keys[i : i + self.KEYS_LIMIT] @@ -169,14 +169,14 @@ def supports_clear(self, strategy_tag: str) -> bool: """Matches 'DsTensorClient' strategy tag.""" return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag() - def clear(self, keys: list[str]): + def clear(self, keys: list[str]) -> None: """Delete NPU tensor keys in batches.""" for i in range(0, len(keys), self.KEYS_LIMIT): batch = keys[i : i + self.KEYS_LIMIT] # Todo(dpj): Test call clear when no (key,value) put in ds self._ds_client.delete(batch) - def _create_empty_npu_tensorlist(self, shapes: list, dtypes: list): + def _create_empty_npu_tensorlist(self, shapes: list[Any], dtypes: list[Any]) -> list[Tensor]: """ Create a list of empty NPU tensors with given shapes and dtypes. 
@@ -184,7 +184,7 @@ def _create_empty_npu_tensorlist(self, shapes: list, dtypes: list): shapes (list): List of tensor shapes (e.g., [(3,), (2, 4)]) dtypes (list): List of torch dtypes (e.g., [torch.float32, torch.int64]) Returns: - list: List of uninitialized NPU tensors + list[Tensor]: List of uninitialized NPU tensors """ tensors: list[Tensor] = [] for shape, dtype in zip(shapes, dtypes, strict=True): @@ -243,7 +243,7 @@ def supports_put(self, value: Any) -> bool: """Accepts any Python object.""" return True - def put(self, keys: list[str], values: list[Any]): + def put(self, keys: list[str], values: list[Any]) -> None: """Store objects via zero-copy serialization in batches.""" for i in range(0, len(keys), self.PUT_KEYS_LIMIT): batch_keys = keys[i : i + self.PUT_KEYS_LIMIT] @@ -267,7 +267,7 @@ def supports_clear(self, strategy_tag: str) -> bool: """Matches 'KVClient' strategy tag.""" return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag() - def clear(self, keys: list[str]): + def clear(self, keys: list[str]) -> None: """Delete keys in batches.""" for i in range(0, len(keys), self.GET_CLEAR_KEYS_LIMIT): batch_keys = keys[i : i + self.GET_CLEAR_KEYS_LIMIT] @@ -433,7 +433,13 @@ def put_task(strategy, indexes): strategy_tags[original_index] = tag return strategy_tags - def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=None) -> list[Any]: + def get( + self, + keys: list[str], + shapes: Optional[list[Any]] = None, + dtypes: Optional[list[Any]] = None, + custom_backend_meta: Optional[list[str]] = None, + ) -> list[Any]: """Retrieves multiple values from remote storage with expected metadata. Requires shape and dtype hints to reconstruct NPU tensors correctly. 
@@ -472,7 +478,7 @@ def get_task(strategy, indexes): results[original_index] = value return results - def clear(self, keys: list[str], custom_backend_meta=None): + def clear(self, keys: list[str], custom_backend_meta: Optional[list[str]] = None) -> None: """Deletes multiple keys from remote storage. Args: @@ -513,8 +519,8 @@ def _route_to_strategies( The order must correspond to the original keys. selector: A function that determines whether a strategy supports an item. Signature: `(strategy: StorageStrategy, item: Any) -> bool`. - failback: If True, items that don't match any strategy will be ignored (not included in output). - If False, a ValueError will be raised for any unmatched item. + ignore_unmatched: If True, items that don't match any strategy will be ignored (not included in output). + If False, a ValueError will be raised for any unmatched item. Returns: A dictionary mapping each active strategy to a list of indexes in `items` diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 42f17db8..6dfa0935 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -401,19 +401,20 @@ def _generate_keys(field_names: list[str], global_indexes: list[int]) -> list[st return [pfx + sfx for sfx, pfx in itertools.product(keys_suffixes, keys_prefixes)] @staticmethod - def _generate_values(data: TensorDict) -> list[Tensor]: + def _generate_values(data: TensorDict) -> list[Any]: """ - Extract and flatten tensor values from a TensorDict in field-major order. + Extract and flatten values from a TensorDict in field-major order. Values are ordered by sorted field names, then by row (sample) order within each field. This matches the key order generated by `_generate_keys`. Args: - data (TensorDict): Input data where keys are field names and values are tensors. + data (TensorDict): Input data where keys are field names and values are tensors or any type + wrapped by NonTensorStack. 
Returns: - list[Tensor]: Flattened list of tensors, e.g., - [data[field_a][0], data[field_a][1], data[field_a][2], ..., data[field_b][0], ...] + list[Any]: Flattened list of values, e.g., + [data[field_a][0], data[field_a][1], data[field_a][2], ..., data[field_b][0], ...] """ - results: list[Tensor] = [] + results: list[Any] = [] for field in sorted(data.keys()): field_data = data[field] if isinstance(field_data, Tensor) and field_data.is_nested: @@ -457,17 +458,17 @@ def _get_executor(self) -> ThreadPoolExecutor: assert self._multi_threads_executor is not None return self._multi_threads_executor - def _merge_tensors_to_tensordict(self, metadata: BatchMeta, values: list[Tensor]) -> TensorDict: + def _merge_tensors_to_tensordict(self, metadata: BatchMeta, values: list[Any]) -> TensorDict: """ Reconstruct a TensorDict from a list of values using metadata. The values list is assumed to be in the same order as keys generated by `_generate_keys`. According to field names and global indexes in metadata, this method can determine - which dict key and which row this tensor belongs to. Then it reshapes the flat tensors list + which dict key and which row this value belongs to. Then it reshapes the flat values list back into a structured TensorDict . Args: metadata (BatchMeta): Metadata containing global indexes and field names. - values (list[Tensor]): List of tensors in field-major order. + values (list[Any]): List of values in field-major order. Returns: TensorDict: Reconstructed tensor dictionary with batch size equal to number of samples. """ @@ -534,7 +535,9 @@ def process_field(field_idx: int): return TensorDict(merged_data, batch_size=num_samples) @staticmethod - def _get_shape_type_custom_backend_meta_list(metadata: BatchMeta): + def _get_shape_type_custom_backend_meta_list( + metadata: BatchMeta, + ) -> tuple[list[torch.Size], list[torch.dtype], list[Any]]: """ Extract the expected shape, dtype, and custom_backend_meta for each field-sample pair in metadata. 
The order matches the key/value order: sorted by field name, then by global index.