Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 68 additions & 67 deletions transfer_queue/storage/managers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import ray
import torch
import zmq
import zmq.asyncio
from omegaconf import DictConfig
from tensordict import NonTensorStack, TensorDict
from torch import Tensor
Expand Down Expand Up @@ -65,10 +66,10 @@ def __init__(self, controller_info: ZMQServerInfo, config: DictConfig):
self.config = config
self.controller_info = controller_info

self.data_status_update_socket: Optional[zmq.Socket[bytes]] = None
self.controller_handshake_socket: Optional[zmq.Socket[bytes]] = None
# Handshake socket is sync (used only during initialization)
self.controller_handshake_socket: Optional[zmq.Socket] = None

self.zmq_context: Optional[zmq.Context[Any]] = None
self.zmq_context: Optional[zmq.asyncio.Context] = None
self._connect_to_controller()

def _connect_to_controller(self) -> None:
Expand All @@ -77,26 +78,28 @@ def _connect_to_controller(self) -> None:
raise ValueError(f"controller_info should be ZMQServerInfo, but got {type(self.controller_info)}")

try:
# create zmq context
self.zmq_context = zmq.Context()
# Create a synchronous context for handshake (blocking operation)
sync_zmq_context = zmq.Context()

# create zmq sockets for handshake and data status update
# create zmq socket for handshake (sync, for initial connection)
self.controller_handshake_socket = create_zmq_socket(
self.zmq_context,
sync_zmq_context,
zmq.DEALER,
identity=f"{self.storage_manager_id}-controller_handshake_socket-{uuid4().hex[:8]}".encode(),
)
self.data_status_update_socket = create_zmq_socket(
self.zmq_context,
zmq.DEALER,
identity=f"{self.storage_manager_id}-data_status_update_socket-{uuid4().hex[:8]}".encode(),
)
assert self.data_status_update_socket is not None, "data_status_update_socket is not properly initialized"
self.data_status_update_socket.connect(self.controller_info.to_addr("data_status_update_socket"))

# do handshake with controller
# do handshake with controller using sync socket
self._do_handshake_with_controller()

# close the sync handshake socket and context after handshake
if self.controller_handshake_socket and not self.controller_handshake_socket.closed:
self.controller_handshake_socket.close(linger=0)
self.controller_handshake_socket = None
sync_zmq_context.term()

# create async context for data status update
self.zmq_context = zmq.asyncio.Context()

