From a8e143c833b97913c2341bbc55d4cdc611cc8c56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AE=81=E6=9C=AC=E5=93=B2?= Date: Tue, 7 Apr 2026 03:30:32 +0000 Subject: [PATCH 1/2] [fix] Allow None values in _pack_field_values and fallback to NonTensorStack MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Modify _pack_field_values to tolerate None placeholders in the values list, falling back to NonTensorStack instead of raising ValueError. - Pure tensor lists (no None) still use torch.stack or nested tensor. - Update docstring to reflect the new None-tolerant behavior. Signed-off-by: 宁本哲 --- .../managers/simple_backend_manager.py | 44 ++++++++++++------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 27e173c7..df536801 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -387,24 +387,38 @@ def _pack_field_values(values: list) -> torch.Tensor | NonTensorStack: """ Pack a list of per-sample values into a batched container. - For tensor values, this performs a memory copy via stacking or nested tensor creation. - Non-tensor values are grouped into a ``NonTensorStack`` without copying. + For pure tensor lists (no None), this performs a memory copy via stacking + or nested tensor creation. Mixed types, non-tensor values, or lists + containing None placeholders are grouped into a ``NonTensorStack``. + + Args: + values: List of per-sample values to pack. May contain None for + unfilled batch positions. + + Returns: + A stacked ``torch.Tensor`` (or nested tensor) when all values are + tensors, otherwise a ``NonTensorStack``. + + Raises: + ValueError: If *values* is empty. """ if not values: raise ValueError("_pack_field_values received empty values list; caller should filter empty batches") - if any(v is None for v in values): - raise ValueError("_pack_field_values received None in values list; some batch positions were not filled") - if all(isinstance(v, torch.Tensor) for v in values): - if all(v.shape == values[0].shape for v in values): - return torch.stack(values) - try: - return torch.nested.as_nested_tensor(values, layout=torch.jagged) - except (RuntimeError, TypeError) as e: - logger.warning( - f"Failed to pack nested tensor with jagged layout. " - f"Falling back to strided layout. Detailed error: {e}" - ) - return torch.nested.as_nested_tensor(values, layout=torch.strided) + non_none = [v for v in values if v is not None] + if non_none and all(isinstance(v, torch.Tensor) for v in non_none): + if not any(v is None for v in values): + # Pure tensor list — try stacking / nested tensor + if all(v.shape == values[0].shape for v in values): + return torch.stack(values) + try: + return torch.nested.as_nested_tensor(values, layout=torch.jagged) + except (RuntimeError, TypeError) as e: + logger.warning( + f"Failed to pack nested tensor with jagged layout. " + f"Falling back to strided layout. Detailed error: {e}" + ) + return torch.nested.as_nested_tensor(values, layout=torch.strided) + # Mixed tensor + None — cannot stack, fall through to NonTensorStack return NonTensorStack(*values) async def get_data(self, metadata: BatchMeta) -> TensorDict: From 4797ca25a0c14058c427c57e7d40d644fbe7c6be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AE=81=E6=9C=AC=E5=93=B2?= Date: Wed, 8 Apr 2026 03:13:54 +0000 Subject: [PATCH 2/2] add test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 宁本哲 --- tests/test_async_simple_storage_manager.py | 30 +++++++++++++++++-- .../managers/simple_backend_manager.py | 2 +- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py index bbf6d4b3..4d1419a6 100644 --- a/tests/test_async_simple_storage_manager.py +++ b/tests/test_async_simple_storage_manager.py @@ -542,7 +542,7 @@ class TestPackFieldValues: def test_uniform_tensors_to_stack(self): """Same-shape tensors → torch.stack.""" values = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])] - result = AsyncSimpleStorageManager._pack_field_values(values) + result = AsyncSimpleStorageManager._pack_field_values(values) # type: ignore[attr-defined] assert isinstance(result, torch.Tensor) assert not result.is_nested assert result.shape == (2, 2) @@ -550,13 +550,37 @@ def test_uniform_tensors_to_stack(self): def test_variable_length_tensors_to_nested(self): """Different-shape tensors → nested tensor.""" values = [torch.tensor([1.0]), torch.tensor([2.0, 3.0])] - result = AsyncSimpleStorageManager._pack_field_values(values) + result = AsyncSimpleStorageManager._pack_field_values(values) # type: ignore[attr-defined] assert isinstance(result, torch.Tensor) assert result.is_nested def test_non_tensors_to_nontensorstack(self): """Non-tensor values → NonTensorStack.""" values = ["hello", "world"] - result = AsyncSimpleStorageManager._pack_field_values(values) + result = AsyncSimpleStorageManager._pack_field_values(values) # type: ignore[attr-defined] assert isinstance(result, NonTensorStack) assert result.tolist() == ["hello", "world"] + + def test_mixed_tensors_and_none_to_nontensorstack(self): + """Mixed tensor + None values should stay as NonTensorStack (no stacking).""" + t0 = torch.tensor([1.0, 2.0]) + t2 = torch.tensor([3.0, 4.0]) + values = [t0, None, t2] + + result = AsyncSimpleStorageManager._pack_field_values(values) # type: ignore[attr-defined] + + assert isinstance(result, NonTensorStack) + unpacked = result.tolist() + assert len(unpacked) == 3 + assert torch.equal(unpacked[0], t0) + assert unpacked[1] is None + assert torch.equal(unpacked[2], t2) + + def test_all_none_to_nontensorstack(self): + """All-None values should be preserved in NonTensorStack.""" + values = [None, None] + + result = AsyncSimpleStorageManager._pack_field_values(values) # type: ignore[attr-defined] + + assert isinstance(result, NonTensorStack) + assert result.tolist() == [None, None] diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index df536801..00d87822 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -406,7 +406,7 @@ def _pack_field_values(values: list) -> torch.Tensor | NonTensorStack: raise ValueError("_pack_field_values received empty values list; caller should filter empty batches") non_none = [v for v in values if v is not None] if non_none and all(isinstance(v, torch.Tensor) for v in non_none): - if not any(v is None for v in values): + if len(non_none) == len(values): # Pure tensor list — try stacking / nested tensor if all(v.shape == values[0].shape for v in values): return torch.stack(values)