Support storage unit in TransferQueue#1
Conversation
There was a problem hiding this comment.
Pull Request Overview
This PR adds a storage unit component to the TransferQueue experimental feature, implementing a distributed storage system with ZMQ-based communication between storage units and controllers.
Key Changes
- Implements
StorageUnitDataclass for managing field-based data storage with validation - Adds
TransferQueueStorageSimpleUnitas a Ray remote actor for distributed storage operations - Establishes ZMQ-based communication protocol for PUT/GET/CLEAR operations with controllers
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| verl/experimental/transfer_queue/storage.py | New storage implementation with data management and ZMQ communication |
| verl/experimental/transfer_queue/init.py | Package initialization file with copyright header |
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
| from transfer_queue.utils.utils import TransferQueueRole | ||
| from transfer_queue.utils.zmq_utils import ( | ||
| ZMQMessage, | ||
| ZMQRequestType, | ||
| ZMQServerInfo, | ||
| create_zmq_socket, | ||
| get_free_port, | ||
| ) |
There was a problem hiding this comment.
The import statements use relative paths starting with 'transfer_queue' but should use absolute imports from the package. These should be 'verl.experimental.transfer_queue.utils.utils' and 'verl.experimental.transfer_queue.utils.zmq_utils'.
| else: | ||
| result[field] = gathered_item.unsqueeze(0) | ||
| else: | ||
| gathered_items = list(itemgetter(*local_indexes)(self.field_data[field])) |
There was a problem hiding this comment.
Using itemgetter with unpacked local_indexes will fail when local_indexes has only one element, as itemgetter(single_value) returns the item directly, not a tuple. This inconsistency with the multi-item case could cause issues.
| gathered_items = list(itemgetter(*local_indexes)(self.field_data[field])) | |
| # Ensure gathered_items is always a list, even if local_indexes has one element | |
| if len(local_indexes) == 1: | |
| gathered_items = [self.field_data[field][local_indexes[0]]] | |
| else: | |
| gathered_items = list(itemgetter(*local_indexes)(self.field_data[field])) |
| if gathered_items: | ||
| all_tensors = all(isinstance(x, torch.Tensor) for x in gathered_items) | ||
| if all_tensors: | ||
| result[field] = torch.nested.as_nested_tensor(gathered_items) |
There was a problem hiding this comment.
Creating nested tensors can be expensive and may not be the most efficient representation for storage operations. Consider using torch.stack() when tensors have compatible shapes, falling back to nested tensors only when necessary.
| result[field] = torch.nested.as_nested_tensor(gathered_items) | |
| try: | |
| result[field] = torch.stack(gathered_items) | |
| except RuntimeError: | |
| result[field] = torch.nested.as_nested_tensor(gathered_items) |
| per_tensor_dtypes: dict[int, torch.dtype] = {} | ||
| per_tensor_shapes: dict[int, torch.Size] = {} |
There was a problem hiding this comment.
The type annotations are incorrect. Based on the code below, these dictionaries should be dict[int, dict[str, torch.dtype]] and dict[int, dict[str, torch.Size]] respectively, as they store nested dictionaries mapping field names to dtypes/shapes.
| per_tensor_dtypes: dict[int, torch.dtype] = {} | |
| per_tensor_shapes: dict[int, torch.Size] = {} | |
| per_tensor_dtypes: dict[int, dict[str, torch.dtype]] = {} | |
| per_tensor_shapes: dict[int, dict[str, torch.Size]] = {} |
| per_tensor_dtypes[global_idx][field] = data_item.dtype if hasattr(data_item, "dtype") else None | ||
| per_tensor_shapes[global_idx][field] = data_item.shape if hasattr(data_item, "shape") else None |
There was a problem hiding this comment.
Using hasattr() for dtype and shape checks is fragile. Consider using isinstance() checks for torch.Tensor or other expected types to make the code more explicit and maintainable.
| per_tensor_dtypes[global_idx][field] = data_item.dtype if hasattr(data_item, "dtype") else None | |
| per_tensor_shapes[global_idx][field] = data_item.shape if hasattr(data_item, "shape") else None | |
| per_tensor_dtypes[global_idx][field] = data_item.dtype if isinstance(data_item, torch.Tensor) else None | |
| per_tensor_shapes[global_idx][field] = data_item.shape if isinstance(data_item, torch.Tensor) else None |
What does this PR do?
Support storage unit in TransferQueue
Checklist Before Starting
[{modules}] {type}: {description}(This will be checked by the CI){modules}includefsdp,megatron,sglang,vllm,rollout,trainer,ci,training_utils,recipe,hardware,deployment,ray,worker,single_controller,misc,perf,model,algo,env,tool,ckpt,doc,data,like[megatron, fsdp, doc]{type}is infeat,fix,refactor,chore,test[BREAKING]to the beginning of the title.[BREAKING][fsdp, megatron] feat: dynamic batchingTest
Not related.
API and Usage Example
Not related.
Design & Code Changes
Not related.
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=alwaysci-requestchannel in theverlSlack workspace. (If not accessible, please try the Feishu group (飞书群).)