From f6ab22e66f9085cbb961b64a4f5cedbfc0931d78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9C=8B=E6=88=9172=E9=81=8D?= Date: Fri, 6 Mar 2026 10:44:51 +0800 Subject: [PATCH] refactor: convert BatchMeta to columnar layout with zero-copy serialization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Convert BatchMeta/KVBatchMeta to columnar list layout for zero-copy serialization - Add columnar custom_meta and _custom_backend_meta support - Add with_data_fields to BatchMeta; fix cross-shard e2e test - Add CUSTOM_TYPE_NUMPY for native numpy round-trip in serial_utils - Apply code review fixes from columnar-batchmeta branch review - Simplify storage manager: extract helpers, rename variables for clarity - Rename local_indexes/gi_list to global_indexes across codebase - Remove unused StorageMetaGroup dead code - Replace deepcopy with shallow copy in BatchMeta.__post_init__ - Rewrite concat extra_info merge to batch-level semantics - Replace chunk-based routing with deterministic hash routing - Detect dtype/shape changes in field_schema_cache - Make _SampleView a complete read-only single-sample view - Remove to_dict/from_dict/_parse_dtype, use direct pickle for BatchMeta - Rename encode/decode_with_fallback to encode/decode Signed-off-by: 看我72遍 --- scripts/put_benchmark.py | 12 +- tests/e2e/test_e2e_lifecycle_consistency.py | 77 +- tests/test_async_simple_storage_manager.py | 363 +++-- tests/test_client.py | 73 +- tests/test_controller.py | 45 +- tests/test_kv_storage_manager.py | 72 +- tests/test_metadata.py | 1324 +++++------------ tests/test_ray_p2p.py | 27 +- tests/test_serial_utils_on_cpu.py | 165 +- tests/test_simple_storage_unit.py | 81 +- tests/test_yuanrong_client_zero_copy.py | 2 + transfer_queue/client.py | 10 +- transfer_queue/controller.py | 180 ++- transfer_queue/metadata.py | 1093 +++++++------- transfer_queue/storage/__init__.py | 3 +- transfer_queue/storage/managers/base.py | 124 +- .../managers/simple_backend_manager.py | 364 ++--- transfer_queue/storage/simple_backend.py | 195 +-- transfer_queue/utils/serial_utils.py | 128 +- transfer_queue/utils/zmq_utils.py | 59 +- tutorial/03_metadata_concepts.py | 381 ++--- 21 files changed, 2166 insertions(+), 2612 deletions(-) diff --git a/scripts/put_benchmark.py b/scripts/put_benchmark.py index 1700d55e..6b2afb59 100644 --- a/scripts/put_benchmark.py +++ b/scripts/put_benchmark.py @@ -33,13 +33,11 @@ parent_dir = Path(__file__).resolve().parent.parent.parent sys.path.append(str(parent_dir)) -from transfer_queue import ( # noqa: E402 - AsyncTransferQueueClient, - SimpleStorageUnit, - TransferQueueController, - process_zmq_server_info, -) +from transfer_queue import TransferQueueClient # noqa: E402 +from transfer_queue.controller import TransferQueueController # noqa: E402 +from transfer_queue.storage.simple_backend import SimpleStorageUnit # noqa: E402 from transfer_queue.utils.common import get_placement_group # noqa: E402 +from transfer_queue.utils.zmq_utils import process_zmq_server_info # noqa: E402 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -309,7 +307,7 @@ def initialize_system(self, config_dict): self.tq_config = OmegaConf.merge(tq_internal_conf, self.tq_config) # Client Init - self.data_system_client = AsyncTransferQueueClient( + self.data_system_client = TransferQueueClient( client_id="Trainer", controller_info=self.data_system_controller_info ) self.data_system_client.initialize_storage_manager( diff --git 
a/tests/e2e/test_e2e_lifecycle_consistency.py b/tests/e2e/test_e2e_lifecycle_consistency.py index 4e7025f0..22b45c5c 100644 --- a/tests/e2e/test_e2e_lifecycle_consistency.py +++ b/tests/e2e/test_e2e_lifecycle_consistency.py @@ -420,15 +420,22 @@ def test_cross_shard_complex_update(e2e_client): "Region 30-39 tensor_f32 should match original Put B" ) - # 9. Verify new fields exist in update region - extended_fields = base_fields + ["new_extra_tensor", "new_extra_non_tensor"] - update_region_meta = poll_for_meta( - client, partition_id, extended_fields, 20, "update_region_task", mode="force_fetch" + # 9. Verify new fields exist in update region (indices 10-29 only have new fields). + # Build extended_meta from full_meta (which has valid _custom_backend_meta) + # by selecting the subset of samples whose global_indexes match meta_update. + # Using meta_update directly would fail because it was derived from alloc_meta + # before put(), so its _custom_backend_meta may be incomplete. + update_gis = set(meta_update.global_indexes) + update_positions_in_full = [ + i for i, global_index in enumerate(full_meta.global_indexes) if global_index in update_gis + ] + update_meta_with_backend = full_meta.select_samples(update_positions_in_full) + extended_meta = update_meta_with_backend.with_data_fields( + base_fields + ["new_extra_tensor", "new_extra_non_tensor"] ) - if update_region_meta is not None and update_region_meta.size > 0: - update_region_data = client.get_data(update_region_meta) - assert "new_extra_tensor" in update_region_data.keys(), "new_extra_tensor should exist" - assert "new_extra_non_tensor" in update_region_data.keys(), "new_extra_non_tensor should exist" + update_region_data = client.get_data(extended_meta) + assert "new_extra_tensor" in update_region_data.keys(), "new_extra_tensor should exist" + assert "new_extra_non_tensor" in update_region_data.keys(), "new_extra_non_tensor should exist" finally: client.clear_partition(partition_id) @@ -641,5 +648,59 @@ def test_clear_partition(e2e_client): pass +# Scenario Six: Dynamic Tensor Shape → Nested Tensor Transition +def test_dynamic_tensor_shape_nested_transition(e2e_client): + """ + Test transition from regular tensor to nested tensor. + First put tensors of identical shape, then put tensors of a different shape. + Verify that the field schema marks is_nested=True, and getting all samples returns a nested tensor. + """ + client = e2e_client + partition_id = "test_nested_transition_partition" + task_name = "test_task" + + try: + # 1. Put same-shape tensor (shape: (2, 4)) — initial insert + data1 = TensorDict({"dynamic_feature": torch.ones(2, 4)}, batch_size=2) + meta1_put = client.put(data=data1, partition_id=partition_id) + assert meta1_put.size == 2 + + # Poll and verify first batch is regular tensor + meta1 = poll_for_meta(client, partition_id, ["dynamic_feature"], 2, task_name, mode="force_fetch") + assert not meta1.field_schema["dynamic_feature"]["is_nested"] + retrieved_1 = client.get_data(meta1) + assert not retrieved_1["dynamic_feature"].is_nested + assert retrieved_1["dynamic_feature"].shape == (2, 4) + + # 2. 
Allocate 2 more slots via insert mode, put different-shape tensor (shape: (2, 6)) + alloc_meta2 = client.get_meta( + partition_id=partition_id, + data_fields=["dynamic_feature"], + batch_size=2, + mode="insert", + task_name="allocator", + ) + assert alloc_meta2.size == 2 + data2 = TensorDict({"dynamic_feature": torch.ones(2, 6)}, batch_size=2) + client.put(data=data2, metadata=alloc_meta2) + + # Poll and verify metadata now indicates nested tensor + meta2 = poll_for_meta(client, partition_id, ["dynamic_feature"], 2, task_name, mode="force_fetch") + + # After second put with different shape, is_nested should be True + assert meta2.field_schema["dynamic_feature"]["is_nested"] is True + + # 3. Retrieve all 4 samples together + meta_all = poll_for_meta(client, partition_id, ["dynamic_feature"], 4, task_name, mode="force_fetch") + assert meta_all.field_schema["dynamic_feature"]["is_nested"] is True + + retrieved_all = client.get_data(meta_all) + # The merged result should be a nested tensor since the shapes vary + assert retrieved_all["dynamic_feature"].is_nested is True + assert len(retrieved_all["dynamic_feature"]) == 4 + finally: + client.clear_partition(partition_id) + + if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py index 8254895a..ed532aa7 100644 --- a/tests/test_async_simple_storage_manager.py +++ b/tests/test_async_simple_storage_manager.py @@ -17,6 +17,7 @@ from pathlib import Path from unittest.mock import AsyncMock, Mock, patch +import numpy as np import pytest import pytest_asyncio import torch @@ -27,7 +28,7 @@ parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 +from transfer_queue.metadata import BatchMeta # noqa: E402 from transfer_queue.storage import AsyncSimpleStorageManager # noqa: E402 from transfer_queue.utils.enum_utils import TransferQueueRole # noqa: E402 from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo # noqa: E402 @@ -79,11 +80,6 @@ async def mock_async_storage_manager(): manager.controller_handshake_socket = None manager.zmq_context = None - # Add mapping functions - storage_unit_keys = list(storage_unit_infos.keys()) - manager.global_index_storage_unit_mapping = lambda x: storage_unit_keys[x % len(storage_unit_keys)] - manager.global_index_local_index_mapping = lambda x: x // len(storage_unit_keys) - # Mock essential methods manager._connect_to_controller = mock_connect @@ -100,41 +96,31 @@ async def test_async_storage_manager_initialization(mock_async_storage_manager): assert "storage_0" in manager.storage_unit_infos assert "storage_1" in manager.storage_unit_infos - # Test mapping functions - assert manager.global_index_storage_unit_mapping(0) == "storage_0" - assert manager.global_index_storage_unit_mapping(1) == "storage_1" - assert manager.global_index_local_index_mapping(0) == 0 - assert manager.global_index_local_index_mapping(3) == 1 - @pytest.mark.asyncio async def test_async_storage_manager_mock_operations(mock_async_storage_manager): """Test AsyncSimpleStorageManager operations with mocked ZMQ.""" manager = mock_async_storage_manager - # Create test metadata - sample_metas = [ - SampleMeta( - partition_id="0", - global_index=0, - fields={ - "test_field": FieldMeta(name="test_field", dtype=torch.float32, shape=(2,)), - }, - ), - SampleMeta( - partition_id="0", - global_index=1, - fields={ 
- "test_field": FieldMeta(name="test_field", dtype=torch.float32, shape=(2,)), - }, - ), - ] - batch_meta = BatchMeta(samples=sample_metas) + # Create test metadata using columnar API + batch_meta = BatchMeta( + global_indexes=[0, 1], + partition_ids=["0", "0"], + field_schema={ + "test_field": { + "dtype": torch.float32, + "shape": (2,), + "is_nested": False, + "is_non_tensor": False, + } + }, + production_status=np.ones(2, dtype=np.int8), + ) # Create test data test_data = TensorDict( { - "test_field": [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])], + "test_field": torch.stack([torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]), }, batch_size=2, ) @@ -163,92 +149,6 @@ async def test_async_storage_manager_mock_operations(mock_async_storage_manager) await manager.clear_data(batch_meta) -@pytest.mark.asyncio -async def test_async_storage_manager_mapping_functions(): - """Test AsyncSimpleStorageManager mapping functions.""" - - # Mock storage unit infos - storage_unit_infos = { - "storage_0": ZMQServerInfo( - role=TransferQueueRole.STORAGE, - id="storage_0", - ip="127.0.0.1", - ports={"put_get_socket": 12345}, - ), - "storage_1": ZMQServerInfo( - role=TransferQueueRole.STORAGE, - id="storage_1", - ip="127.0.0.1", - ports={"put_get_socket": 12346}, - ), - "storage_2": ZMQServerInfo( - role=TransferQueueRole.STORAGE, - id="storage_2", - ip="127.0.0.1", - ports={"put_get_socket": 12347}, - ), - } - - # Mock controller info - controller_info = ZMQServerInfo( - role=TransferQueueRole.CONTROLLER, - id="controller_0", - ip="127.0.0.1", - ports={"handshake_socket": 12348, "data_status_update_socket": 12349}, - ) - - config = { - "zmq_info": storage_unit_infos, - } - - # Mock ZMQ operations - with ( - patch("transfer_queue.storage.managers.base.create_zmq_socket") as mock_create_socket, - patch("zmq.Poller") as mock_poller, - ): - # Create mock socket with proper sync methods - mock_socket = Mock() - mock_socket.connect = Mock() # sync method - mock_socket.send = Mock() # sync method - mock_create_socket.return_value = mock_socket - - # Mock poller with sync methods - mock_poller_instance = Mock() - mock_poller_instance.register = Mock() # sync method - # Return mock socket in poll to simulate handshake response - mock_poller_instance.poll = Mock(return_value=[(mock_socket, zmq.POLLIN)]) # sync method - mock_poller.return_value = mock_poller_instance - - # Mock handshake response - handshake_response = ZMQMessage.create( - request_type=ZMQRequestType.HANDSHAKE_ACK, # type: ignore[arg-type] - sender_id="controller_0", - body={"message": "Handshake successful"}, - ) - mock_socket.recv_multipart = Mock(return_value=handshake_response.serialize()) - - # Create manager - manager = AsyncSimpleStorageManager(controller_info, config) - - # Test round-robin mapping for 3 storage units - # global_index -> storage_unit mapping: 0->storage_0, 1->storage_1, 2->storage_2, - # 3->storage_0, 4->storage_1, ... 
- assert manager.global_index_storage_unit_mapping(0) == "storage_0" - assert manager.global_index_storage_unit_mapping(1) == "storage_1" - assert manager.global_index_storage_unit_mapping(2) == "storage_2" - assert manager.global_index_storage_unit_mapping(3) == "storage_0" - assert manager.global_index_storage_unit_mapping(4) == "storage_1" - assert manager.global_index_storage_unit_mapping(5) == "storage_2" - - # global_index -> local_index mapping: global_index // num_storage_units - assert manager.global_index_local_index_mapping(0) == 0 - assert manager.global_index_local_index_mapping(1) == 0 - assert manager.global_index_local_index_mapping(2) == 0 - assert manager.global_index_local_index_mapping(3) == 1 - assert manager.global_index_local_index_mapping(4) == 1 - assert manager.global_index_local_index_mapping(5) == 1 - - @pytest.mark.asyncio async def test_async_storage_manager_error_handling(): """Test AsyncSimpleStorageManager error handling.""" @@ -310,22 +210,25 @@ async def test_async_storage_manager_error_handling(): manager._clear_single_storage_unit = AsyncMock(side_effect=RuntimeError("Mock CLEAR error")) manager.notify_data_update = AsyncMock() - # Create test metadata - sample_metas = [ - SampleMeta( - partition_id="0", - global_index=0, - fields={ - "test_field": FieldMeta(name="test_field", dtype=torch.float32, shape=(2,)), - }, - ), - ] - batch_meta = BatchMeta(samples=sample_metas) + # Create test metadata using columnar API + batch_meta = BatchMeta( + global_indexes=[0], + partition_ids=["0"], + field_schema={ + "test_field": { + "dtype": torch.float32, + "shape": (2,), + "is_nested": False, + "is_non_tensor": False, + } + }, + production_status=np.ones(1, dtype=np.int8), + ) # Create test data test_data = TensorDict( { - "test_field": [torch.tensor([1.0, 2.0])], + "test_field": torch.tensor([[1.0, 2.0]]), }, batch_size=1, ) @@ -340,3 +243,205 @@ async def test_async_storage_manager_error_handling(): # Note: clear_data uses return_exceptions=True, so it doesn't raise exceptions directly # Instead, we can verify that the clear operation was attempted await manager.clear_data(batch_meta) # Should not raise due to return_exceptions=True + + +@pytest.mark.asyncio +async def test_get_data_routes_from_hash(): + """get_data should route using global_idx % num_su (hash routing).""" + storage_unit_infos = { + "storage_0": ZMQServerInfo( + role=TransferQueueRole.STORAGE, + id="storage_0", + ip="127.0.0.1", + ports={"put_get_socket": 19010}, + ), + "storage_1": ZMQServerInfo( + role=TransferQueueRole.STORAGE, + id="storage_1", + ip="127.0.0.1", + ports={"put_get_socket": 19011}, + ), + } + with patch("transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"): + manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager) + manager.storage_manager_id = "test_get" + manager.storage_unit_infos = storage_unit_infos + manager.controller_info = None + manager.data_status_update_socket = None + manager.controller_handshake_socket = None + manager.zmq_context = None + + # global_index 0,2 → storage_0 (even % 2 = 0); 1,3 → storage_1 (odd % 2 = 1) + batch_meta = BatchMeta( + global_indexes=[0, 1, 2, 3], + partition_ids=["p0"] * 4, + field_schema={"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(4, dtype=np.int8), + ) + + # Mock _get_from_single_storage_unit to record which su_id and global_index were requested + called_with: dict[str, list] = {} + + async def 
fake_get(global_indexes, fields, target_storage_unit=None, **kwargs): + su = target_storage_unit + called_with[su] = list(global_indexes) + tensors = [torch.zeros(2) for _ in global_indexes] + return global_indexes, fields, {"f": tensors}, b"" + + manager._get_from_single_storage_unit = fake_get + + await manager.get_data(batch_meta) + + assert "storage_0" in called_with, "storage_0 was not called by get" + assert "storage_1" in called_with, "storage_1 was not called by get" + assert set(called_with["storage_0"]) == {0, 2} + assert set(called_with["storage_1"]) == {1, 3} + + +@pytest.mark.asyncio +async def test_clear_data_routes_from_hash(): + """clear_data should route using global_idx % num_su (hash routing).""" + storage_unit_infos = { + "storage_0": ZMQServerInfo( + role=TransferQueueRole.STORAGE, + id="storage_0", + ip="127.0.0.1", + ports={"put_get_socket": 19020}, + ), + "storage_1": ZMQServerInfo( + role=TransferQueueRole.STORAGE, + id="storage_1", + ip="127.0.0.1", + ports={"put_get_socket": 19021}, + ), + } + with patch("transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"): + manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager) + manager.storage_manager_id = "test_clear" + manager.storage_unit_infos = storage_unit_infos + manager.controller_info = None + manager.data_status_update_socket = None + manager.controller_handshake_socket = None + manager.zmq_context = None + + # global_index 0,2 → storage_0 (even); 1,3 → storage_1 (odd) + batch_meta = BatchMeta( + global_indexes=[0, 1, 2, 3], + partition_ids=["p0"] * 4, + field_schema={"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(4, dtype=np.int8), + ) + + called_with: dict[str, list] = {} + + async def fake_clear(global_indexes, target_storage_unit=None, **kwargs): + called_with[target_storage_unit] = list(global_indexes) + + manager._clear_single_storage_unit = fake_clear + + await manager.clear_data(batch_meta) + + assert set(called_with.get("storage_0", [])) == {0, 2} + assert set(called_with.get("storage_1", [])) == {1, 3} + + +@pytest.mark.asyncio +async def test_hash_routing_stable_across_batch_sizes(): + """Hash routing must produce the same SU assignment regardless of batch size. + + Put 10 samples in one batch vs two batches of 5 — each global_idx must route + to the same SU in both cases. 
+ """ + storage_unit_infos = { + "storage_0": ZMQServerInfo( + role=TransferQueueRole.STORAGE, + id="storage_0", + ip="127.0.0.1", + ports={"put_get_socket": 19030}, + ), + "storage_1": ZMQServerInfo( + role=TransferQueueRole.STORAGE, + id="storage_1", + ip="127.0.0.1", + ports={"put_get_socket": 19031}, + ), + } + with patch("transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"): + manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager) + manager.storage_manager_id = "test_hash_batch" + manager.storage_unit_infos = storage_unit_infos + manager.controller_info = None + manager.data_status_update_socket = None + manager.controller_handshake_socket = None + manager.zmq_context = None + + all_indexes = list(range(10)) + full_routing = manager._group_by_hash(all_indexes) + + # Build per-index mapping from the full-batch result + idx_to_su_full: dict[int, str] = {} + for su_id, gi_list in full_routing.items(): + for gi in gi_list: + idx_to_su_full[gi] = su_id + + # Route as two batches of 5 + batch_a_routing = manager._group_by_hash(all_indexes[:5]) + batch_b_routing = manager._group_by_hash(all_indexes[5:]) + + idx_to_su_split: dict[int, str] = {} + for su_id, gi_list in batch_a_routing.items(): + for gi in gi_list: + idx_to_su_split[gi] = su_id + for su_id, gi_list in batch_b_routing.items(): + for gi in gi_list: + idx_to_su_split[gi] = su_id + + assert idx_to_su_full == idx_to_su_split, ( + f"Routing differs between full batch and split batches:\n full: {idx_to_su_full}\n split: {idx_to_su_split}" + ) + + +@pytest.mark.asyncio +async def test_hash_routing_stable_reversed_order(): + """Hash routing must produce the same SU assignment regardless of key order. + + Forward order [0..9] and reversed order [9..0] must yield identical routing. 
+ """ + storage_unit_infos = { + "storage_0": ZMQServerInfo( + role=TransferQueueRole.STORAGE, + id="storage_0", + ip="127.0.0.1", + ports={"put_get_socket": 19040}, + ), + "storage_1": ZMQServerInfo( + role=TransferQueueRole.STORAGE, + id="storage_1", + ip="127.0.0.1", + ports={"put_get_socket": 19041}, + ), + } + with patch("transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"): + manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager) + manager.storage_manager_id = "test_hash_order" + manager.storage_unit_infos = storage_unit_infos + manager.controller_info = None + manager.data_status_update_socket = None + manager.controller_handshake_socket = None + manager.zmq_context = None + + forward = list(range(10)) + reversed_indexes = list(reversed(forward)) + + routing_fwd = manager._group_by_hash(forward) + routing_rev = manager._group_by_hash(reversed_indexes) + + # Build per-index mapping + def _to_idx_map(routing): + m = {} + for su_id, gi_list in routing.items(): + for gi in gi_list: + m[gi] = su_id + return m + + assert _to_idx_map(routing_fwd) == _to_idx_map(routing_rev), "Hash routing should be order-independent" diff --git a/tests/test_client.py b/tests/test_client.py index 6729196f..ccf8e8b9 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -31,10 +31,8 @@ from transfer_queue import TransferQueueClient # noqa: E402 from transfer_queue.metadata import ( # noqa: E402 BatchMeta, - FieldMeta, - SampleMeta, ) -from transfer_queue.utils.enum_utils import ProductionStatus, TransferQueueRole # noqa: E402 +from transfer_queue.utils.enum_utils import TransferQueueRole # noqa: E402 from transfer_queue.utils.zmq_utils import ( # noqa: E402 ZMQMessage, ZMQRequestType, @@ -175,24 +173,17 @@ def _mock_batch_meta(self, request_body): batch_size = request_body.get("batch_size", 1) data_fields = request_body.get("data_fields", []) - samples = [] - for i in range(batch_size): - fields = [] - for field_name in data_fields: - field_meta = FieldMeta( - name=field_name, - dtype=None, - shape=None, - production_status=ProductionStatus.NOT_PRODUCED, - ) - fields.append(field_meta) - sample = SampleMeta( - partition_id="0", - global_index=i, - fields={field.name: field for field in fields}, - ) - samples.append(sample) - metadata = BatchMeta(samples=samples) + # Build columnar field_schema + field_schema = { + field_name: {"dtype": None, "shape": None, "is_nested": False, "is_non_tensor": False} + for field_name in data_fields + } + + metadata = BatchMeta( + global_indexes=list(range(batch_size)), + partition_ids=["0"] * batch_size, + field_schema=field_schema, + ) return {"metadata": metadata} @@ -202,39 +193,31 @@ def _mock_kv_retrieve_meta(self, request_body): create = request_body.get("create", False) partition_id = request_body.get("partition_id", "") - # Initialize key tracking if not exists if not hasattr(self, "_kv_partition_keys"): self._kv_partition_keys = {} - # Generate global indexes for the keys start_index = self._get_next_kv_index(partition_id) global_indexes = list(range(start_index, start_index + len(keys))) - # Create metadata for each key - samples = [] - for i, key in enumerate(keys): - field_meta = FieldMeta( - name="data", - dtype=torch.float32, - shape=torch.Size([1, 10]), - production_status=ProductionStatus.READY_FOR_CONSUME, - ) - sample = SampleMeta( - partition_id=partition_id, - global_index=global_indexes[i], - fields={"data": field_meta}, - ) - samples.append(sample) - - metadata = BatchMeta(samples=samples) + # 
Build columnar BatchMeta for KV interface + field_schema = { + "data": {"dtype": "torch.float32", "shape": [1, 10], "is_nested": False, "is_non_tensor": False} + } + import numpy as np + + production_status = np.ones(len(global_indexes), dtype=np.int8) + metadata = BatchMeta( + global_indexes=global_indexes, + partition_ids=[partition_id] * len(global_indexes), + field_schema=field_schema, + production_status=production_status, + ) - # Store keys for this partition (only when create=True) if create: if partition_id not in self._kv_partition_keys: self._kv_partition_keys[partition_id] = [] self._kv_partition_keys[partition_id].extend(keys) - # Update the next index for this partition if global_indexes: self._update_kv_index(partition_id, global_indexes[-1] + 1) @@ -384,12 +367,12 @@ def _handle_data_requests(self): def _handle_get_data(self, request_body): """Handle GET_DATA request by retrieving stored data""" - local_indexes = request_body.get("local_indexes", []) + global_indexes = request_body.get("global_indexes", []) fields = request_body.get("fields", []) result: dict[str, list] = {} for field in fields: - gathered_items = [TEST_DATA[field][i] for i in local_indexes] + gathered_items = [TEST_DATA[field][i] for i in global_indexes] if gathered_items: all_tensors = all(isinstance(x, torch.Tensor) for x in gathered_items) @@ -847,7 +830,7 @@ async def test_async_clear_samples_with_empty_metadata(client_setup): client, _, _ = client_setup # Create empty BatchMeta - metadata = BatchMeta(samples=[]) + metadata = BatchMeta(global_indexes=[], partition_ids=[], field_schema={}) # The clear operation should complete without raising an exception # because the mock storage manager is configured to handle this diff --git a/tests/test_controller.py b/tests/test_controller.py index 77565bce..fc5b4100 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -29,7 +29,6 @@ logger = logging.getLogger(__name__) from transfer_queue.controller import TransferQueueController # noqa: E402 -from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 @pytest.fixture(scope="function") @@ -67,13 +66,9 @@ def test_controller_with_single_partition(self, ray_setup): ) assert metadata.global_indexes == list(range(gbs * num_n_samples)) - assert metadata.samples[0].partition_id == "train_0" - assert sum([int(sample.fields.get("prompt_ids").production_status) for sample in metadata.samples]) == int( - ProductionStatus.NOT_PRODUCED - ) - assert sum([int(sample.fields.get("attention_mask").production_status) for sample in metadata.samples]) == int( - ProductionStatus.NOT_PRODUCED - ) + assert metadata.partition_ids[0] == "train_0" + # In insert mode, production_status should be all zeros (NOT_PRODUCED) + assert metadata.production_status is not None and all(metadata.production_status == 0) partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id)) assert partition_index_range == list(range(gbs * num_n_samples)) @@ -158,7 +153,7 @@ def test_controller_with_single_partition(self, ray_setup): ) assert gen_meta.global_indexes == list(range(gbs * num_n_samples)) - assert gen_meta.samples[0].partition_id == "train_0" + assert gen_meta.partition_ids[0] == "train_0" assert gen_meta.field_names == ["prompt_ids"] partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id)) assert torch.equal(partition.consumption_status["generate_sequences"], torch.ones(gbs * num_n_samples)) @@ -187,7 +182,8 @@ def test_controller_with_single_partition(self, 
ray_setup): ) ) assert clear_meta.global_indexes == list(range(gbs * num_n_samples)) - assert [sample.fields for sample in clear_meta.samples] == [{}] * (gbs * num_n_samples) + # In insert mode with no fields, field_schema should be empty + assert clear_meta.field_schema == {} or clear_meta.field_names == [] print("✓ Clear metadata correct") # Test clear_partition @@ -456,13 +452,9 @@ def test_controller_with_multi_partitions(self, ray_setup): part1_index_range = gbs_1 * num_n_samples_1 part2_index_range = gbs_2 * num_n_samples_2 assert val_metadata.global_indexes == list(range(part1_index_range, part2_index_range + part1_index_range)) - assert val_metadata.samples[0].partition_id == "val_0" - assert sum([int(sample.fields.get("prompt_ids").production_status) for sample in val_metadata.samples]) == int( - ProductionStatus.NOT_PRODUCED - ) - assert sum( - [int(sample.fields.get("attention_mask").production_status) for sample in val_metadata.samples] - ) == int(ProductionStatus.NOT_PRODUCED) + assert val_metadata.partition_ids[0] == "val_0" + # In insert mode, production_status should be all zeros (NOT_PRODUCED) + assert val_metadata.production_status is not None and all(val_metadata.production_status == 0) partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_2)) assert partition_index_range == list(range(part1_index_range, part2_index_range + part1_index_range)) @@ -536,13 +528,9 @@ def test_controller_with_multi_partitions(self, ray_setup): ) ) assert metadata_2.global_indexes == list(range(32)) + list(range(48, 80)) - assert metadata_2.samples[0].partition_id == "train_1" - assert sum([int(sample.fields.get("prompt_ids").production_status) for sample in metadata_2.samples]) == int( - ProductionStatus.NOT_PRODUCED - ) - assert sum( - [int(sample.fields.get("attention_mask").production_status) for sample in metadata_2.samples] - ) == int(ProductionStatus.NOT_PRODUCED) + assert metadata_2.partition_ids[0] == "train_1" + # In insert mode, production_status should be all zeros (NOT_PRODUCED) + assert metadata_2.production_status is not None and all(metadata_2.production_status == 0) partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_3)) assert partition_index_range == list(range(32)) + list(range(48, 80)) print("✓ Correctly assign partition_3") @@ -884,12 +872,9 @@ def test_controller_kv_retrieve_meta_with_production_status(self, ray_setup): tq_controller.kv_retrieve_meta.remote(keys=keys, partition_id=partition_id, create=False) ) - # Verify production status is available - assert len(retrieved_metadata.samples) == len(keys) - for sample in retrieved_metadata.samples: - assert "data" in sample.fields - assert sample.fields["data"].dtype == "torch.float32" - assert sample.fields["data"].shape == (64,) + # Verify production status is available (columnar API) + assert len(retrieved_metadata.global_indexes) == len(keys) + assert "data" in retrieved_metadata.field_schema print("✓ kv_retrieve_meta works with production status") diff --git a/tests/test_kv_storage_manager.py b/tests/test_kv_storage_manager.py index ec5b29c7..6320f231 100644 --- a/tests/test_kv_storage_manager.py +++ b/tests/test_kv_storage_manager.py @@ -25,33 +25,35 @@ parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 +from transfer_queue.metadata import BatchMeta # noqa: E402 from transfer_queue.storage.managers.base import 
KVStorageManager # noqa: E402 -from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 def get_meta(data, global_indexes=None): if not global_indexes: - global_indexes = range(data.batch_size[0]) - samples = [] - for sample_id in range(data.batch_size[0]): - fields_dict = {} - for field_name in data.keys(): - tensor = data[field_name][sample_id] - field_meta = FieldMeta( - name=field_name, - dtype=tensor.dtype if isinstance(tensor, torch.Tensor) else None, - shape=tensor.shape if isinstance(tensor, torch.Tensor) else None, - production_status=ProductionStatus.READY_FOR_CONSUME, - ) - fields_dict[field_name] = field_meta - sample = SampleMeta( - partition_id=0, - global_index=global_indexes[sample_id], - fields=fields_dict, - ) - samples.append(sample) - metadata = BatchMeta(samples=samples) + global_indexes = list(range(data.batch_size[0])) + + # Build columnar field_schema from the data + field_schema = {} + for field_name in data.keys(): + tensor = data[field_name][0] + field_schema[field_name] = { + "dtype": tensor.dtype if isinstance(tensor, torch.Tensor) else type(tensor), + "shape": tensor.shape if isinstance(tensor, torch.Tensor) else None, + "is_nested": False, + "is_non_tensor": not isinstance(tensor, torch.Tensor), + } + + import numpy as np + + production_status = np.ones(len(global_indexes), dtype=np.int8) + + metadata = BatchMeta( + global_indexes=list(global_indexes), + partition_ids=["0"] * len(global_indexes), + field_schema=field_schema, + production_status=production_status, + ) return metadata @@ -196,14 +198,13 @@ def test_get_shape_type_custom_backend_meta_list_without_custom_backend_meta(tes def test_get_shape_type_custom_backend_meta_list_with_custom_backend_meta(test_data): """Test _get_shape_type_custom_backend_meta_list returns correct custom_backend_meta when provided.""" - # Add custom_backend_meta to metadata - custom_backend_meta = { - 8: {"text": {"key1": "value1"}, "label": {"key2": "value2"}, "mask": {"key3": "value3"}}, - 9: {"text": {"key4": "value4"}, "label": {"key5": "value5"}, "mask": {"key6": "value6"}}, - 10: {"text": {"key7": "value7"}, "label": {"key8": "value8"}, "mask": {"key9": "value9"}}, - } + # Add custom_backend_meta to metadata (columnar: list aligned with global_indexes [8, 9, 10]) metadata = test_data["metadata"] - metadata._custom_backend_meta.update(custom_backend_meta) + metadata._custom_backend_meta = [ + {"text": {"key1": "value1"}, "label": {"key2": "value2"}, "mask": {"key3": "value3"}}, # global_index=8 + {"text": {"key4": "value4"}, "label": {"key5": "value5"}, "mask": {"key6": "value6"}}, # global_index=9 + {"text": {"key7": "value7"}, "label": {"key8": "value8"}, "mask": {"key9": "value9"}}, # global_index=10 + ] shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list(metadata) @@ -224,14 +225,13 @@ def test_get_shape_type_custom_backend_meta_list_with_custom_backend_meta(test_d def test_get_shape_type_custom_backend_meta_list_with_partial_custom_backend_meta(test_data): """Test _get_shape_type_custom_backend_meta_list handles partial custom_backend_meta correctly.""" - # Add custom_backend_meta only for some global_indexes and fields - custom_backend_meta = { - 8: {"text": {"key1": "value1"}}, # Only text field - # global_index 9 has no custom_backend_meta - 10: {"label": {"key2": "value2"}, "mask": {"key3": "value3"}}, # label and mask only - } + # Add custom_backend_meta only for some fields (columnar: list aligned with global_indexes [8, 9, 10]) metadata = 
test_data["metadata"] - metadata._custom_backend_meta.update(custom_backend_meta) + metadata._custom_backend_meta = [ + {"text": {"key1": "value1"}}, # global_index=8: only text field + {}, # global_index=9: no custom_backend_meta + {"label": {"key2": "value2"}, "mask": {"key3": "value3"}}, # global_index=10: label and mask only + ] shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list(metadata) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 2a129b57..c6f7828b 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -13,1038 +13,373 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for TransferQueue metadata module - Learning Examples.""" +"""Unit tests for TransferQueue metadata module - Columnar BatchMeta + KVBatchMeta.""" import sys from pathlib import Path +import numpy as np import pytest import torch -from tensordict import TensorDict -from tensordict.tensorclass import NonTensorStack # Setup path parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue.metadata import BatchMeta, FieldMeta, KVBatchMeta, SampleMeta # noqa: E402 -from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 +from transfer_queue.metadata import BatchMeta, KVBatchMeta # noqa: E402 +# ============================================================================== +# Columnar BatchMeta Tests +# ============================================================================== -class TestFieldMeta: - """FieldMeta learning examples.""" - def test_field_meta_is_ready(self): - """Test the is_ready property based on production status.""" - field_ready = FieldMeta( - name="test_field", dtype=torch.float32, shape=(2, 3), production_status=ProductionStatus.READY_FOR_CONSUME - ) - assert field_ready.is_ready is True - - field_not_ready = FieldMeta( - name="test_field", dtype=torch.float32, shape=(2, 3), production_status=ProductionStatus.NOT_PRODUCED - ) - assert field_not_ready.is_ready is False - - -class TestSampleMeta: - """SampleMeta learning examples.""" - - def test_sample_meta_union(self): - """Example: Union fields from two samples with matching global indexes.""" - # Create first sample - fields1 = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - } - sample1 = SampleMeta(partition_id="partition_0", global_index=0, fields=fields1) +class TestBatchMetaColumnar: + """Columnar BatchMeta using field_schema + production_status (numpy array).""" - # Create second sample with additional fields - fields2 = { - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,)), + def _make_batch(self, batch_size=3, field_names=None): + """Helper: create a simple columnar BatchMeta.""" + if field_names is None: + field_names = ["field_a", "field_b"] + field_schema = { + field_name: {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False} + for field_name in field_names } - sample2 = SampleMeta(partition_id="partition_0", global_index=0, fields=fields2) - - # Union samples - result = sample1.union(sample2) - - # Result contains all fields from both samples - assert "field1" in result.fields - assert "field2" in result.fields # From sample2 - assert "field3" in result.fields - - def test_sample_meta_union_validation_error(self): - 
"""Example: Union validation catches mismatched global indexes.""" - sample1 = SampleMeta( - partition_id="partition_0", - global_index=0, - fields={"field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,))}, - ) - - sample2 = SampleMeta( - partition_id="partition_0", - global_index=1, # Different global index - fields={"field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,))}, + production_status = np.ones(batch_size, dtype=np.int8) + return BatchMeta( + global_indexes=list(range(batch_size)), + partition_ids=["partition_0"] * batch_size, + field_schema=field_schema, + production_status=production_status, ) - with pytest.raises(ValueError) as exc_info: - sample1.union(sample2, validate=True) - assert "Global indexes" in str(exc_info.value) - - def test_sample_meta_add_fields(self): - """Example: Add new fields to a sample.""" - initial_fields = { - "field1": FieldMeta( - name="field1", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - sample = SampleMeta(partition_id="partition_0", global_index=0, fields=initial_fields) - - new_fields = { - "field2": FieldMeta( - name="field2", dtype=torch.int64, shape=(3,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - sample.add_fields(new_fields) - - assert "field1" in sample.fields - assert "field2" in sample.fields - assert sample.is_ready is True - - def test_sample_meta_select_fields(self): - """Example: Select specific fields from a sample.""" - fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,)), - } - sample = SampleMeta(partition_id="partition_0", global_index=0, fields=fields) - - # Select only field1 and field3 - selected_sample = sample.select_fields(["field1", "field3"]) - - assert "field1" in selected_sample.fields - assert "field3" in selected_sample.fields - assert "field2" not in selected_sample.fields - # Original sample is unchanged - assert len(sample.fields) == 3 - # Selected sample has correct metadata - assert selected_sample.fields["field1"].dtype == torch.float32 - assert selected_sample.fields["field1"].shape == (2,) - assert selected_sample.global_index == 0 - assert selected_sample.partition_id == "partition_0" - - def test_sample_meta_select_fields_with_nonexistent_fields(self): - """Example: Select fields ignores non-existent field names.""" - fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - } - sample = SampleMeta(partition_id="partition_0", global_index=0, fields=fields) - - # Try to select a field that doesn't exist - selected_sample = sample.select_fields(["field1", "nonexistent_field"]) - - # Only existing field is selected - assert "field1" in selected_sample.fields - assert "nonexistent_field" not in selected_sample.fields - assert "field2" not in selected_sample.fields - - def test_sample_meta_select_fields_empty_list(self): - """Example: Select with empty field list returns sample with no fields.""" - fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - } - sample = SampleMeta(partition_id="partition_0", global_index=0, fields=fields) - - # Select with empty list - selected_sample = sample.select_fields([]) - - assert len(selected_sample.fields) == 0 - assert 
selected_sample.global_index == 0 - assert selected_sample.partition_id == "partition_0" - - -class TestBatchMeta: - """BatchMeta learning examples - Core Operations.""" - - def test_batch_meta_chunk(self): - """Example: Split a batch into multiple chunks.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - samples = [SampleMeta(partition_id="partition_0", global_index=i, fields=fields) for i in range(10)] + def test_basic_init(self): + """Test basic columnar BatchMeta initialization.""" + batch = self._make_batch() + assert len(batch) == 3 + assert batch.global_indexes == [0, 1, 2] + assert batch.partition_ids == ["partition_0", "partition_0", "partition_0"] + assert "field_a" in batch.field_schema + assert "field_b" in batch.field_schema + assert batch.field_names == ["field_a", "field_b"] + + def test_production_status_vector(self): + """Test that production_status is accessible per sample.""" + batch = self._make_batch() + assert batch.production_status is not None + assert len(batch.production_status) == 3 + assert all(batch.production_status == 1) + + def test_chunk(self): + """Test splitting a batch into chunks.""" batch = BatchMeta( - samples=samples, - custom_meta={i: {"uid": i} for i in range(10)}, - _custom_backend_meta={i: {"test_field": {"dtype": torch.float32}} for i in range(10)}, + global_indexes=list(range(10)), + partition_ids=["partition_0"] * 10, + field_schema={"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(10, dtype=np.int8), + custom_meta=[{"uid": i} for i in range(10)], + _custom_backend_meta=[{"f": {"key": i}} for i in range(10)], ) - - # Chunk into 3 parts chunks = batch.chunk(3) - assert len(chunks) == 3 - assert len(chunks[0]) == 4 # First chunk gets extra element + # First chunk gets extra element (ceil division) + assert len(chunks[0]) == 4 assert len(chunks[1]) == 3 assert len(chunks[2]) == 3 - - # validate custom_meta is chunked - assert 0 in chunks[0].custom_meta - assert 1 in chunks[0].custom_meta - assert 2 in chunks[0].custom_meta - assert 3 in chunks[0].custom_meta - assert 4 not in chunks[0].custom_meta - assert 4 in chunks[1].custom_meta - - # validate _custom_backend_meta is chunked - assert 0 in chunks[0]._custom_backend_meta - assert 1 in chunks[0]._custom_backend_meta - assert 2 in chunks[0]._custom_backend_meta - assert 3 in chunks[0]._custom_backend_meta - assert 4 not in chunks[0]._custom_backend_meta - assert 4 in chunks[1]._custom_backend_meta - - def test_batch_meta_chunk_by_partition(self): - """Example: Split a batch into multiple chunks.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - samples = [SampleMeta(partition_id=f"partition_{i % 4}", global_index=i + 10, fields=fields) for i in range(10)] + # custom_meta is chunked correctly (positional) + assert chunks[0].custom_meta[0] == {"uid": 0} + assert chunks[0].custom_meta[3] == {"uid": 3} + assert len(chunks[0].custom_meta) == 4 + assert chunks[1].custom_meta[0] == {"uid": 4} + + def test_chunk_by_partition(self): + """Test splitting by partition_id.""" batch = BatchMeta( - samples=samples, - custom_meta={i + 10: {"uid": i + 10} for i in range(10)}, - _custom_backend_meta={i + 10: {"test_field": {"dtype": torch.float32}} for i in range(10)}, + global_indexes=[10, 11, 12, 13], + partition_ids=["part_A", 
"part_B", "part_A", "part_B"], + field_schema={"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}}, ) - - # Chunk according to partition_id chunks = batch.chunk_by_partition() - - assert len(chunks) == 4 - assert len(chunks[0]) == 3 - assert chunks[0].partition_ids == ["partition_0", "partition_0", "partition_0"] - assert chunks[0].global_indexes == [10, 14, 18] - assert len(chunks[1]) == 3 - assert chunks[1].partition_ids == ["partition_1", "partition_1", "partition_1"] - assert chunks[1].global_indexes == [11, 15, 19] - assert len(chunks[2]) == 2 - assert chunks[2].partition_ids == ["partition_2", "partition_2"] - assert chunks[2].global_indexes == [12, 16] - assert len(chunks[3]) == 2 - assert chunks[3].partition_ids == ["partition_3", "partition_3"] - assert chunks[3].global_indexes == [13, 17] - - # validate custom_meta is chunked - assert 10 in chunks[0].custom_meta - assert 14 in chunks[0].custom_meta - assert 18 in chunks[0].custom_meta - assert 11 not in chunks[0].custom_meta - assert 11 in chunks[1].custom_meta - - # validate _custom_backend_meta is chunked - assert 10 in chunks[0]._custom_backend_meta - assert 14 in chunks[0]._custom_backend_meta - assert 18 in chunks[0]._custom_backend_meta - assert 11 not in chunks[0]._custom_backend_meta - assert 11 in chunks[1]._custom_backend_meta - - def test_batch_meta_init_validation_error_different_field_names(self): - """Example: Init validation catches samples with different field names.""" - # Create first sample with field1 - fields1 = {"field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,))} - sample1 = SampleMeta(partition_id="partition_0", global_index=0, fields=fields1) - - # Create second sample with field2 - fields2 = {"field2": FieldMeta(name="field2", dtype=torch.float32, shape=(2,))} - sample2 = SampleMeta(partition_id="partition_0", global_index=1, fields=fields2) - - # Attempt to create BatchMeta with samples having different field names - with pytest.raises(ValueError) as exc_info: - BatchMeta(samples=[sample1, sample2]) - assert "All samples in BatchMeta must have the same field_names." 
in str(exc_info.value) - - def test_batch_meta_concat(self): - """Example: Concatenate multiple batches.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - - # Create two batches - batch1 = BatchMeta( - samples=[ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ], - custom_meta={i: {"uid": i} for i in [0, 1]}, - _custom_backend_meta={i: {"test_field": {"dtype": torch.float32}} for i in [0, 1]}, - ) - + assert len(chunks) == 2 + part_ids = [c.partition_ids[0] for c in chunks] + assert "part_A" in part_ids + assert "part_B" in part_ids + + def test_concat(self): + """Test concatenating two batches.""" + batch1 = self._make_batch(batch_size=2) batch2 = BatchMeta( - samples=[ - SampleMeta(partition_id="partition_0", global_index=2, fields=fields), - SampleMeta(partition_id="partition_0", global_index=3, fields=fields), - ], - custom_meta={i: {"uid": i} for i in [2, 3]}, - _custom_backend_meta={i: {"test_field": {"dtype": torch.float32}} for i in [2, 3]}, + global_indexes=[2, 3], + partition_ids=["partition_0", "partition_0"], + field_schema=batch1.field_schema, + production_status=np.ones(2, dtype=np.int8), ) - - # Concatenate batches result = BatchMeta.concat([batch1, batch2]) - assert len(result) == 4 assert result.global_indexes == [0, 1, 2, 3] - assert result.custom_meta == {i: {"uid": i} for i in [0, 1, 2, 3]} - assert result._custom_backend_meta == {i: {"test_field": {"dtype": torch.float32}} for i in [0, 1, 2, 3]} - - def test_batch_meta_concat_with_tensor_extra_info(self): - """Example: Concat handles tensor extra_info by concatenating along dim=0.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - batch1 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) - batch1.extra_info["tensor"] = torch.randn(3, 4) - batch1.extra_info["scalar"] = torch.tensor(1.0) + def test_custom_meta_update(self): + """Test update_custom_meta method.""" + batch = self._make_batch(batch_size=2) + batch.update_custom_meta([{"tag": "alpha"}, {"tag": "beta"}]) + assert batch.custom_meta[0]["tag"] == "alpha" + assert batch.custom_meta[1]["tag"] == "beta" + + def test_custom_backend_meta(self): + """Test _custom_backend_meta attribute.""" + batch = self._make_batch(batch_size=2) + batch._custom_backend_meta[0]["field_a"] = {"storage_key": "abc"} + assert batch._custom_backend_meta[0]["field_a"]["storage_key"] == "abc" + + def test_size_property(self): + """Test size == len property.""" + batch = self._make_batch(batch_size=5) + assert batch.size == 5 + assert len(batch) == 5 + + def test_add_fields_empty_batch_is_non_tensor_unknown(self): + """add_fields with empty field value leaves is_non_tensor as None (unknown). + + When a field has zero samples, we cannot determine the field type from data. + is_non_tensor must not default to False (which would incorrectly imply Tensor). 
+ """ + from tensordict import TensorDict + + batch = BatchMeta.empty() + # TensorDict with an empty tensor of batch_size=0 + empty_td = TensorDict({"empty_field": torch.empty(0, 2)}, batch_size=0) + batch.add_fields(empty_td) + assert batch.field_schema["empty_field"]["is_non_tensor"] is None + + def test_pickle_roundtrip_preserves_batchmeta(self): + """BatchMeta must survive pickle round-trip with all fields intact.""" + import pickle - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=1, fields=fields)]) - batch2.extra_info["tensor"] = torch.randn(3, 4) - batch2.extra_info["scalar"] = torch.tensor(2.0) - - result = BatchMeta.concat([batch1, batch2]) - - # Tensors are concatenated along dim=0 - assert result.extra_info["tensor"].shape == (6, 4) - # Scalars are stacked - assert result.extra_info["scalar"].shape == (2,) - - def test_batch_meta_concat_with_non_tensor_stack(self): - """Example: Concat handles NonTensorStack extra_info.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - - batch1 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) - batch1.extra_info["non_tensor"] = NonTensorStack(1, 2, 3) - - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=1, fields=fields)]) - batch2.extra_info["non_tensor"] = NonTensorStack(4, 5, 6) - - result = BatchMeta.concat([batch1, batch2]) - - # NonTensorStack is stacked - assert isinstance(result.extra_info["non_tensor"], NonTensorStack) - assert result.extra_info["non_tensor"].batch_size == torch.Size([2, 3]) - - def test_batch_meta_concat_with_list_extra_info(self): - """Example: Concat handles list extra_info by flattening.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - - batch1 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) - batch1.extra_info["list"] = [1, 2, 3] - - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=1, fields=fields)]) - batch2.extra_info["list"] = [4, 5, 6] - - result = BatchMeta.concat([batch1, batch2]) - - # Lists are flattened - assert result.extra_info["list"] == [1, 2, 3, 4, 5, 6] - - def test_batch_meta_concat_with_mixed_types(self): - """Example: Concat handles mixed extra_info types correctly.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - - batch1 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) - batch1.extra_info["tensor"] = torch.randn(3, 4) - batch1.extra_info["list"] = [1, 2, 3] - batch1.extra_info["string"] = "hello" - batch1.extra_info["non_tensor"] = NonTensorStack(1, 2, 3) - - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=1, fields=fields)]) - batch2.extra_info["tensor"] = torch.randn(3, 4) - batch2.extra_info["list"] = [4, 5] - batch2.extra_info["string"] = "world" - batch2.extra_info["non_tensor"] = NonTensorStack(4, 5, 6) + batch = BatchMeta( + global_indexes=[0, 1], + partition_ids=["p0", "p0"], + field_schema={ + "tensor_field": { + "dtype": torch.float32, + "shape": (4,), + "is_nested": False, + "is_non_tensor": False, + }, + "scalar_field": { + "dtype": torch.float32, + "shape": (), + "is_nested": False, + "is_non_tensor": 
False, + }, + }, + production_status=np.ones(2, dtype=np.int8), + extra_info={"step": 42}, + custom_meta=[{"score": 0.9}, {"score": 0.8}], + ) - result = BatchMeta.concat([batch1, batch2]) + data = pickle.dumps(batch) + restored = pickle.loads(data) - # Each type is handled appropriately - assert result.extra_info["tensor"].shape == (6, 4) # Concatenated - assert result.extra_info["list"] == [1, 2, 3, 4, 5] # Flattened - assert result.extra_info["string"] == "world" # Last value wins - assert isinstance(result.extra_info["non_tensor"], NonTensorStack) # Stacked - - def test_batch_meta_union(self): - """Example: Union two batches with matching global indexes.""" - fields1 = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - } - fields2 = { - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,)), - } + assert restored.global_indexes == batch.global_indexes + assert restored.partition_ids == batch.partition_ids + assert restored.field_schema["tensor_field"]["dtype"] == torch.float32 + assert restored.field_schema["scalar_field"]["shape"] == () + assert list(restored.production_status) == list(batch.production_status) + assert restored.extra_info == {"step": 42} + assert restored.custom_meta == [{"score": 0.9}, {"score": 0.8}] + def test_concat_extra_info_scalar_conflict_raises_value_error(self): + """concat raises ValueError when scalar extra_info values conflict.""" batch1 = BatchMeta( - samples=[ - SampleMeta(partition_id="partition_0", global_index=8, fields=fields1), - SampleMeta(partition_id="partition_0", global_index=9, fields=fields1), - ], - _custom_backend_meta={ - i: {"field1": {"dtype": torch.float32}, "field2": {"dtype": torch.int64}} for i in [8, 9] - }, + global_indexes=[0], + partition_ids=["p0"], + field_schema={"f": {"dtype": torch.float32, "shape": (1,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(1, dtype=np.int8), + extra_info={"step": 1}, ) - batch1.extra_info["info1"] = "value1" - batch2 = BatchMeta( - samples=[ - SampleMeta(partition_id="partition_0", global_index=8, fields=fields2), - SampleMeta(partition_id="partition_0", global_index=9, fields=fields2), - ], - _custom_backend_meta={ - i: {"field2": {"dtype": torch.int64}, "field3": {"dtype": torch.bool}} for i in [8, 9] - }, + global_indexes=[1], + partition_ids=["p0"], + field_schema={"f": {"dtype": torch.float32, "shape": (1,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(1, dtype=np.int8), + extra_info={"step": 2}, ) - batch2.extra_info["info2"] = "value2" - - result = batch1.union(batch2) - - assert len(result) == 2 - # All fields are present - for sample in result.samples: - assert "field1" in sample.fields - assert "field2" in sample.fields - assert "field3" in sample.fields - # Extra info is merged - assert result.extra_info["info1"] == "value1" - assert result.extra_info["info2"] == "value2" - - # _custom_backend_meta is merged - assert result._custom_backend_meta == { - i: {"field1": {"dtype": torch.float32}, "field2": {"dtype": torch.int64}, "field3": {"dtype": torch.bool}} - for i in [8, 9] - } - - def test_batch_meta_union_validation(self): - """Example: Union validation catches mismatched conditions.""" - fields = {"test_field": FieldMeta(name="test_field", dtype=torch.float32, shape=(2,))} - - batch1 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, 
fields=fields)]) + with pytest.raises(ValueError, match="conflicting values"): + BatchMeta.concat([batch1, batch2]) + def test_concat_extra_info_key_union_with_warning(self): + """concat unions extra_info keys when sets differ, with a warning.""" + batch1 = BatchMeta( + global_indexes=[0], + partition_ids=["p0"], + field_schema={"f": {"dtype": torch.float32, "shape": (1,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(1, dtype=np.int8), + extra_info={"common": "ok", "only_a": 1}, + ) batch2 = BatchMeta( - samples=[ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), # Different size - ] + global_indexes=[1], + partition_ids=["p0"], + field_schema={"f": {"dtype": torch.float32, "shape": (1,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(1, dtype=np.int8), + extra_info={"common": "ok", "only_b": 2}, ) + result = BatchMeta.concat([batch1, batch2]) + assert result.extra_info["common"] == "ok" + assert result.extra_info["only_a"] == 1 + assert result.extra_info["only_b"] == 2 - with pytest.raises(ValueError) as exc_info: - batch1.union(batch2, validate=True) - assert "Batch sizes do not match" in str(exc_info.value) - - def test_batch_meta_reorder(self): - """Example: Reorder samples in a batch.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=4, fields=fields), - SampleMeta(partition_id="partition_0", global_index=5, fields=fields), - SampleMeta(partition_id="partition_0", global_index=6, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Reorder to [2, 0, 1] - batch.reorder([2, 0, 1]) - - assert batch.global_indexes == [6, 4, 5] - # Batch indexes are updated - assert batch.samples[0].batch_index == 0 - assert batch.samples[1].batch_index == 1 - assert batch.samples[2].batch_index == 2 - - def test_batch_meta_add_fields(self): - """Example: Add fields from TensorDict to all samples.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Create TensorDict with new fields - tensor_dict = TensorDict({"new_field1": torch.randn(2, 3), "new_field2": torch.randn(2, 5)}, batch_size=[2]) - - batch.add_fields(tensor_dict) - - # Fields are added to all samples - for sample in batch.samples: - assert "new_field1" in sample.fields - assert "new_field2" in sample.fields - assert sample.is_ready is True - - def test_batch_meta_select_fields(self): - """Example: Select specific fields from all samples in a batch.""" - fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] + def test_concat_extra_info_tensor_equal_preserved(self): + """concat preserves identical Tensor extra_info values.""" + t = torch.tensor([1.0, 2.0, 3.0]) + batch1 = BatchMeta( 
+ global_indexes=[0], + partition_ids=["p0"], + field_schema={"f": {"dtype": torch.float32, "shape": (1,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(1, dtype=np.int8), + extra_info={"embedding": t.clone()}, + ) + batch2 = BatchMeta( + global_indexes=[1], + partition_ids=["p0"], + field_schema={"f": {"dtype": torch.float32, "shape": (1,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(1, dtype=np.int8), + extra_info={"embedding": t.clone()}, + ) + result = BatchMeta.concat([batch1, batch2]) + assert torch.equal(result.extra_info["embedding"], t) + + def test_setstate_readonly_production_status(self): + """__setstate__ must make read-only production_status writable. + + When Ray deserializes a BatchMeta via Arrow zero-copy, numpy arrays + become read-only. Since pickle skips __init__/__post_init__, the + .copy() guard is bypassed. __setstate__ must fix this. + """ + batch = self._make_batch() + # Simulate pickle round-trip with Arrow zero-copy (read-only array) + state = batch.__dict__.copy() + state["production_status"] = state["production_status"].copy() + state["production_status"].flags.writeable = False + + restored = BatchMeta.__new__(BatchMeta) + restored.__setstate__(state) + + # production_status must be writable after __setstate__ + assert restored.production_status.flags.writeable + # Verify add_fields works without ValueError + from tensordict import TensorDict + + td = TensorDict({"new_field": torch.randn(3, 4)}, batch_size=3) + restored.add_fields(td) # Should not raise + assert restored.is_ready + + def test_shallow_copy_isolation_global_indexes(self): + """Modifying the original global_indexes list does not affect BatchMeta.""" + original_indexes = [0, 1, 2] batch = BatchMeta( - samples=samples, - extra_info={"test_key": "test_value"}, - _custom_backend_meta={ - i: { - "field1": {"dtype": torch.float32}, - "field2": {"dtype": torch.int64}, - "field3": {"dtype": torch.bool}, - } - for i in [0, 1] - }, + global_indexes=original_indexes, + partition_ids=["p"] * 3, ) + original_indexes.append(99) + assert batch.global_indexes == [0, 1, 2] + assert len(batch) == 3 - # Select only field1 and field3 - selected_batch = batch.select_fields(["field1", "field3"]) - - # Check all samples have correct fields - assert len(selected_batch) == 2 - for sample in selected_batch.samples: - assert "field1" in sample.fields - assert "field3" in sample.fields - assert "field2" not in sample.fields - # Original batch is unchanged - assert len(batch.samples[0].fields) == 3 - # Extra info is preserved - assert selected_batch.extra_info["test_key"] == "test_value" - # Global indexes are preserved - assert selected_batch.global_indexes == [0, 1] - - # _custom_backend_meta is selected - assert "field1" in selected_batch._custom_backend_meta[0] - assert "field2" not in selected_batch._custom_backend_meta[0] - assert "field3" in selected_batch._custom_backend_meta[0] - assert "field1" in selected_batch._custom_backend_meta[1] - assert "field2" not in selected_batch._custom_backend_meta[1] - assert "field3" in selected_batch._custom_backend_meta[1] - - def test_batch_meta_select_fields_with_nonexistent_fields(self): - """Example: Select fields ignores non-existent field names in batch.""" - fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - 
SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Try to select fields including non-existent ones - selected_batch = batch.select_fields(["field1", "nonexistent_field"]) - - # Only existing fields are selected - for sample in selected_batch.samples: - assert "field1" in sample.fields - assert "nonexistent_field" not in sample.fields - assert "field2" not in sample.fields - - def test_batch_meta_select_fields_empty_list(self): - """Example: Select with empty field list returns batch with no fields.""" - fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Select with empty list - selected_batch = batch.select_fields([]) - - assert len(selected_batch) == 2 - for sample in selected_batch.samples: - assert len(sample.fields) == 0 - # Global indexes are preserved - assert selected_batch.global_indexes == [0, 1] - - def test_batch_meta_select_fields_single_sample(self): - """Example: Select fields works correctly for batch with single sample.""" - fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - } - sample = SampleMeta(partition_id="partition_0", global_index=0, fields=fields) - batch = BatchMeta(samples=[sample]) - - # Select only field2 - selected_batch = batch.select_fields(["field2"]) - - assert len(selected_batch) == 1 - assert "field2" in selected_batch.samples[0].fields - assert "field1" not in selected_batch.samples[0].fields - - def test_batch_meta_select_fields_preserves_field_metadata(self): - """Example: Selected fields preserve their original metadata.""" - fields = { - "field1": FieldMeta( - name="field1", dtype=torch.float32, shape=(2, 3), production_status=ProductionStatus.READY_FOR_CONSUME - ), - "field2": FieldMeta( - name="field2", dtype=torch.int64, shape=(5,), production_status=ProductionStatus.NOT_PRODUCED - ), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Select field1 - selected_batch = batch.select_fields(["field1"]) - selected_field = selected_batch.samples[0].fields["field1"] - - assert selected_field.dtype == torch.float32 - assert selected_field.shape == (2, 3) - assert selected_field.production_status == ProductionStatus.READY_FOR_CONSUME - assert selected_field.name == "field1" - - def test_batch_meta_select_samples(self): - """Example: Select specific samples from a batch.""" - fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=4, fields=fields), - SampleMeta(partition_id="partition_0", global_index=5, fields=fields), - SampleMeta(partition_id="partition_0", global_index=6, fields=fields), - SampleMeta(partition_id="partition_0", global_index=7, fields=fields), - ] - batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"}) - - # Select samples at indices [0, 2] - selected_batch = batch.select_samples([0, 2]) # This will select the first two samples with global_index=4/5 - - # Check number of samples - assert 
len(selected_batch) == 2 - # Check global indexes - assert selected_batch.global_indexes == [4, 6] - # Check fields are preserved - for sample in selected_batch.samples: - assert "field1" in sample.fields - assert "field2" in sample.fields - # Original batch is unchanged - assert len(batch) == 4 - # Extra info is preserved - assert selected_batch.extra_info["test_key"] == "test_value" - - def test_batch_meta_select_samples_all_indices(self): - """Example: Select all samples using complete index list.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=4, fields=fields), - SampleMeta(partition_id="partition_0", global_index=5, fields=fields), - SampleMeta(partition_id="partition_0", global_index=6, fields=fields), - ] - batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"}) - - # Select all samples - selected_batch = batch.select_samples([0, 1, 2]) - - # All samples are selected - assert len(selected_batch) == 3 - assert selected_batch.global_indexes == [4, 5, 6] - # Extra info is preserved - assert selected_batch.extra_info["test_key"] == "test_value" - - def test_batch_meta_select_samples_single_sample(self): - """Example: Select a single sample from batch.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - SampleMeta(partition_id="partition_0", global_index=2, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Select only the middle sample - selected_batch = batch.select_samples([1]) - - assert len(selected_batch) == 1 - assert selected_batch.global_indexes == [1] - assert selected_batch.samples[0].batch_index == 0 # New batch index - - def test_batch_meta_select_samples_empty_list(self): - """Example: Select with empty list returns empty batch.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] - batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"}) - - # Select with empty list - selected_batch = batch.select_samples([]) - - assert len(selected_batch) == 0 - assert selected_batch.global_indexes == [] - # Extra info is still preserved - assert selected_batch.extra_info["test_key"] == "test_value" - - def test_batch_meta_select_samples_reverse_order(self): - """Example: Select samples in reverse order.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - SampleMeta(partition_id="partition_0", global_index=2, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Select samples in reverse order - selected_batch = batch.select_samples([2, 1, 0]) - - assert len(selected_batch) == 3 - assert selected_batch.global_indexes == [2, 1, 0] - # Batch indexes are re-assigned - 
assert selected_batch.samples[0].global_index == 2 - assert selected_batch.samples[1].global_index == 1 - assert selected_batch.samples[2].global_index == 0 - - def test_batch_meta_select_samples_with_extra_info(self): - """Example: Select samples preserves all extra info types.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Add various extra info types - batch.extra_info["tensor"] = torch.randn(3, 4) - batch.extra_info["string"] = "test_string" - batch.extra_info["number"] = 42 - batch.extra_info["list"] = [1, 2, 3] - - # Select one sample - selected_batch = batch.select_samples([0]) - - # All extra info is preserved - assert "tensor" in selected_batch.extra_info - assert selected_batch.extra_info["string"] == "test_string" - assert selected_batch.extra_info["number"] == 42 - assert selected_batch.extra_info["list"] == [1, 2, 3] - - # ===================================================== - # Custom Meta Tests - # ===================================================== - def test_batch_meta_update_custom_meta(self): - """Test update_custom_meta adds metadata for different global indices.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - "field_b": FieldMeta(name="field_b", dtype=torch.int64, shape=(3,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Initial custom_meta for sample 0 - batch.update_custom_meta([{"sample_score": 0.9}, {"sample_score": 0.1}]) - - result = batch.get_all_custom_meta() - assert result[0]["sample_score"] == 0.9 - assert result[1]["sample_score"] == 0.1 - - def test_batch_meta_update_custom_meta_overwrites(self): - """Test update_custom_meta overwrites existing metadata at same key.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Initial custom_meta - batch.update_custom_meta([{"sample_score": 0.9, "quality": "high"}]) - - # Update with new value for same field - dict.update replaces - batch.update_custom_meta([{"sample_score": 0.1, "quality": "low"}]) - - result = batch.get_all_custom_meta() - assert result[0]["sample_score"] == 0.1 - assert result[0]["quality"] == "low" - - def test_batch_meta_update_custom_meta_with_none(self): - """Test update_custom_meta with None does nothing.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Set initial value - batch.update_custom_meta([{"sample_score": 0.9}]) - - # Update with None should not change anything - batch.update_custom_meta(None) - - result = batch.get_all_custom_meta() - assert result[0]["sample_score"] == 0.9 - - def test_batch_meta_clear_custom_meta(self): - """Test clear_custom_meta removes all custom metadata.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - } - samples = [ - 
SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Set custom_meta - batch.update_custom_meta([{"sample_score": 0.9}, {"sample_score": 0.1}]) - - # Clear all - batch.clear_custom_meta() - - result = batch.get_all_custom_meta() - assert result == [{}, {}] - - def test_batch_meta_get_all_custom_meta_returns_deep_copy(self): - """Test get_all_custom_meta returns a deep copy.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - ] - batch = BatchMeta(samples=samples) - - custom_meta = [{"sample_score": 0.9, "nested": {"value": 1}}] - batch.update_custom_meta(custom_meta) - - # Get all custom_meta - result = batch.get_all_custom_meta() - - # Verify it's a deep copy - modifying result should not affect original - result[0]["sample_score"] = 0.1 - result[0]["nested"]["value"] = 999 - - original = batch.get_all_custom_meta() - assert original[0]["sample_score"] == 0.9 - assert original[0]["nested"]["value"] == 1 - - def test_batch_meta_get_all_custom_meta_empty(self): - """Test get_all_custom_meta with no custom_meta returns empty dict.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - ] - batch = BatchMeta(samples=samples) - - result = batch.get_all_custom_meta() - assert result == [{}] - - def test_batch_meta_custom_meta_with_nested_data(self): - """Test custom_meta supports nested dictionary data.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - ] - batch = BatchMeta(samples=samples) - - nested_meta = { - "model_info": {"name": "llama", "version": "7b", "config": {"hidden_size": 4096, "num_layers": 32}}, - "tags": ["training", "inference"], - } - batch.update_custom_meta([nested_meta]) - - result = batch.get_all_custom_meta() - assert result[0]["model_info"]["name"] == "llama" - assert result[0]["model_info"]["version"] == "7b" - assert result[0]["model_info"]["config"]["hidden_size"] == 4096 - assert result[0]["tags"] == ["training", "inference"] - - # ===================================================== - # Extra Info Methods Tests - # ===================================================== - - def test_batch_meta_update_extra_info(self): - """Test update_extra_info adds multiple values.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - batch = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) - - # Update with multiple values - batch.update_extra_info({"key1": "value1", "key2": "value2", "key3": "value3"}) - - # Verify all exist - assert "key1" in batch.extra_info - assert "key2" in batch.extra_info - assert "key3" in batch.extra_info - assert batch.extra_info["key1"] == "value1" - assert batch.extra_info["key2"] == "value2" - - def test_batch_meta_extra_info_preserved_in_operations(self): - """Test extra_info is preserved in batch operations.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - batch1 = 
BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) - batch1.extra_info["test_key1"] = "test_value" - - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=1, fields=fields)]) - batch2.extra_info["test_key2"] = "test_value_2" + def test_shallow_copy_isolation_extra_info(self): + """Modifying the original extra_info dict does not affect BatchMeta.""" + original_info = {"key": "value"} + batch = BatchMeta( + global_indexes=[0], + partition_ids=["p"], + extra_info=original_info, + ) + original_info["key"] = "corrupted" + original_info["new_key"] = "new" + assert batch.extra_info == {"key": "value"} + def test_shallow_copy_isolation_field_schema(self): + """Modifying the original field_schema dict does not affect BatchMeta.""" + original_schema = {"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}} + batch = BatchMeta( + global_indexes=[0], + partition_ids=["p"], + field_schema=original_schema, + ) + original_schema["f"]["dtype"] = torch.int64 + assert batch.field_schema["f"]["dtype"] == torch.float32 + + def test_select_fields_isolation_extra_info(self): + """select_fields result has isolated extra_info from the original.""" + batch = self._make_batch() + batch.set_extra_info("key", "original") + selected = batch.select_fields(["field_a"]) + selected.set_extra_info("key", "modified") + selected.set_extra_info("new_key", "new") + assert batch.extra_info["key"] == "original" + assert "new_key" not in batch.extra_info + + def test_select_fields_isolation_custom_meta(self): + """select_fields result has isolated custom_meta from the original.""" + batch = self._make_batch() + batch.update_custom_meta([{"score": 0.9}, {"score": 0.8}, {"score": 0.7}]) + selected = batch.select_fields(["field_a"]) + selected.update_custom_meta([{"score": 0.0}, {"score": 0.0}, {"score": 0.0}]) + assert batch.custom_meta[0]["score"] == 0.9 + + def test_concat_no_double_copy_regression(self): + """concat still works correctly after removing double-copy in __post_init__.""" + batch1 = self._make_batch(batch_size=2) + batch2 = BatchMeta( + global_indexes=[2, 3], + partition_ids=["partition_0", "partition_0"], + field_schema=batch1.field_schema, + production_status=np.ones(2, dtype=np.int8), + custom_meta=[{"id": 2}, {"id": 3}], + ) result = BatchMeta.concat([batch1, batch2]) + assert len(result) == 4 + assert result.global_indexes == [0, 1, 2, 3] + assert result.custom_meta[2] == {"id": 2} + assert result.custom_meta[3] == {"id": 3} - # Extra info is preserved - assert "test_key1" in result.extra_info - - def test_batch_meta_extra_info_with_concat(self): - """Test extra_info handling in concat with mixed types.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - - batch1 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) - batch1.extra_info["string"] = "hello" - batch1.extra_info["number"] = 42 - - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=1, fields=fields)]) - batch2.extra_info["string"] = "world" - batch2.extra_info["number"] = 100 - + def test_concat_extra_info_identical_scalars_preserved(self): + """concat preserves identical scalar extra_info (int, str, dict).""" + common_info = {"step": 42, "mode": "train", "config": {"lr": 0.01}} + batch1 = BatchMeta( + global_indexes=[0], + partition_ids=["p0"], + field_schema={"f": {"dtype": 
torch.float32, "shape": (1,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(1, dtype=np.int8), + extra_info=dict(common_info), + ) + batch2 = BatchMeta( + global_indexes=[1], + partition_ids=["p0"], + field_schema={"f": {"dtype": torch.float32, "shape": (1,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(1, dtype=np.int8), + extra_info=dict(common_info), + ) result = BatchMeta.concat([batch1, batch2]) + assert result.extra_info == common_info + assert len(result) == 2 - # String: last value wins - assert result.extra_info["string"] == "world" - - -class TestEdgeCases: - """Edge cases and important boundaries.""" - - def test_batch_meta_chunk_with_more_chunks_than_samples(self): - """Example: Chunking when chunks > samples produces empty chunks.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # 5 chunks for 2 samples - chunks = batch.chunk(5) - - assert len(chunks) == 5 - # First 2 chunks have samples - assert len(chunks[0]) == 1 - assert len(chunks[1]) == 1 - # Last 3 chunks are empty - assert len(chunks[2]) == 0 - assert len(chunks[3]) == 0 - assert len(chunks[4]) == 0 - - def test_batch_meta_concat_with_empty_batches(self): - """Example: Concat handles empty batches gracefully.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - - batch1 = BatchMeta(samples=[]) - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) - batch3 = BatchMeta(samples=[]) - - # Empty batches are filtered out - result = BatchMeta.concat([batch1, batch2, batch3]) - assert len(result) == 1 - assert result.global_indexes == [0] - - def test_batch_meta_concat_validation_error(self): - """Example: Concat validation catches field name mismatches.""" - fields1 = {"field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,))} - fields2 = {"field2": FieldMeta(name="field2", dtype=torch.float32, shape=(2,))} - - batch1 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields1)]) + def test_chunk_concat_roundtrip_preserves_extra_info(self): + """chunk followed by concat preserves extra_info without errors.""" + batch = BatchMeta( + global_indexes=list(range(6)), + partition_ids=["p0"] * 6, + field_schema={"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(6, dtype=np.int8), + extra_info={"metrics": {"loss": 0.5}, "step": 100, "tags": ["train"]}, + ) + chunks = batch.chunk(3) + restored = BatchMeta.concat(chunks) + assert restored.extra_info == {"metrics": {"loss": 0.5}, "step": 100, "tags": ["train"]} + assert len(restored) == 6 + assert restored.global_indexes == list(range(6)) - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=1, fields=fields2)]) - with pytest.raises(ValueError) as exc_info: - BatchMeta.concat([batch1, batch2], validate=True) - assert "Field names do not match" in str(exc_info.value) +# ============================================================================== +# KVBatchMeta Tests (all migrated from main with no modification) +# 
============================================================================== class TestKVBatchMeta: @@ -1239,7 +574,7 @@ def test_kv_batch_meta_concat(self): tags=[{"idx": 2}, {"idx": 3}], partition_id="partition_0", fields=["field1"], - extra_info={"test": "value2"}, + extra_info={"test": "value1"}, ) result = KVBatchMeta.concat([kv_meta1, kv_meta2]) @@ -1332,3 +667,36 @@ def test_kv_batch_meta_deepcopy_extra_info(self): # Original should not be modified assert original_extra["nested"]["value"] == 1 + + def test_kv_batch_meta_concat_extra_info_conflict_raises(self): + """KVBatchMeta.concat raises ValueError on conflicting extra_info values.""" + kv1 = KVBatchMeta( + keys=["k0"], + tags=[{}], + extra_info={"step": 1}, + ) + kv2 = KVBatchMeta( + keys=["k1"], + tags=[{}], + extra_info={"step": 2}, + ) + with pytest.raises(ValueError, match="conflicting"): + KVBatchMeta.concat([kv1, kv2]) + + +# ============================================================================== +# StorageUnitData Tests +# ============================================================================== + + +class TestStorageUnitDataStrict: + """Tests for StorageUnitData length validation.""" + + def test_put_data_length_mismatch_raises(self): + """put_data must raise when global_indexes and field values have different lengths.""" + from transfer_queue.storage.simple_backend import StorageUnitData + + sud = StorageUnitData(storage_size=10) + # 3 indexes but only 2 values — must raise, not silently drop + with pytest.raises(ValueError, match="length mismatch"): + sud.put_data({"field_a": [1, 2]}, global_indexes=[0, 1, 2]) diff --git a/tests/test_ray_p2p.py b/tests/test_ray_p2p.py index 4bd54a7b..e958b84c 100644 --- a/tests/test_ray_p2p.py +++ b/tests/test_ray_p2p.py @@ -17,6 +17,7 @@ import time from pathlib import Path +import numpy as np import ray import torch from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy @@ -26,7 +27,7 @@ sys.path.append(str(parent_dir)) from transfer_queue.client import TransferQueueClient # noqa: E402 -from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 +from transfer_queue.metadata import BatchMeta # noqa: E402 from transfer_queue.storage.managers.base import KVStorageManager # noqa: E402 from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory # noqa: E402 from transfer_queue.utils.zmq_utils import ZMQServerInfo # noqa: E402 @@ -115,17 +116,19 @@ def generate_data( batch_size=batch_size, ) - samples = [ - SampleMeta( - global_index=i, - partition_id=partition_id, - fields={ - "input_ids": FieldMeta(name="input_ids", dtype=torch.float32, shape=(seq_len,)), - }, - ) - for i in range(batch_size) - ] - meta = BatchMeta(samples=samples) + meta = BatchMeta( + global_indexes=list(range(batch_size)), + partition_ids=[partition_id] * batch_size, + field_schema={ + "input_ids": { + "dtype": torch.float32, + "shape": (seq_len,), + "is_nested": False, + "is_non_tensor": False, + } + }, + production_status=np.zeros(batch_size, dtype=np.int8), + ) self.data = data self.meta = meta diff --git a/tests/test_serial_utils_on_cpu.py b/tests/test_serial_utils_on_cpu.py index 316bbd92..9a0ec45a 100644 --- a/tests/test_serial_utils_on_cpu.py +++ b/tests/test_serial_utils_on_cpu.py @@ -16,6 +16,7 @@ import sys from pathlib import Path +import numpy as np import pytest import torch from tensordict import TensorDict @@ -76,6 +77,8 @@ def test_zmq_msg_serialization(): encoded_msg = msg.serialize() decoded_msg = 
ZMQMessage.deserialize(encoded_msg) assert decoded_msg.request_type == msg.request_type + # TensorDict converts numpy arrays to Tensors on insertion, + # so decoding yields a Tensor (not np.ndarray). assert torch.allclose(decoded_msg.body["data"]["numpy_array"], msg.body["data"]["numpy_array"]) assert torch.allclose(decoded_msg.body["data"]["normal_tensor"], msg.body["data"]["normal_tensor"]) assert msg.body["data"]["nested_tensor"].layout == decoded_msg.body["data"]["nested_tensor"].layout @@ -823,20 +826,23 @@ def worker(thread_id: int) -> tuple[int, list[str]]: # ============================================================================ -# Numpy Array Type Compatibility Tests +# Numpy Serialization Tests # ============================================================================ -class TestNumpyArrayTypeCompatibility: +class TestNumpySerialization: """Test numpy array serialization with various dtypes. - These tests verify the fix for the TypeError when using torch.from_numpy() - with unsupported numpy dtypes (e.g., object arrays). The fix uses pickle - fallback for incompatible types while maintaining zero-copy for numeric types. + These tests verify: + 1. The fix for the TypeError when using torch.from_numpy() with unsupported + numpy dtypes (e.g., object arrays). The fix uses pickle fallback for + incompatible types while maintaining zero-copy for numeric types. + 2. Numeric numpy arrays round-trip as np.ndarray (not torch.Tensor), + preserving dtype and shape exactly, using zero-copy path. """ + # --- Object / string array tests (formerly TestNumpyArrayTypeCompatibility) --- + def test_numpy_object_array_strings(self): """Test numpy object array with string elements.""" - import numpy as np - encoder = MsgpackEncoder() decoder = MsgpackDecoder() @@ -851,8 +857,6 @@ def test_numpy_object_array_strings(self): def test_numpy_object_array_mixed_types(self): """Test numpy object array with mixed Python types.""" - import numpy as np - encoder = MsgpackEncoder() decoder = MsgpackDecoder() @@ -867,8 +871,6 @@ def test_numpy_object_array_mixed_types(self): def test_numpy_object_array_dicts(self): """Test numpy object array containing Python dicts.""" - import numpy as np - encoder = MsgpackEncoder() decoder = MsgpackDecoder() @@ -883,13 +885,10 @@ def test_numpy_object_array_dicts(self): assert orig == decoded def test_numpy_numeric_arrays_zero_copy(self): - """Test that numeric numpy arrays use zero-copy path.""" - import numpy as np - + """Test that numeric numpy arrays use zero-copy path and return np.ndarray.""" encoder = MsgpackEncoder() decoder = MsgpackDecoder() - # These should use zero-copy (torch.from_numpy + tensor encoding) numeric_dtypes = [ np.float32, np.float64, @@ -910,19 +909,20 @@ def test_numpy_numeric_arrays_zero_copy(self): serialized = encoder.encode(arr) - # Zero-copy should produce multiple buffers (metadata + tensor buffer) + # Zero-copy must produce multiple buffers (metadata + data buffer) assert len(serialized) > 1, f"Expected zero-copy for dtype {dtype}" deserialized = decoder.decode(serialized) - # Deserialized as torch.Tensor (due to zero-copy path) - assert isinstance(deserialized, torch.Tensor) - assert torch.allclose(deserialized, torch.from_numpy(arr)) + # After the fix: deserialized must be np.ndarray, not torch.Tensor + assert isinstance(deserialized, np.ndarray), ( + f"Expected np.ndarray but got {type(deserialized)} for dtype={dtype}" + ) + assert deserialized.dtype == arr.dtype + assert np.array_equal(deserialized, arr) def 
test_numpy_object_array_in_zmq_message(self): """Test numpy object array inside ZMQMessage.""" - import numpy as np - from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType # Create message with both object array and regular tensors @@ -951,8 +951,6 @@ def test_numpy_object_array_in_zmq_message(self): def test_numpy_unicode_string_array(self): """Test numpy unicode string array (dtype=' 1.""" + encoder = MsgpackEncoder() + arr = np.arange(100, dtype=np.float32) + serialized = encoder.encode(arr) + assert len(serialized) > 1, "Expected zero-copy (aux buffer) for float32 ndarray" + + def test_numpy_non_contiguous_roundtrip(self): + """Non-C-contiguous arrays must be made contiguous before serialization.""" + encoder = MsgpackEncoder() + decoder = MsgpackDecoder() + + base = np.arange(100, dtype=np.float64).reshape(10, 10) + arr = base[::2, ::2] # non-contiguous view + assert not arr.flags["C_CONTIGUOUS"] + + serialized = encoder.encode(arr) + deserialized = decoder.decode(serialized) + + assert isinstance(deserialized, np.ndarray) + assert np.array_equal(deserialized, arr) + + def test_numpy_multidim_shape_preserved(self): + """Shape must survive a round-trip for multi-dimensional arrays.""" + encoder = MsgpackEncoder() + decoder = MsgpackDecoder() + + arr = np.arange(60, dtype=np.int32).reshape(3, 4, 5) + serialized = encoder.encode(arr) + deserialized = decoder.decode(serialized) + + assert isinstance(deserialized, np.ndarray) + assert deserialized.shape == (3, 4, 5) + assert np.array_equal(deserialized, arr) + + def test_numpy_empty_array_roundtrip(self): + """Empty arrays must round-trip correctly.""" + encoder = MsgpackEncoder() + decoder = MsgpackDecoder() + + arr = np.empty((0,), dtype=np.float32) + serialized = encoder.encode(arr) + deserialized = decoder.decode(serialized) + + assert isinstance(deserialized, np.ndarray) + assert deserialized.shape == (0,) + assert deserialized.dtype == np.float32 + + def test_numpy_object_array_still_uses_pickle(self): + """Object arrays (kind='O' or hasobject) must fall back to pickle.""" + encoder = MsgpackEncoder() + decoder = MsgpackDecoder() + + # dtype=object — kind 'O', cannot be viewed as a contiguous byte buffer + arr = np.array(["a", "b", "c"], dtype=object) + serialized = encoder.encode(arr) + + # Pickle-fallback produces a single buffer (no aux tensor buffer appended) + assert len(serialized) == 1, "Object array should not use zero-copy path" + + deserialized = decoder.decode(serialized) + assert isinstance(deserialized, np.ndarray) + assert np.array_equal(deserialized, arr) diff --git a/tests/test_simple_storage_unit.py b/tests/test_simple_storage_unit.py index ed43e41e..b18f8dd1 100644 --- a/tests/test_simple_storage_unit.py +++ b/tests/test_simple_storage_unit.py @@ -40,29 +40,29 @@ def __init__(self, storage_put_get_address): self.socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout self.socket.connect(storage_put_get_address) - def send_put(self, client_id, local_indexes, field_data): + def send_put(self, client_id, global_indexes, field_data): msg = ZMQMessage.create( request_type=ZMQRequestType.PUT_DATA, sender_id=f"mock_client_{client_id}", - body={"local_indexes": local_indexes, "data": field_data}, + body={"global_indexes": global_indexes, "data": field_data}, ) self.socket.send_multipart(msg.serialize()) return ZMQMessage.deserialize(self.socket.recv_multipart()) - def send_get(self, client_id, local_indexes, fields): + def send_get(self, client_id, global_indexes, fields): msg = ZMQMessage.create( 
request_type=ZMQRequestType.GET_DATA, sender_id=f"mock_client_{client_id}", - body={"local_indexes": local_indexes, "fields": fields}, + body={"global_indexes": global_indexes, "fields": fields}, ) self.socket.send_multipart(msg.serialize()) return ZMQMessage.deserialize(self.socket.recv_multipart()) - def send_clear(self, client_id, local_indexes): + def send_clear(self, client_id, global_indexes): msg = ZMQMessage.create( request_type=ZMQRequestType.CLEAR_DATA, sender_id=f"mock_client_{client_id}", - body={"local_indexes": local_indexes}, + body={"global_indexes": global_indexes}, ) self.socket.send_multipart(msg.serialize()) return ZMQMessage.deserialize(self.socket.recv_multipart()) @@ -107,13 +107,13 @@ def test_put_get_single_client(storage_setup): client = MockStorageClient(put_get_address) # PUT data - local_indexes = [0, 1, 2] + global_indexes = [0, 1, 2] field_data = { "log_probs": [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]), torch.tensor([7.0, 8.0, 9.0])], "rewards": [torch.tensor([10.0]), torch.tensor([20.0]), torch.tensor([30.0])], } - response = client.send_put(0, local_indexes, field_data) + response = client.send_put(0, global_indexes, field_data) assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE # GET data @@ -142,9 +142,9 @@ def test_put_get_multiple_clients(storage_setup): num_clients = 3 clients = [MockStorageClient(put_get_address) for _ in range(num_clients)] - # Each client puts unique data using different local_indexes + # Each client puts unique data using different global_indexes for i, client in enumerate(clients): - local_indexes = [i * 10 + 0, i * 10 + 1, i * 10 + 2] + global_indexes = [i * 10 + 0, i * 10 + 1, i * 10 + 2] field_data = { "log_probs": [ torch.tensor([i, i + 1, i + 2]), @@ -154,14 +154,14 @@ def test_put_get_multiple_clients(storage_setup): "rewards": [torch.tensor([i * 10]), torch.tensor([i * 10 + 10]), torch.tensor([i * 10 + 20])], } - response = client.send_put(i, local_indexes, field_data) + response = client.send_put(i, global_indexes, field_data) assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE - # Test overlapping local indexes + # Test overlapping global indexes overlapping_client = MockStorageClient(put_get_address) - overlap_local_indexes = [0] # Overlaps with first client's index 0 + overlap_global_indexes = [0] # Overlaps with first client's index 0 overlap_field_data = {"log_probs": [torch.tensor([999, 999, 999])], "rewards": [torch.tensor([999])]} - response = overlapping_client.send_put(99, overlap_local_indexes, overlap_field_data) + response = overlapping_client.send_put(99, overlap_global_indexes, overlap_field_data) assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE # Each original client gets its own data (except for index 0 which was overwritten) @@ -209,7 +209,7 @@ def test_performance_basic(storage_setup): start = time.time() # Use batch size and index mapping - local_indexes = list(range(i * batch_size, (i + 1) * batch_size)) + global_indexes = list(range(i * batch_size, (i + 1) * batch_size)) # Create tensor data log_probs_data = [] @@ -224,7 +224,7 @@ def test_performance_basic(storage_setup): field_data = {"log_probs": log_probs_data, "rewards": rewards_data} - response = client.send_put(0, local_indexes, field_data) + response = client.send_put(0, global_indexes, field_data) latency = time.time() - start put_latencies.append(latency) assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE @@ -236,8 +236,8 @@ def 
test_performance_basic(storage_setup): for i in range(num_gets): start = time.time() # Retrieve batch of data - local_indexes = list(range(i * batch_size, (i + 1) * batch_size)) - response = client.send_get(0, local_indexes, ["log_probs", "rewards"]) + global_indexes = list(range(i * batch_size, (i + 1) * batch_size)) + response = client.send_get(0, global_indexes, ["log_probs", "rewards"]) latency = time.time() - start get_latencies.append(latency) assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE @@ -259,7 +259,7 @@ def test_put_get_nested_tensor(storage_setup): client = MockStorageClient(put_get_address) # PUT data with nested tensors - local_indexes = [0, 1, 2] + global_indexes = [0, 1, 2] field_data = { "variable_length_sequences": [ torch.tensor([-0.5, -1.2, -0.8]), @@ -269,7 +269,7 @@ def test_put_get_nested_tensor(storage_setup): "attention_mask": [torch.tensor([1, 1, 1]), torch.tensor([1, 1, 1, 1]), torch.tensor([1, 1])], } - response = client.send_put(0, local_indexes, field_data) + response = client.send_put(0, global_indexes, field_data) assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE # GET data @@ -298,13 +298,13 @@ def test_put_get_non_tensor_data(storage_setup): client = MockStorageClient(put_get_address) # PUT data with non-tensor data - local_indexes = [0, 1, 2] + global_indexes = [0, 1, 2] field_data = { "prompt_text": ["Hello world!", "This is a longer sentence for testing", "Test case"], "response_text": ["Hi there!", "This is the response to the longer sentence", "Test response"], } - response = client.send_put(0, local_indexes, field_data) + response = client.send_put(0, global_indexes, field_data) assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE # GET data @@ -366,13 +366,13 @@ def test_clear_data(storage_setup): client = MockStorageClient(put_get_address) # PUT data first - local_indexes = [0, 1, 2] + global_indexes = [0, 1, 2] field_data = { "log_probs": [torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([3.0])], "rewards": [torch.tensor([10.0]), torch.tensor([20.0]), torch.tensor([30.0])], } - response = client.send_put(0, local_indexes, field_data) + response = client.send_put(0, global_indexes, field_data) assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE # Verify data exists @@ -399,25 +399,44 @@ def test_storage_unit_data_direct(): storage_data = StorageUnitData(storage_size=10) - # Test put_data field_data = { "log_probs": [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])], "rewards": [torch.tensor([10.0]), torch.tensor([20.0])], } + # global_indexes = global_index values (e.g., 0 and 1) storage_data.put_data(field_data, [0, 1]) - # Test get_data result = storage_data.get_data(["log_probs", "rewards"], [0, 1]) assert "log_probs" in result assert "rewards" in result assert len(result["log_probs"]) == 2 assert len(result["rewards"]) == 2 - # Test single index get result_single = storage_data.get_data(["log_probs"], [0]) - assert torch.allclose(result_single["log_probs"][0], torch.tensor([1.0, 2.0])) + torch.testing.assert_close(result_single["log_probs"][0], torch.tensor([1.0, 2.0])) - # Test clear + # clear: key is removed (not set to None) storage_data.clear([0]) - result_after_clear = storage_data.get_data(["log_probs"], [0]) - assert result_after_clear["log_probs"][0] is None + assert 0 not in storage_data.field_data["log_probs"] # key gone + assert 1 in storage_data.field_data["log_probs"] # other key intact + + +def test_storage_unit_data_capacity_uses_active_keys(): + """Capacity check 
must use _active_keys, not scan field_data.""" + from transfer_queue.storage.simple_backend import StorageUnitData + + storage = StorageUnitData(storage_size=3) + + # Fill to capacity + storage.put_data({"f": [1, 2, 3]}, global_indexes=[0, 1, 2]) + assert len(storage._active_keys) == 3 + + # Exceeding capacity must raise + with pytest.raises(ValueError, match="Storage capacity exceeded"): + storage.put_data({"f": [4]}, global_indexes=[3]) + + # After clearing one key, adding one more should succeed + storage.clear(keys=[2]) + assert len(storage._active_keys) == 2 + storage.put_data({"f": [4]}, global_indexes=[3]) + assert storage._active_keys == {0, 1, 3} diff --git a/tests/test_yuanrong_client_zero_copy.py b/tests/test_yuanrong_client_zero_copy.py index 3048ec57..b93fd32a 100644 --- a/tests/test_yuanrong_client_zero_copy.py +++ b/tests/test_yuanrong_client_zero_copy.py @@ -21,6 +21,8 @@ import pytest import torch +pytest.importorskip("yr") + parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 2090ad1a..199ceeda 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -238,8 +238,7 @@ async def async_get_meta( ) if response_msg.request_type == ZMQRequestType.GET_META_RESPONSE: - metadata_dict = response_msg.body["metadata"] - return BatchMeta.from_dict(metadata_dict) if isinstance(metadata_dict, dict) else metadata_dict + return response_msg.body["metadata"] else: raise RuntimeError( f"[{self.client_id}]: Failed to get metadata from controller {self._controller.id}: " @@ -578,11 +577,7 @@ async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta if response_msg.request_type != ZMQRequestType.GET_PARTITION_META_RESPONSE: raise RuntimeError("Failed to get metadata for clear operation.") - return ( - BatchMeta.from_dict(response_msg.body["metadata"]) - if isinstance(response_msg.body["metadata"], dict) - else response_msg.body["metadata"] - ) + return response_msg.body["metadata"] @dynamic_socket(socket_name="request_handle_socket") async def _clear_partition_in_controller(self, partition_id, socket=None): @@ -971,7 +966,6 @@ async def async_kv_retrieve_meta( if response_msg.request_type == ZMQRequestType.KV_RETRIEVE_META_RESPONSE: metadata = response_msg.body.get("metadata", BatchMeta.empty()) - metadata = BatchMeta.from_dict(metadata) if isinstance(metadata, dict) else metadata return metadata else: raise RuntimeError( diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 50ff10df..bc070580 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -25,6 +25,7 @@ from typing import Any, Optional from uuid import uuid4 +import numpy as np import ray import torch import zmq @@ -33,11 +34,9 @@ from transfer_queue.metadata import ( BatchMeta, - FieldMeta, - SampleMeta, ) from transfer_queue.sampler import BaseSampler, SequentialSampler -from transfer_queue.utils.enum_utils import ProductionStatus, TransferQueueRole +from transfer_queue.utils.enum_utils import TransferQueueRole from transfer_queue.utils.perf_utils import IntervalPerfMonitor from transfer_queue.utils.zmq_utils import ( ZMQMessage, @@ -228,6 +227,9 @@ class DataPartitionStatus: field_name_mapping: dict[str, int] = field(default_factory=dict) # field_name -> column_index field_dtypes: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: dtype} field_shapes: dict[int, dict[str, Any]] = field(default_factory=dict) # 
global_idx -> {field: shape} + # O(F) schema cache: field_name -> {dtype, shape, is_nested, is_non_tensor} + # Updated eagerly in _update_field_metadata; used by get_field_schema() for O(1) per-field lookup. + field_schema_cache: dict[str, dict[str, Any]] = field(default_factory=dict) field_custom_backend_meta: dict[int, dict[str, Any]] = field( default_factory=dict ) # global_idx -> {field: custom_backend_meta} @@ -484,12 +486,50 @@ def _update_field_metadata( if global_idx not in self.field_dtypes: self.field_dtypes[global_idx] = {} self.field_dtypes[global_idx].update(dtype_value[i]) + # Update field_schema_cache with new dtype info + for field_name, dtype in dtype_value[i].items(): + if field_name not in self.field_schema_cache: + self.field_schema_cache[field_name] = { + "dtype": dtype, + "shape": None, + "is_nested": False, + "is_non_tensor": False, + } + else: + cached = self.field_schema_cache[field_name] + if cached.get("dtype") is None and not cached.get("_dtype_mixed"): + cached["dtype"] = dtype + elif cached.get("dtype") is not None and cached["dtype"] != dtype: + logger.warning( + f"Field '{field_name}' dtype changed from " + f"{cached['dtype']} to {dtype} at global_index " + f"{global_idx}. Setting cached dtype to None." + ) + cached["dtype"] = None + cached["_dtype_mixed"] = True # Only create and update shape mapping if a shape value was provided if shape_value[i] is not None: if global_idx not in self.field_shapes: self.field_shapes[global_idx] = {} self.field_shapes[global_idx].update(shape_value[i]) + # Update field_schema_cache with new shape info + for field_name, shape in shape_value[i].items(): + if field_name not in self.field_schema_cache: + self.field_schema_cache[field_name] = { + "dtype": None, + "shape": shape, + "is_nested": False, + "is_non_tensor": False, + } + else: + cached = self.field_schema_cache[field_name] + if cached.get("shape") is None and not cached.get("is_nested"): + cached["shape"] = shape + elif cached.get("shape") is not None and cached["shape"] != shape: + # Shapes differ across samples → mark as nested + cached["is_nested"] = True + cached["shape"] = None # Only create and update custom_backend_meta mapping if a custom_backend_meta value was provided if custom_backend_meta_value[i] is not None: @@ -670,13 +710,23 @@ def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]: # ==================== Metadata Methods ==================== - def get_field_dtype(self, global_index: int, field_name: str) -> Optional[Any]: - """Get dtype for a specific sample and field.""" - return self.field_dtypes.get(global_index, {}).get(field_name) + def get_field_schema(self, field_names: list[str]) -> dict[str, dict[str, Any]]: + """Return field_schema for the requested fields from the O(F) cache. - def get_field_shape(self, global_index: int, field_name: str) -> Optional[Any]: - """Get shape for a specific sample and field.""" - return self.field_shapes.get(global_index, {}).get(field_name) + Complexity: O(F) — one dict-lookup per field, no full scan of per-sample maps. + The cache is populated eagerly in _update_field_metadata() at put time. 
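+
+        Fields that have not yet been registered in the cache are omitted from
+        the returned mapping. Example of one returned entry (the field name and
+        values below are illustrative only):
+
+            {"input_ids": {"dtype": torch.float32, "shape": (128,),
+                           "is_nested": False, "is_non_tensor": False}}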
+ """ + schema = {} + for field_name in field_names: + cached = self.field_schema_cache.get(field_name) + if cached is not None: + schema[field_name] = { + "dtype": cached.get("dtype"), + "shape": cached.get("shape"), + "is_nested": cached.get("is_nested", False), + "is_non_tensor": cached.get("is_non_tensor", False), + } + return schema def get_field_custom_backend_meta( self, global_indices: list[int], field_names: list[str] @@ -706,6 +756,14 @@ def get_field_custom_backend_meta( if idx in self.field_custom_backend_meta } + def get_field_dtype(self, global_index: int, field_name: str) -> Optional[Any]: + """Get the dtype for a specific (global_index, field_name) pair.""" + return self.field_dtypes.get(global_index, {}).get(field_name, None) + + def get_field_shape(self, global_index: int, field_name: str) -> Optional[Any]: + """Get the shape for a specific (global_index, field_name) pair.""" + return self.field_shapes.get(global_index, {}).get(field_name, None) + def get_custom_meta(self, global_indices: list[int]) -> dict[int, dict]: """ Get custom_meta for multiple samples. @@ -1298,6 +1356,8 @@ def generate_batch_meta( """ Generate BatchMeta for specific samples in a partition. + O(F) optimized version that uses field_schema instead of per-sample metadata. + This function is responsible only for metadata generation and does not modify consumption state. State management is handled by the calling function. @@ -1320,55 +1380,63 @@ def generate_batch_meta( if mode not in ["fetch", "insert", "force_fetch"]: raise ValueError(f"Invalid mode: {mode}") - # Generate sample metadata - samples = [] - for global_index in batch_global_indexes: - fields = {} - for field_name in data_fields: - # Determine production status - if mode == "fetch": - production_status = ProductionStatus.READY_FOR_CONSUME - dtype = partition.get_field_dtype(global_index, field_name) - shape = partition.get_field_shape(global_index, field_name) - elif mode == "insert": - production_status = ProductionStatus.NOT_PRODUCED - dtype = None - shape = None - elif mode == "force_fetch": - field_index = partition.field_name_mapping.get(field_name) - if ( - field_index is not None - and partition.production_status is not None - and partition.production_status[global_index, field_index] == 1 - ): - production_status = ProductionStatus.READY_FOR_CONSUME - dtype = partition.get_field_dtype(global_index, field_name) - shape = partition.get_field_shape(global_index, field_name) - else: - production_status = ProductionStatus.NOT_PRODUCED - dtype = None - shape = None - - fields[field_name] = FieldMeta( - name=field_name, - dtype=dtype, - shape=shape, - production_status=production_status, - ) + batch_size = len(batch_global_indexes) - sample = SampleMeta( - partition_id=partition_id, - global_index=global_index, - fields=fields, - ) - samples.append(sample) + field_schema = partition.get_field_schema(data_fields) + + # For nested fields, populate per_sample_shapes from per-sample field_shapes + # so that downstream consumers (e.g. _get_shape_type_custom_backend_meta_list) + # can reconstruct tensors with correct per-sample dimensions. + for field_name, meta in field_schema.items(): + if meta.get("is_nested"): + meta["per_sample_shapes"] = [partition.get_field_shape(gi, field_name) for gi in batch_global_indexes] - custom_meta = partition.get_custom_meta(batch_global_indexes) + # In insert mode, create placeholder schema for unregistered fields so that + # metadata.field_names is complete and update_production_status() can recognize them. 
+ if mode == "insert": + for field_name in data_fields: + if field_name not in field_schema: + field_schema[field_name] = { + "dtype": None, + "shape": None, + "is_nested": False, + "is_non_tensor": False, + } + + if mode == "fetch": + production_status = np.ones(batch_size, dtype=np.int8) + elif mode == "insert": + production_status = np.zeros(batch_size, dtype=np.int8) + else: # force_fetch + production_status = np.zeros(batch_size, dtype=np.int8) + if partition.production_status is not None and data_fields: + field_indices = [ + partition.field_name_mapping.get(field_name) + for field_name in data_fields + if field_name in partition.field_name_mapping + ] + if field_indices: + for i, global_idx in enumerate(batch_global_indexes): + if global_idx < partition.production_status.shape[0]: + sample_status = partition.production_status[global_idx, field_indices] + if torch.all(sample_status == 1): + production_status[i] = 1 + + custom_meta_dict = partition.get_custom_meta(batch_global_indexes) custom_backend_meta = partition.get_field_custom_backend_meta(batch_global_indexes, data_fields) - batch_meta = BatchMeta(samples=samples) - batch_meta.update_custom_meta([custom_meta.get(idx, {}) for idx in batch_meta.global_indexes]) - batch_meta._custom_backend_meta.update(custom_backend_meta) + # Convert controller dict[int, dict] → BatchMeta list[dict] (aligned with batch_global_indexes) + custom_meta_list = [custom_meta_dict.get(global_index, {}) for global_index in batch_global_indexes] + custom_backend_meta_list = [custom_backend_meta.get(global_index, {}) for global_index in batch_global_indexes] + + batch_meta = BatchMeta( + global_indexes=batch_global_indexes, + partition_ids=[partition_id] * batch_size, + field_schema=field_schema, + production_status=production_status, + custom_meta=custom_meta_list, + _custom_backend_meta=custom_backend_meta_list, + ) return batch_meta def clear_partition(self, partition_id: str, clear_consumption: bool = True): @@ -1528,9 +1596,9 @@ def kv_retrieve_meta( verified_global_indexes ) data_fields = [] - for fname, col_idx in partition.field_name_mapping.items(): + for field_name, col_idx in partition.field_name_mapping.items(): if col_idx < len(col_mask) and col_mask[col_idx]: - data_fields.append(fname) + data_fields.append(field_name) metadata = self.generate_batch_meta(partition_id, verified_global_indexes, data_fields, mode="force_fetch") diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 64134bfc..d1c92dbf 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -20,15 +20,12 @@ import os from collections import defaultdict from dataclasses import dataclass +from types import MappingProxyType from typing import Any, Optional import numpy as np import torch from tensordict import TensorDict -from tensordict.tensorclass import NonTensorData, NonTensorStack -from torch import Tensor - -from transfer_queue.utils.enum_utils import ProductionStatus logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) @@ -40,241 +37,235 @@ logger.addHandler(handler) -# TODO: Add UT for metadata operations -@dataclass -class FieldMeta: - """Records the metadata of a single data field (name, dtype, shape, etc.).""" - - name: str - dtype: Optional[Any] # Data type (e.g., torch.float32, numpy.float32) - shape: Optional[Any] # Data shape (e.g., torch.Size([3, 224, 224]), (3, 224, 224)) - production_status: ProductionStatus = ProductionStatus.NOT_PRODUCED - - def __str__(self) -> str: - return ( 
- f"FieldMeta(name='{self.name}', dtype={self.dtype}, " - f"shape={self.shape}, production_status={self.production_status})" - ) - - @property - def is_ready(self) -> bool: - """Check if this field is ready for consumption""" - return self.production_status == ProductionStatus.READY_FOR_CONSUME +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- - @classmethod - def from_dict(cls, data: dict) -> "FieldMeta": - """Create FieldMeta from dictionary.""" - return cls( - name=data["name"], - dtype=data["dtype"], - shape=data["shape"], - production_status=ProductionStatus(str(data["production_status"])) - if isinstance(data["production_status"], int | str) - else data["production_status"], - ) +def _extra_info_values_equal(a: Any, b: Any) -> bool: + """Compare two extra_info values for equality. -@dataclass -class SampleMeta: - """Records the metadata of a single data sample (stored as a row in the data system).""" + Handles torch.Tensor, np.ndarray specially to avoid ambiguous truth values. + """ + if type(a) is not type(b): + return False + if isinstance(a, torch.Tensor): + return torch.equal(a, b) + if isinstance(a, np.ndarray): + return np.array_equal(a, b) + try: + return a == b + except Exception: + return False + + +class _SampleView: + """Lazy read-only view of a single sample row in a columnar BatchMeta. + + All returned dicts are ``MappingProxyType`` – attempts to mutate them + raise ``TypeError``, making it obvious that this is a snapshot view. + """ - partition_id: str # Partition id, used for data versioning - global_index: int # Global row index, uniquely identifies a data sample - fields: dict[str, FieldMeta] # Fields of interest for this sample + __slots__ = ("_batch", "_idx") - def __post_init__(self): - """Initialize is_ready property based on field readiness""" - # Check if all fields are ready and update is_ready property - object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values())) - - def __str__(self) -> str: - return f"SampleMeta(partition_id={self.partition_id}, global_index={self.global_index})" + def __init__(self, batch: "BatchMeta", idx: int) -> None: + self._batch = batch + self._idx = idx @property - def field_names(self) -> list[str]: - """Get list of field names for this sample""" - return list(self.fields.keys()) + def global_index(self) -> int: + """Return the global sample index for this sample.""" + return self._batch.global_indexes[self._idx] @property - def batch_index(self) -> int: - """Get the batch index of this sample (to be set by BatchMeta)""" - return getattr(self, "_batch_index", -1) - - def get_field_by_name(self, name: str) -> Optional[FieldMeta]: - """Get FieldMeta by field name""" - return self.fields.get(name) - - def has_field(self, name: str) -> bool: - """Check if this sample has a specific field""" - return name in self.fields + def partition_id(self) -> str: + """Return the partition ID for this sample.""" + return self._batch.partition_ids[self._idx] - def is_field_ready(self, field_name: str) -> bool: - """Check if a specific field is ready for consumption""" - field = self.fields.get(field_name) - return field.is_ready if field else False - - def add_fields(self, fields: dict[str, FieldMeta]) -> "SampleMeta": - """ - Add new fields to this sample. New fields will be initialized with given dtype, shape - and production_status (if provided). 
If not provided, default values (None, None, READY_FOR_CONSUME) - will be used. This modifies the sample in-place to include the new fields. - """ - self.fields = _union_fields(self.fields, fields) - # Update is_ready property - object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values())) - return self - - def select_fields(self, field_names: list[str]) -> "SampleMeta": - """ - Select specific fields from this sample. - This will construct a new SampleMeta instance containing only the specified fields. + @property + def production_status(self) -> int: + """Return the production status for this sample.""" + return int(self._batch.production_status[self._idx]) - Args: - field_names (list[str]): List of field names to retain. + @property + def custom_meta(self) -> "MappingProxyType[str, Any]": + """Read-only view of per-sample custom metadata.""" + return MappingProxyType(self._batch.custom_meta[self._idx]) - Returns: - SampleMeta: A new SampleMeta instance containing only the specified fields. - """ - selected_fields = {name: self.fields[name] for name in field_names if name in self.fields} + @property + def fields(self) -> "MappingProxyType[str, MappingProxyType]": + """Read-only per-sample field schema. + + For nested-tensor fields the batch-level ``per_sample_shapes`` list is + replaced by a single ``shape`` entry for *this* sample, so callers + always see ``fields['x']['shape']`` as a tuple (not a list-of-tuples). + """ + result: dict[str, MappingProxyType] = {} + for name, meta in self._batch.field_schema.items(): + per_sample = meta.get("per_sample_shapes") + if per_sample is not None: + sample_meta = {k: v for k, v in meta.items() if k != "per_sample_shapes"} + sample_meta["shape"] = per_sample[self._idx] + else: + sample_meta = dict(meta) + result[name] = MappingProxyType(sample_meta) + return MappingProxyType(result) - # construct new SampleMeta instance - # TODO(tianyi): (maybe) move _custom_backend_meta and custom_meta to FieldMeta level? - selected_sample_meta = SampleMeta( - fields=selected_fields, - partition_id=self.partition_id, - global_index=self.global_index, + def __repr__(self) -> str: + return ( + f"_SampleView(global_index={self.global_index}, " + f"partition_id={self.partition_id!r}, " + f"production_status={self.production_status}, " + f"fields={list(self._batch.field_schema.keys())})" ) - return selected_sample_meta - - def union(self, other: "SampleMeta", validate: bool = True) -> "SampleMeta": - """ - Create a union of this sample's fields with another sample's fields. - Assume both samples have the same global index. If fields overlap, the - fields in this sample will be replaced by the other sample's fields. - Args: - other: Another SampleMeta to union with - validate: Whether to validate union conditions +class _SampleViewList: + """Lazy indexable list returned by BatchMeta.samples. - Returns: - New SampleMeta with unioned fields (None if validation fails) - """ - if validate: - if self.global_index != other.global_index: - raise ValueError( - f"Error: Global indexes ({self.global_index} and {other.global_index}) do not match for union." - ) + Supports: indexing (samples[i]), len(), and iteration. 
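+
+    Example (illustrative sketch; ``batch_meta`` stands for any existing
+    BatchMeta instance)::
+
+        >>> views = batch_meta.samples
+        >>> len(views) == batch_meta.size
+        True
+        >>> views[0].global_index == batch_meta.global_indexes[0]
+        True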
+ """ - # Merge fields - self.fields = _union_fields(self.fields, other.fields) + __slots__ = ("_batch",) - # Update is_ready property - object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values())) - return self + def __init__(self, batch: "BatchMeta") -> None: + self._batch = batch - @property - def is_ready(self) -> bool: - """Check if all fields in this sample are ready for consumption""" - return getattr(self, "_is_ready", False) + def __len__(self) -> int: + return len(self._batch.global_indexes) - @property - def production_status(self) -> dict[str, ProductionStatus]: - """Get production status for all fields (backward compatibility)""" - return {name: field.production_status for name, field in self.fields.items()} + def __getitem__(self, idx: int) -> _SampleView: + return _SampleView(self._batch, idx) - @classmethod - def from_dict(cls, data: dict) -> "SampleMeta": - """Create SampleMeta from dictionary.""" - fields = { - name: FieldMeta.from_dict(field_data) if isinstance(field_data, dict) else field_data - for name, field_data in data["fields"].items() - } - return cls( - partition_id=data["partition_id"], - global_index=data["global_index"], - fields=fields, - ) + def __iter__(self): + return (_SampleView(self._batch, i) for i in range(len(self))) @dataclass class BatchMeta: - """Records the metadata of a batch of data samples.""" - - samples: list[SampleMeta] + """Records the metadata of a batch of data samples with optimized field-level schema. + + This is the O(BxF) optimized version that stores field metadata at the field level + instead of per-sample, reducing storage from O(B*F) to O(F). + + Attributes: + global_indexes: List of global sample indices in this batch. + partition_ids: List of partition IDs corresponding to each sample. + field_schema: Field-level metadata {field_name: {dtype, shape, is_nested, is_non_tensor, per_sample_shapes}}. + production_status: Vectorized production status, shape (B,) where B is batch size. + extra_info: Additional batch-level information. + custom_meta: Per-sample user-defined metadata, list aligned with global_indexes. + _custom_backend_meta: Per-sample per-field storage backend metadata, list aligned with global_indexes. 
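+
+    Example (illustrative sketch; the field name, dtype and shape below are
+    hypothetical and not prescribed by the API)::
+
+        >>> meta = BatchMeta(
+        ...     global_indexes=[0, 1],
+        ...     partition_ids=["train_0", "train_0"],
+        ...     field_schema={
+        ...         "obs": {"dtype": torch.float32, "shape": (4,),
+        ...                 "is_nested": False, "is_non_tensor": False},
+        ...     },
+        ...     production_status=np.ones(2, dtype=np.int8),
+        ... )
+        >>> meta.size, meta.is_ready
+        (2, True)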
+ """ - # external meta for non-sample level information + global_indexes: list[int] + partition_ids: list[str] + # O(F) field-level metadata: {field_name: {dtype, shape, is_nested, is_non_tensor}} + field_schema: dict[str, dict[str, Any]] = dataclasses.field(default_factory=dict) + # O(B) vectorized production status; always np.ndarray after __post_init__ (never None) + production_status: np.ndarray = dataclasses.field(default=None, repr=False) # type: ignore[assignment] extra_info: dict[str, Any] = dataclasses.field(default_factory=dict) - - # user-defined meta for each sample - custom_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) - - # internal meta for different storage backends in per-sample per-field level - _custom_backend_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) + # user-defined meta for each sample (sample-level), list aligned with global_indexes + custom_meta: list[dict[str, Any]] = dataclasses.field(default_factory=list) + # internal meta for different storage backends (per-sample per-field level), list aligned with global_indexes + _custom_backend_meta: list[dict[str, Any]] = dataclasses.field(default_factory=list) def __post_init__(self): """Initialize all computed properties during initialization""" - self.samples = copy.deepcopy(self.samples) - self.extra_info = copy.deepcopy(self.extra_info) + self.global_indexes = list(self.global_indexes) + self.partition_ids = list(self.partition_ids) + self.field_schema = {k: dict(v) for k, v in self.field_schema.items()} + self.extra_info = dict(self.extra_info) + + # Validation + if len(self.global_indexes) != len(self.partition_ids): + raise ValueError( + f"Length mismatch: global_indexes has {len(self.global_indexes)}, " + f"partition_ids has {len(self.partition_ids)}" + ) - # Basic properties - object.__setattr__(self, "_size", len(self.samples)) - object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples)) + batch_size = len(self.global_indexes) - # Pre-compute all list properties for better performance - if self.samples: - for idx, sample in enumerate(self.samples): - object.__setattr__(sample, "_batch_index", idx) # Ensure batch_index is set correctly + if self.production_status is not None: + self.production_status = np.array(self.production_status, dtype=np.int8, copy=True) - object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples]) + if len(self.production_status) != batch_size: + raise ValueError(f"production_status length {len(self.production_status)} != batch_size {batch_size}") + else: + # Default: all NOT_PRODUCED (including empty batches) + self.production_status = np.zeros(batch_size, dtype=np.int8) + + for field_name, meta in self.field_schema.items(): + if meta.get("per_sample_shapes") is not None: + if len(meta["per_sample_shapes"]) != batch_size: + raise ValueError( + f"Field '{field_name}' per_sample_shapes length {len(meta['per_sample_shapes'])} " + f"!= batch_size {batch_size}" + ) - # check if all samples have the same field names - first_sample_field_names = sorted(self.samples[0].field_names) - if not all(sorted(sample.field_names) == first_sample_field_names for sample in self.samples): - raise ValueError("All samples in BatchMeta must have the same field_names.") - object.__setattr__(self, "_field_names", first_sample_field_names) + self._size = batch_size + self._field_names = sorted(self.field_schema.keys()) - object.__setattr__(self, "_partition_ids", [sample.partition_id for sample 
in self.samples]) + is_ready = batch_size > 0 and bool(np.all(self.production_status == 1)) + self._is_ready = is_ready - # filter custom_meta and _custom_backend_meta - self.custom_meta = copy.deepcopy( - {k: self.custom_meta[k] for k in self.global_indexes if k in self.custom_meta} - ) - self._custom_backend_meta = copy.deepcopy( - {k: self._custom_backend_meta[k] for k in self.global_indexes if k in self._custom_backend_meta} - ) + # Validate or initialize columnar custom_meta / _custom_backend_meta + if not self.custom_meta: + self.custom_meta = [{} for _ in range(batch_size)] else: - self.custom_meta = {} - self._custom_backend_meta = {} - object.__setattr__(self, "_global_indexes", []) - object.__setattr__(self, "_field_names", []) - object.__setattr__(self, "_partition_ids", []) + self.custom_meta = [dict(d) for d in self.custom_meta] + if len(self.custom_meta) != batch_size: + raise ValueError(f"custom_meta length {len(self.custom_meta)} != batch_size {batch_size}") + if not self._custom_backend_meta: + self._custom_backend_meta = [{} for _ in range(batch_size)] + else: + self._custom_backend_meta = [dict(d) for d in self._custom_backend_meta] + if len(self._custom_backend_meta) != batch_size: + raise ValueError( + f"_custom_backend_meta length {len(self._custom_backend_meta)} != batch_size {batch_size}" + ) + + def __setstate__(self, state): + """Restore instance from pickle/Ray deserialization. + + Python dataclass pickle skips __init__/__post_init__, so the + .copy() guard for production_status is bypassed. Ray Arrow + zero-copy deserialization produces read-only numpy arrays. + This method ensures writability after deserialization. + """ + self.__dict__.update(state) + if isinstance(self.production_status, np.ndarray) and not self.production_status.flags.writeable: + self.production_status = self.production_status.copy() @property def size(self) -> int: """Return the number of samples in this batch""" return getattr(self, "_size", 0) - @property - def global_indexes(self) -> list[int]: - """Get all global indexes in this batch""" - return getattr(self, "_global_indexes", []) - @property def field_names(self) -> list[str]: """Get all unique field names in this batch""" return getattr(self, "_field_names", []) + @property + def samples(self) -> _SampleViewList: + """Lazy per-sample view: supports samples[i].fields['a'], len(samples), for s in samples.""" + return _SampleViewList(self) + @property def is_ready(self) -> bool: """Check if all samples in this batch are ready for consumption""" - # TODO: get ready status from controller realtime return getattr(self, "_is_ready", False) - @property - def partition_ids(self) -> list[str]: - """Get partition ids for all samples in this batch as a list (one per sample)""" - return getattr(self, "_partition_ids", []) + # ==================== Extra Info Methods ==================== + + def get_extra_info(self, key: str, default: Any = None) -> Any: + """Get extra info by key""" + return self.extra_info.get(key, default) + + def set_extra_info(self, key: str, value: Any) -> None: + """Set extra info by key""" + self.extra_info[key] = value def get_all_extra_info(self) -> dict[str, Any]: """Get all extra_info as a dictionary (deep copy for immutability). @@ -285,49 +276,44 @@ def get_all_extra_info(self) -> dict[str, Any]: return copy.deepcopy(self.extra_info) def update_extra_info(self, info_dict: dict[str, Any]) -> None: - """ - Update extra_info with multiple key-value pairs. 
- - This method updates the extra_info dictionary with the provided key-value pairs. - Existing keys will be overwritten with new values. + """Update extra_info with multiple key-value pairs. Args: info_dict: Dictionary of key-value pairs to add/update in extra_info """ self.extra_info.update(info_dict) - def clear_extra_info(self) -> None: - """ - Clear all extra_info. + def remove_extra_info(self, key: str) -> Any: + """Remove extra info by key and return its value""" + return self.extra_info.pop(key, None) - This method removes all key-value pairs from the extra_info dictionary. - """ + def clear_extra_info(self) -> None: + """Clear all extra_info.""" self.extra_info.clear() + def has_extra_info(self, key: str) -> bool: + """Check if extra info contains a specific key""" + return key in self.extra_info + + # ==================== Custom Meta Methods (User Layer) ==================== + def get_all_custom_meta(self) -> list[dict[str, Any]]: - """ - Get all custom_meta as a list of dictionary. + """Get all custom_meta as a list of dictionary (one per sample, in global_indexes order). Returns: A deep copy of the custom_meta list """ - custom_meta = [self.custom_meta.get(i, {}) for i in self.global_indexes] - return copy.deepcopy(custom_meta) + return copy.deepcopy(self.custom_meta) def update_custom_meta(self, custom_meta: list[dict[str, Any]]): - """ - Update custom_meta with a list of dictionary of custom metadata. - - This method updates the custom_meta dictionary with the provided metadata. - Existing keys will be overwritten with new values. + """Update custom_meta with a list of dictionary of custom metadata. Args: - custom_meta: list of custom_meta dictionary + custom_meta: list of custom_meta dictionary (one per sample, in global_indexes order) Raises: - ValueError: If the length of custom_meta cannot match the batch size + ValueError: If the length of custom_meta does not match the batch size """ - if custom_meta is None: return @@ -336,142 +322,192 @@ def update_custom_meta(self, custom_meta: list[dict[str, Any]]): f"The length of custom_meta list {len(custom_meta)} must match the batch size: {self.size}" ) - custom_meta_dict: dict[int, dict[str, Any]] = { - self.global_indexes[i]: custom_meta[i] for i in range(len(custom_meta)) - } - - self.custom_meta.update(custom_meta_dict) + for i, meta in enumerate(custom_meta): + self.custom_meta[i].update(meta) def clear_custom_meta(self) -> None: - """ - Clear all custom_meta. + """Clear all custom_meta.""" + self.custom_meta = [{} for _ in range(self.size)] - This method removes all entries from the custom_meta dictionary. - """ - self.custom_meta.clear() + # ==================== Core BatchMeta Operations ==================== def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "BatchMeta": - """ - Add new fields from a TensorDict to all samples in this batch. - This modifies each sample in-place to include the new fields. + """Add new fields from a TensorDict to all samples in this batch. + This modifies the batch in-place to include the new fields. Args: tensor_dict (TensorDict): The input TensorDict containing new fields. set_all_ready (bool): If True, set all production_status to READY_FOR_CONSUME. Default is True. 
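+
+        Example (sketch; ``meta`` is an existing BatchMeta and ``reward`` is a
+        hypothetical field name)::
+
+            >>> td = TensorDict({"reward": torch.zeros(meta.size)}, batch_size=[meta.size])
+            >>> meta = meta.add_fields(td)
+            >>> "reward" in meta.field_names
+            True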
""" - fields = _extract_field_metas(tensor_dict, set_all_ready) - - if fields: - if len(self.samples) != len(fields): - raise ValueError(f"add_fields length mismatch: samples={len(self.samples)} vs fields={len(fields)}") - for idx, sample in enumerate(self.samples): - sample.add_fields(fields=fields[idx]) - - # Update batch-level fields cache - if self.samples: - object.__setattr__(self, "_field_names", sorted(self.samples[0].field_names)) - object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples)) + batch_size = tensor_dict.batch_size[0] + if batch_size != self.size: + raise ValueError(f"add_fields batch size mismatch: self.size={self.size} vs tensor_dict={batch_size}") + + for name, value in tensor_dict.items(): + # Determine if this is a nested tensor + is_nested = isinstance(value, torch.Tensor) and value.is_nested + + first_item = None + if is_nested: + unbound = value.unbind() + first_item = unbound[0] if unbound else None + else: + first_item = value[0] if len(value) > 0 else None + + # Determine if this is non-tensor data. + # When first_item is None (empty field), we cannot determine type—leave as None. + is_non_tensor = not isinstance(first_item, torch.Tensor) if first_item is not None else None + + field_meta = { + "dtype": getattr(first_item, "dtype", type(first_item) if first_item is not None else None), + "shape": getattr(first_item, "shape", None) if not is_nested else None, + "is_nested": is_nested, + "is_non_tensor": is_non_tensor, + } + + # For nested tensors, record per-sample shapes + if is_nested: + field_meta["per_sample_shapes"] = [tuple(t.shape) for t in value.unbind()] + + self.field_schema[name] = field_meta + + if set_all_ready: + self.production_status[:] = 1 + + self._field_names = sorted(self.field_schema.keys()) + + self._is_ready = self.size > 0 and bool(np.all(self.production_status == 1)) + return self - def select_samples(self, indexes: list[int]) -> "BatchMeta": - """ - Select specific samples from this batch. + def select_samples(self, sample_indices: list[int]) -> "BatchMeta": + """Select specific samples from this batch. This will construct a new BatchMeta instance containing only the specified samples. Args: - indexes (list[int]): List of indexes (relative to this batch, not global_indexes) - to retain. + sample_indices (list[int]): List of sample indices (relative to this batch) to retain. Returns: BatchMeta: A new BatchMeta instance containing only the specified samples. 
""" + if any(i < 0 or i >= self.size for i in sample_indices): + raise ValueError(f"Sample indices must be in range [0, {self.size})") + + new_global_indexes = [self.global_indexes[i] for i in sample_indices] + new_partition_ids = [self.partition_ids[i] for i in sample_indices] - if any(i < 0 or i >= len(self.samples) for i in indexes): - raise ValueError(f"Sample indices must be in range [0, {len(self.samples)})") + # Select production_status + new_production_status = self.production_status[sample_indices] - selected_samples = [self.samples[i] for i in indexes] + new_field_schema = {} + for field_name, meta in self.field_schema.items(): + new_meta = copy.deepcopy(meta) + if meta.get("per_sample_shapes") is not None: + new_meta["per_sample_shapes"] = [meta["per_sample_shapes"][i] for i in sample_indices] + new_field_schema[field_name] = new_meta - global_indexes = [self.global_indexes[i] for i in indexes] - selected_custom_meta = {i: self.custom_meta[i] for i in global_indexes if i in self.custom_meta} - selected_custom_backend_meta = { - i: self._custom_backend_meta[i] for i in global_indexes if i in self._custom_backend_meta - } + new_custom_meta = [copy.deepcopy(self.custom_meta[i]) for i in sample_indices] - # construct new BatchMeta instance - selected_batch_meta = BatchMeta( - samples=selected_samples, + new_custom_backend_meta = [copy.deepcopy(self._custom_backend_meta[i]) for i in sample_indices] + + return BatchMeta( + global_indexes=new_global_indexes, + partition_ids=new_partition_ids, + field_schema=new_field_schema, + production_status=new_production_status, extra_info=self.extra_info, - custom_meta=selected_custom_meta, - _custom_backend_meta=selected_custom_backend_meta, + custom_meta=new_custom_meta, + _custom_backend_meta=new_custom_backend_meta, ) - return selected_batch_meta - def select_fields(self, field_names: list[str]) -> "BatchMeta": - """ - Select specific fields from all samples in this batch. + """Select specific fields from all samples in this batch. This will construct a new BatchMeta instance containing only the specified fields. Args: field_names (list[str]): List of field names to retain. Returns: - BatchMeta: A new BatchMeta instance containing only the specified fields from all samples. + BatchMeta: A new BatchMeta instance containing only the specified fields. 
""" - # select fields for each SampleMeta - new_samples = [sample.select_fields(field_names=field_names) for sample in self.samples] - - # select fields in _custom_backend_meta - selected_custom_backend_meta = {} - for idx in self.global_indexes: - if idx in self._custom_backend_meta: - custom_backend_meta_idx = self._custom_backend_meta[idx] - - selected_custom_backend_meta[idx] = { - field: custom_backend_meta_idx[field] for field in field_names if field in custom_backend_meta_idx - } - - # construct new BatchMeta instance - new_batch_meta = BatchMeta( - samples=new_samples, - extra_info=self.extra_info, - custom_meta=self.custom_meta, + new_field_schema = {} + for field_name in field_names: + if field_name in self.field_schema: + new_field_schema[field_name] = copy.deepcopy(self.field_schema[field_name]) + + selected_custom_backend_meta = [ + {f: v for f, v in m.items() if f.startswith("_") or f in field_names} for m in self._custom_backend_meta + ] + + return BatchMeta( + global_indexes=self.global_indexes, + partition_ids=self.partition_ids, + field_schema=new_field_schema, + production_status=self.production_status.copy(), + extra_info=copy.deepcopy(self.extra_info), + custom_meta=copy.deepcopy(self.custom_meta), _custom_backend_meta=selected_custom_backend_meta, ) - return new_batch_meta + def with_data_fields(self, field_names: list[str]) -> "BatchMeta": + """Return a new BatchMeta with the given data fields, replacing the current field_schema. - def __len__(self) -> int: - """Return the number of samples in this batch.""" - return len(self.samples) + Unlike ``select_fields``, this method allows specifying field names that are not + yet present in the current ``field_schema`` (e.g. fields added by a subsequent + ``put`` call on a subset of samples). Unknown fields are included in the new + ``field_schema`` with an empty metadata dict so that ``get_data`` can retrieve + them from the storage backend. - def __getitem__(self, item): - if isinstance(item, int | np.integer): - sample_meta = self.samples[item] if self.samples else [] - global_idx = self.global_indexes[item] + Args: + field_names (list[str]): List of field names to request. May include fields + not present in the current ``field_schema``. - if global_idx in self.custom_meta: - custom_meta = {global_idx: self.custom_meta[global_idx]} + Returns: + BatchMeta: A new BatchMeta instance whose ``field_schema`` contains exactly + the requested fields (existing metadata is preserved where available). + """ + new_field_schema = {} + for field_name in field_names: + if field_name in self.field_schema: + new_field_schema[field_name] = copy.deepcopy(self.field_schema[field_name]) else: - custom_meta = {} + # Unknown field — include with empty schema so get_data can fetch it. 
+ new_field_schema[field_name] = {} - if global_idx in self._custom_backend_meta: - custom_backend_meta = {global_idx: self._custom_backend_meta[global_idx]} - else: - custom_backend_meta = {} + selected_custom_backend_meta = [ + {f: v for f, v in m.items() if f.startswith("_") or f in field_names} for m in self._custom_backend_meta + ] - return BatchMeta( - samples=[sample_meta], - extra_info=self.extra_info, - custom_meta=custom_meta, - _custom_backend_meta=custom_backend_meta, - ) + return BatchMeta( + global_indexes=self.global_indexes, + partition_ids=self.partition_ids, + field_schema=new_field_schema, + production_status=self.production_status.copy(), + extra_info=copy.deepcopy(self.extra_info), + custom_meta=copy.deepcopy(self.custom_meta), + _custom_backend_meta=selected_custom_backend_meta, + ) + + def __len__(self) -> int: + """Return the number of samples in this batch.""" + return self.size + + def __getitem__(self, item) -> "BatchMeta": + if isinstance(item, int | np.integer): + if item < 0: + item += self.size + if item < 0 or item >= self.size: + raise IndexError("BatchMeta index out of range") + return self.select_samples([item]) + elif isinstance(item, slice): + start, stop, step = item.indices(self.size) + indices = list(range(start, stop, step)) + return self.select_samples(indices) else: - raise TypeError(f"Indexing with {type(item)} is not supported now!") + raise TypeError(f"Indexing with {type(item)} is not supported.") def chunk(self, chunks: int) -> list["BatchMeta"]: - """ - Split this batch into smaller chunks. + """Split this batch into smaller chunks. Args: chunks: number of chunks @@ -480,7 +516,7 @@ def chunk(self, chunks: int) -> list["BatchMeta"]: List of smaller BatchMeta chunks """ chunk_list = [] - n = len(self.samples) + n = self.size if n < chunks: logger.warning( @@ -494,47 +530,57 @@ def chunk(self, chunks: int) -> list["BatchMeta"]: start = 0 for i in range(chunks): - # Calculate the size of the current chunk(the first remainder chunk is 1 more than the base size) current_chunk_size = base_size + 1 if i < remainder else base_size end = start + current_chunk_size - chunk_samples = self.samples[start:end] - global_indexes = self.global_indexes[start:end] - chunk_custom_meta = {i: self.custom_meta[i] for i in global_indexes if i in self.custom_meta} - chunk_custom_backend_meta = { - i: self._custom_backend_meta[i] for i in global_indexes if i in self._custom_backend_meta - } - chunk = BatchMeta( - samples=chunk_samples, - extra_info=self.extra_info, - custom_meta=chunk_custom_meta, - _custom_backend_meta=chunk_custom_backend_meta, - ) + indices = list(range(start, end)) + chunk = self.select_samples(indices) chunk_list.append(chunk) start = end return chunk_list - def chunk_by_partition( - self, - ) -> list["BatchMeta"]: - """ - Split this batch into smaller chunks according to partition_ids. + def chunk_by_partition(self) -> list["BatchMeta"]: + """Split this batch into smaller chunks according to partition_ids. 
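+
+        Example (sketch; ``meta`` is assumed to hold samples from two partitions)::
+
+            >>> chunks = meta.chunk_by_partition()
+            >>> all(len(set(chunk.partition_ids)) == 1 for chunk in chunks)
+            True
+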
Return: List of smaller BatchMeta chunks, each chunk has samples with identical partition_id """ - grouped_indexes = defaultdict(list) - for partition_id, indexes in zip(self.partition_ids, range(self.size), strict=False): + for partition_id, indexes in zip(self.partition_ids, range(self.size), strict=True): grouped_indexes[partition_id].append(indexes) chunk_list = [self.select_samples(idx) for idx in grouped_indexes.values()] - return chunk_list + def union(self, other: "BatchMeta") -> "BatchMeta": + """Return the union of this BatchMeta and another BatchMeta. + Samples with global_indexes already present in this batch are ignored from the other batch. + + Args: + other: The other BatchMeta to merge with. + + Returns: + BatchMeta: A new merged BatchMeta. + """ + if not other or other.size == 0: + return self + if self.size == 0: + return other + + self_indexes = set(self.global_indexes) + unique_indices_in_other = [i for i, idx in enumerate(other.global_indexes) if idx not in self_indexes] + + if not unique_indices_in_other: + return self + + if len(unique_indices_in_other) == other.size: + return BatchMeta.concat([self, other]) + + other_unique = other.select_samples(unique_indices_in_other) + return BatchMeta.concat([self, other_unique]) + @classmethod def concat(cls, data: list["BatchMeta"], validate: bool = True) -> "BatchMeta": - """ - Concatenate multiple BatchMeta chunks into one large batch. + """Concatenate multiple BatchMeta chunks into one large batch. Args: data: List of BatchMeta chunks to concatenate @@ -548,214 +594,139 @@ def concat(cls, data: list["BatchMeta"], validate: bool = True) -> "BatchMeta": """ if not data: logger.warning("Try to concat empty BatchMeta chunks. Returning empty BatchMeta.") - return BatchMeta(samples=[], extra_info={}, custom_meta={}, _custom_backend_meta={}) + return BatchMeta.empty() # skip empty chunks - data = [chunk for chunk in data if chunk and len(chunk.samples) > 0] + data = [chunk for chunk in data if chunk and chunk.size > 0] if len(data) == 0: logger.warning("No valid BatchMeta chunks to concatenate. Returning empty BatchMeta.") - return BatchMeta(samples=[], extra_info={}, custom_meta={}, _custom_backend_meta={}) + return BatchMeta.empty() if validate: base_fields = data[0].field_names - for chunk in data: + for i, chunk in enumerate(data): if chunk.field_names != base_fields: - raise ValueError("Error: Field names do not match for concatenation.") + raise ValueError( + f"BatchMeta.concat: field_names mismatch at chunk[{i}]. " + f"Expected {base_fields}, got {chunk.field_names}. 
" + f"Extra in chunk: {set(chunk.field_names) - set(base_fields)}, " + f"Missing from chunk: {set(base_fields) - set(chunk.field_names)}" + ) - # Combine all samples - all_samples = list(itertools.chain.from_iterable(chunk.samples for chunk in data)) + # Validate field_schema dtype and is_nested consistency across chunks + for field_name in base_fields: + base_meta = data[0].field_schema.get(field_name, {}) + base_dtype = base_meta.get("dtype") + base_is_nested = base_meta.get("is_nested", False) + for i, chunk in enumerate(data[1:], start=1): + chunk_meta = chunk.field_schema.get(field_name, {}) + if chunk_meta.get("dtype") != base_dtype: + raise ValueError( + f"Field '{field_name}' dtype mismatch in concat: " + f"chunk[0]={base_dtype}, chunk[{i}]={chunk_meta.get('dtype')}" + ) + if chunk_meta.get("is_nested", False) != base_is_nested: + raise ValueError( + f"Field '{field_name}' is_nested mismatch in concat: " + f"chunk[0]={base_is_nested}, chunk[{i}]={chunk_meta.get('is_nested', False)}" + ) + + all_global_indexes = list(itertools.chain.from_iterable(chunk.global_indexes for chunk in data)) + all_partition_ids = list(itertools.chain.from_iterable(chunk.partition_ids for chunk in data)) + + all_production_status = np.concatenate([chunk.production_status for chunk in data]) + + all_field_schema: dict[str, dict[str, Any]] = {} + first_chunk = data[0] + for field_name, meta in first_chunk.field_schema.items(): + all_field_schema[field_name] = { + "dtype": meta.get("dtype"), + "shape": meta.get("shape"), + "is_nested": meta.get("is_nested", False), + "is_non_tensor": meta.get("is_non_tensor", False), + } + if any(chunk.field_schema.get(field_name, {}).get("per_sample_shapes") for chunk in data): + all_shapes = [] + for chunk in data: + chunk_meta = chunk.field_schema.get(field_name, {}) + chunk_shapes = chunk_meta.get("per_sample_shapes") + if chunk_shapes: + all_shapes.extend(chunk_shapes) + else: + all_shapes.extend([None] * chunk.size) + all_field_schema[field_name]["per_sample_shapes"] = all_shapes - # Merge all extra_info dictionaries from the chunks - merged_extra_info = dict() - merged_custom_meta = dict() - merged_custom_backend_meta = dict() + all_custom_meta: list[dict[str, Any]] = [] + all_custom_backend_meta: list[dict[str, Any]] = [] + for chunk in data: + all_custom_meta.extend(chunk.custom_meta) + all_custom_backend_meta.extend(chunk._custom_backend_meta) - values_by_key = defaultdict(list) + # Merge extra_info: batch-level metadata with equality check + all_keys: set[str] = set() for chunk in data: - # For the sample-level custom_meta and field-level _custom_backend_meta, we directly update the dict. - merged_custom_meta.update(chunk.custom_meta) - merged_custom_backend_meta.update(chunk._custom_backend_meta) - - for key, value in chunk.extra_info.items(): - values_by_key[key].append(value) - - # For the batch-level extra_info, we concat the tensor/NonTensorStack/NonTensorData/list - # objects to prevent information losses. - for key, values in values_by_key.items(): - if all(isinstance(v, torch.Tensor) for v in values): - try: - if all(v.dim() == 0 for v in values): - merged_extra_info[key] = torch.cat([v.unsqueeze(0) for v in values], dim=0) - else: - merged_extra_info[key] = torch.cat(values, dim=0) - except RuntimeError as e: - logger.warning( - f"BatchMeta.concat try to use torch.cat(dim=0) to merge extra_info key '{key}'" - f" fails, with RuntimeError {e}. Falling back to use list." 
+ all_keys.update(chunk.extra_info.keys()) + + merged_extra_info = {} + base_keys = set(data[0].extra_info.keys()) + + # Warn if chunks have different key sets + if any(set(chunk.extra_info.keys()) != base_keys for chunk in data[1:]): + logger.warning("BatchMeta.concat: extra_info key sets differ across chunks, performing union of keys.") + + for key in all_keys: + values = [chunk.extra_info[key] for chunk in data if key in chunk.extra_info] + # Check all values are equal + first = values[0] + for i, v in enumerate(values[1:], start=1): + if not _extra_info_values_equal(first, v): + raise ValueError( + f"BatchMeta.concat: extra_info key '{key}' has conflicting values " + f"across chunks and cannot be merged. " + f"chunk[0]={first!r}, chunk[{i}]={v!r}" ) - merged_extra_info[key] = values - elif all(isinstance(v, NonTensorStack | NonTensorData) for v in values): - merged_extra_info[key] = torch.stack(values) - elif all(isinstance(v, list) for v in values): - merged_extra_info[key] = list(itertools.chain.from_iterable(values)) - else: - merged_extra_info[key] = values[-1] + merged_extra_info[key] = first return BatchMeta( - samples=all_samples, + global_indexes=all_global_indexes, + partition_ids=all_partition_ids, + field_schema=all_field_schema, + production_status=all_production_status, extra_info=merged_extra_info, - custom_meta=merged_custom_meta, - _custom_backend_meta=merged_custom_backend_meta, + custom_meta=all_custom_meta, + _custom_backend_meta=all_custom_backend_meta, ) - def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMeta"]: - """ - Create a union of this batch's fields with another batch's fields. - Assume both batches have the same global indices and matching partition_ids for all samples. - If fields overlap, the fields in this batch will be replaced by the other batch's fields. 
- - Args: - other: Another BatchMeta to union with - validate: Whether to validate union conditions - - Returns: - New BatchMeta with unioned fields - - Raises: - ValueError: If validation fails (e.g., batch sizes or global indexes do not match) - """ - if validate: - if self.size != other.size: - raise ValueError("Error: Batch sizes do not match for union.") - - self_global_indexes = sorted(self.global_indexes) - other_global_indexes = sorted(other.global_indexes) - if self_global_indexes != other_global_indexes: - raise ValueError("Error: Global indexes do not match for union.") - - if self.partition_ids != other.partition_ids: - raise ValueError("Error: Partition IDs do not match for union.") - - # Create a mapping from global_index to SampleMeta in the other batch - other_sample_map = {sample.global_index: sample for sample in other.samples} - - # Merge samples - merged_samples = [] - for sample in self.samples: - if sample.global_index in other_sample_map: - other_sample = other_sample_map[sample.global_index] - merged_sample = sample.union(other_sample, validate=validate) - merged_samples.append(merged_sample) - else: - merged_samples.append(sample) - - # Merge extra info dictionaries - merged_extra_info = {**self.extra_info, **other.extra_info} - - # Merge custom_meta dictionaries - merged_custom_meta = {**self.custom_meta, **other.custom_meta} - - # Merge custom_backend_meta dictionaries - merged_custom_backend_meta = {} - for idx in self.global_indexes: - if idx in self._custom_backend_meta and idx in other._custom_backend_meta: - merged_custom_backend_meta[idx] = {**self._custom_backend_meta[idx], **other._custom_backend_meta[idx]} - elif idx in self._custom_backend_meta: - merged_custom_backend_meta[idx] = {**self._custom_backend_meta[idx]} - elif idx in other._custom_backend_meta: - merged_custom_backend_meta[idx] = {**other._custom_backend_meta[idx]} - - return BatchMeta( - samples=merged_samples, - extra_info=merged_extra_info, - custom_meta=merged_custom_meta, - _custom_backend_meta=merged_custom_backend_meta, - ) - - def reorder(self, indexes: list[int]): - """ - Reorder the SampleMeta in the BatchMeta according to the given indexes (must equal to the length of samples). - - The operation is performed in-place, modifying the current BatchMeta's SampleMeta order. - - To select a subset of samples or repeat specific samples, please use the non-inplace method select_samples(). - - Args: - indexes : list[int] - A list of integers specifying the new order of SampleMeta. Each integer - represents the current index of the SampleMeta in the BatchMeta. - """ - - if len(indexes) != self.size: - raise ValueError( - f"Attempted to reorder with indexes length {len(indexes)} that does not match samples length " - f"{self.size}. Please use non-inplace method select_samples() instead if you want to " - f"select a subset of samples or repeat specific samples." - ) - - if len(set(indexes)) != self.size: - raise ValueError( - f"Indexes={indexes} contain duplicates. Please use non-inplace method " - f"select_samples() instead if you want to select a subset of samples or repeat specific samples." 
- ) - - if any(i < 0 or i >= len(self.samples) for i in indexes): - raise ValueError(f"Reorder indexes must be in the range [0, {self.size}).") - - # Reorder the samples - reordered_samples = [self.samples[i] for i in indexes] - object.__setattr__(self, "samples", reordered_samples) - - # Update necessary attributes - self._update_after_reorder() - - def _update_after_reorder(self) -> None: - """Update related attributes specifically for the reorder operation""" - # Update batch_index for each sample - for idx, sample in enumerate(self.samples): - object.__setattr__(sample, "_batch_index", idx) - - # Update cached index lists - object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples]) - object.__setattr__(self, "_partition_ids", [sample.partition_id for sample in self.samples]) - - # Note: No need to update _size, _field_names, _is_ready, etc., as these remain unchanged after reorder - - @classmethod - def from_samples( - cls, samples: SampleMeta | list[SampleMeta], extra_info: Optional[dict[str, Any]] = None - ) -> "BatchMeta": + def reorder(self, indices: list[int]): + """Reorder the samples in the BatchMeta according to the given indices. + The operation is performed in-place. """ - Create a BatchMeta from a single SampleMeta or a list of SampleMeta objects. + if len(indices) != self.size: + raise ValueError(f"Indices length {len(indices)} mismatch batch size {self.size}") - Args: - samples: A single SampleMeta or a list of SampleMeta objects - extra_info: Optional additional information to store with the batch + if len(set(indices)) != self.size: + raise ValueError("Indices contain duplicates") - Returns: - BatchMeta instance containing the provided sample(s) + if any(i < 0 or i >= self.size for i in indices): + raise ValueError(f"Reorder indices must be in range [0, {self.size})") - Example: - >>> sample_meta = SampleMeta(...) - >>> batch_meta = BatchMeta.from_samples(sample_meta) + self.global_indexes = [self.global_indexes[i] for i in indices] + self.partition_ids = [self.partition_ids[i] for i in indices] - >>> sample_metas = [sample1, sample2, sample3] - >>> batch_meta = BatchMeta.from_samples(sample_metas, extra_info={"source": "training"}) - """ - if extra_info is None: - extra_info = {} + self.production_status = self.production_status[indices] - if isinstance(samples, SampleMeta): - samples = [samples] + for field_name, meta in self.field_schema.items(): + if meta.get("per_sample_shapes") is not None: + meta["per_sample_shapes"] = [meta["per_sample_shapes"][i] for i in indices] - return cls(samples=samples, extra_info=extra_info) + self.custom_meta = [self.custom_meta[i] for i in indices] + self._custom_backend_meta = [self._custom_backend_meta[i] for i in indices] @classmethod def empty(cls, extra_info: Optional[dict[str, Any]] = None) -> "BatchMeta": - """ - Create an empty BatchMeta with no samples. + """Create an empty BatchMeta with no samples. 
Args: extra_info: Optional additional information to store with the batch @@ -768,77 +739,22 @@ def empty(cls, extra_info: Optional[dict[str, Any]] = None) -> "BatchMeta": """ if extra_info is None: extra_info = {} - return cls(samples=[], extra_info=extra_info, custom_meta={}, _custom_backend_meta={}) + return cls( + global_indexes=[], + partition_ids=[], + field_schema={}, + production_status=None, + extra_info=extra_info, + custom_meta=[], + _custom_backend_meta=[], + ) def __str__(self): - sample_strs = ", ".join(str(sample) for sample in self.samples) return ( f"BatchMeta(size={self.size}, field_names={self.field_names}, is_ready={self.is_ready}, " - f"samples=[{sample_strs}], extra_info={self.extra_info})" + f"global_indexes={self.global_indexes}, extra_info={self.extra_info})" ) - @classmethod - def from_dict(cls, data: dict) -> "BatchMeta": - """Create BatchMeta from dictionary.""" - samples = [ - SampleMeta.from_dict(sample_data) if isinstance(sample_data, dict) else sample_data - for sample_data in data["samples"] - ] - return cls( - samples=samples, - extra_info=data.get("extra_info", {}), - custom_meta=data.get("custom_meta", {}), - _custom_backend_meta=data.get("_custom_backend_meta", {}), - ) - - -def _union_fields(fields1: dict[str, FieldMeta], fields2: dict[str, FieldMeta]) -> dict[str, FieldMeta]: - """Union two sample's fields. If fields overlap, the fields in fields1 will be replaced by fields2.""" - for name in fields2.keys(): - fields1[name] = fields2[name] - return fields1 - - -def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) -> list[dict[str, FieldMeta]]: - """ - Extract field metas from a TensorDict. If data in tensor_dict does not have dtype or shape attribute, - the corresponding dtype or shape will be set to None. - - Args: - tensor_dict (TensorDict): The input TensorDict. - set_all_ready (bool): If True, set all production_status to READY_FOR_CONSUME. - Otherwise, set to NOT_PRODUCED. Default is True. - - Returns: - all_fields (list[dict[str, FieldMeta]]): A list of dictionaries containing field metadata. - """ - batch_size = tensor_dict.batch_size[0] - - production_status = ProductionStatus.READY_FOR_CONSUME if set_all_ready else ProductionStatus.NOT_PRODUCED - - # unbind nested tensor - results: dict = {} - for field in tensor_dict.keys(): - field_data = tensor_dict[field] - if batch_size > 1 and isinstance(field_data, Tensor) and field_data.is_nested: - results[field] = field_data.unbind() - else: - results[field] = field_data - - all_fields = [] - for idx in range(batch_size): - dict_of_field_meta = {} - for field_name in results.keys(): - dict_of_field_meta[field_name] = FieldMeta( - name=field_name, - dtype=getattr(results[field_name][idx], "dtype", None), - shape=getattr(results[field_name][idx], "shape", None), - production_status=production_status, - ) - all_fields.append(dict_of_field_meta) - - return all_fields - # ==================== KV Interface Metadata ==================== @dataclass @@ -889,9 +805,7 @@ def __str__(self): return f"KVBatchMeta(size={self.size}, field_names={self.fields}, extra_info={self.extra_info})" def select_keys(self, keys_to_select: list[str]) -> "KVBatchMeta": - """ - Select specific keys from this batch. - This will construct a new KVBatchMeta instance containing only the specified keys. + """Select specific keys from this batch. Args: keys_to_select (list[str]): List of keys to retain. 
@@ -903,7 +817,6 @@ def select_keys(self, keys_to_select: list[str]) -> "KVBatchMeta": ValueError: If duplicate keys exist in input param `keys_to_select`. RuntimeError: If `keys_to_select` contains keys that do not exist in this batch. """ - if len(set(keys_to_select)) != len(keys_to_select): raise ValueError("Contain duplicate keys.") @@ -925,14 +838,13 @@ def select_keys(self, keys_to_select: list[str]) -> "KVBatchMeta": ) def reorder(self, indexes: list[int]): - """ - Reorder the samples in this batch according to the specified indexes. + """Reorder the samples in this batch according to the specified indexes. The operation is performed in-place. Args: indexes : list[int] - A list of integers specifying the new order of SampleMeta. + A list of integers specifying the new order of samples. Raises: ValueError: If the size of input `indexes` does not match with the batch size. @@ -951,8 +863,7 @@ def reorder(self, indexes: list[int]): self.tags = [self.tags[i] for i in indexes] def chunk(self, chunks: int) -> list["KVBatchMeta"]: - """ - Split this batch into smaller chunks. + """Split this batch into smaller chunks. Args: chunks: number of chunks @@ -960,7 +871,6 @@ def chunk(self, chunks: int) -> list["KVBatchMeta"]: Return: List of smaller KVBatchMeta chunks """ - chunk_list = [] if self.size < chunks: logger.warning( @@ -974,7 +884,6 @@ def chunk(self, chunks: int) -> list["KVBatchMeta"]: start = 0 for i in range(chunks): - # Calculate the size of the current chunk(the first remainder chunk is 1 more than the base size) current_chunk_size = base_size + 1 if i < remainder else base_size end = start + current_chunk_size chunk_keys = self.keys[start:end] @@ -994,8 +903,7 @@ def chunk(self, chunks: int) -> list["KVBatchMeta"]: @classmethod def concat(cls, data: list["KVBatchMeta"]) -> "KVBatchMeta": - """ - Concatenate multiple KVBatchMeta chunks into one large batch. + """Concatenate multiple KVBatchMeta chunks into one large batch. Args: data: List of KVBatchMeta chunks to concatenate @@ -1027,7 +935,7 @@ def concat(cls, data: list["KVBatchMeta"]) -> "KVBatchMeta": all_keys = [] all_tags = [] - all_extra_info = {} + for chunk in data: if chunk.fields is not None and set(chunk.fields) != base_fields_set: raise ValueError("Field names do not match for concatenation.") @@ -1036,8 +944,33 @@ def concat(cls, data: list["KVBatchMeta"]) -> "KVBatchMeta": all_keys.extend(chunk.keys) all_tags.extend(chunk.tags) - if chunk.extra_info is not None: - all_extra_info.update(chunk.extra_info) + + # Merge extra_info with conflict detection + all_extra_keys: set[str] = set() + for chunk in data: + if chunk.extra_info: + all_extra_keys.update(chunk.extra_info.keys()) + + all_extra_info = {} + if all_extra_keys: + base_info_keys = set(data[0].extra_info.keys()) if data[0].extra_info else set() + for chunk in data[1:]: + chunk_keys = set(chunk.extra_info.keys()) if chunk.extra_info else set() + if chunk_keys != base_info_keys: + logger.warning( + "KVBatchMeta.concat: extra_info key sets differ across chunks, performing union of keys." + ) + break + + for key in all_extra_keys: + values = [chunk.extra_info[key] for chunk in data if chunk.extra_info and key in chunk.extra_info] + first = values[0] + for i, v in enumerate(values[1:], start=1): + if not _extra_info_values_equal(first, v): + raise ValueError( + f"KVBatchMeta.concat: extra_info key '{key}' has conflicting values across chunks." 
+ ) + all_extra_info[key] = first return KVBatchMeta( keys=all_keys, diff --git a/transfer_queue/storage/__init__.py b/transfer_queue/storage/__init__.py index 04b07452..6fbd3415 100644 --- a/transfer_queue/storage/__init__.py +++ b/transfer_queue/storage/__init__.py @@ -21,12 +21,11 @@ TransferQueueStorageManagerFactory, YuanrongStorageManager, ) -from .simple_backend import SimpleStorageUnit, StorageMetaGroup, StorageUnitData +from .simple_backend import SimpleStorageUnit, StorageUnitData __all__ = [ "SimpleStorageUnit", "StorageUnitData", - "StorageMetaGroup", "TransferQueueStorageManager", "TransferQueueStorageManagerFactory", "AsyncSimpleStorageManager", diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index cb714c32..37b4b3fb 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -14,7 +14,6 @@ # limitations under the License. import asyncio -import copy import itertools import logging import os @@ -122,7 +121,6 @@ def _do_handshake_with_controller(self) -> None: ) poller.register(self.controller_handshake_socket, zmq.POLLIN) - # Initial handshake request send self._send_handshake_requests() start_time = time.time() @@ -132,7 +130,6 @@ def _do_handshake_with_controller(self) -> None: not is_connected # Only one controller to connect to and time.time() - start_time < TQ_STORAGE_HANDSHAKE_TIMEOUT ): - # Check for timeout and retransmission current_time = time.time() if pending_connection: if ( @@ -214,6 +211,7 @@ async def notify_data_update( shapes: Per-field shapes for each field, in {global_index: {field: shape}} format. custom_backend_meta: Per-field custom_meta for each sample, in {global_index: {field: custom_meta}} format. """ + if not self.controller_info: logger.warning(f"No controller connected for storage manager {self.storage_manager_id}") return @@ -321,6 +319,41 @@ async def clear_data(self, metadata: BatchMeta) -> None: """ raise NotImplementedError("Subclasses must implement clear_data") + @staticmethod + def _extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]: + """Extract field-level schema from TensorDict. 
O(F) complexity.""" + field_schema: dict[str, dict[str, Any]] = {} + + for field_name in data.keys(): + field_data = data[field_name] + + is_tensor = isinstance(field_data, torch.Tensor) + is_nested = is_tensor and field_data.is_nested + + if is_nested: + unbound = field_data.unbind() + first_item = unbound[0] if unbound else None + elif is_tensor: + first_item = field_data[0] if field_data.shape[0] > 0 else None + else: + first_item = field_data[0] if len(field_data) > 0 else None + + is_non_tensor = not isinstance(first_item, torch.Tensor) if first_item is not None else False + + field_meta = { + "dtype": getattr(first_item, "dtype", type(first_item) if first_item is not None else None), + "shape": getattr(first_item, "shape", None) if is_tensor and not is_nested else None, + "is_nested": is_nested, + "is_non_tensor": is_non_tensor, + } + + if is_nested: + field_meta["per_sample_shapes"] = [tuple(t.shape) for t in unbound] + + field_schema[field_name] = field_meta + + return field_schema + def close(self) -> None: """Close all ZMQ sockets and context to prevent resource leaks.""" # Close handshake socket if it exists @@ -362,7 +395,6 @@ def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]): super().__init__(controller_info, config) self.storage_client = StorageClientFactory.create(client_name, config) self._multi_threads_executor: Optional[ThreadPoolExecutor] = None - # Register a cleanup function: automatically invoke shutdown when the instance is garbage collected. self._executor_finalizer = weakref.finalize(self, self._shutdown_executor, self._multi_threads_executor) @staticmethod @@ -502,7 +534,6 @@ def process_field(field_idx: int): # Prioritize processing fields with larger tensor sizes to improve parallel efficiency field_sizes = [] for i in range(num_fields): - # Estimate size based on the first value _first_value = values[i * num_samples] if isinstance(_first_value, torch.Tensor): size = _first_value.nelement() * _first_value.element_size() @@ -532,30 +563,31 @@ def _get_shape_type_custom_backend_meta_list(metadata: BatchMeta): shapes = [] dtypes = [] custom_backend_meta_list = [] - all_custom_backend_meta = copy.deepcopy(metadata._custom_backend_meta) + num_samples = len(metadata) + for field_name in sorted(metadata.field_names): - for index in range(len(metadata)): - field = metadata.samples[index].get_field_by_name(field_name) - assert field is not None, f"Field {field_name} not found in sample {index}" - shapes.append(field.shape) - dtypes.append(field.dtype) - global_index = metadata.global_indexes[index] - custom_backend_meta_list.append(all_custom_backend_meta.get(global_index, {}).get(field_name, None)) + field_meta = metadata.field_schema.get(field_name, {}) + field_shape = field_meta.get("shape") + field_dtype = field_meta.get("dtype") + per_sample_shapes = field_meta.get("per_sample_shapes") + + for index in range(num_samples): + if per_sample_shapes is not None: + shapes.append(per_sample_shapes[index]) + else: + shapes.append(field_shape) + dtypes.append(field_dtype) + custom_backend_meta_list.append(metadata._custom_backend_meta[index].get(field_name, None)) return shapes, dtypes, custom_backend_meta_list async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: """ Store tensor data in the backend storage and notify the controller. - - Serializes the input tensors, stores them using the storage client, - extracts per-sample dtype and shape information, and sends a notification - to the controller that new data is available. 
""" if not metadata.field_names: logger.warning("Attempted to put data, but metadata contains no fields.") return - # For each field, extract dtype and shape for each sample num_samples = len(metadata.global_indexes) if num_samples == 0: return @@ -564,37 +596,18 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: values = self._generate_values(data) loop = asyncio.get_event_loop() - # put to storage backends custom_backend_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values) - per_field_dtypes: dict[int, dict[str, Any]] = {} - per_field_shapes: dict[int, dict[str, Any]] = {} + field_schema = self._extract_field_schema(data) - # Initialize the data structure for each global index - for global_idx in metadata.global_indexes: - per_field_dtypes[global_idx] = {} - per_field_shapes[global_idx] = {} - - for field_name, field_data in data.items(): - for i in range(num_samples): - data_item = field_data[i] - global_idx = metadata.global_indexes[i] - per_field_dtypes[global_idx][field_name] = ( - getattr(data_item, "dtype", None) if isinstance(data_item, Tensor) else None - ) - per_field_shapes[global_idx][field_name] = ( - getattr(data_item, "shape", None) if isinstance(data_item, Tensor) else None - ) - - # Prepare per-field custom_backend_meta if available per_field_custom_backend_meta: dict[int, dict[str, Any]] = {} if custom_backend_meta: if len(custom_backend_meta) != len(keys): raise ValueError( f"Length of custom_backend_meta ({len(custom_backend_meta)}) does not match expected ({len(keys)})" ) - # custom meta is a flat list aligned with keys/values - # Use itertools.product to eliminate nested loops + global_index_to_position = {global_index: i for i, global_index in enumerate(metadata.global_indexes)} + for global_idx in metadata.global_indexes: per_field_custom_backend_meta[global_idx] = {} @@ -607,13 +620,34 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: strict=True, ): per_field_custom_backend_meta[global_idx][field_name] = meta_value - metadata._custom_backend_meta.update(per_field_custom_backend_meta) + metadata._custom_backend_meta[global_index_to_position[global_idx]][field_name] = meta_value # Get current data partition id - # Note: Currently we only support putting to & getting data from a single data partition simultaneously, - # but in the future we may support putting to & getting data from multiple data partitions concurrently. 
- partition_id = metadata.samples[0].partition_id - # notify controller that new data is ready + partition_id = metadata.partition_ids[0] + + per_field_dtypes: dict[int, dict[str, Any]] = {} + per_field_shapes: dict[int, dict[str, Any]] = {} + for field_name, field_meta in field_schema.items(): + is_nested = field_meta.get("is_nested", False) + + if is_nested: + per_sample_shapes = field_meta.get("per_sample_shapes", []) + for i, global_idx in enumerate(metadata.global_indexes): + if global_idx not in per_field_dtypes: + per_field_dtypes[global_idx] = {} + per_field_shapes[global_idx] = {} + per_field_dtypes[global_idx][field_name] = field_meta.get("dtype") + per_field_shapes[global_idx][field_name] = ( + per_sample_shapes[i] if i < len(per_sample_shapes) else None + ) + else: + for global_idx in metadata.global_indexes: + if global_idx not in per_field_dtypes: + per_field_dtypes[global_idx] = {} + per_field_shapes[global_idx] = {} + per_field_dtypes[global_idx][field_name] = field_meta.get("dtype") + per_field_shapes[global_idx][field_name] = field_meta.get("shape") + await self.notify_data_update( partition_id, list(data.keys()), diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 608a0827..829ebed1 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -17,6 +17,7 @@ import logging import os import warnings +from collections import defaultdict from collections.abc import Mapping from functools import wraps from operator import itemgetter @@ -27,13 +28,10 @@ import zmq from omegaconf import DictConfig from tensordict import NonTensorStack, TensorDict -from torch import Tensor from transfer_queue.metadata import BatchMeta from transfer_queue.storage.managers.base import TransferQueueStorageManager from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory -from transfer_queue.storage.simple_backend import StorageMetaGroup -from transfer_queue.utils.common import get_env_bool from transfer_queue.utils.zmq_utils import ( ZMQMessage, ZMQRequestType, @@ -53,8 +51,6 @@ TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT", 200)) # seconds -TQ_ZERO_COPY_SERIALIZATION = get_env_bool("TQ_ZERO_COPY_SERIALIZATION", default=False) - @TransferQueueStorageManagerFactory.register("SimpleStorage") class AsyncSimpleStorageManager(TransferQueueStorageManager): @@ -83,7 +79,6 @@ def __init__(self, controller_info: ZMQServerInfo, config: DictConfig): raise ValueError("AsyncSimpleStorageManager requires non-empty 'zmq_info' in config.") self.storage_unit_infos = self._register_servers(server_infos) - self._build_storage_mapping_functions() def _register_servers(self, server_infos: "ZMQServerInfo | dict[Any, ZMQServerInfo]"): """Register and validate server information. @@ -112,16 +107,6 @@ def _register_servers(self, server_infos: "ZMQServerInfo | dict[Any, ZMQServerIn return server_infos_transform - def _build_storage_mapping_functions(self): - """Build mapping functions for global index to storage unit and local index. - - Creates round-robin mapping functions to distribute data across storage units. 
- """ - self.global_index_storage_unit_mapping = lambda x: list(self.storage_unit_infos.keys())[ - x % len(self.storage_unit_infos) - ] - self.global_index_local_index_mapping = lambda x: x // len(self.storage_unit_infos) - # TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong. @staticmethod def dynamic_storage_manager_socket(socket_name: str, timeout: int): @@ -174,7 +159,9 @@ async def wrapper(self, *args, **kwargs): return await func(self, *args, **kwargs) except Exception as e: logger.error( - f"[{self.storage_manager_id}]: Error in socket operation with StorageUnit {server_info.id}: {e}" + f"[{self.storage_manager_id}]: Error in socket operation with " + f"StorageUnit {server_info.id} at {address}: " + f"{type(e).__name__}: {e}" ) raise finally: @@ -192,10 +179,29 @@ async def wrapper(self, *args, **kwargs): return decorator + def _group_by_hash(self, global_indexes: list[int]) -> dict[str, list[int]]: + """Group samples by global_idx % num_su, return {storage_id: [global_indexes]}. + + Routing depends solely on global_idx, independent of batch_size, key ordering, + or number of calls. The same global_idx always routes to the same SU across + put/get/clear operations. + + NOTE: Dynamic SU scaling requires a data migration mechanism (not yet supported). + """ + storage_unit_keys = list(self.storage_unit_infos.keys()) + num_units = len(storage_unit_keys) + groups: dict[str, list[int]] = defaultdict(list) + for global_idx in global_indexes: + groups[storage_unit_keys[global_idx % num_units]].append(global_idx) + return dict(groups) + async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: """ Send data to remote StorageUnit based on metadata. + Routes each sample to its target SU using global_idx % num_su (hash routing). + Complexity: O(F) for schema extraction + O(S) for data distribution. + Args: data: TensorDict containing the data to store. metadata: BatchMeta containing storage location information. 
@@ -203,62 +209,84 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: logger.debug(f"[{self.storage_manager_id}]: receive put_data request, putting {metadata.size} samples.") - # group samples by storage unit - storage_meta_groups = build_storage_meta_groups( - metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping - ) + batch_size = metadata.size - # unbind nested tensor - results: dict = {} - for field in data.keys(): - field_data = data[field] - if data.batch_size[0] > 1 and isinstance(field_data, Tensor) and field_data.is_nested: - results[field] = field_data.unbind() - else: - results[field] = field_data + if batch_size == 0: + return + + field_schema = self._extract_field_schema(data) - # send data to each storage unit + storage_unit_to_global_indexes = self._group_by_hash(metadata.global_indexes) + # Build global_idx -> batch position mapping for non-contiguous slicing + gi_to_pos = {gi: pos for pos, gi in enumerate(metadata.global_indexes)} tasks = [ - self._put_to_single_storage_unit( - meta_group.get_local_indexes(), - _filter_storage_data(meta_group, results), - target_storage_unit=storage_id, + self._prepare_and_send_to_unit_by_positions( + storage_id=su_id, + positions=[gi_to_pos[gi] for gi in gi_list], + data=data, + metadata=metadata, ) - for storage_id, meta_group in storage_meta_groups.items() + for su_id, gi_list in storage_unit_to_global_indexes.items() ] - await asyncio.gather(*tasks) - - # Gather per-field dtype and shape information for each field - # global_indexes, local_indexes, and field_data correspond one-to-one - per_field_dtypes: dict[int, dict[str, Any]] = {} - per_field_shapes: dict[int, dict[str, Any]] = {} - - # Initialize the data structure for each global index - for global_idx in metadata.global_indexes: - per_field_dtypes[global_idx] = {} - per_field_shapes[global_idx] = {} - - # For each field, extract dtype and shape for each sample - for field in results.keys(): - for i, data_item in enumerate(results[field]): - global_idx = metadata.global_indexes[i] - per_field_dtypes[global_idx][field] = data_item.dtype if hasattr(data_item, "dtype") else None - per_field_shapes[global_idx][field] = data_item.shape if hasattr(data_item, "shape") else None - - # Get current data partition id - # Note: Currently we only support putting to & getting data from a single data partition simultaneously, - # but in the future we may support putting to & getting data from multiple data partitions concurrently. - partition_id = metadata.samples[0].partition_id - - # notify controller that new data is ready + + try: + await asyncio.gather(*tasks) + except Exception as e: + logger.error( + f"[{self.storage_manager_id}]: put_data failed. 
" + f"partition_id={metadata.partition_ids[0]}, " + f"num_samples={metadata.size}, " + f"storage_units={list(storage_unit_to_global_indexes.keys())}, " + f"error={type(e).__name__}: {e}" + ) + raise + + partition_id = metadata.partition_ids[0] + dtypes_for_notify = { + global_index: {field_name: field_meta.get("dtype") for field_name, field_meta in field_schema.items()} + for global_index in metadata.global_indexes + } + shapes_for_notify = { + global_index: {field_name: field_meta.get("shape") for field_name, field_meta in field_schema.items()} + for global_index in metadata.global_indexes + } await self.notify_data_update( - partition_id, list(results.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes + partition_id, + list(data.keys()), + metadata.global_indexes, + dtypes_for_notify, + shapes_for_notify, ) + async def _prepare_and_send_to_unit_by_positions( + self, + storage_id, + positions, + data, + metadata, + ) -> None: + """Slice data by non-contiguous positions and send to the specified SU.""" + global_indexes = [metadata.global_indexes[pos] for pos in positions] + storage_data = {} + for field_name in data.keys(): + field_data = data[field_name] + if isinstance(field_data, torch.Tensor) and field_data.is_nested: + unbound = field_data.unbind() + storage_data[field_name] = [unbound[pos] for pos in positions] + elif isinstance(field_data, NonTensorStack): + items = field_data.tolist() + storage_data[field_name] = NonTensorStack(*[items[pos] for pos in positions]) + elif isinstance(field_data, list): + storage_data[field_name] = [field_data[pos] for pos in positions] + else: + # torch.Tensor (non-nested) and numpy arrays support fancy indexing + storage_data[field_name] = field_data[positions] + await self._put_to_single_storage_unit(global_indexes, storage_data, target_storage_unit=storage_id) + @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT) async def _put_to_single_storage_unit( self, - local_indexes: list[int], + global_indexes: list[int], storage_data: dict[str, Any], target_storage_unit: str, socket: zmq.Socket = None, @@ -271,7 +299,7 @@ async def _put_to_single_storage_unit( request_type=ZMQRequestType.PUT_DATA, # type: ignore[arg-type] sender_id=self.storage_manager_id, receiver_id=target_storage_unit, - body={"local_indexes": local_indexes, "data": storage_data}, + body={"global_indexes": global_indexes, "data": storage_data}, ) try: @@ -285,13 +313,29 @@ async def _put_to_single_storage_unit( f"Failed to put data to storage unit {target_storage_unit}: " f"{response_msg.body.get('message', 'Unknown error')}" ) + except zmq.error.Again as e: + timeout_sec = TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT + logger.error( + f"[{self.storage_manager_id}]: ZMQ recv timeout ({timeout_sec}s) " + f"during put to storage unit {target_storage_unit}. " + f"The storage unit may be overloaded or crashed." + ) + raise RuntimeError( + f"ZMQ recv timeout ({timeout_sec}s) during put to storage unit {target_storage_unit}" + ) from e except Exception as e: - raise RuntimeError(f"Error in put to storage unit {target_storage_unit}: {str(e)}") from e + logger.error( + f"[{self.storage_manager_id}]: Unexpected error during put to storage unit " + f"{target_storage_unit}: {type(e).__name__}: {e}" + ) + raise RuntimeError(f"Error in put to storage unit {target_storage_unit}: {type(e).__name__}: {e}") from e async def get_data(self, metadata: BatchMeta) -> TensorDict: """ Retrieve data from remote StorageUnit based on metadata. 
+ Routes to each SU using global_idx % num_su (hash routing). + Args: metadata: BatchMeta that contains metadata for data retrieval. @@ -301,20 +345,27 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: logger.debug(f"[{self.storage_manager_id}]: receive get_data request, getting {metadata.size} samples.") - # group samples by storage unit - storage_meta_groups = build_storage_meta_groups( - metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping - ) + if metadata.size == 0: + return TensorDict({}, batch_size=0) + + storage_unit_groups = self._group_by_hash(metadata.global_indexes) - # retrieve data tasks = [ - self._get_from_single_storage_unit(meta_group, target_storage_unit=storage_id) - for storage_id, meta_group in storage_meta_groups.items() + self._get_from_single_storage_unit(global_indexes, metadata.field_names, target_storage_unit=su_id) + for su_id, global_indexes in storage_unit_groups.items() ] + try: + results = await asyncio.gather(*tasks) + except Exception as e: + logger.error( + f"[{self.storage_manager_id}]: get_data failed. " + f"partition_id={metadata.partition_ids[0]}, " + f"num_samples={metadata.size}, " + f"storage_units={list(storage_unit_groups.keys())}, " + f"error={type(e).__name__}: {e}" + ) + raise - results = await asyncio.gather(*tasks) - - # post-process data segments to generate a batch of data merged_data: dict[int, dict[str, torch.Tensor]] = {} for global_indexes, fields, data_from_single_storage_unit, messages in results: field_getter = itemgetter(*fields) @@ -335,7 +386,7 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: ordered_data[field] = [merged_data[global_idx][field] for global_idx in metadata.global_indexes] # In the final packing stage we intentionally perform a memory copy through torch.stack and as_nested_tensor. - # This detaches the received tensors from the original zero‑copy buffers, + # This detaches the received tensors from the original zero-copy buffers, # gives them their own lifetime, and ensures the resulting tensors are writable. tensor_data = { field: ( @@ -356,17 +407,18 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT) async def _get_from_single_storage_unit( - self, storage_meta_group: StorageMetaGroup, target_storage_unit: str, socket: zmq.Socket = None + self, + global_indexes: list[int], + fields: list[str], + target_storage_unit: str, + socket: zmq.Socket = None, ): - global_indexes = storage_meta_group.get_global_indexes() - local_indexes = storage_meta_group.get_local_indexes() - fields = storage_meta_group.get_field_names() - + """Get data from a single SU by global index keys.""" request_msg = ZMQMessage.create( request_type=ZMQRequestType.GET_DATA, # type: ignore[arg-type] sender_id=self.storage_manager_id, receiver_id=target_storage_unit, - body={"local_indexes": local_indexes, "fields": fields}, + body={"global_indexes": global_indexes, "fields": fields}, ) try: await socket.send_multipart(request_msg.serialize()) @@ -374,9 +426,6 @@ async def _get_from_single_storage_unit( response_msg = ZMQMessage.deserialize(messages) if response_msg.request_type == ZMQRequestType.GET_DATA_RESPONSE: - # Return data and index information from this storage unit - # We need to return messages to get_data() since the zero-copy deserialization directly points to the - # memory of messages object. 
storage_unit_data = response_msg.body["data"] return global_indexes, fields, storage_unit_data, messages else: @@ -384,27 +433,42 @@ async def _get_from_single_storage_unit( f"Failed to get data from storage unit {target_storage_unit}: " f"{response_msg.body.get('message', 'Unknown error')}" ) + except zmq.error.Again as e: + timeout_sec = TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT + logger.error( + f"[{self.storage_manager_id}]: ZMQ recv timeout ({timeout_sec}s) " + f"from storage unit {target_storage_unit}. " + f"The storage unit may be overloaded or crashed." + ) + raise RuntimeError(f"ZMQ recv timeout ({timeout_sec}s) from storage unit {target_storage_unit}") from e except Exception as e: - raise RuntimeError(f"Error getting data from storage unit {target_storage_unit}: {str(e)}") from e + logger.error( + f"[{self.storage_manager_id}]: Unexpected error from storage unit " + f"{target_storage_unit}: {type(e).__name__}: {e}" + ) + raise RuntimeError( + f"Error getting data from storage unit {target_storage_unit}: {type(e).__name__}: {e}" + ) from e async def clear_data(self, metadata: BatchMeta) -> None: """Clear data in remote StorageUnit. + Routes to each SU using global_idx % num_su (hash routing). + Args: metadata: BatchMeta that contains metadata for data clearing. """ logger.debug(f"[{self.storage_manager_id}]: receive clear_data request, clearing {metadata.size} samples.") - # group samples by storage unit - storage_meta_groups = build_storage_meta_groups( - metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping - ) + if metadata.size == 0: + return + + storage_unit_groups = self._group_by_hash(metadata.global_indexes) - # clear data tasks = [ - self._clear_single_storage_unit(meta_group.get_local_indexes(), target_storage_unit=storage_id) - for storage_id, meta_group in storage_meta_groups.items() + self._clear_single_storage_unit(global_indexes, target_storage_unit=su_id) + for su_id, global_indexes in storage_unit_groups.items() ] results = await asyncio.gather(*tasks, return_exceptions=True) @@ -414,13 +478,13 @@ async def clear_data(self, metadata: BatchMeta) -> None: logger.error(f"[{self.storage_manager_id}]: Error in clear operation task {i}: {result}") @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT) - async def _clear_single_storage_unit(self, local_indexes, target_storage_unit=None, socket=None): + async def _clear_single_storage_unit(self, global_indexes, target_storage_unit=None, socket=None): try: request_msg = ZMQMessage.create( request_type=ZMQRequestType.CLEAR_DATA, sender_id=self.storage_manager_id, receiver_id=target_storage_unit, - body={"local_indexes": local_indexes}, + body={"global_indexes": global_indexes}, ) await socket.send_multipart(request_msg.serialize()) @@ -448,115 +512,3 @@ def get_zmq_server_info(self) -> dict[str, ZMQServerInfo]: def close(self) -> None: """Close all ZMQ sockets and context to prevent resource leaks.""" super().close() - - -def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: dict) -> dict[str, Any]: - """Filter batch-aligned data from a dict using batch indexes from a StorageMetaGroup. - This helper extracts a subset of items from each field in ``data`` according to the - batch indexes stored in ``storage_meta_group``. The same indexes are applied to every - field in the input dict so that the returned samples remain aligned across - fields. 
- - Args: - storage_meta_group: A :class:`StorageMetaGroup` instance that provides - a sequence of batch indexes via :meth:`get_batch_indexes`. Each index - refers to a position along the batch dimension of the tensors stored - in ``data``. - data: A dict containing batched data fields. All fields are expected to - be indexable by the batch indexes returned by - ``storage_meta_group.get_batch_indexes()``. - Returns: - dict[str, Any]: A dictionary mapping each field name in ``data`` to a list - of items selected at the requested batch indexes. The order of items in - each list matches the order of ``storage_meta_group.get_batch_indexes()``. - """ - - # We use dict here instead of TensorDict to avoid unnecessary TensorDict overhead - results: dict[str, Any] = {} - batch_indexes = storage_meta_group.get_batch_indexes() - - if not batch_indexes: - return results - - for fname in data.keys(): - field_data = data[fname] - result = itemgetter(*batch_indexes)(field_data) - - if not isinstance(result, tuple): - result = (result,) - results[fname] = list(result) - - if not TQ_ZERO_COPY_SERIALIZATION: - # Explicitly copy tensor slices to prevent pickling the whole tensor for every storage unit. - # The tensors may still be contiguous, so we cannot use .contiguous() to trigger copy from parent tensors. - results[fname] = [item.clone() if isinstance(item, torch.Tensor) else item for item in results[fname]] - return results - - -def build_storage_meta_groups( - batch_meta: BatchMeta, - global_index_storage_unit_mapping: Callable, - global_index_local_index_mapping: Callable, -) -> dict[str, StorageMetaGroup]: - """Build storage meta groups from batch metadata for distributed storage. - - This function is the starting point of the data distribution workflow. It analyzes - BatchMeta containing SampleMeta objects (originating from client requests) and - groups them by target storage unit based on their global_index. - - Key Data Flow: - 1. BatchMeta contains SampleMeta objects with batch_index (original TensorDict position) - 2. Each SampleMeta is assigned to a storage unit using global_index mapping - 3. Local storage positions are calculated for each sample - 4. Results in StorageMetaGroup objects ready for transfer operations - - Args: - batch_meta: BatchMeta containing SampleMeta objects from client request. - Each SampleMeta has: - - batch_index: Position in original TensorDict (0-based) - - global_index: Global unique identifier across all storage - global_index_storage_unit_mapping: Function to map global_index to storage_unit_id. - Example: lambda x: storage_unit_ids[x % num_storage_units] (round-robin distribution) - global_index_local_index_mapping: Function to map global_index to local_index. - Example: lambda x: x // num_storage_units (local position within storage unit) - - Returns: - Dictionary mapping storage_unit_id to StorageMetaGroup, where each group contains: - - storage_id: Target storage unit identifier - - sample_metas: List of SampleMeta objects assigned to this unit - - local_indexes: List of storage positions for each sample - - Example: - >>> # Input: BatchMeta with samples at global_indexes [10, 11, 12] - >>> # 3 storage units available: storage_0, storage_1, storage_2 - >>> batch_meta = BatchMeta(samples=[ - ... SampleMeta(batch_index=0, global_index=10), # Original position 0 - ... SampleMeta(batch_index=1, global_index=11), # Original position 1 - ... SampleMeta(batch_index=2, global_index=12) # Original position 2 - ... ]) - >>> groups = build_storage_meta_groups( - ... 
batch_meta, - ... lambda x: f"storage_{x % 3}", # 10->storage_1, 11->storage_2, 12->storage_0 - ... lambda x: x // 3 # 10->3, 11->3, 12->4 - ... ) - >>> groups["storage_1"].sample_metas[0].batch_index # 0 - original TensorDict position - >>> groups["storage_1"].sample_metas[0].local_index # 3 - storage position - - Note: - This function preserves the crucial batch_index information that links each - SampleMeta back to its original position in the client's TensorDict. - This batch_index is later used by _add_field_data() to extract - the correct data items for storage. - """ - storage_meta_groups: dict[str, StorageMetaGroup] = {} - - for sample in batch_meta.samples: - storage_id = global_index_storage_unit_mapping(sample.global_index) - local_index = global_index_local_index_mapping(sample.global_index) - if storage_id not in storage_meta_groups: - storage_meta_groups[storage_id] = StorageMetaGroup(storage_id=storage_id) - - # Use add_sample_meta to store SampleMeta references directly - storage_meta_groups[storage_id].add_sample_meta(sample, local_index) - - return storage_meta_groups diff --git a/transfer_queue/storage/simple_backend.py b/transfer_queue/storage/simple_backend.py index d84abe3b..d139f6b2 100644 --- a/transfer_queue/storage/simple_backend.py +++ b/transfer_queue/storage/simple_backend.py @@ -13,13 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import dataclasses import logging import os import time import weakref -from dataclasses import dataclass -from operator import itemgetter from threading import Event, Thread from typing import Any, Optional from uuid import uuid4 @@ -27,7 +24,6 @@ import ray import zmq -from transfer_queue.metadata import SampleMeta from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads from transfer_queue.utils.enum_utils import TransferQueueRole from transfer_queue.utils.perf_utils import IntervalPerfMonitor @@ -57,100 +53,82 @@ class StorageUnitData: """Storage unit for managing 2D data structure (samples × fields). - This class provides efficient storage and retrieval of data in a 2D matrix format - where rows represent samples (indexed by local_index) and columns represent fields. - Each field contains a list of data items indexed by their local position. + Uses dict-based storage keyed by global_index instead of pre-allocated list. + This allows O(1) insert/delete without index translation and avoids capacity bloat. Data Structure Example: - ┌─────────────┬─────────────┬─────────────┬─────────┐ - │ local_index │ field_name1 │ field_name2 │ ... │ - ├─────────────┼─────────────┼─────────────┼─────────┤ - │ 0 │ item1 │ item2 │ ... │ - │ 1 │ item3 │ item4 │ ... │ - │ 2 │ item5 │ item6 │ ... 
│ - └─────────────┴─────────────┴─────────────┴─────────┘ + field_data = { + "field_name1": {global_index_0: item1, global_index_3: item2, ...}, + "field_name2": {global_index_0: item3, global_index_3: item4, ...}, + } """ def __init__(self, storage_size: int): - # Dict containing field names and corresponding data in the field - # Format: {"field_name": [data_at_index_0, data_at_index_1, ...]} - self.field_data: dict[str, list] = {} - - # Maximum number of elements stored in storage unit + # field_name -> {global_index: data} nested dict + self.field_data: dict[str, dict] = {} + # Capacity upper bound (not pre-allocated list length) self.storage_size = storage_size + # Track active global_index keys for O(1) capacity checks + self._active_keys: set = set() - def get_data(self, fields: list[str], local_indexes: list[int]) -> dict[str, list]: - """ - Get data from storage unit according to given fields and local_indexes. + def get_data(self, fields: list[str], global_indexes: list) -> dict[str, list]: + """Get data by global index keys. Args: fields: Field names used for getting data. - local_indexes: Local indexes used for getting data. + global_indexes: Global indexes used as dict keys. Returns: dict with field names as keys, corresponding data list as values. """ result: dict[str, list] = {} - for field in fields: - # Validate field name if field not in self.field_data: raise ValueError( - f"StorageUnitData get_data operation receive invalid field: {field} beyond {self.field_data.keys()}" + f"StorageUnitData get_data: field '{field}' not found. Available: {list(self.field_data.keys())}" ) - - if len(local_indexes) == 1: - gathered_item = self.field_data[field][local_indexes[0]] - result[field] = [gathered_item] - - else: - gathered_items = list(itemgetter(*local_indexes)(self.field_data[field])) - - result[field] = gathered_items - + try: + result[field] = [self.field_data[field][k] for k in global_indexes] + except KeyError as e: + raise KeyError(f"StorageUnitData get_data: key {e} not found in field '{field}'") from e return result - def put_data(self, field_data: dict[str, Any], local_indexes: list[int]) -> None: - """ - Put or update data into storage unit according to given field_data and local_indexes. + def put_data(self, field_data: dict[str, Any], global_indexes: list) -> None: + """Put data into storage. Args: - field_data: Dict with field names as keys, corresponding data in the field as values. - local_indexes: Local indexes used for putting data. + field_data: Dict with field names as keys, data list as values. + global_indexes: Global indexes to use as dict keys. 
""" - + # Capacity is enforced per unique sample key, not counted per-field + new_global_keys = [k for k in global_indexes if k not in self._active_keys] + if len(self._active_keys) + len(new_global_keys) > self.storage_size: + raise ValueError( + f"Storage capacity exceeded: {len(self._active_keys)} existing + " + f"{len(new_global_keys)} new > {self.storage_size}" + ) for f, values in field_data.items(): + if len(values) != len(global_indexes): + raise ValueError( + f"StorageUnitData put_data: field '{f}' values length {len(values)} " + f"!= global_indexes length {len(global_indexes)}, length mismatch" + ) if f not in self.field_data: - self.field_data[f] = [None] * self.storage_size + self.field_data[f] = {} + for key, val in zip(global_indexes, values, strict=True): + self.field_data[f][key] = val + self._active_keys.update(global_indexes) - for i, idx in enumerate(local_indexes): - if idx < 0 or idx >= self.storage_size: - raise ValueError( - f"StorageUnitData put_data operation receive invalid local_index: {idx} beyond " - f"storage_size: {self.storage_size}" - ) - - self.field_data[f][idx] = values[i] - - def clear(self, local_indexes: list[int]) -> None: - """ - Clear data at specified local_indexes by setting all related fields to None. + def clear(self, keys: list[int]) -> None: + """Remove data at given global index keys, immediately freeing memory. Args: - local_indexes: local_indexes to clear. + keys: Global indexes to remove. """ - # Validate local_indexes - for idx in local_indexes: - if idx < 0 or idx >= self.storage_size: - raise ValueError( - f"StorageUnitData clear operation receive invalid local_index: {idx} beyond " - f"storage_size: {self.storage_size}" - ) - - # Clear data at specified local_indexes for f in self.field_data: - for idx in local_indexes: - self.field_data[f][idx] = None + for key in keys: + self.field_data[f].pop(key, None) + self._active_keys -= set(keys) @ray.remote(num_cpus=1) @@ -332,6 +310,10 @@ def _worker_routine(self) -> None: }, ) except Exception as e: + logger.error( + f"[{self.storage_unit_id}]: worker error during {operation} " + f"from sender={request_msg.sender_id}: {type(e).__name__}: {e}" + ) response_msg = ZMQMessage.create( request_type=ZMQRequestType.PUT_GET_ERROR, # type: ignore[arg-type] sender_id=self.storage_unit_id, @@ -359,18 +341,18 @@ def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage: Put data success response ZMQMessage. """ try: - local_indexes = data_parts.body["local_indexes"] + global_indexes = data_parts.body["global_indexes"] field_data = data_parts.body["data"] # field_data should be a TensorDict. 
with limit_pytorch_auto_parallel_threads( target_num_threads=TQ_NUM_THREADS, info=f"[{self.storage_unit_id}] _handle_put" ): - self.storage_data.put_data(field_data, local_indexes) + self.storage_data.put_data(field_data, global_indexes) # After put operation finish, send a message to the client response_msg = ZMQMessage.create( request_type=ZMQRequestType.PUT_DATA_RESPONSE, # type: ignore[arg-type] sender_id=self.storage_unit_id, - body={}, # type: ignore[arg-type] + body={}, ) return response_msg @@ -396,12 +378,12 @@ def _handle_get(self, data_parts: ZMQMessage) -> ZMQMessage: """ try: fields = data_parts.body["fields"] - local_indexes = data_parts.body["local_indexes"] + global_indexes = data_parts.body["global_indexes"] with limit_pytorch_auto_parallel_threads( target_num_threads=TQ_NUM_THREADS, info=f"[{self.storage_unit_id}] _handle_get" ): - result_data = self.storage_data.get_data(fields, local_indexes) + result_data = self.storage_data.get_data(fields, global_indexes) response_msg = ZMQMessage.create( request_type=ZMQRequestType.GET_DATA_RESPONSE, # type: ignore[arg-type] @@ -411,6 +393,10 @@ def _handle_get(self, data_parts: ZMQMessage) -> ZMQMessage: }, ) except Exception as e: + logger.error( + f"[{self.storage_unit_id}]: _handle_get error, " + f"fields={fields}, global_indexes={global_indexes}: {type(e).__name__}: {e}" + ) response_msg = ZMQMessage.create( request_type=ZMQRequestType.GET_ERROR, # type: ignore[arg-type] sender_id=self.storage_unit_id, @@ -423,21 +409,21 @@ def _handle_get(self, data_parts: ZMQMessage) -> ZMQMessage: def _handle_clear(self, data_parts: ZMQMessage) -> ZMQMessage: """ - Handle clear request, clear data in storage unit according to given local_indexes. + Handle clear request, clear data in storage unit according to given global_indexes. Args: - data_parts: ZMQMessage from client, including target local_indexes. + data_parts: ZMQMessage from client, including target global_indexes. Returns: Clear data success response ZMQMessage. """ try: - local_indexes = data_parts.body["local_indexes"] + global_indexes = data_parts.body["global_indexes"] with limit_pytorch_auto_parallel_threads( target_num_threads=TQ_NUM_THREADS, info=f"[{self.storage_unit_id}] _handle_clear" ): - self.storage_data.clear(local_indexes) + self.storage_data.clear(global_indexes) response_msg = ZMQMessage.create( request_type=ZMQRequestType.CLEAR_DATA_RESPONSE, # type: ignore[arg-type] @@ -492,60 +478,3 @@ def get_zmq_server_info(self) -> ZMQServerInfo: ZMQServerInfo containing connection details for this storage unit. """ return self.zmq_server_info - - -@dataclass -class StorageMetaGroup: - """ - Represents a group of samples stored in the same storage unit. - Used to organize samples by their storage_id for efficient client operations. 
- """ - - storage_id: str - sample_metas: list[SampleMeta] = dataclasses.field(default_factory=list) - local_indexes: list[int] = dataclasses.field(default_factory=list) - - def add_sample_meta(self, sample_meta: SampleMeta, local_index: int) -> None: - """Add a SampleMeta object to this storage group""" - self.sample_metas.append(sample_meta) - self.local_indexes.append(local_index) - - def get_batch_indexes(self) -> list[int]: - """Get all internal indexes from stored SampleMeta objects""" - return [meta.batch_index for meta in self.sample_metas] - - def get_global_indexes(self) -> list[int]: - """Get all global indexes from stored SampleMeta objects""" - return [meta.global_index for meta in self.sample_metas] - - def get_local_indexes(self) -> list[int]: - """Get all local indexes from stored SampleMeta objects""" - return self.local_indexes - - def get_field_names(self) -> list[str]: - """Get all unique field names from stored SampleMeta objects""" - all_fields: set[str] = set() - for meta in self.sample_metas: - all_fields.update(meta.fields.keys()) - return list(all_fields) - - @property - def size(self) -> int: - """Number of samples in this storage meta group""" - return len(self.sample_metas) - - @property - def is_empty(self) -> bool: - """Check if this storage meta group is empty""" - return len(self.sample_metas) == 0 - - def __len__(self) -> int: - """Number of samples in this storage meta group""" - return self.size - - def __bool__(self) -> bool: - """Truthiness based on whether group has samples""" - return not self.is_empty - - def __str__(self) -> str: - return f"StorageMetaGroup(storage_id='{self.storage_id}', size={self.size})" diff --git a/transfer_queue/utils/serial_utils.py b/transfer_queue/utils/serial_utils.py index aa74009b..38917a11 100644 --- a/transfer_queue/utils/serial_utils.py +++ b/transfer_queue/utils/serial_utils.py @@ -17,6 +17,8 @@ # This implementation is inspired by https://github.com/vllm-project/vllm/blob/main/vllm/v1/serial_utils.py +import logging +import os import pickle import warnings from collections.abc import Sequence @@ -35,9 +37,17 @@ CUSTOM_TYPE_CLOUDPICKLE = 2 CUSTOM_TYPE_TENSOR = 3 # For tensor with buffer reference CUSTOM_TYPE_NESTED_TENSOR = 4 # For nested tensor (strided or jagged) +CUSTOM_TYPE_BATCHMETA = 5 # For BatchMeta serialization +CUSTOM_TYPE_NUMPY = 6 # For numpy ndarray with buffer reference + +# 0xC1 is permanently reserved (invalid) in msgpack spec — safe to use as pickle fallback sentinel. +_PICKLE_FALLBACK_SENTINEL = b"\xc1\xfe\xed" bytestr: TypeAlias = bytes | bytearray | memoryview | zmq.Frame +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) + # Ignore warnings about non-writable buffers from torch.frombuffer. Upper codes will ensure # the tensors are writable to users. warnings.filterwarnings(action="ignore", message=r"The given buffer is not writable*", category=UserWarning) @@ -69,6 +79,9 @@ def aux_buffers(self) -> list[bytestr]: def encode(self, obj: Any) -> Sequence[bytestr]: """Encode a given object to a byte array.""" + # Pre-process to convert BatchMeta to Ext; msgspec auto-serializes dataclasses and won't call enc_hook for them. 
+ obj = self._preprocess_for_batchmeta(obj) + bufs: list[bytestr] = [b""] token = _encoder_aux_buffers.set(bufs) try: @@ -81,6 +94,24 @@ def encode(self, obj: Any) -> Sequence[bytestr]: finally: _encoder_aux_buffers.reset(token) + def _preprocess_for_batchmeta(self, obj: Any) -> Any: + """Recursively preprocess object to convert BatchMeta to Ext. + + This is necessary because msgspec auto-serializes dataclasses and + won't call enc_hook for them. + """ + from transfer_queue.metadata import BatchMeta + + if isinstance(obj, BatchMeta): + return self._encode_batchmeta(obj) + elif isinstance(obj, dict): + return {k: self._preprocess_for_batchmeta(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [self._preprocess_for_batchmeta(item) for item in obj] + elif isinstance(obj, tuple): + return tuple(self._preprocess_for_batchmeta(item) for item in obj) + return obj + def enc_hook(self, obj: Any) -> Any: """Custom encoding hook for types msgspec doesn't natively support. @@ -88,6 +119,9 @@ def enc_hook(self, obj: Any) -> Any: - torch.Tensor: Extract buffer, store metadata - TensorDict: Convert to dict structure for recursive processing - numpy.ndarray: Convert to tensor for unified handling + + Note: BatchMeta is handled by _preprocess_for_batchmeta() before encode() is called, + so it will never reach this hook. """ if isinstance(obj, torch.Tensor): return self._encode_tensor(obj) @@ -96,17 +130,15 @@ def enc_hook(self, obj: Any) -> Any: if isinstance(obj, TensorDictBase): return self._encode_tensordict(obj) - # Handle numpy arrays by converting to tensor - # Only numeric dtypes are supported by torch.from_numpy: - # f=float, i=signed int, u=unsigned int, b=bool, c=complex + # Numpy arrays: serialize natively unless the dtype contains Python objects. if isinstance(obj, np.ndarray): - if obj.dtype.kind in ("f", "i", "u", "b", "c"): + if obj.dtype.kind != "O" and not obj.dtype.hasobject: try: - return self._encode_tensor(torch.from_numpy(obj)) - except (TypeError, RuntimeError): - # Fallback to pickle for unsupported dtypes (e.g., float16 on some platforms) + return self._encode_numpy(obj) + except (TypeError, RuntimeError, ValueError): + # Fallback to pickle for platforms that don't support the view pass - # For object arrays, strings, or other unsupported types, use pickle + # Only true object arrays (or structured dtypes with object fields) reach here return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)) if isinstance(obj, FunctionType): @@ -116,6 +148,15 @@ def enc_hook(self, obj: Any) -> Any: # Fallback to pickle for unknown types return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)) + def _encode_batchmeta(self, obj: Any) -> msgpack.Ext: + """Encode BatchMeta as a pickle-based Ext payload. + + BatchMeta must be preprocessed before encode() because msgspec auto-serializes + dataclasses (bypassing enc_hook), and BatchMeta fields contain torch.dtype which + msgpack cannot handle natively. + """ + return msgpack.Ext(CUSTOM_TYPE_BATCHMETA, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)) + def _encode_tensordict(self, obj: Any) -> dict: """Convert TensorDict to a dict structure for recursive msgpack processing. @@ -134,16 +175,7 @@ def _encode_tensordict(self, obj: Any) -> dict: } def _encode_tensor(self, obj: torch.Tensor) -> msgpack.Ext: - """Encode tensor with zero-copy buffer extraction. 
- - Features: - - Auto GPU->CPU conversion - - Auto contiguous conversion - - Direct memoryview extraction via uint8 view (for BFloat16 support) - - Nested tensors: unbind and serialize each sub-tensor with zero-copy - - Returns Ext type so decoding goes through ext_hook (which has buffer access). - """ + """Encode tensor with zero-copy buffer extraction (handles GPU, non-contiguous, nested).""" assert len(self.aux_buffers) > 0 # Handle nested tensors (strided or jagged) via unbind @@ -218,6 +250,20 @@ def _encode_regular_tensor(self, obj: torch.Tensor) -> msgpack.Ext: meta = (dtype, tuple(obj.shape), idx) return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(meta, protocol=pickle.HIGHEST_PROTOCOL)) + def _encode_numpy(self, obj: np.ndarray) -> msgpack.Ext: + """Encode numpy array with zero-copy buffer extraction.""" + # Ensure C-contiguous layout; no-op when already contiguous + if not obj.flags["C_CONTIGUOUS"]: + obj = np.ascontiguousarray(obj) + + # Byte-level view as uint8 then ravel → 1-D C-contiguous raw-bytes array + buf = memoryview(obj.view(np.uint8).ravel()) + idx = len(self.aux_buffers) + self.aux_buffers.append(buf) + + meta = (str(obj.dtype), tuple(obj.shape), idx) + return msgpack.Ext(CUSTOM_TYPE_NUMPY, pickle.dumps(meta, protocol=pickle.HIGHEST_PROTOCOL)) + class MsgpackDecoder: """Decoder with custom torch tensor and numpy array serialization. @@ -307,6 +353,19 @@ def _decode_nested_tensor(self, nested_meta: dict) -> torch.Tensor: else: # strided return torch.nested.as_nested_tensor(sub_tensors, layout=torch.strided) + def _decode_numpy(self, meta: tuple) -> np.ndarray: + """Decode numpy array from (dtype_str, shape, buffer_idx) tuple.""" + dtype_str, shape, idx = meta + buffer = self.aux_buffers[idx] + np_dtype = np.dtype(dtype_str) + + if not buffer: # empty array + return np.empty(shape, dtype=np_dtype) + + # Reconstruct from raw bytes: uint8 view → reinterpret as original dtype + arr = np.frombuffer(buffer, dtype=np.uint8) + return arr.view(np_dtype).reshape(shape) + def ext_hook(self, code: int, data: memoryview) -> Any: """Custom decoding hook for types msgspec doesn't natively support. @@ -314,6 +373,7 @@ def ext_hook(self, code: int, data: memoryview) -> Any: - torch.Tensor: Extract buffer, store metadata - TensorDict: Convert to dict structure for recursive processing - numpy.ndarray: Convert to tensor for unified handling + - BatchMeta: Reconstruct from pickle """ if code == CUSTOM_TYPE_PICKLE: return pickle.loads(data) @@ -325,9 +385,41 @@ def ext_hook(self, code: int, data: memoryview) -> Any: if code == CUSTOM_TYPE_NESTED_TENSOR: nested_meta = pickle.loads(data) return self._decode_nested_tensor(nested_meta) + if code == CUSTOM_TYPE_BATCHMETA: + return pickle.loads(data) + if code == CUSTOM_TYPE_NUMPY: + meta = pickle.loads(data) + return self._decode_numpy(meta) raise NotImplementedError(f"Extension type code {code} is not supported") _encoder = MsgpackEncoder() _decoder = MsgpackDecoder() + + +def encode(obj: Any) -> list[bytestr]: + """Encode an object via msgpack zero-copy; falls back to pickle on failure. + + The pickle path is a normal degradation path (e.g. body contains torch.dtype + objects). Use this as the single entry point for all ZMQ message serialization. 
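# Illustrative sketch (not part of the patch): the raw-bytes round trip behind
# CUSTOM_TYPE_NUMPY - the array travels as a uint8 view of its buffer plus a
# (dtype, shape) header, and is rebuilt on the receiving side from those bytes.
import numpy as np

original = np.arange(12, dtype=np.float32).reshape(3, 4)

# Encode side: contiguous uint8 view of the underlying bytes, plus metadata.
payload = memoryview(np.ascontiguousarray(original).view(np.uint8).ravel())
meta = (str(original.dtype), original.shape)

# Decode side: reinterpret the received bytes as the original dtype and shape.
dtype_str, shape = meta
restored = np.frombuffer(payload, dtype=np.uint8).view(np.dtype(dtype_str)).reshape(shape)

assert restored.dtype == original.dtype and np.array_equal(restored, original)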
+ """ + try: + return list(_encoder.encode(obj)) + except (TypeError, ValueError) as e: + logger.debug( + "encode: msgpack failed (%s), falling back to pickle.", + type(e).__name__, + ) + return [_PICKLE_FALLBACK_SENTINEL, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)] + + +def decode(frames: list) -> Any: + """Decode frames produced by encode. + + Transparently handles both the msgpack zero-copy path and the pickle + fallback path based on the leading sentinel frame. + """ + if len(frames) >= 2 and frames[0] == _PICKLE_FALLBACK_SENTINEL: + return pickle.loads(frames[1]) + return _decoder.decode(frames) diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 7d0d8d18..8afbb480 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -15,7 +15,6 @@ import logging import os -import pickle import socket import time from dataclasses import dataclass @@ -27,11 +26,8 @@ import zmq from ray.util import get_node_ip_address -from transfer_queue.utils.common import ( - get_env_bool, -) from transfer_queue.utils.enum_utils import ExplicitEnum, TransferQueueRole -from transfer_queue.utils.serial_utils import _decoder, _encoder +from transfer_queue.utils.serial_utils import decode, encode logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) @@ -45,8 +41,6 @@ bytestr: TypeAlias = bytes | bytearray | memoryview -TQ_ZERO_COPY_SERIALIZATION = get_env_bool("TQ_ZERO_COPY_SERIALIZATION", default=False) - class ZMQRequestType(ExplicitEnum): """ @@ -171,43 +165,32 @@ def create( ) def serialize(self) -> list: - """ - Serialize message using unified MsgpackEncoder or pickle. - Returns: list[bytestr] - [msgpack_header, *tensor_buffers] or [bytes] - """ - if TQ_ZERO_COPY_SERIALIZATION: - msg_dict = { - "request_type": self.request_type.value, # Enum -> str for msgpack - "sender_id": self.sender_id, - "receiver_id": self.receiver_id, - "request_id": self.request_id, - "timestamp": self.timestamp, - "body": self.body, - } - return list(_encoder.encode(msg_dict)) - else: - return [pickle.dumps(self)] + """Serialize using zero-copy msgpack; falls back to pickle for unsupported types.""" + msg_dict = { + "request_type": self.request_type.value, # Enum -> str for msgpack + "sender_id": self.sender_id, + "receiver_id": self.receiver_id, + "request_id": self.request_id, + "timestamp": self.timestamp, + "body": self.body, + } + return encode(msg_dict) @classmethod def deserialize(cls, frames: list) -> "ZMQMessage": - """ - Deserialize message using unified MsgpackDecoder or pickle. 
- """ + """Deserialize: choose decoding path based on the first frame marker (zero-copy or pickle fallback).""" if not frames: raise ValueError("Empty frames received") - if TQ_ZERO_COPY_SERIALIZATION: - msg_dict = _decoder.decode(frames) - return cls( - request_type=ZMQRequestType(msg_dict["request_type"]), - sender_id=msg_dict["sender_id"], - receiver_id=msg_dict["receiver_id"], - body=msg_dict["body"], - request_id=msg_dict["request_id"], - timestamp=msg_dict["timestamp"], - ) - else: - return pickle.loads(frames[0]) + result = decode(frames) + return cls( + request_type=ZMQRequestType(result["request_type"]), + sender_id=result["sender_id"], + receiver_id=result["receiver_id"], + body=result["body"], + request_id=result["request_id"], + timestamp=result["timestamp"], + ) def is_ipv6_address(ip: str) -> bool: diff --git a/tutorial/03_metadata_concepts.py b/tutorial/03_metadata_concepts.py index 2c819416..24c0e4b7 100644 --- a/tutorial/03_metadata_concepts.py +++ b/tutorial/03_metadata_concepts.py @@ -36,6 +36,7 @@ ) +import numpy as np # noqa: E402 import ray # noqa: E402 import torch # noqa: E402 from tensordict import TensorDict # noqa: E402 @@ -45,253 +46,190 @@ sys.path.append(str(parent_dir)) import transfer_queue as tq # noqa: E402 -from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 -from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 +from transfer_queue.metadata import BatchMeta # noqa: E402 # Configure Ray os.environ["RAY_DEDUP_LOGS"] = "0" os.environ["RAY_DEBUG"] = "1" -def demonstrate_field_meta(): +def demonstrate_batch_meta_schema(): """ - Demonstrate FieldMeta - specific data fields of each training sample. + Demonstrate BatchMeta basic usage. """ print("=" * 80) - print("FieldMeta - Specific data fields of each training sample") + print("BatchMeta - Fine-Grained Metadata in Field Level") print("=" * 80) - print("FieldMeta represents a single field in ONE sample:") - print("- name: Field identifier ('Prompt', 'Response', etc.)") + print("field_schema stores metadata for each field:") print("- dtype: Data type (torch.float32, torch.int64, etc.)") - print("- shape: Shape of ONE sample's data (NO batch dimension)") - print("- production_status: Whether data is ready (has been produced and written to the TQ backend)") - - # Example 1: Create a field for input_ids - print("[Example 1] Manually creating FieldMeta for input_ids...") - input_ids_field = FieldMeta( - name="input_ids", - dtype=torch.int64, - shape=(512,), # Sequence length for ONE sample - production_status=ProductionStatus.READY_FOR_CONSUME, + print("- shape: Shape of ONE sample's data") + print("- is_nested: Whether the field uses nested/ragged tensors") + print("- is_non_tensor: Whether the field is non-tensor data") + + # Example 1: Create a field schema entry for input_ids + print("[Example 1] Creating field schema entry for input_ids...") + batch = BatchMeta( + global_indexes=[0, 1, 2], + partition_ids=["train_0"] * 3, + field_schema={ + "input_ids": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False}, + }, ) - print(f"✓ Created: {input_ids_field}") - print(f" Is ready: {input_ids_field.is_ready}") + print("✓ Created: BatchMeta with field 'input_ids'") + print(f" input_ids schema: {batch.field_schema['input_ids']}") + print(f" Is ready: {batch.is_ready}") print(" Note: Shape (512,) means ONE sample has 512 tokens (no batch dimension)") - # Example 2: Create a field for attention_mask - print("[Example 2] Creating FieldMeta for 
attention_mask...") - attention_mask_field = FieldMeta( - name="attention_mask", - dtype=torch.int64, - shape=(512,), # Sequence length for ONE sample - production_status=ProductionStatus.NOT_PRODUCED, + # Example 2: Create a field schema entry for attention_mask + print("[Example 2] Creating field schema entry for attention_mask...") + batch2 = BatchMeta( + global_indexes=[0, 1, 2], + partition_ids=["train_0"] * 3, + field_schema={ + "attention_mask": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False}, + }, ) - print(f"✓ Created: {attention_mask_field}") - print(f" Is ready: {attention_mask_field.is_ready}") + print("✓ Created: BatchMeta with field 'attention_mask'") + print(f" attention_mask schema: {batch2.field_schema['attention_mask']}") + print(f" Is ready: {batch2.is_ready}") - # Example 3: Check field readiness + # Example 3: Check field readiness via is_ready and production_status print("[Example 3] Checking field readiness...") - print(f" input_ids ready: {input_ids_field.is_ready}") - print(f" attention_mask ready: {attention_mask_field.is_ready}") + ready_batch = BatchMeta( + global_indexes=[0, 1, 2], + partition_ids=["train_0"] * 3, + field_schema={ + "input_ids": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False}, + "attention_mask": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False}, + }, + production_status=np.array([1, 1, 1], dtype="int8"), # 1 = READY_FOR_CONSUME + ) + print(f" input_ids field exists: {'input_ids' in ready_batch.field_schema}") + print(f" attention_mask field exists: {'attention_mask' in ready_batch.field_schema}") + print(f" not-ready batch is_ready: {batch.is_ready}") + print(f" ready batch is_ready: {ready_batch.is_ready}") + # Example 4: Access per-sample view and individual field schema by key + print("[Example 4] Accessing sample view and individual field by key...") + view = ready_batch.samples[0] + print(f" batch.samples[0].fields -> {view.fields}") + print(f" batch.samples[0].fields['input_ids'] -> {view.fields['input_ids']}") + print(f" batch.samples[0].fields['input_ids']['dtype'] -> {view.fields['input_ids']['dtype']}") -def demonstrate_sample_meta(): + +def demonstrate_batch_meta_operations(): """ - Demonstrate SampleMeta - describes a single data sample. + Demonstrate BatchMeta construction and operations. + Covers: manual creation, add_fields, select_fields, select_samples, + reorder, chunk, concat, union, extra_info, custom_meta. 
""" print("=" * 80) - print("SampleMeta - Describing a Single Data Sample") + print("BatchMeta - Construction & Operations") print("=" * 80) - print("SampleMeta represents ONE data sample:") - print("- partition_id: Which partition the sample belongs to") - print("- global_index: Unique identifier across ALL partitions") - print("- fields: Dict of FieldMeta objects (describing each field of THIS sample)") - - # Example 1: Manually create a sample - print("[Example 1] Creating a SampleMeta...") - fields = { - "input_ids": FieldMeta("input_ids", torch.int64, (512,)), - "attention_mask": FieldMeta("attention_mask", torch.int64, (512,)), - } - sample = SampleMeta(partition_id="train_0", global_index=0, fields=fields) - print(f"✓ Created: {sample}") - print(f" Partition: {sample.partition_id}") - print(f" Global index: {sample.global_index}") - print(f" Fields: {sample.field_names}") - print(f" Is ready: {sample.is_ready}") - - # Example 2: Manually add fields to a sample - print("[Example 2] Adding fields to a sample...") - new_fields = { - "responses": FieldMeta("responses", torch.int64, (128,)), - "log_probs": FieldMeta("log_probs", torch.float32, (128,)), - } - sample.add_fields(new_fields) - print(f"✓ Added fields: {list(new_fields.keys())}") - print(f" Now has fields: {sample.field_names}") - print(f" Is ready: {sample.is_ready}") - - # Example 3: Select specific fields - print("[Example 3] Selecting specific fields...") - selected_sample = sample.select_fields(["input_ids", "responses"]) - print(f"✓ Selected fields: {selected_sample.field_names}") - print(f" Original fields: {sample.field_names}") - - # Example 4: Union two samples - print("[Example 4] Unioning two samples...") - print(" IMPORTANT: Union requires samples to have IDENTICAL partition_id and global_index!") - sample1 = SampleMeta( - partition_id="train_0", - global_index=5, - fields={ - "input_ids": FieldMeta("input_ids", torch.int64, (512,)), - "attention_mask": FieldMeta("attention_mask", torch.int64, (512,)), + print("BatchMeta uses a columnar layout:") + print("- global_indexes: list[int] - unique IDs across ALL partitions") + print("- partition_ids: list[str] - which partition each sample belongs to") + print("- field_schema: dict[str, dict] - field metadata") + print("- Operations: add_fields, select_fields, select_samples, reorder, chunk, concat, union") + + # Helper to manually create a BatchMeta + def make_batch(global_indexes, fields=None): + if fields is None: + fields = ["input_ids", "attention_mask", "responses"] + schema = { + "input_ids": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False}, + "attention_mask": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False}, + "responses": {"dtype": torch.int64, "shape": (128,), "is_nested": False, "is_non_tensor": False}, + } + return BatchMeta( + global_indexes=global_indexes, + partition_ids=["train_0"] * len(global_indexes), + field_schema={k: v for k, v in schema.items() if k in fields}, + ) + + # --- 1. Manual Construction --- + print("[Example 1] Creating a BatchMeta with input_ids and attention_mask...") + batch = BatchMeta( + global_indexes=[0, 1, 2, 3, 4], + partition_ids=["train_0"] * 5, + field_schema={ + "input_ids": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False}, + "attention_mask": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False}, }, ) - sample2 = SampleMeta( - partition_id="train_0", - global_index=5, # Same global index! 
- fields={ - "responses": FieldMeta("responses", torch.int64, (128,)), - "log_probs": FieldMeta("log_probs", torch.float32, (128,)), - }, + print(f"✓ Created: {len(batch)} samples") + print(f" Global indexes: {batch.global_indexes}, Fields: {batch.field_names}") + print(f" Is ready: {batch.is_ready}") + + # --- 2. add_fields --- + print("[Example 2] Adding new fields via add_fields(TensorDict)...") + new_data = TensorDict( + {"responses": torch.randint(0, 1000, (5, 128)), "log_probs": torch.randn(5, 128)}, + batch_size=5, ) - print(f" Sample1: partition={sample1.partition_id}, global_index={sample1.global_index}") - print(f" Sample2: partition={sample2.partition_id}, global_index={sample2.global_index}") - - try: - unioned = sample1.union(sample2) - print("✓ Union successful!") - print(f" Unioned fields: {unioned.field_names}") - print(f" Global index preserved: {unioned.global_index}") - except ValueError as e: - print(f"✗ Union failed: {e}") - - -def demonstrate_batch_meta(): - """ - Demonstrate BatchMeta - describes a batch of samples with operations. - """ - print("=" * 80) - print("BatchMeta - Describing a Batch of Samples") - print("=" * 80) + batch.add_fields(new_data) + print(f"✓ Added fields: ['responses', 'log_probs']. Now has: {batch.field_names}") + print(f" Is ready: {batch.is_ready} (add_fields sets all to READY by default)") - print("BatchMeta represents a collection of samples:") - print("- samples: List of SampleMeta objects") - print("- extra_info: Additional batch-level information") - print("- Provides operations: chunk, concat, union, select, reorder") - - # Example 1: Manually create a batch - print("[Example 1] Creating a BatchMeta...") - fields = { - "input_ids": FieldMeta("input_ids", torch.int64, (512,)), - "attention_mask": FieldMeta("attention_mask", torch.int64, (512,)), - "responses": FieldMeta("responses", torch.int64, (128,)), - } - samples = [SampleMeta(partition_id="train_0", global_index=i, fields=fields) for i in range(5)] - batch = BatchMeta(samples=samples) - print(f"✓ Created batch with {len(batch)} samples") - print(f" Global indexes: {batch.global_indexes}") - print(f" Field names: {batch.field_names}") - print(f" Size: {batch.size}") - - # Example 2: Add extra_info - print("[Example 2] Adding batch-level information through extra_info...") - print("Note: The extra info will not be stored into TransferQueueController.") + # --- 3. 
extra_info & custom_meta --- + print("[Example 3] Adding batch-level extra_info and sample-level custom_meta...") batch.extra_info["epoch"] = 1 batch.extra_info["batch_idx"] = 0 - print(f"✓ Extra info: {batch.get_all_extra_info()}") - - print("[Example 3] Adding sample-level information through custom_meta...") - batch.update_custom_meta( - [ - {"uid": "prompt@0", "session_id": "session@0", "model_version": "epoch@0"}, - {"uid": "prompt@1", "session_id": "session@0", "model_version": "epoch@0"}, - {"uid": "prompt@2", "session_id": "session@0", "model_version": "epoch@0"}, - {"uid": "prompt@3", "session_id": "session@0", "model_version": "epoch@0"}, - {"uid": "prompt@4", "session_id": "session@0", "model_version": "epoch@0"}, - ] - ) - print(f"✓ Custom meta: {batch.get_all_custom_meta()}") - - # Example 4: Chunk a batch - print("[Example 4] Chunking a batch into parts...") - chunks = batch.chunk(3) - print(f"✓ Split into {len(chunks)} chunks:") - for i, chunk in enumerate(chunks): - print(f" Chunk {i}: {len(chunk)} samples, indexes={chunk.global_indexes}") + batch.update_custom_meta([{"uid": f"prompt@{i}", "session_id": "session@0"} for i in range(5)]) + print(f" Extra info: {batch.get_all_extra_info()}") + print(f" custom_meta[0]: {batch.custom_meta[0]}") - # Example 5: Select specific fields - print("[Example 5] Selecting specific fields...") - selected_batch = batch.select_fields(["input_ids", "responses"]) - print(f"✓ Selected fields: {selected_batch.field_names}") - print(f" Original fields: {batch.field_names}") + # --- 4. select_fields --- + print("[Example 4] Selecting specific fields...") + selected = batch.select_fields(["input_ids", "responses"]) + print(f"✓ Selected: {selected.field_names} (original: {batch.field_names})") - # Example 6: Select specific samples - print("[Example 6] Selecting specific samples...") + # --- 5. select_samples --- + print("[Example 5] Selecting specific samples...") selected_samples = batch.select_samples([0, 2, 4]) - print(f"✓ Selected samples at indexes: {selected_samples.global_indexes}") + print(f"✓ Selected samples at [0,2,4]: indexes={selected_samples.global_indexes}") - # Example 7: Reorder samples - print("[Example 7] Reordering samples...") - print(f" Original order: {batch.global_indexes}") + # --- 6. reorder --- + print("[Example 6] Reordering samples...") + print(f" Before: {batch.global_indexes}") batch.reorder([4, 3, 2, 1, 0]) - print(f" After reorder: {batch.global_indexes}") + print(f" After: {batch.global_indexes}") - # Example 8: Concat batches + # --- 7. chunk --- + print("[Example 7] Chunking a batch into parts...") + batch_for_chunk = make_batch(list(range(10))) + chunks = batch_for_chunk.chunk(3) + print(f"✓ Split into {len(chunks)} chunks:") + for i, chunk in enumerate(chunks): + print(f" Chunk {i}: {len(chunk)} samples, indexes={chunk.global_indexes}") + + # --- 8. 
concat --- print("[Example 8] Concatenating batches...") - batch1 = BatchMeta(samples=[SampleMeta(partition_id="train_0", global_index=i, fields=fields) for i in range(3)]) - batch2 = BatchMeta(samples=[SampleMeta(partition_id="train_0", global_index=i, fields=fields) for i in range(3, 6)]) + batch1 = make_batch(list(range(3))) + batch2 = make_batch(list(range(3, 6))) concatenated = BatchMeta.concat([batch1, batch2]) print(f"✓ Concatenated {len(batch1)} + {len(batch2)} = {len(concatenated)} samples") print(f" Global indexes: {concatenated.global_indexes}") - print(" Note: concat combines multiple batches into one (same structure)") - - # Example 9: Union batches - print("[Example 9] Unioning batches (different fields, same samples)...") - batch_with_input = BatchMeta( - samples=[ - SampleMeta( - partition_id="train_0", - global_index=i, - fields={ - "input_ids": FieldMeta("input_ids", torch.int64, (512,)), - "attention_mask": FieldMeta("attention_mask", torch.int64, (512,)), - }, - ) - for i in range(3) - ] - ) - batch_with_output = BatchMeta( - samples=[ - SampleMeta( - partition_id="train_0", - global_index=i, - fields={ - "responses": FieldMeta("responses", torch.int64, (128,)), - "log_probs": FieldMeta("log_probs", torch.float32, (128,)), - }, - ) - for i in range(3) - ] - ) - print(f" Batch1 has fields: {batch_with_input.field_names}") - print(f" Batch2 has fields: {batch_with_output.field_names}") - print(f" Both have same samples (global_indexes: {batch_with_input.global_indexes})") - unioned_batch = batch_with_input.union(batch_with_output) - print("✓ Union successful!") - print(f" Unioned fields: {unioned_batch.field_names}") - print(" Note: union merges fields from two batches with SAME samples (same global_indexes)") + # --- 9. union (dedup by global_index) --- + print("[Example 9] Unioning batches with overlapping global_indexes...") + batch_a = make_batch(list(range(3)), fields=["input_ids", "attention_mask"]) + batch_b = make_batch(list(range(2, 5)), fields=["input_ids", "attention_mask"]) + print(f" BatchA: {batch_a.global_indexes}, BatchB: {batch_b.global_indexes}") + unioned = batch_a.union(batch_b) + print(f"✓ Unioned: {unioned.global_indexes} (global_index=2 deduplicated)") + + # --- 10. 
Empty BatchMeta --- + print("[Example 10] Creating an empty BatchMeta...") + empty = BatchMeta.empty() + print(f"✓ Empty: size={empty.size}, is_ready={empty.is_ready}") print("=" * 80) print("concat vs union:") - print(" - concat: Combines multiple batches with SAME structure into one larger batch") - print(" Example: batch1[0,1,2] + batch2[3,4,5] = batch[0,1,2,3,4,5]") - print(" - union: Merges fields from two batches with IDENTICAL samples") - print(" Example: batch1[0,1] with fields A + batch2[0,1] with fields B = batch[0,1] with fields A+B") + print(" - concat: Combines batches with SAME field structure") + print(" - union: Merges batches, deduplicating by global_index") print("=" * 80) @@ -351,14 +289,13 @@ def demonstrate_real_workflow(): print(f" Number of samples: {len(batch_meta)}") print(f" Global indexes: {batch_meta.global_indexes}") print(f" Field names: {batch_meta.field_names}") - print(f" Partition ID: {batch_meta.samples[0].partition_id}") - print(f" Sample structure: {batch_meta.samples[0]}") + print(f" Partition IDs: {batch_meta.partition_ids}") print(f" Custom Meta: {batch_meta.get_all_custom_meta()}") print("[Step 4] Retrieve samples with specific fields..") selected_meta = batch_meta.select_fields(["input_ids"]) print("✓ Selected 'input_ids' field only:") - print(f" New field names: {selected_meta.field_names}") + print(f" Field names in new BatchMeta: {selected_meta.field_names}") print(f" Samples still have same global indexes: {selected_meta.global_indexes}") retrieved_data = tq_client.get_data(selected_meta) print(f" Retrieved data keys: {list(retrieved_data.keys())}") @@ -366,7 +303,7 @@ def demonstrate_real_workflow(): print("[Step 5] Select specific samples from the retrieved BatchMeta...") partial_meta = batch_meta.select_samples([0, 2, 4, 6]) print("✓ Selected samples at indices [0, 2, 4, 6]:") - print(f" New global indexes: {partial_meta.global_indexes}") + print(f" Global indexes in new BatchMeta: {partial_meta.global_indexes}") print(f" Number of samples: {len(partial_meta)}") retrieved_data = tq_client.get_data(partial_meta) print(f" Retrieved data samples: {retrieved_data}, all the data samples: {data_batch}") @@ -397,37 +334,31 @@ def main(): This script introduces the metadata system in TransferQueue, which tracks the structure and state of data: - 1. FieldMeta - Describes a single field (name, dtype, shape, production status) - 2. SampleMeta - Describes a single data sample (partition_id, global_index, fields) - 3. BatchMeta - Describes a batch of samples (collection of SampleMeta with operations) - Key Concepts: - - Metadata tracks data structure without storing actual data - - User can set their own custom metadata into BatchMeta, and use TQ controller to store them. - - BatchMeta provides operations: chunk, concat, union, select, reorder... 
- - Metadata is lightweight and can be passed around efficiently - - Union requires samples to have identical partition_id and global_index - """ + - BatchMeta stores global_indexes, partition_ids, and field_schema directly + - field_schema: dict[field_name, {dtype, shape, is_nested, is_non_tensor}] + - custom_meta: list[dict] aligned with global_indexes (one dict per sample) + - Metadata operations: chunk, concat, union, select_fields, select_samples, reorder + - batch.samples[i] returns a lazy view with .fields -> field_schema (read-only) + """ ) ) print("=" * 80) try: - demonstrate_field_meta() - demonstrate_sample_meta() - demonstrate_batch_meta() + demonstrate_batch_meta_schema() + demonstrate_batch_meta_operations() demonstrate_real_workflow() print("=" * 80) print("Tutorial Complete!") print("=" * 80) print("Key Takeaways:") - print("1. FieldMeta describes individual data fields (NO batch dimension in shape)") - print("2. SampleMeta describes a single data sample") - print("3. BatchMeta manages collections of samples with operations") - print("4. Metadata operations: chunk, concat, union, select, reorder... You can retrieve subsets easily!") - print("5. extra_info is in batch-level, and custom_meta is in sample-level.") - print("6. You can put custom_meta into TQ controller, so you can retrieve them from anywhere!") + print("1. BatchMeta uses columnar storage") + print("2. Construct BatchMeta with: BatchMeta(global_indexes=[...], partition_ids=[...], field_schema={...})") + print("3. BatchMeta operations: chunk, concat, union, select_fields, select_samples, reorder") + print("4. extra_info is batch-level; custom_meta is sample-level (list[dict])") + print("5. Store custom_meta via TQ controller: tq_client.set_custom_meta(batch_meta)") # Cleanup ray.shutdown()
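Note (illustrative, not part of the patch): the sketch below exercises the columnar BatchMeta API exactly as the rewritten tutorial above uses it — direct construction from global_indexes/partition_ids/field_schema, update_custom_meta, select_samples, and select_fields. The import path and the alignment invariants asserted here are assumptions inferred from the tutorial code in this diff, not verified against transfer_queue/metadata.py.

import torch
from transfer_queue.metadata import BatchMeta  # assumed import path

# Columnar construction: parallel per-sample lists plus one shared field_schema.
batch = BatchMeta(
    global_indexes=[0, 1, 2],
    partition_ids=["train_0"] * 3,
    field_schema={
        "input_ids": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False},
    },
)

# Sample-level custom_meta is a list aligned with global_indexes (one dict per sample).
batch.update_custom_meta([{"uid": f"prompt@{i}"} for i in range(3)])

# Assumed invariants of the columnar layout, per the tutorial's description:
# the per-sample lists share one length; field_schema has one entry per field, not per sample.
assert len(batch.global_indexes) == len(batch.partition_ids) == len(batch.custom_meta) == batch.size == 3
assert set(batch.field_names) == {"input_ids"}

# Selection returns new BatchMeta objects; arguments are positions, results carry global indexes.
subset = batch.select_samples([0, 2]).select_fields(["input_ids"])
assert subset.global_indexes == [0, 2]
assert subset.size == 2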