except Exception as e:
logger.error(f"Failed to connect to controller: {e}")
raise
Expand Down Expand Up @@ -210,22 +213,19 @@ async def notify_data_update(
shapes: Per-field shapes for each field, in {global_index: {field: shape}} format.
custom_backend_meta: Per-field custom_meta for each sample, in {global_index: {field: custom_meta}} format.
"""
# Create zmq poller for notifying data update information

if not self.controller_info:
logger.warning(f"No controller connected for storage manager {self.storage_manager_id}")
return

# Create zmq poller for notifying data update information
poller = zmq.Poller()
# Note: data_status_update_socket is already connected during initialization
assert self.data_status_update_socket is not None, "data_status_update_socket is not properly initialized"
# create dynamic socket
identity = f"{self.storage_manager_id}-data_update-{uuid4().hex[:8]}".encode()
sock = create_zmq_socket(self.zmq_context, zmq.DEALER, identity=identity)

try:
poller.register(self.data_status_update_socket, zmq.POLLIN)
sock.connect(self.controller_info.to_addr("data_status_update_socket"))

request_msg = ZMQMessage.create(
request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, # type: ignore[arg-type]
request_type=ZMQRequestType.NOTIFY_DATA_UPDATE,
sender_id=self.storage_manager_id,
body={
"partition_id": partition_id,
Expand All @@ -237,51 +237,54 @@ async def notify_data_update(
},
).serialize()

self.data_status_update_socket.send_multipart(request_msg)
await sock.send_multipart(request_msg)
logger.debug(
f"[{self.storage_manager_id}]: Send data status update request "
f"from storage manager id #{self.storage_manager_id} "
f"to controller id #{self.controller_info.id} successfully."
)
except Exception as e:
request_msg = ZMQMessage.create(
request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, # type: ignore[arg-type]
sender_id=self.storage_manager_id,
body={
"message": f"Failed to notify data status update information from "
f"storage manager id #{self.storage_manager_id}, "
f"detail error message: {str(e)}"
},
).serialize()

self.data_status_update_socket.send_multipart(request_msg)

# Make sure controller successfully receives data status update information.
response_received: bool = False
start_time = time.time()
response_received = False
timeout = TQ_DATA_UPDATE_RESPONSE_TIMEOUT

while (
not response_received # Only one controller to get response from
and time.time() - start_time < TQ_DATA_UPDATE_RESPONSE_TIMEOUT
):
socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000))
while not response_received and timeout > 0:
try:
poll_interval = min(TQ_STORAGE_POLLER_TIMEOUT, timeout)
messages = await asyncio.wait_for(sock.recv_multipart(), timeout=poll_interval)
response_msg = ZMQMessage.deserialize(messages)

if self.data_status_update_socket in socks:
response_msg = ZMQMessage.deserialize(self.data_status_update_socket.recv_multipart())
if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK:
response_received = True
logger.debug(
f"[{self.storage_manager_id}]: Get data status update ACK response "
f"from controller id #{response_msg.sender_id} successfully."
)
except asyncio.TimeoutError:
timeout -= poll_interval
except Exception as e:
logger.warning(f"[{self.storage_manager_id}]: Error receiving response: {e}")
break

if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK:
response_received = True
logger.debug(
f"[{self.storage_manager_id}]: Get data status update ACK response "
f"from controller id #{response_msg.sender_id} "
f"to storage manager id #{self.storage_manager_id} successfully."
)
if not response_received:
logger.error(f"[{self.storage_manager_id}]: Did not receive data status update ACK.")

if not response_received:
logger.error(
f"[{self.storage_manager_id}]: Storage manager id #{self.storage_manager_id} "
f"did not receive data status update ACK response from controller."
)
except Exception as e:
logger.error(f"[{self.storage_manager_id}]: Error during notify_data_update: {e}")
try:
error_msg = ZMQMessage.create(
request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR,
sender_id=self.storage_manager_id,
body={"message": f"Failed to notify: {str(e)}"},
).serialize()
await sock.send_multipart(error_msg)
except Exception:
pass
finally:
try:
if not sock.closed:
sock.close(linger=-1)
except Exception:
pass

@abstractmethod
async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
Expand Down Expand Up @@ -319,18 +322,16 @@ async def clear_data(self, metadata: BatchMeta) -> None:

def close(self) -> None:
"""Close all ZMQ sockets and context to prevent resource leaks."""
for sock in (self.controller_handshake_socket, self.data_status_update_socket):
# Close handshake socket if it exists
if self.controller_handshake_socket:
try:
if sock and not sock.closed:
sock.close(linger=0)
if not self.controller_handshake_socket.closed:
self.controller_handshake_socket.close(linger=0)
except Exception as e:
logger.error(f"[{self.storage_manager_id}]: Error closing socket {sock}: {str(e)}")
logger.error(f"[{self.storage_manager_id}]: Error closing controller_handshake_socket: {str(e)}")

try:
if self.zmq_context:
self.zmq_context.term()
except Exception as e:
logger.error(f"[{self.storage_manager_id}]: Error terminating zmq_context: {str(e)}")
if self.zmq_context:
self.zmq_context.term()

def __del__(self):
"""Destructor to ensure resources are cleaned up."""
Expand Down
16 changes: 8 additions & 8 deletions transfer_queue/storage/managers/simple_backend_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
logger.addHandler(handler)

TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT", 200)) # seconds
TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT", 200)) # seconds
TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT", 200)) # seconds

TQ_ZERO_COPY_SERIALIZATION = get_env_bool("TQ_ZERO_COPY_SERIALIZATION", default=False)

Expand Down Expand Up @@ -119,11 +118,12 @@ def _build_storage_mapping_functions(self):

# TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong.
@staticmethod
def dynamic_storage_manager_socket(socket_name: str):
def dynamic_storage_manager_socket(socket_name: str, timeout: int):
"""Decorator to auto-manage ZMQ sockets for Controller/Storage servers (create -> connect -> inject -> close).

Args:
socket_name (str): Port name (from server config) to use for ZMQ connection (e.g., "data_req_port").
timeout (float): Timeout in seconds for ZMQ connection (in seconds).

Decorated Function Rules:
1. Must be an async class method (needs `self`).
Expand Down Expand Up @@ -157,8 +157,8 @@ async def wrapper(self, *args, **kwargs):
try:
sock.connect(address)
# Timeouts to avoid indefinite await on recv/send
sock.setsockopt(zmq.RCVTIMEO, TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT * 1000)
sock.setsockopt(zmq.SNDTIMEO, TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT * 1000)
sock.setsockopt(zmq.RCVTIMEO, timeout * 1000)
sock.setsockopt(zmq.SNDTIMEO, timeout * 1000)
logger.debug(
f"[{self.storage_manager_id}]: Connected to StorageUnit {server_info.id} at {address} "
f"with identity {identity.decode()}"
Expand Down Expand Up @@ -249,7 +249,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
partition_id, list(results.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes
)

@dynamic_storage_manager_socket(socket_name="put_get_socket")
@dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT)
async def _put_to_single_storage_unit(
self,
local_indexes: list[int],
Expand Down Expand Up @@ -348,7 +348,7 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict:

return TensorDict(tensor_data, batch_size=len(metadata))

@dynamic_storage_manager_socket(socket_name="put_get_socket")
@dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT)
async def _get_from_single_storage_unit(
self, storage_meta_group: StorageMetaGroup, target_storage_unit: str, socket: zmq.Socket = None
):
Expand Down Expand Up @@ -407,7 +407,7 @@ async def clear_data(self, metadata: BatchMeta) -> None:
if isinstance(result, Exception):
logger.error(f"[{self.storage_manager_id}]: Error in clear operation task {i}: {result}")

@dynamic_storage_manager_socket(socket_name="put_get_socket")
@dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT)
async def _clear_single_storage_unit(self, local_indexes, target_storage_unit=None, socket=None):
try:
request_msg = ZMQMessage.create(
Expand Down