Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions scripts/put_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,11 @@
# Resolve this file's absolute path, walk three `.parent` hops up from it, and
# append that directory to sys.path (presumably so the `transfer_queue` imports
# that follow resolve without an installed package; confirm against repo layout).
parent_dir = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(parent_dir))

from transfer_queue import ( # noqa: E402
AsyncTransferQueueClient,
SimpleStorageUnit,
TransferQueueController,
process_zmq_server_info,
)
from transfer_queue import TransferQueueClient # noqa: E402
from transfer_queue.controller import TransferQueueController # noqa: E402
from transfer_queue.storage.simple_backend import SimpleStorageUnit # noqa: E402
from transfer_queue.utils.common import get_placement_group # noqa: E402
from transfer_queue.utils.zmq_utils import process_zmq_server_info # noqa: E402

# Configure the root logger at INFO level and create a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -309,7 +307,7 @@ def initialize_system(self, config_dict):
self.tq_config = OmegaConf.merge(tq_internal_conf, self.tq_config)

# Client Init
self.data_system_client = AsyncTransferQueueClient(
self.data_system_client = TransferQueueClient(
client_id="Trainer", controller_info=self.data_system_controller_info
)
self.data_system_client.initialize_storage_manager(
Expand Down
77 changes: 69 additions & 8 deletions tests/e2e/test_e2e_lifecycle_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,15 +420,22 @@ def test_cross_shard_complex_update(e2e_client):
"Region 30-39 tensor_f32 should match original Put B"
)

# 9. Verify new fields exist in update region
extended_fields = base_fields + ["new_extra_tensor", "new_extra_non_tensor"]
update_region_meta = poll_for_meta(
client, partition_id, extended_fields, 20, "update_region_task", mode="force_fetch"
# 9. Verify new fields exist in update region (indices 10-29 only have new fields).
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very hard to understand

# Build extended_meta from full_meta (which has valid _custom_backend_meta)
# by selecting the subset of samples whose global_indexes match meta_update.
# Using meta_update directly would fail because it was derived from alloc_meta
# before put(), so its _custom_backend_meta may be incomplete.
update_gis = set(meta_update.global_indexes)
update_positions_in_full = [
i for i, global_index in enumerate(full_meta.global_indexes) if global_index in update_gis
]
update_meta_with_backend = full_meta.select_samples(update_positions_in_full)
extended_meta = update_meta_with_backend.with_data_fields(
Copy link
Copy Markdown
Collaborator

@0oshowero0 0oshowero0 Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this new interface needed?

base_fields + ["new_extra_tensor", "new_extra_non_tensor"]
)
if update_region_meta is not None and update_region_meta.size > 0:
update_region_data = client.get_data(update_region_meta)
assert "new_extra_tensor" in update_region_data.keys(), "new_extra_tensor should exist"
assert "new_extra_non_tensor" in update_region_data.keys(), "new_extra_non_tensor should exist"
update_region_data = client.get_data(extended_meta)
assert "new_extra_tensor" in update_region_data.keys(), "new_extra_tensor should exist"
assert "new_extra_non_tensor" in update_region_data.keys(), "new_extra_non_tensor should exist"
finally:
client.clear_partition(partition_id)

Expand Down Expand Up @@ -641,5 +648,59 @@ def test_clear_partition(e2e_client):
pass


# Scenario Six: Dynamic Tensor Shape → Nested Tensor Transition
def test_dynamic_tensor_shape_nested_transition(e2e_client):
    """
    Check that a field moves from a regular tensor to a nested tensor.

    A batch of identically shaped tensors is written first, then a second batch
    with a different shape is written into freshly allocated slots. The field
    schema is then expected to report is_nested=True, and fetching all four
    samples together is expected to return a nested tensor.
    """
    field = "dynamic_feature"
    partition_id = "test_nested_transition_partition"
    task_name = "test_task"
    client = e2e_client

    try:
        # Step 1: initial insert of a uniform batch built from torch.ones(2, 4).
        uniform_batch = TensorDict({field: torch.ones(2, 4)}, batch_size=2)
        first_put_meta = client.put(data=uniform_batch, partition_id=partition_id)
        assert first_put_meta.size == 2

        # The first batch must still be reported and returned as a regular tensor.
        first_meta = poll_for_meta(client, partition_id, [field], 2, task_name, mode="force_fetch")
        assert not first_meta.field_schema[field]["is_nested"]
        first_feature = client.get_data(first_meta)[field]
        assert not first_feature.is_nested
        assert first_feature.shape == (2, 4)

        # Step 2: reserve two more slots through insert mode, then fill them
        # with a batch of a different shape, built from torch.ones(2, 6).
        second_alloc = client.get_meta(
            partition_id=partition_id,
            data_fields=[field],
            batch_size=2,
            mode="insert",
            task_name="allocator",
        )
        assert second_alloc.size == 2
        wider_batch = TensorDict({field: torch.ones(2, 6)}, batch_size=2)
        client.put(data=wider_batch, metadata=second_alloc)

        # Once shapes differ, the polled schema should flag the field as nested.
        second_meta = poll_for_meta(client, partition_id, [field], 2, task_name, mode="force_fetch")
        assert second_meta.field_schema[field]["is_nested"] is True

        # Step 3: fetch all four samples in one request.
        combined_meta = poll_for_meta(client, partition_id, [field], 4, task_name, mode="force_fetch")
        assert combined_meta.field_schema[field]["is_nested"] is True

        # Mixed shapes mean the merged result has to be a nested tensor.
        combined_feature = client.get_data(combined_meta)[field]
        assert combined_feature.is_nested is True
        assert len(combined_feature) == 4
    finally:
        client.clear_partition(partition_id)


# Allow running this test module directly: invoke pytest verbosely on this
# file and exit the process with pytest's return code.
if __name__ == "__main__":
    raise SystemExit(pytest.main(["-v", __file__]))
Loading