From 1465fac54245765967abe136200d8aa1d4e7f407 Mon Sep 17 00:00:00 2001 From: hzy_19 Date: Wed, 28 Jan 2026 20:32:12 +0800 Subject: [PATCH 1/2] add back non-zero-copy capability Signed-off-by: hzy_19 --- .../managers/simple_backend_manager.py | 11 +++- transfer_queue/utils/zmq_utils.py | 52 +++++++++++-------- 2 files changed, 40 insertions(+), 23 deletions(-) diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 250438ca..6d2a8e41 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -30,6 +30,7 @@ from transfer_queue.storage.managers.base import TransferQueueStorageManager from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory from transfer_queue.storage.simple_backend import StorageMetaGroup +from transfer_queue.utils.utils import get_env_bool from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket logger = logging.getLogger(__name__) @@ -44,6 +45,8 @@ TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT", 200)) # seconds TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT", 200)) # seconds +TQ_ZERO_COPY_SERIALIZATION = get_env_bool("TQ_ZERO_COPY_SERIALIZATION", default=False) + @TransferQueueStorageManagerFactory.register("AsyncSimpleStorageManager") class AsyncSimpleStorageManager(TransferQueueStorageManager): @@ -236,7 +239,7 @@ async def _put_to_single_storage_unit( """ request_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_DATA, + request_type=ZMQRequestType.PUT_DATA, # type: ignore[arg-type] sender_id=self.storage_manager_id, receiver_id=target_storage_unit, body={"local_indexes": local_indexes, "data": storage_data}, @@ -331,7 +334,7 @@ async def _get_from_single_storage_unit( fields = storage_meta_group.get_field_names() request_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_DATA, + request_type=ZMQRequestType.GET_DATA, # type: ignore[arg-type] sender_id=self.storage_manager_id, receiver_id=target_storage_unit, body={"local_indexes": local_indexes, "fields": fields}, @@ -452,6 +455,10 @@ def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: TensorDict) result = (result,) results[fname] = list(result) + if not TQ_ZERO_COPY_SERIALIZATION: + # Explicitly copy tensor slices to prevent pickling the whole tensor for every storage unit. + # The tensors may still be contiguous, so we cannot use .contiguous() to trigger copy from parent tensors. + results[fname] = [item.clone() if isinstance(item, torch.Tensor) else item for item in results[fname]] return results diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 4887d073..b19fd603 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -15,6 +15,7 @@ import logging import os +import pickle import socket import time from dataclasses import dataclass @@ -28,6 +29,7 @@ from transfer_queue.utils.utils import ( ExplicitEnum, TransferQueueRole, + get_env_bool, ) logger = logging.getLogger(__name__) @@ -42,6 +44,8 @@ bytestr: TypeAlias = bytes | bytearray | memoryview +TQ_ZERO_COPY_SERIALIZATION = get_env_bool("TQ_ZERO_COPY_SERIALIZATION", default=False) + class ZMQRequestType(ExplicitEnum): """ @@ -155,36 +159,42 @@ def create( def serialize(self) -> list: """ - Serialize message using unified MsgpackEncoder. - Returns: list[bytestr] - [msgpack_header, *tensor_buffers] + Serialize message using unified MsgpackEncoder or pickle. + Returns: list[bytestr] - [msgpack_header, *tensor_buffers] or [bytes] """ - msg_dict = { - "request_type": self.request_type.value, # Enum -> str for msgpack - "sender_id": self.sender_id, - "receiver_id": self.receiver_id, - "request_id": self.request_id, - "timestamp": self.timestamp, - "body": self.body, - } - return list(_encoder.encode(msg_dict)) + if TQ_ZERO_COPY_SERIALIZATION: + msg_dict = { + "request_type": self.request_type.value, # Enum -> str for msgpack + "sender_id": self.sender_id, + "receiver_id": self.receiver_id, + "request_id": self.request_id, + "timestamp": self.timestamp, + "body": self.body, + } + return list(_encoder.encode(msg_dict)) + else: + return [pickle.dumps(self)] @classmethod def deserialize(cls, frames: list) -> "ZMQMessage": """ - Deserialize message using unified MsgpackDecoder. + Deserialize message using unified MsgpackDecoder or pickle. """ if not frames: raise ValueError("Empty frames received") - msg_dict = _decoder.decode(frames) - return cls( - request_type=ZMQRequestType(msg_dict["request_type"]), - sender_id=msg_dict["sender_id"], - receiver_id=msg_dict["receiver_id"], - body=msg_dict["body"], - request_id=msg_dict["request_id"], - timestamp=msg_dict["timestamp"], - ) + if TQ_ZERO_COPY_SERIALIZATION: + msg_dict = _decoder.decode(frames) + return cls( + request_type=ZMQRequestType(msg_dict["request_type"]), + sender_id=msg_dict["sender_id"], + receiver_id=msg_dict["receiver_id"], + body=msg_dict["body"], + request_id=msg_dict["request_id"], + timestamp=msg_dict["timestamp"], + ) + else: + return pickle.loads(frames[0]) def get_free_port() -> str: From d6f599fb5a83f9386325cdd56bcbc589bb630a54 Mon Sep 17 00:00:00 2001 From: hzy_19 Date: Wed, 28 Jan 2026 20:43:43 +0800 Subject: [PATCH 2/2] fix Signed-off-by: hzy_19 --- transfer_queue/storage/managers/simple_backend_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 6d2a8e41..eb130675 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -455,10 +455,10 @@ def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: TensorDict) result = (result,) results[fname] = list(result) - if not TQ_ZERO_COPY_SERIALIZATION: - # Explicitly copy tensor slices to prevent pickling the whole tensor for every storage unit. - # The tensors may still be contiguous, so we cannot use .contiguous() to trigger copy from parent tensors. - results[fname] = [item.clone() if isinstance(item, torch.Tensor) else item for item in results[fname]] + if not TQ_ZERO_COPY_SERIALIZATION: + # Explicitly copy tensor slices to prevent pickling the whole tensor for every storage unit. + # The tensors may still be contiguous, so we cannot use .contiguous() to trigger copy from parent tensors. + results[fname] = [item.clone() if isinstance(item, torch.Tensor) else item for item in results[fname]] return results